BERT 预训练数据加载代码

时间:2022-07-23
本文介绍一段 BERT 预训练数据加载代码：基于 PyTorch Dataset 的 BERTDataset 类，负责读取语料、按字符切分并随机掩码（MLM），同时构造下一句预测（NSP）的正负样本，具有一定的参考价值，需要的朋友可以参考。
import ast
import json
import random

import numpy as np
import torch
import tqdm
from sklearn.utils import shuffle
from torch.utils.data import Dataset


class BERTDataset(Dataset):
    """Character-level BERT pre-training dataset (masked LM + next-sentence prediction).

    Each non-blank line of the corpus file must be a Python/JSON-style dict
    literal containing at least the keys ``"text1"`` and ``"text2"``.
    Characters are mapped to ids with the ``word2idx`` JSON dictionary;
    characters missing from it map to ``unk_index``.

    Args:
        corpus_path: Path of the UTF-8 corpus file (one record per line).
        word2idx_path: Path of a UTF-8 JSON file mapping character -> id.
        seq_len: Maximum length of ``[CLS] t1 [SEP] t2 [SEP]``; longer
            sequences are truncated (no padding is applied here).
        hidden_dim: Stored on the instance; not used by this class.
        on_memory: If True, parse the whole corpus into memory. If False,
            records are streamed sequentially from disk and the index passed
            to ``__getitem__`` is ignored (records come in file order).

    In streaming mode the dataset keeps two open file handles; call
    :meth:`close` when done with it.
    """

    def __init__(self, corpus_path, word2idx_path, seq_len, hidden_dim=384, on_memory=True):
        # Hidden dimension (kept for callers; not used in this class).
        self.hidden_dim = hidden_dim
        self.word2idx_path = word2idx_path
        # Maximum sequence length, special tokens included.
        self.seq_len = seq_len
        # Whether the whole corpus is parsed into memory up front.
        self.on_memory = on_memory
        self.corpus_path = corpus_path
        # Ids of the special symbols.
        self.pad_index = 0
        self.unk_index = 1
        self.cls_index = 2
        self.sep_index = 3
        self.mask_index = 4
        self.num_index = 5

        # Load the character -> id dictionary.
        with open(word2idx_path, "r", encoding="utf-8") as f:
            self.word2idx = json.load(f)

        # Load (in-memory mode) or merely count (streaming mode) the records.
        with open(corpus_path, "r", encoding="utf-8") as f:
            if on_memory:
                self.lines = [
                    self._parse_record(line)
                    for line in tqdm.tqdm(f, desc="Loading Dataset")
                    if line.strip()
                ]
                self.corpus_lines = len(self.lines)
            else:
                self.corpus_lines = sum(
                    1 for line in tqdm.tqdm(f, desc="Loading Dataset") if line.strip()
                )

        if not on_memory:
            # Sequential reader for the (text1, text2) pairs ...
            self.file = open(corpus_path, "r", encoding="utf-8")
            # ... and a second reader, started at a random offset, from which
            # negative next-sentence samples are drawn.
            self.random_file = None
            self._open_random_file()

    @staticmethod
    def _parse_record(line):
        """Parse one corpus line into a dict.

        ``ast.literal_eval`` accepts Python literal syntax (dicts, strings,
        numbers, ...) but rejects arbitrary expressions, so a corpus line
        cannot execute code the way it could with ``eval``.
        """
        return ast.literal_eval(line.strip())

    @staticmethod
    def _next_record_line(f):
        """Return the next non-blank line of ``f``, or ``None`` at end of file.

        File iteration raises ``StopIteration`` at EOF rather than returning
        ``None``; this helper converts that into an explicit ``None``.
        """
        for line in f:
            if line.strip():
                return line
        return None

    def _open_random_file(self):
        """(Re)open the negative-sample reader and skip a random number of records.

        Skips ``np.random.randint(min(corpus_lines, 1000))`` records so that this
        reader is offset from ``self.file``.
        """
        if getattr(self, "random_file", None) is not None:
            self.random_file.close()
        self.random_file = open(self.corpus_path, "r", encoding="utf-8")
        # max(1, ...) keeps np.random.randint valid (it rejects 0) on an empty corpus.
        for _ in range(np.random.randint(max(1, min(self.corpus_lines, 1000)))):
            if self._next_record_line(self.random_file) is None:
                break

    def close(self):
        """Close the file handles opened in streaming mode (no-op otherwise)."""
        for name in ("file", "random_file"):
            handle = getattr(self, name, None)
            if handle is not None:
                handle.close()

    def __len__(self):
        """Return the number of (non-blank) records in the corpus."""
        return self.corpus_lines

    def __getitem__(self, item):
        """Build one pre-training example.

        Returns a dict of tensors:
            ``bert_input``: ``[CLS] t1 [SEP] t2 [SEP]`` ids after masking,
                truncated to ``seq_len``;
            ``bert_label``: masked-LM targets aligned with ``bert_input``
                (``pad_index`` at positions that are not predicted);
            ``segment_label``: 0 over the ``t1`` part, 1 over the ``t2`` part;
            ``is_next``: shape-(1,) tensor, 1 if ``t2`` is the record's own
                ``text2`` and 0 if it was drawn at random.
        """
        t1, t2, is_next_label = self.random_sent(item)

        t1_random, t1_label = self.random_char(t1)
        t2_random, t2_label = self.random_char(t2)

        t1 = [self.cls_index] + t1_random + [self.sep_index]
        t2 = t2_random + [self.sep_index]

        # Special tokens are never prediction targets.
        t1_label = [self.pad_index] + t1_label + [self.pad_index]
        t2_label = t2_label + [self.pad_index]

        segment_label = ([0] * len(t1) + [1] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        return {"bert_input": torch.tensor(bert_input),
                "bert_label": torch.tensor(bert_label),
                "segment_label": torch.tensor(segment_label),
                "is_next": torch.tensor([is_next_label])}

    def tokenize_char(self, segments):
        """Map each character to its id, falling back to ``unk_index``."""
        return [self.word2idx.get(char, self.unk_index) for char in segments]

    def random_char(self, sentence):
        """Tokenize ``sentence`` per character and apply BERT-style masking.

        Each position is selected with probability 0.30. A selected position is
        replaced by ``mask_index`` 80% of the time, by a random id from
        ``range(len(word2idx))`` 10% of the time, and left unchanged otherwise.

        Returns:
            ``(tokens, labels)`` where ``labels[i]`` is the original id at
            selected positions and ``pad_index`` everywhere else.
        """
        char_tokens = self.tokenize_char(list(sentence))

        output_label = []
        for i, token in enumerate(char_tokens):
            prob = random.random()
            if prob < 0.30:
                # Rescale to [0, 1) to choose among mask / random / keep.
                prob /= 0.30
                output_label.append(token)
                if prob < 0.8:
                    char_tokens[i] = self.mask_index
                elif prob < 0.9:
                    char_tokens[i] = random.randrange(len(self.word2idx))
            else:
                output_label.append(self.pad_index)
        return char_tokens, output_label

    def random_sent(self, index):
        """Return ``(t1, t2, is_next)`` for next-sentence prediction.

        With probability 0.5 ``t2`` is the record's own ``text2`` (label 1);
        otherwise it is the ``text2`` of another record (label 0).
        """
        t1, t2 = self.get_corpus_line(index)

        if random.random() > 0.5:
            return t1, t2, 1
        return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        """Return ``(text1, text2)`` of a record.

        In memory mode this is record ``item``. In streaming mode ``item`` is
        ignored: the next record of ``self.file`` is returned, wrapping around
        to the start of the file at EOF.

        Raises:
            ValueError: if the corpus contains no records, or a line is not a
                valid literal.
        """
        if self.on_memory:
            return self.lines[item]["text1"], self.lines[item]["text2"]

        line = self._next_record_line(self.file)
        if line is None:
            # EOF: restart from the beginning of the corpus.
            self.file.close()
            self.file = open(self.corpus_path, "r", encoding="utf-8")
            line = self._next_record_line(self.file)
            if line is None:
                raise ValueError(f"corpus file {self.corpus_path!r} contains no records")
        record = self._parse_record(line)
        return record["text1"], record["text2"]

    def get_random_line(self):
        """Return a ``text2`` for a negative next-sentence sample.

        In memory mode a uniformly random record is used. In streaming mode the
        next record of the offset reader is used; at EOF the reader is reopened
        at a fresh random offset.

        Raises:
            ValueError: if the corpus contains no records, or a line is not a
                valid literal.
        """
        if self.on_memory:
            return self.lines[random.randrange(len(self.lines))]["text2"]

        line = self._next_record_line(self.random_file)
        if line is None:
            self._open_random_file()
            line = self._next_record_line(self.random_file)
            if line is None:
                raise ValueError(f"corpus file {self.corpus_path!r} contains no records")
        return self._parse_record(line)["text2"]