| import json, time | |
| import numpy as np | |
| from tqdm import tqdm | |
| from torch.utils.data import Dataset, DataLoader | |
| from pytorch_pretrained_bert import BertModel, BertTokenizer | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import os | |
| # import paddle | |
| # import paddle.nn.functional as F | |
| import unicodedata | |
| from pyhanlp import * | |
| from torch_geometric.nn import RGCNConv | |
| from gcn import * | |
| from graphModule import * | |
| from einops import rearrange | |
| from config import args | |
| from biaffine import * | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "1" | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # DEVICE = torch.device("cpu") | |
| # BERT_PATH = "./SpanBERT/Spanbert-base-cased" | |
| # BERT_PATH = "./chinese_roberta_wwm_ext_pytorch" | |
| BERT_PATH = "./bert" | |
| maxlen = 256  # maximum sequence length | |
| def load_data(filename): | |
| D = [] | |
| with open(filename) as data_file: | |
| data = data_file.read() | |
| # print(data) | |
| data = json.loads(data) | |
| for item in data: | |
| d = {'text': item['text'], 'triple_list': []} | |
| for sub_item in item['triple_list']: | |
| d['triple_list'].append( | |
| (sub_item[0], sub_item[1], sub_item[2]) | |
| ) | |
| D.append(d) | |
| return D | |
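| # Illustrative note (format inferred from this loader, not from the data files themselves): | |
| # each JSON entry is expected to look like | |
| #   {"text": "...", "triple_list": [["subject", "predicate", "object"], ...]} | |
| # and load_data() returns a list of dicts of the same shape, with each triple as a tuple: | |
| #   {'text': '...', 'triple_list': [('subject', 'predicate', 'object'), ...]} | |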
| # Load the datasets | |
| train_data = load_data('./data/CMED/train_triples.json') | |
| valid_data = load_data('./data/CMED/dev_triples.json') | |
| def search(pattern, sequence): | |
| """从sequence中寻找子串pattern | |
| 如果找到,返回第一个下标;否则返回-1。 | |
| """ | |
| n = len(pattern) | |
| for i in range(len(sequence)): | |
| if sequence[i:i + n] == pattern: | |
| return i | |
| return -1 | |
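| # Quick illustration of search(): it slides `pattern` over `sequence` and returns the first match. | |
| #   search([3, 4], [1, 2, 3, 4, 5])  ->  2 | |
| #   search('药', '用药建议')          ->  1   (also works on strings, since slicing applies to both) | |
| #   search([9], [1, 2, 3])           -> -1   (not found) | |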
| train_data_new = []  # Filtered training set: drop samples whose entities cannot be located or whose spans end beyond maxlen (256); only a few samples end up removed | |
| for data in tqdm(train_data): | |
| # print (data) | |
| flag = 1 | |
| for s, p, o in data['triple_list']: | |
| s_begin = search(s, data['text']) | |
| o_begin = search(o, data['text']) | |
| if s_begin == -1 or o_begin == -1 or s_begin + len(s) > 256 or o_begin + len(o) > 256: | |
| flag = 0 | |
| break | |
| if flag == 1: | |
| train_data_new.append(data) | |
| print("去除大于250的文本:\t", len(train_data_new)) | |
| # Read the relation schema | |
| ''' | |
| with open('RE/data/schema.json', encoding='utf-8') as f: | |
| id2predicate, predicate2id, n = {}, {}, 0 | |
| predicate2type = {} | |
| for l in f: | |
| l = json.loads(l) | |
| predicate2type[l['predicate']] = (l['subject_type'], l['object_type']) | |
| for k, _ in sorted(l['object_type'].items()): | |
| key = l['predicate'] + '_' + k | |
| id2predicate[n] = key | |
| predicate2id[key] = n | |
| n += 1 | |
| print(len(predicate2id)) | |
| ''' | |
| with open('./data/CMED/rel2id.json', encoding='utf-8') as f: | |
| # id2predicate, predicate2id, n = {}, {}, 0 | |
| l = json.load(f) | |
| id2predicate = l[0] | |
| predicate2id = l[1] | |
| print("关系类型数量:\t", len(predicate2id)) | |
| class OurTokenizer(BertTokenizer): | |
| def tokenize(self, text): | |
| R = [] | |
| for c in text: | |
| if c in self.vocab: | |
| R.append(c) | |
| elif self._is_whitespace(c): | |
| R.append('[unused1]') | |
| else: | |
| R.append('[UNK]') | |
| return R | |
| def _is_whitespace(self, char): | |
| if char == " " or char == "\t" or char == "\n" or char == "\r": | |
| return True | |
| cat = unicodedata.category(char) | |
| if cat == "Zs": | |
| return True | |
| return False | |
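| # OurTokenizer tokenizes character by character, so token positions stay aligned with character | |
| # positions in the raw text (the span search in TorchDataset relies on this alignment). | |
| # Illustrative example, assuming the characters are in the BERT vocab: | |
| #   tokenizer.tokenize("齐志江 男")  ->  ['齐', '志', '江', '[unused1]', '男'] | |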
| # Initialize the tokenizer | |
| tokenizer = OurTokenizer(vocab_file="./chinese_roberta_wwm_ext_pytorch/vocab.txt") | |
| ###### Dependency parsing + word segmentation (HanLP) | |
| def seg_pos(text): | |
| head, seg_word, Dep_rel, str_le = [], [], [], [] | |
| # tree = HanLP.parseDependency(text) | |
| parser = JClass('com.hankcs.hanlp.dependency.nnparser.NeuralNetworkDependencyParser')() | |
| parser.enableDeprelTranslator(False) | |
| tree = parser.parse(text) | |
| for word in tree.iterator():  # dir() can be used to inspect the methods available on each word | |
| head.append(word.HEAD.ID) | |
| for i in word.LEMMA.split(): | |
| str_le.append(i) | |
| seg_word.append(word.LEMMA) | |
| Dep_rel.append(word.DEPREL) | |
| return head, seg_word, Dep_rel, str_le | |
| def out_list_word(seg_word): | |
| temp = "" | |
| for word in seg_word: | |
| temp += " " + word | |
| text_out = temp.lstrip(" ") | |
| return text_out | |
| def map_to_ids(tokens, vocab): | |
| ids = [vocab[t] if t in vocab.keys() else 0 for t in tokens] | |
| return ids | |
| def vocab_json(): | |
| vocab_out = json.load(open("./vacab.json")) | |
| return vocab_out | |
| def dep_json(): | |
| dep_out = json.load(open("./dep.json")) | |
| return dep_out | |
| class TorchDataset(Dataset): | |
| def __init__(self, data): | |
| self.data = data | |
| def __getitem__(self, i): | |
| t = self.data[i] | |
| # print ('t!!!',t) ######{'text': '齐志江,男,汉族,中共党员,大学学历', 'triple_list': [('齐志江', '民族', '汉族')]} | |
| x = tokenizer.tokenize(t['text']) | |
| # print (x) | |
| x = ["[CLS]"] + x + ["[SEP]"] | |
| token_ids = tokenizer.convert_tokens_to_ids(x) | |
| seg_ids = [0] * len(token_ids) | |
| assert len(token_ids) == len(t['text']) + 2 | |
| spoes = {} | |
| for s, p, o in t['triple_list']: | |
| s = tokenizer.tokenize(s) | |
| s = tokenizer.convert_tokens_to_ids(s) | |
| p = predicate2id[p] | |
| o = tokenizer.tokenize(o) | |
| o = tokenizer.convert_tokens_to_ids(o) | |
| s_idx = search(s, token_ids) | |
| o_idx = search(o, token_ids) | |
| if s_idx != -1 and o_idx != -1: | |
| s = (s_idx, s_idx + len(s) - 1) | |
| o = (o_idx, o_idx + len(o) - 1, p)  # predict object and predicate jointly | |
| if s not in spoes: | |
| spoes[s] = [] | |
| spoes[s].append(o) | |
| # print(spoes) {(2, 5): [(13, 15, 31), (19, 21, 38), (29, 31, 45)]} | |
| if spoes: | |
| sub_labels = np.zeros((len(token_ids), 2)) | |
| # print (sub_labels) | |
| for s in spoes: | |
| # print (s) #(2, 5) | |
| # print (sub_labels) | |
| # print(s[0]) | |
| sub_labels[s[0], 0] = 1 | |
| sub_labels[s[1], 1] = 1 | |
| # Randomly pick one subject for this sample | |
| start, end = np.array(list(spoes.keys())).T | |
| start = np.random.choice(start) | |
| # print (start) | |
| end = sorted(end[end >= start])[0] | |
| sub_ids = (start, end) | |
| obj_labels = np.zeros((len(token_ids), len(predicate2id), 2)) | |
| for o in spoes.get(sub_ids, []): | |
| # print (o) | |
| obj_labels[o[0], o[2], 0] = 1 | |
| obj_labels[o[1], o[2], 1] = 1 | |
| token_ids = self.sequence_padding(token_ids, maxlen=maxlen) | |
| seg_ids = self.sequence_padding(seg_ids, maxlen=maxlen) | |
| sub_labels = self.sequence_padding(sub_labels, maxlen=maxlen, padding=np.zeros(2)) | |
| sub_ids = np.array(sub_ids) | |
| obj_labels = self.sequence_padding(obj_labels, maxlen=maxlen, | |
| padding=np.zeros((len(predicate2id), 2))) | |
| return (torch.LongTensor(token_ids), torch.LongTensor(seg_ids), torch.LongTensor(sub_ids), | |
| torch.LongTensor(sub_labels), torch.LongTensor(obj_labels)) | |
| def __len__(self): | |
| data_len = len(self.data) | |
| return data_len | |
| def sequence_padding(self, x, maxlen, padding=0): | |
| output = np.concatenate([x, [padding] * (maxlen - len(x))]) if len(x) < maxlen else np.array(x[:maxlen]) | |
| return output | |
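| # Shapes returned by TorchDataset.__getitem__ after padding to maxlen=256 (assuming at least one | |
| # subject was located in the token sequence): | |
| #   token_ids  [maxlen]                        BERT input ids | |
| #   seg_ids    [maxlen]                        segment ids (all zeros) | |
| #   sub_ids    [2]                             (start, end) of one randomly chosen subject | |
| #   sub_labels [maxlen, 2]                     start/end tags of every subject | |
| #   obj_labels [maxlen, len(predicate2id), 2]  start/end tags of the chosen subject's objects, per relation | |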
| train_dataset = TorchDataset(train_data_new) | |
| train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch1, shuffle=True, drop_last=True) | |
| # for i, x in enumerate(train_loader): | |
| # print([_.shape for _ in x]) | |
| # if i == 10: | |
| # break | |
| class GRUnet(nn.Module): | |
| def __init__(self, vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim): | |
| """ | |
| vocab_size: size of the vocabulary, i.e. the number of rows of the embedding matrix | |
| embedding_dim: dimensionality of the word vectors (columns of the embedding matrix), i.e. the size of each GRU input x_t | |
| hidden_dim: number of GRU units (hidden state size) | |
| layer_dim: number of stacked GRU layers | |
| output_dim: dimensionality of the final output | |
| """ | |
| super(GRUnet, self).__init__() | |
| # Embedding layer | |
| self.embedding = nn.Embedding(vocab_size, embedding_dim) | |
| # GRU followed by a fully connected head | |
| self.gru = nn.GRU(embedding_dim, hidden_dim, layer_dim, | |
| batch_first=True) | |
| self.fc1 = nn.Sequential( | |
| nn.Linear(hidden_dim, hidden_dim), | |
| nn.Dropout(0.5), | |
| torch.nn.ReLU(), | |
| nn.Linear(hidden_dim, output_dim) | |
| ) | |
| def forward(self, x): | |
| # x : [batch, time_step]  (token ids) | |
| embeds = self.embedding(x) | |
| # print(embeds.shape) | |
| # embeds : [batch, time_step, embedding_dim] | |
| r_out, h_n = self.gru(embeds, None) | |
| # print (r_out.shape) | |
| # r_out : [batch, time_step, hidden_dim] | |
| # out = self.fc1(r_out[:, -1, :]) | |
| out = self.fc1(r_out) | |
| # out : [batch, time_step, output_dim] | |
| return out | |
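| # Minimal usage sketch for GRUnet (illustrative only, same sizes as REModel uses below): | |
| #   gru_net = GRUnet(vocab_size=23923, embedding_dim=768, hidden_dim=1024, layer_dim=6, output_dim=768) | |
| #   ids = torch.randint(0, 23923, (2, 256))   # [batch, time_step] | |
| #   feats = gru_net(ids)                       # [2, 256, 768] | |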
| class GCN(nn.Module): | |
| def __init__(self, hidden_size=768): | |
| super(GCN, self).__init__() | |
| self.hidden_size = hidden_size | |
| # self.fc = nn.Linear(self.hidden_size, self.hidden_size // 2) | |
| def forward(self, x, adj, is_relu=True): | |
| out = x | |
| # Make permutations for matrix multiplication | |
| # Assuming batch_first = False | |
| # print (out.shape) | |
| # out = out.permute(1, 0, 2) # to: batch, seq_len, hidden | |
| # adj = adj.permute(2, 0, 1) # to: batch, seq_len, seq_len | |
| out = torch.bmm(adj, out) # .permute(1, 0, 2) # to: seq_len, batch, hidden | |
| if is_relu: | |
| out = F.relu(out) | |
| return out | |
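| # This GCN layer is a single batched adjacency multiplication: for node features x [batch, seq_len, hidden] | |
| # and an adjacency matrix adj [batch, seq_len, seq_len], each output row is the (optionally ReLU'd) sum of | |
| # its neighbours' features.  Illustrative call: | |
| #   gcn = GCN(hidden_size=768) | |
| #   out = gcn(torch.randn(2, 256, 768), torch.eye(256).repeat(2, 1, 1))   # identity adj -> relu(x) | |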
| class RGCN(torch.nn.Module): | |
| def __init__(self, in_channels, hidden_channels, out_channels, n_layers=2, dropout=0.5): | |
| super().__init__() | |
| self.convs = torch.nn.ModuleList() | |
| self.norms = torch.nn.ModuleList() | |
| self.dropout = dropout | |
| self.convs.append(RGCNConv(in_channels, hidden_channels, num_relations=24, num_bases=1)) | |
| for i in range(n_layers - 2): | |
| self.convs.append(RGCNConv(hidden_channels, hidden_channels, num_relations=24, num_bases=1)) | |
| self.norms.append(torch.nn.BatchNorm1d(hidden_channels)) | |
| self.convs.append(RGCNConv(hidden_channels, out_channels, num_relations=24, num_bases=1)) | |
| def forward(self, x, edge_index=None, edge_type=None): | |
| # edge_index: LongTensor [2, num_edges]; edge_type: LongTensor [num_edges] | |
| # NOTE: with n_layers=2 self.norms is empty, so this loop never runs and x is returned unchanged | |
| for conv, norm in zip(self.convs, self.norms): | |
| x = norm(conv(x, edge_index, edge_type)) | |
| x = F.relu(x) | |
| x = F.dropout(x, p=self.dropout, training=self.training) | |
| return x | |
| class BertLayerNorm(nn.Module): | |
| def __init__(self, hidden_size, eps=1e-12): | |
| super(BertLayerNorm, self).__init__() | |
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |
| self.bias = nn.Parameter(torch.zeros(hidden_size)) | |
| self.variance_epsilon = eps | |
| def forward(self, x): | |
| u = x.mean(-1, keepdim=True) # [bs, maxlen, 1] | |
| s = (x - u).pow(2).mean(-1, keepdim=True) | |
| x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |
| return self.weight * x + self.bias | |
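| # BertLayerNorm is standard layer normalization over the last dimension: | |
| #   y = weight * (x - mean(x)) / sqrt(var(x) + eps) + bias | |
| # using the biased variance estimate, as in the original BERT implementation. | |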
| class Attention2(nn.Module): | |
| """ | |
| 1. input [batch_size, time_step, hidden_dim] -> Linear, ReLU | |
| 2. [batch_size, time_step, hidden_dim] -> transpose | |
| 3. [batch_size, hidden_dim, time_step] -> Softmax | |
| 4. [batch_size, hidden_dim, time_step] -> mean | |
| 5. [batch_size, time_step] -> unsqueeze | |
| 6. [batch_size, 1, time_step] -> expand | |
| 7. [batch_size, hidden_dim, time_step] -> transpose | |
| 8. [batch_size, time_step, hidden_dim] | |
| """ | |
| def __init__(self, hidden_dim): | |
| super(Attention2, self).__init__() | |
| self.hidden_dim = hidden_dim | |
| self.dense = nn.Linear(hidden_dim, hidden_dim) | |
| def forward(self, features, mean=True): | |
| batch_size, time_step, hidden_dim = features.size() | |
| # weight = nn.Tanh()(self.dense(features)) | |
| weight = nn.ReLU()(self.dense(features)) | |
| # mask padded positions with a large negative value so that their weight becomes 0 | |
| mask_idx = torch.sign(torch.abs(features).sum(dim=-1)) | |
| mask_idx = mask_idx.unsqueeze(-1).expand(batch_size, time_step, hidden_dim) | |
| paddings = torch.ones_like(mask_idx) * (-2 ** 32 + 1) | |
| weight = torch.where(torch.eq(mask_idx, 1), weight, paddings) | |
| weight = weight.transpose(2, 1) | |
| # weight = nn.Softmax(dim=2)(weight) | |
| # weight = nn.Sigmoid(weight) | |
| if mean: | |
| weight = weight.mean(dim=1) | |
| weight = weight.unsqueeze(1) | |
| weight = weight.expand(batch_size, hidden_dim, time_step) | |
| weight = weight.transpose(2, 1) | |
| features_attention = weight * features | |
| return features_attention | |
| class KeyValueMemoryNetwork(nn.Module): | |
| def __init__(self, vocab_size, feature_vocab_size, emb_size): | |
| super(KeyValueMemoryNetwork, self).__init__() | |
| self.key_embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0) | |
| self.value_embedding = nn.Embedding(feature_vocab_size, emb_size, padding_idx=0) | |
| self.scale = np.power(emb_size, 0.5) | |
| def forward(self, key_embed, value_embed, hidden, mask_matrix): | |
| # key_embed = self.key_embedding(key_seq) | |
| # print (key_embed.shape) | |
| # value_embed = self.value_embedding(value_seq) | |
| # print (value_embed.shape) | |
| # hidden = self.key_embedding(hidden) | |
| u = torch.bmm(hidden.float(), key_embed.transpose(1, 2)) | |
| u = u / self.scale | |
| exp_u = torch.exp(u) | |
| # print ('exp_u',exp_u.shape) | |
| delta_exp_u = torch.mul(exp_u.float(), mask_matrix.float()) | |
| sum_delta_exp_u = torch.stack([torch.sum(delta_exp_u, 2)] * delta_exp_u.shape[2], 2) | |
| p = torch.div(delta_exp_u, sum_delta_exp_u + 1e-10) | |
| # print ('exp_u',p.shape)(9,256,256) | |
| # embedding_val = value_embed.permute(3, 0, 1, 2) | |
| o = torch.mul(p.float(), value_embed.float()) | |
| # print (o.shape) | |
| # o = o.permute(1, 2, 3, 0) | |
| # o = torch.sum(o, 2) | |
| # aspect_len = (o != 0).sum(dim=1) | |
| # o = o.float().sum(dim=1) | |
| # avg_o = torch.div(o, aspect_len) | |
| return o # avg_o.type_as(hidden) | |
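| # Rough sketch of what KeyValueMemoryNetwork.forward computes (my reading of the code above): | |
| #   p = masked-normalized exp(hidden @ key^T / sqrt(emb_size))   # attention over memory keys | |
| #   o = p * value                                                # element-wise weighting of the values | |
| # Note: the layer is instantiated in REModel below but not called in its forward pass. | |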
| class REModel(nn.Module): | |
| def __init__(self): | |
| super(REModel, self).__init__() | |
| self.bert = BertModel.from_pretrained(BERT_PATH) | |
| for param in self.bert.parameters(): | |
| param.requires_grad = True | |
| self.linear = nn.Linear(768, 768) | |
| self.relu = nn.ReLU() | |
| self.sub_output = nn.Linear(768, 2) | |
| self.suopand = nn.Linear(1024, 768) | |
| self.cat_output = nn.Linear(1024, 768) | |
| self.obj_output = nn.Linear(768, len(predicate2id) * 2) | |
| self.sub_pos_emb = nn.Embedding(256, 768)  # position embedding for the subject span | |
| self.layernorm = BertLayerNorm(768, eps=1e-12) | |
| # self.GCN_model = GCNClassifier(opt, emb_matrix=None) | |
| self.GRU = GRUnet(23923, 768, 1024, 6, 768) | |
| # self.CRF_S = CRF_S(768, 16, if_bias=True) | |
| # self.LSTM_CRF = LSTM_CRF(23922, 16, 768, 768, 1, 0.5, large_CRF=True) | |
| self.biaffine = BiaffineTagger(768, 2) | |
| # self.GCN = GCN(hidden_size=768) | |
| self.attention2 = Attention2(hidden_dim=768) | |
| self.gcu1 = GraphConv1(batch=args.batch1, h=[16, 32, 64, 128, 256], w=[16, 32, 64, 128, 256], d=[768, 512], | |
| V=[2, 4, 8, 32], outfeatures=[256, 128]) | |
| # self.gcu2 = GraphConv2(batch = args.batch2, h=[16,32,64,128,256], w=[16,32,64,128,256], d=[768,512], V=[2,4,8,32],outfeatures=[256,128]) | |
| self.cov = nn.Conv2d(768, 768, 1) | |
| self.GCN_model = GCNClassifier(opt, emb_matrix=None)  # NOTE: `opt` is not defined in this file; it must come from one of the wildcard imports (e.g. `from gcn import *`), otherwise this line raises NameError | |
| self.emb = nn.Embedding(23923, 768) | |
| self.emb1 = nn.Embedding(37, 256) | |
| self.keyvalue = KeyValueMemoryNetwork(23923, 23923, 768) | |
| # self.apnb = APNB(in_channels=768, out_channels=768, key_channels=256, value_channels=256,dropout=0.05, sizes=([1])) | |
| def forward(self, token_ids, seg_ids, sub_ids=None): | |
| out, _ = self.bert(token_ids, token_type_ids=seg_ids, | |
| output_all_encoded_layers=False) # [batch_size, maxlen, size] | |
| # print ("1",out.shape) | |
| out = self.attention2(out) | |
| # print("1", out.shape) | |
| sub_preds = self.sub_output(out) # [batch_size, maxlen, 2] | |
| sub_preds = torch.sigmoid(sub_preds) | |
| # sub_preds = sub_preds ** 2 | |
| if sub_ids is None: | |
| return sub_preds | |
| # print(sub_ids) | |
| # print(sub_ids[:, :1]) | |
| # Fuse subject feature information into the sequence representation | |
| sub_pos_start = self.sub_pos_emb(sub_ids[:, :1])  # position embedding of the subject start | |
| sub_pos_end = self.sub_pos_emb(sub_ids[:, 1:])  # [batch_size, 1, size]  position embedding of the subject end | |
| # print(sub_pos_start) | |
| sub_id1 = sub_ids[:, :1].unsqueeze(-1).repeat(1, 1, out.shape[-1])  # subject start index, repeated across the hidden dimension | |
| # print (sub_id1) | |
| sub_id2 = sub_ids[:, 1:].unsqueeze(-1).repeat(1, 1, out.shape[-1])  # [batch_size, 1, size] | |
| sub_start = torch.gather(out, 1, sub_id1)  # gather the BERT encoding at the subject start position (index along the sequence dimension) | |
| # print(sub_start.shape) | |
| sub_end = torch.gather(out, 1, sub_id2)  # [batch_size, 1, size] | |
| sub_start = sub_pos_start + sub_start  # position embedding + BERT character encoding | |
| sub_end = sub_pos_end + sub_end | |
| out1 = out + sub_start + sub_end | |
| out1 = torch.reshape(out1, (-1, 16, 16, 768))  # 256 = 16 x 16: fold the sequence into a 16x16 grid | |
| # print ('out1:',out1.shape) | |
| out1 = out1.permute(0, 3, 1, 2) | |
| # print(out1.shape) | |
| # out1 = HGT(in_channels=1, hidden_channels=5, out_channels=2, n_layers=2, n_heads=3)(out1) | |
| out1 = RGCN(in_channels=1, hidden_channels=5, out_channels=2, n_layers=2, dropout=0.5)(out1) | |
| # NOTE: a fresh RGCN is built on every forward pass; with n_layers=2 its forward is effectively a no-op (see the class above) | |
| # print(out1.shape) | |
| # print(1) | |
| # print(out1.shape) | |
| # if out1.shape[0] == args.batch1: | |
| # out1 = self.gcu1(out1) | |
| # # word_re_embed,_ = self.LSTM_CRF(inputs[0],hidden=None,t = True) | |
| # else: | |
| # out1 = GraphConv2(batch=out1.shape[0], h=[16, 32, 64, 128, 256], w=[16, 32, 64, 128, 256], d=[768, 512], | |
| # V=[2, 4, 8, 32], outfeatures=[256, 128])(out1) | |
| # # word_re_embed,_ = LSTM_CRF1(23922, 16, 768, 768, 1, 0.5, large_CRF=True, t = out1.shape[0]).to(DEVICE)(inputs[0],hidden=None) | |
| # print ('out1_',out1.shape) | |
| out1 = self.cov(out1) | |
| # out1 = self.apnb(out1) | |
| # out = out.permute(0,2,3,1) | |
| # print (out.shape) | |
| # b, c, h, w = out1.shape | |
| out1 = rearrange(out1, 'b c h w -> b c (h w)') | |
| out1 = out1.permute(0, 2, 1) | |
| # out1 = torch.cat((out1,pooling_output),dim=1) | |
| out1 = self.layernorm(out1) | |
| out1 = F.dropout(out1, p=0.5, training=self.training) | |
| # print(2) | |
| # print(out1.shape) | |
| output = self.relu(self.linear(out1)) | |
| output = F.dropout(output, p=0.4, training=self.training) | |
| output = self.obj_output(output) # [batch_size, maxlen, 2*plen] | |
| # print(3) | |
| # print(output.shape) | |
| ###### | |
| # logits_output = torch.unsqueeze(logits, dim = 1) | |
| # final_output = logits_output + output | |
| output = torch.sigmoid(output) | |
| # output = output ** 2 | |
| obj_preds = output.view(-1, output.shape[1], len(predicate2id), 2) | |
| return sub_preds, obj_preds | |
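| # Shape flow through REModel.forward (maxlen=256, hidden size 768): | |
| #   BERT + Attention2      out        [batch, 256, 768] | |
| #   subject head           sub_preds  [batch, 256, 2] | |
| #   + subject fusion       out1       [batch, 256, 768] -> reshaped to [batch, 768, 16, 16] for the conv path | |
| #   1x1 conv + rearrange   out1       [batch, 256, 768] | |
| #   object head            obj_preds  [batch, 256, len(predicate2id), 2] | |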
| net = REModel().to(DEVICE) | |
| print(DEVICE) | |
| optimizer = torch.optim.Adam(net.parameters(), lr=1e-5) | |
| def get_long_tensor(tokens_list, batch_size): | |
| """ Convert list of list of tokens to a padded LongTensor. """ | |
| token_len = max(len(x) for x in tokens_list) | |
| tokens = torch.LongTensor(batch_size, token_len).fill_(0) | |
| for i, s in enumerate(tokens_list): | |
| tokens[i, :len(s)] = torch.LongTensor(s) | |
| return tokens | |
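| # Illustrative use of get_long_tensor(): pad a ragged batch of id lists into one LongTensor. | |
| #   get_long_tensor([[1, 2, 3], [4, 5]], batch_size=2) | |
| #   -> tensor([[1, 2, 3], | |
| #              [4, 5, 0]]) | |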
| class ValidDataset(Dataset): | |
| def __init__(self, data): | |
| self.data = data | |
| def __getitem__(self, i): | |
| t = self.data[i] | |
| # word_input, center_word = [],[] | |
| # print (t['triple_list']) | |
| if len(t['text']) > 254: | |
| t['text'] = t['text'][:254] | |
| x = tokenizer.tokenize(t['text']) | |
| x = ["[CLS]"] + x + ["[SEP]"] | |
| token_ids = tokenizer.convert_tokens_to_ids(x) | |
| seg_ids = [0] * len(token_ids) | |
| assert len(token_ids) == len(t['text']) + 2 | |
| token_ids = torch.LongTensor(self.sequence_padding(token_ids, maxlen=maxlen)) | |
| seg_ids = torch.LongTensor(self.sequence_padding(seg_ids, maxlen=maxlen)) | |
| # tri = t['triple_list'] | |
| # print('tri',tri) | |
| ''' | |
| return {'token_ids':token_ids, | |
| 'seg_ids':seg_ids, | |
| 'text':t['text'], | |
| 'triple_list':t['triple_list']} | |
| ''' | |
| # return token_ids, seg_ids, list(t['text']), list(t['triple_list']) | |
| return token_ids, seg_ids, t | |
| def __len__(self): | |
| data_len = len(self.data) | |
| return data_len | |
| def sequence_padding(self, x, maxlen, padding=0): | |
| output = np.concatenate([x, [padding] * (maxlen - len(x))]) if len(x) < maxlen else np.array(x[:maxlen]) | |
| return output | |
| valid_dataset = ValidDataset(valid_data) | |
| valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch2, shuffle=False, drop_last=True) | |
| def extract_spoes(data, model, device): | |
| ''' | |
| """抽取三元组""" | |
| if len(text) > 254: | |
| text = text[:254] | |
| tokens = tokenizer.tokenize(text) | |
| tokens = ["[CLS]"] + tokens + ["[SEP]"] | |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) | |
| assert len(token_ids) == len(text) + 2 | |
| seg_ids = [0] * len(token_ids) | |
| ''' | |
| # print (data[2]) | |
| # print (data['text']) | |
| # token_ids = data['token_ids'] | |
| token_ids = data[0] | |
| # seg_ids = data['seg_ids'] | |
| seg_ids = data[1] | |
| # import pdb | |
| # pdb.set_trace() | |
| sub_preds = model(token_ids.to(device), | |
| seg_ids.to(device)) | |
| sub_preds = sub_preds.detach().cpu().numpy() # [1, maxlen, 2] | |
| # print(sub_preds[0,]) | |
| start = np.where(sub_preds[0, :, 0] > 0.5)[0] | |
| end = np.where(sub_preds[0, :, 1] > 0.5)[0] | |
| # print(start, end) | |
| tmp_print = [] | |
| subjects = [] | |
| for i in start: | |
| j = end[end >= i] | |
| if len(j) > 0: | |
| j = j[0] | |
| subjects.append((i, j)) | |
| tmp_print.append(data[2][i - 1: j]) | |
| if subjects: | |
| spoes = [] | |
| # print (len(subjects))  # only 2 here | |
| token_ids = np.repeat(token_ids, len(subjects), 0) # [len_subjects, seqlen] | |
| # print(token_ids.shape) | |
| seg_ids = np.repeat(seg_ids, len(subjects), 0) | |
| subjects = np.array(subjects) # [len_subjects, 2] | |
| # Feed the subjects in and extract the corresponding objects and predicates | |
| _, object_preds = model(token_ids.to(device), | |
| seg_ids.to(device), | |
| torch.LongTensor(subjects).to(device)) | |
| object_preds = object_preds.detach().cpu().numpy() | |
| # print(object_preds.shape) | |
| for sub, obj_pred in zip(subjects, object_preds): | |
| # obj_pred [maxlen, 55, 2] | |
| start = np.where(obj_pred[:, :, 0] > 0.3) | |
| end = np.where(obj_pred[:, :, 1] > 0.3) | |
| for _start, predicate1 in zip(*start): | |
| for _end, predicate2 in zip(*end): | |
| if _start <= _end and predicate1 == predicate2: | |
| spoes.append( | |
| ((sub[0] - 1, sub[1] - 1), predicate1, (_start - 1, _end - 1)) | |
| ) | |
| break | |
| # print (spoes) | |
| return [(data[2][s[0]:s[1] + 1], id2predicate[str(p)], data[2][o[0]:o[1] + 1]) for s, p, o in spoes] | |
| else: | |
| return [] | |
| def evaluate(valid_data, valid_load, model, device): | |
| """评估函数,计算f1、precision、recall | |
| """ | |
| # F1 = [] | |
| # P = [] | |
| # Re = [] | |
| X, Y, Z = 1e-10, 1e-10, 1e-10 | |
| f = open("./data/CMED/dev_pred.json", 'w', encoding='utf-8') | |
| pbar = tqdm() | |
| # for d in data: | |
| # with torch.no_grad: | |
| # print (type(valid_load)) | |
| # return | |
| for idx, data in tqdm(enumerate(valid_load)): | |
| input = data[0], data[1], data[2]['text'][0] | |
| # print(input) | |
| # input = data[0], data[1], valid_data[idx]['text'], valid_data[idx]['triple_list'] | |
| R = extract_spoes(input, model, device) | |
| # print ('R:',R) | |
| T = valid_data[idx]['triple_list'] | |
| ''' | |
| tri = data[3] | |
| #tri = tuple(tri) | |
| T = [] | |
| for tris in tri: | |
| temp = tuple() | |
| for i in tris: | |
| temp += i | |
| T.append(temp) | |
| ''' | |
| # print ('tri:',tri) | |
| # print ('tri:',temp_tri) | |
| R = set(R) | |
| # print ('R',R) | |
| T = set(T) | |
| # print('T', R) | |
| X += len(R & T) | |
| Y += len(R) | |
| Z += len(T) | |
| f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z | |
| # F1.append(f1) | |
| # P.append(precision) | |
| # Re.append(recall) | |
| pbar.update() | |
| pbar.set_description( | |
| 'F1: %.5f, \tPrecision: %.5f, \tRecall: %.5f' % (f1, precision, recall) | |
| ) | |
| if f1 > 0.5: | |
| s = json.dumps({ | |
| 'text': valid_data[idx]['text'], | |
| 'triple_list': list(T), | |
| 'triple_list_pred': list(R), | |
| 'new': list(R - T), | |
| 'lack': list(T - R), | |
| }, ensure_ascii=False, indent=4) | |
| f.write(s + '\n') | |
| pbar.close() | |
| f.close() | |
| return f1, precision, recall | |
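| # evaluate() accumulates micro-averaged counts over the whole dev set: | |
| #   X = |predicted ∩ gold|,  Y = |predicted|,  Z = |gold| | |
| #   precision = X / Y,  recall = X / Z,  F1 = 2X / (Y + Z) | |
| # e.g. X=80, Y=100, Z=120  ->  precision 0.80, recall ~0.667, F1 ~0.727 | |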
| ''' | |
| def evaluate(data, model, device): | |
| """评估函数,计算f1、precision、recall | |
| """ | |
| X, Y, Z = 1e-10, 1e-10, 1e-10 | |
| f = open("/home/jason/EXP/NLP/triple_test/data/CMED/dev_pred.json", 'w', encoding='utf-8') | |
| pbar = tqdm() | |
| for d in data: | |
| R = extract_spoes(d['text'], model, device) | |
| T = d['triple_list'] | |
| #print (T) | |
| R = set(R) | |
| #print ('R',R) | |
| T = set(T) | |
| #T = set() | |
| #for item in T1: | |
| # for i in item: | |
| # T.add(i) | |
| #print ('T',T) | |
| X += len(R & T) | |
| Y += len(R) | |
| Z += len(T) | |
| f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z | |
| pbar.update() | |
| pbar.set_description( | |
| 'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall) | |
| ) | |
| if f1 > 0.5: | |
| s = json.dumps({ | |
| 'text': d['text'], | |
| 'triple_list': list(T), | |
| 'triple_list_pred': list(R), | |
| 'new': list(R - T), | |
| 'lack': list(T - R), | |
| }, ensure_ascii=False, indent=4) | |
| f.write(s + '\n') | |
| pbar.close() | |
| f.close() | |
| return f1, precision, recall | |
| ''' | |
| import sys | |
| import os | |
| class Logger(object): | |
| def __init__(self,fileN="default.log"): | |
| self.terminal = sys.stdout | |
| self.log = open(fileN,"a") | |
| def write(self,message): | |
| self.terminal.write(message) | |
| self.log.write(message) | |
| def flush(self): | |
| pass | |
| # def FocalLoss(input, target ,gamma=2,weight=None,reduction='mean'): | |
| # # def __init__(self,gamma=2,weight=None,reduction='mean'): | |
| # # super(FocalLoss, self).__init__() | |
| # # self.gamma = gamma | |
| # # self.weight = weight | |
| # # self.reduction = reduction | |
| # # def forward(self, output, target): | |
| # out_target = torch.stack([input[i,t] for i.type(torch.bool),t.type(torch.bool) in enumerate(target)]) | |
| # probs = torch.sigmoid(out_target) | |
| # focal_weight = torch.pow(1-probs,gamma=2) | |
| # | |
| # ce_loss = F.cross_entropy(input,target,weight=None,reduction='none') | |
| # focal_loss = focal_weight*ce_loss | |
| # | |
| # if reduction == 'mean': | |
| # focal_loss = (focal_loss/focal_weight.sum()).sum() | |
| # elif reduction == 'sum': | |
| # focal_loss = focal_loss.sum() | |
| # | |
| # return focal_loss | |
| # class FocalLoss(nn.Module): | |
| # | |
| # def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7): | |
| # super(FocalLoss, self).__init__() | |
| # self.gamma = gamma | |
| # self.eps = eps | |
| # self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction) | |
| # | |
| # def forward(self, input, target): | |
| # logp = self.ce(input, target) | |
| # print('logp',logp) | |
| # p = torch.exp(-logp) | |
| # loss = (1 - p) ** self.gamma * logp | |
| # return loss.mean() | |
| # def Dice_loss(inputs,target,beta=1,smooth=1e-5): | |
| # n,c,h = inputs.size() | |
| # nt,ht,wt = target.size() | |
| # if n!= nt and h!=wt: | |
| # inputs = F.interpolate(inputs,size=(ht,wt),mode="bilinear",align_corners=True) | |
| # temp_imputs = torch.softmax(inputs.transpose(1,2).transpose(2,3).contiguous().view(n,-1,c),-1) | |
| # temp_target = target.view(n,-1,ct) | |
| # | |
| # #...................... | |
| # #ice loss | |
| # #...................... | |
| # tp = torch.sum(temp_target[...,:-1]*temp_imputs,axis=[0,1]) | |
| # fp = torch.sum(temp_imputs,axis=[0,1])-tp | |
| # fn = torch.sum(temp_target[...,:-1],axis=[0,1])-tp | |
| # | |
| # score = ((1+beta**2)*tp+smooth)/((1+beta**2)*tp+beta**2*fn+fp+smooth) | |
| # dice_loss = 1-torch.mean(score) | |
| # return dice_loss | |
| # def dice_coeff(pred, target): | |
| # smooth = 1. | |
| # num = pred.size(0) | |
| # m1 = pred.view(num, -1) # Flatten | |
| # m2 = target.view(num, -1) # Flatten | |
| # intersection = (m1 * m2).sum() | |
| # | |
| # return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth) | |
| # def train(model, train_loader, optimizer, epoches, device): | |
| # # model.train() | |
| # torch.backends.cudnn.enabled = False | |
| # for _ in range(epoches): | |
| # print('epoch: ', _ + 1) | |
| # start = time.time() | |
| # train_loss_sum = 0.0 | |
| # for batch_idx, x in tqdm(enumerate(train_loader)): | |
| # # token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device) | |
| # token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device) | |
| # # tokens_words, masks_out, head = x[5].to(device), x[6].to(device), x[7].to(device) | |
| # # print (token_ids.shape) | |
| # | |
| # mask = (token_ids > 0).float() | |
| # mask = mask.to(device) # zero-mask | |
| # sub_labels, obj_labels = x[3].float().to(device), x[4].float().to(device) | |
| # sub_preds, obj_preds = model(token_ids, seg_ids, sub_ids) | |
| # # (batch_size, maxlen, 2), (batch_size, maxlen, 55, 2) | |
| # | |
| # #计算loss | |
| # smooth = 1 | |
| # intersection = sub_labels * sub_preds | |
| # sub_dice_eff = (2 * intersection.sum(1) + smooth) / (sub_preds.sum(1) + sub_labels.sum(1) + smooth) | |
| # # print(sub_dice_eff) | |
| # smooth = 1 | |
| # # intersection2 = obj_labels * obj_preds | |
| # # obj_dice_eff = (2 * intersection2.sum(1) + smooth) / (obj_preds.sum(1) + obj_labels.sum(1) + smooth) | |
| # # # print(obj_dice_eff) | |
| # # beta = 1 | |
| # # smooth = 1e-5 | |
| # # p = torch.sigmoid(sub_preds) | |
| # # tp = torch.sum(sub_labels[..., :-1] * p, axis=[0, 1]) | |
| # # # print(tp) | |
| # # fp = torch.sum(p, axis=[0, 1]) - tp | |
| # # # print(fp) | |
| # # fn = torch.sum(sub_labels[..., :-1], axis=[0, 1]) - tp | |
| # # score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) | |
| # # sub_dice_loss = 1-torch.mean(score) | |
| # # # print(sub_dice_loss) | |
| # # ce_loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none') # [bs, ml, 2] | |
| # # p_t = p*sub_labels + (1-p)*(1-sub_labels) | |
| # # gamma = 2 | |
| # # loss_sub= ce_loss_sub*((1-p_t)**gamma) | |
| # | |
| # q = torch.sigmoid(obj_preds) | |
| # # print(q) | |
| # tp = torch.sum(obj_labels[..., :-1] * q, axis=[0, 1]) | |
| # # print(tp) | |
| # fp = torch.sum(q, axis=[0, 1]) - tp | |
| # # print(fp) | |
| # fn = torch.sum(obj_labels[..., :-1], axis=[0, 1]) - tp | |
| # score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) | |
| # obj_dice_loss = 1 - torch.mean(score) | |
| # # print(obj_dice_loss) | |
| # loss_sub = torch.mean(loss_sub, 2) # (batch_size, maxlen) | |
| # loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask) | |
| # # print('loss_sub:',loss_sub) | |
| # q = torch.sigmoid(obj_preds) | |
| # ce_loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none') # [bs, ml, 55, 2] | |
| # q_t = q * obj_labels + (1 - q) * (1 - obj_labels) | |
| # gamma = 2 | |
| # loss_obj = ce_loss_obj * ((1 - q_t) ** gamma) | |
| # loss_obj = torch.sum(torch.mean(loss_obj, 3), 2) # (bs, maxlen) | |
| # loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask) | |
| # loss = loss_sub + loss_obj | |
| # loss_sub = dice_coeff(sub_preds, sub_labels) | |
| # loss_obj = dice_coeff(obj_preds, obj_labels) | |
| # loss = loss_sub+ loss_obj | |
| # # 计算loss | |
| # loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none') # [bs, ml, 2] | |
| # loss_sub = torch.mean(loss_sub, 2) # (batch_size, maxlen) | |
| # loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask) | |
| # loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none') # [bs, ml, 55, 2] | |
| # loss_obj = torch.sum(torch.mean(loss_obj, 3), 2) # (bs, maxlen) | |
| # loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask) | |
| # loss = loss_sub + loss_obj | |
| # optimizer.zero_grad() | |
| # | |
| # loss.backward() | |
| # optimizer.step() | |
| # train_loss_sum += loss.cpu().item() | |
| # if (batch_idx + 1) % 31 == 0: | |
| # print('loss: ', train_loss_sum / (batch_idx + 1), 'time: ', time.time() - start) | |
| # | |
| # torch.save(net.state_dict(), "./checkpoints/best_re.pth") | |
| # | |
| # with torch.no_grad(): | |
| # # model.eval() | |
| # # print (valid_data[:5]) | |
| # val_f1, pre, rec = evaluate(valid_data, valid_loader, net, device) | |
| # | |
| # print('F1_score: %.5f, Precision: %.5f, Recall: %.5f' % (val_f1, pre, rec)) | |
| # # sys.stdout = Logger('./datalog.txt') | |
| # re = tuple((val_f1, pre, rec)) | |
| # with open("./result_Dice_loss.json","a",encoding='utf-8') as f: | |
| # json.dump(re,f,indent=4,ensure_ascii=True) | |
| # # print("f1, pre, rec: ", val_f1, pre, rec) | |
| class FocalLoss(nn.Module): | |
| def __init__(self, alpha=1, gamma=2, size_average=True, ignore_index=255): | |
| super(FocalLoss, self).__init__() | |
| self.alpha = alpha | |
| self.gamma = gamma | |
| self.ignore_index = ignore_index | |
| self.size_average = size_average | |
| def forward(self, inputs, targets): | |
| ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index) | |
| pt = torch.exp(-ce_loss) | |
| focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss | |
| return focal_loss.sum() | |
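| # Focal loss down-weights easy examples: with pt = exp(-CE), the loss is alpha * (1 - pt)^gamma * CE, | |
| # so well-classified samples (pt close to 1) contribute almost nothing. | |
| # Note: the training loop below builds its focal/dice terms inline and does not use this class. | |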
| def train(model, train_loader, optimizer, epoches, device): | |
| # model.train() | |
| torch.backends.cudnn.enabled = False | |
| loss_history = []  # average training loss recorded at each logging step | |
| for _ in range(epoches): | |
| # f = open("./test.txt", 'w+', encoding='utf-8') | |
| print('epoch: ', _ + 1) | |
| start = time.time() | |
| train_loss_sum = 0.0 | |
| for batch_idx, x in tqdm(enumerate(train_loader)): | |
| # token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device) | |
| token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device) | |
| # tokens_words, masks_out, head = x[5].to(device), x[6].to(device), x[7].to(device) | |
| # print (token_ids.shape) | |
| mask = (token_ids > 0).float() | |
| mask = mask.to(device) # zero-mask | |
| sub_labels, obj_labels = x[3].float().to(device), x[4].float().to(device) | |
| sub_preds, obj_preds = model(token_ids, seg_ids, sub_ids) | |
| # (batch_size, maxlen, 2), (batch_size, maxlen, 55, 2) | |
| # Compute the loss | |
| smooth = 1 | |
| intersection2 = obj_labels * obj_preds | |
| obj_dice_eff = (2 * intersection2.sum(1) + smooth) / (obj_preds.sum(1) + obj_labels.sum(1) + smooth) | |
| # print(obj_dice_eff) | |
| beta = 1 | |
| smooth = 1e-5 | |
| p = torch.sigmoid(sub_preds) | |
| tp = torch.sum(sub_labels[..., :-1] * p, axis=[0, 1]) | |
| # print(tp) | |
| fp = torch.sum(p, axis=[0, 1]) - tp | |
| # print(fp) | |
| fn = torch.sum(sub_labels[..., :-1], axis=[0, 1]) - tp | |
| score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) | |
| sub_dice_loss = 1-torch.mean(score) | |
| # print(sub_dice_loss) | |
| ce_loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none') # [bs, ml, 2] | |
| p_t = p*sub_labels + (1-p)*(1-sub_labels) | |
| gamma = 2 | |
| loss_sub= ce_loss_sub*((1-p_t)**gamma) | |
| q = torch.sigmoid(obj_preds) | |
| # print(q) | |
| tp = torch.sum(obj_labels[..., :-1] * q, axis=[0, 1]) | |
| # print(tp) | |
| fp = torch.sum(q, axis=[0, 1]) - tp | |
| # print(fp) | |
| fn = torch.sum(obj_labels[..., :-1], axis=[0, 1]) - tp | |
| score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) | |
| obj_dice_loss = 1 - torch.mean(score) | |
| # print(obj_dice_loss) | |
| loss_sub = torch.mean(loss_sub, 2) # (batch_size, maxlen) | |
| loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask) | |
| # print('loss_sub:',loss_sub) | |
| q = torch.sigmoid(obj_preds) | |
| ce_loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none') # [bs, ml, 55, 2] | |
| q_t = q * obj_labels + (1 - q) * (1 - obj_labels) | |
| gamma = 2 | |
| loss_obj = ce_loss_obj * ((1 - q_t) ** gamma) | |
| loss_obj = torch.sum(torch.mean(loss_obj, 3), 2) # (bs, maxlen) | |
| loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask) | |
| # Plain binary cross-entropy variant (kept for reference) | |
| # loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none') # [bs, ml, 2] | |
| # loss_sub = torch.mean(loss_sub, 2) # (batch_size, maxlen) | |
| # loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask) | |
| # loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none') # [bs, ml, 55, 2] | |
| # loss_obj = torch.sum(torch.mean(loss_obj, 3), 2) # (bs, maxlen) | |
| # loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask) | |
| loss = loss_sub + loss_obj | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| train_loss_sum += loss.cpu().item() | |
| if (batch_idx + 1) % 31 == 0: | |
| print('loss: ', train_loss_sum / (batch_idx + 1), 'time: ', time.time() - start) | |
| loss_history.append(train_loss_sum / (batch_idx + 1)) | |
| torch.save(net.state_dict(), "./checkpoints/best_re.pth") | |
| with torch.no_grad(): | |
| # model.eval() | |
| # print (valid_data[:5]) | |
| val_f1, pre, rec = evaluate(valid_data, valid_loader, net, device) | |
| print('F1_score: %.5f, Precision: %.5f, Recall: %.5f' % (val_f1, pre, rec)) | |
| # print("f1, pre, rec: ", val_f1, pre, rec) | |
| print(loss_history) | |
| # LOGGER = set_logging(name='test', level=logging.INFO, verbose=True) | |
| if __name__ == '__main__': | |
| # net.load_state_dict(torch.load("RE/data/bert_re.pth")) | |
| train(net, train_loader, optimizer, 600, DEVICE) |