"""Joint subject / (predicate, object) extraction on the CMED data set.

The script loads the CMED triples, builds a BERT-based cascade tagger
(``REModel``), trains it with a focal binary-cross-entropy loss and evaluates
triple-level F1 after every epoch.

NOTE: data files are read and the model is instantiated at import time; this
mirrors the original script and is relied upon by the ``__main__`` entry point.
"""
# NOTE: the import order below is kept exactly as in the original file because
# the star imports may shadow earlier names.
import json
import time
import sys
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from pytorch_pretrained_bert import BertModel, BertTokenizer
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# import paddle
# import paddle.nn.functional as F
import unicodedata
from pyhanlp import *
from torch_geometric.nn import RGCNConv
from gcn import *
from graphModule import *
from einops import rearrange
from config import args
from biaffine import *

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device("cpu")
# BERT_PATH = "./SpanBERT/Spanbert-base-cased"
# BERT_PATH = "./chinese_roberta_wwm_ext_pytorch"
BERT_PATH = "./bert"
maxlen = 256  # fixed sequence length; REModel.forward reshapes it to 16 x 16


def load_data(filename):
    """Load a CMED triple file.

    Args:
        filename: path to a JSON file holding a list of
            ``{"text": str, "triple_list": [[s, p, o], ...]}`` records.

    Returns:
        A list of ``{"text": str, "triple_list": [(s, p, o), ...]}`` dicts;
        triples are converted to tuples so they can be put into sets.
    """
    # The corpus is Chinese text: do not rely on the platform default encoding.
    with open(filename, encoding='utf-8') as data_file:
        records = json.load(data_file)
    return [
        {
            'text': item['text'],
            'triple_list': [(t[0], t[1], t[2]) for t in item['triple_list']],
        }
        for item in records
    ]


# Load the data sets.
train_data = load_data('./data/CMED/train_triples.json')
valid_data = load_data('./data/CMED/dev_triples.json')


def search(pattern, sequence):
    """Return the first index at which ``pattern`` occurs in ``sequence``.

    Works for any sliceable sequence (strings, lists of token ids).
    Returns -1 when ``pattern`` does not occur.
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1


def _triples_fit(sample, limit):
    """Return True when every subject/object of ``sample`` is found in its
    text and ends at a character offset not greater than ``limit``."""
    text = sample['text']
    for s, _p, o in sample['triple_list']:
        s_begin = search(s, text)
        o_begin = search(o, text)
        if s_begin == -1 or o_begin == -1:
            return False
        if s_begin + len(s) > limit or o_begin + len(o) > limit:
            return False
    return True


# Build the training set actually used: drop samples whose entities cannot be
# located or do not fit into the model input.  After "[CLS]" is prepended, an
# entity ending at character offset e ends at token index e, which must stay
# below ``maxlen`` (the position embedding and the padded sequence both have
# exactly ``maxlen`` slots).  The previous bound of 256 allowed index 256.
train_data_new = [d for d in tqdm(train_data) if _triples_fit(d, maxlen - 1)]
print("去除大于250的文本:\t", len(train_data_new))

# Read the relation schema: rel2id.json holds [id2predicate, predicate2id].
with open('./data/CMED/rel2id.json', encoding='utf-8') as f:
    l = json.load(f)
    id2predicate = l[0]
    predicate2id = l[1]
print("关系类型数量:\t", len(predicate2id))


class OurTokenizer(BertTokenizer):
    """Character-level tokenizer: exactly one token per input character, so
    token positions map one-to-one onto character offsets."""

    def tokenize(self, text):
        """Map every character to itself, ``[unused1]`` (whitespace) or
        ``[UNK]`` (out of vocabulary)."""
        tokens = []
        for c in text:
            if c in self.vocab:
                tokens.append(c)
            elif self._is_whitespace(c):
                tokens.append('[unused1]')
            else:
                tokens.append('[UNK]')
        return tokens

    def _is_whitespace(self, char):
        """Return True for space, tab, newline, carriage return and any
        character in Unicode category ``Zs``."""
        if char in (" ", "\t", "\n", "\r"):
            return True
        return unicodedata.category(char) == "Zs"


# Initialise the tokenizer.
tokenizer = OurTokenizer(vocab_file="./chinese_roberta_wwm_ext_pytorch/vocab.txt")


# Dependency parsing + word segmentation.
def seg_pos(text):
    """Parse ``text`` with HanLP's neural dependency parser.

    Returns:
        ``(head, seg_word, Dep_rel, str_le)``: head ids, word lemmas,
        dependency relation labels and the whitespace-split lemma pieces.
    """
    head, seg_word, Dep_rel, str_le = [], [], [], []
    # tree = HanLP.parseDependency(text)
    parser = JClass('com.hankcs.hanlp.dependency.nnparser.NeuralNetworkDependencyParser')()
    parser.enableDeprelTranslator(False)
    tree = parser.parse(text)
    for word in tree.iterator():
        head.append(word.HEAD.ID)
        str_le.extend(word.LEMMA.split())
        seg_word.append(word.LEMMA)
        Dep_rel.append(word.DEPREL)
    return head, seg_word, Dep_rel, str_le


def out_list_word(seg_word):
    """Join words with single spaces and strip leading spaces."""
    return " ".join(seg_word).lstrip(" ")


def map_to_ids(tokens, vocab):
    """Map tokens to ids, using 0 for tokens missing from ``vocab``."""
    return [vocab.get(t, 0) for t in tokens]


def vocab_json():
    """Load the word vocabulary from ``./vacab.json``."""
    with open("./vacab.json", encoding='utf-8') as fh:
        return json.load(fh)


def dep_json():
    """Load the dependency-label vocabulary from ``./dep.json``."""
    with open("./dep.json", encoding='utf-8') as fh:
        return json.load(fh)


def _pad_sequence(x, maxlen, padding=0):
    """Pad ``x`` with ``padding`` up to ``maxlen`` items or truncate it."""
    if len(x) < maxlen:
        return np.concatenate([x, [padding] * (maxlen - len(x))])
    return np.array(x[:maxlen])


class TorchDataset(Dataset):
    """Training data set yielding
    ``(token_ids, seg_ids, sub_ids, sub_labels, obj_labels)`` LongTensors."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        t = self.data[i]
        # e.g. {'text': '...', 'triple_list': [(subject, predicate, object)]}
        x = ["[CLS]"] + tokenizer.tokenize(t['text']) + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(x)
        seg_ids = [0] * len(token_ids)
        assert len(token_ids) == len(t['text']) + 2

        # spoes: {(s_start, s_end): [(o_start, o_end, predicate_id), ...]}
        spoes = {}
        for s, p, o in t['triple_list']:
            s = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(s))
            p = predicate2id[p]
            o = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(o))
            s_idx = search(s, token_ids)
            o_idx = search(o, token_ids)
            if s_idx != -1 and o_idx != -1:
                s = (s_idx, s_idx + len(s) - 1)
                o = (o_idx, o_idx + len(o) - 1, p)  # predict o and p jointly
                spoes.setdefault(s, []).append(o)

        # Samples without any locatable triple used to return None, which
        # crashes the default collate function; they now yield all-zero
        # labels with the subject pointing at position 0.
        sub_labels = np.zeros((len(token_ids), 2))
        obj_labels = np.zeros((len(token_ids), len(predicate2id), 2))
        sub_ids = (0, 0)
        if spoes:
            for s in spoes:
                sub_labels[s[0], 0] = 1
                sub_labels[s[1], 1] = 1
            # Pick one subject at random: a random start and the nearest end
            # at or after it.
            start, end = np.array(list(spoes.keys())).T
            start = np.random.choice(start)
            end = sorted(end[end >= start])[0]
            sub_ids = (start, end)
            for o in spoes.get(sub_ids, []):
                obj_labels[o[0], o[2], 0] = 1
                obj_labels[o[1], o[2], 1] = 1

        token_ids = self.sequence_padding(token_ids, maxlen=maxlen)
        seg_ids = self.sequence_padding(seg_ids, maxlen=maxlen)
        sub_labels = self.sequence_padding(sub_labels, maxlen=maxlen, padding=np.zeros(2))
        sub_ids = np.array(sub_ids)
        obj_labels = self.sequence_padding(obj_labels, maxlen=maxlen,
                                           padding=np.zeros((len(predicate2id), 2)))
        # The label arrays are float64; torch.as_tensor converts the dtype
        # explicitly instead of relying on the legacy LongTensor constructor.
        return (torch.as_tensor(token_ids, dtype=torch.long),
                torch.as_tensor(seg_ids, dtype=torch.long),
                torch.as_tensor(sub_ids, dtype=torch.long),
                torch.as_tensor(sub_labels, dtype=torch.long),
                torch.as_tensor(obj_labels, dtype=torch.long))

    def __len__(self):
        return len(self.data)

    def sequence_padding(self, x, maxlen, padding=0):
        """Pad or truncate ``x`` to exactly ``maxlen`` items."""
        return _pad_sequence(x, maxlen, padding)


train_dataset = TorchDataset(train_data_new)
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch1, shuffle=True, drop_last=True)


class GRUnet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim):
        """
        vocab_size: vocabulary size (rows of the embedding matrix)
        embedding_dim: embedding width, i.e. the GRU input size
        hidden_dim: GRU hidden size
        layer_dim: number of stacked GRU layers
        output_dim: width of the per-step output
        """
        super(GRUnet, self).__init__()
        # Embedding layer.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # GRU + fully connected head.
        self.gru = nn.GRU(embedding_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(0.5),
            torch.nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Embed ``x``, run the GRU and apply the head to every time step."""
        embeds = self.embedding(x)
        r_out, h_n = self.gru(embeds, None)
        return self.fc1(r_out)


class GCN(nn.Module):
    """Parameter-free graph convolution: ``relu(adj @ x)``."""

    def __init__(self, hidden_size=768):
        super(GCN, self).__init__()
        self.hidden_size = hidden_size

    def forward(self, x, adj, is_relu=True):
        """Batch-multiply ``adj`` with ``x``; optionally apply ReLU."""
        out = torch.bmm(adj, x)
        if is_relu == True:
            out = F.relu(out)
        return out


class RGCN(torch.nn.Module):
    """Stack of ``RGCNConv`` layers with batch norm between hidden layers.

    NOTE(review): ``forward`` pairs convolutions with norms via ``zip``, so
    only ``n_layers - 2`` convolutions are ever applied and the module is an
    identity for ``n_layers=2``.  This is kept because existing callers rely
    on it; confirm the intended architecture before changing it.
    """

    def __init__(self, in_channels, hideden_channels, out_channels, n_layers=2, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.relu = F.relu
        self.dropout = dropout
        self.convs.append(RGCNConv(in_channels, hideden_channels, num_relations=24, num_bases=1))
        for _ in range(n_layers - 2):
            self.convs.append(RGCNConv(hideden_channels, hideden_channels, num_relations=24, num_bases=1))
            # Was ``BatchNormld`` (letter l), which does not exist in torch.nn.
            self.norms.append(torch.nn.BatchNorm1d(hideden_channels))
        self.convs.append(RGCNConv(hideden_channels, out_channels, num_relations=24, num_bases=1))

    def forward(self, x, edge_index=2561, edge_type=24):
        # The arguments are now forwarded instead of the hard-coded literals
        # 2561 / 24 (identical when the defaults are used).
        # NOTE(review): the integer defaults are not valid graph inputs for
        # RGCNConv; pass real edge_index / edge_type tensors when n_layers > 2.
        for conv, norm in zip(self.convs, self.norms):
            x = norm(conv(x, edge_index, edge_type))
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class BertLayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and
    bias (epsilon inside the square root)."""

    def __init__(self, hidden_size, eps=1e-12):
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Attention2(nn.Module):
    """Feature re-weighting.

    Input and output are ``[batch_size, time_step, hidden_dim]``.  A linear
    layer + ReLU yields per-feature weights; with ``mean=True`` they are
    averaged over the hidden dimension so every time step gets one scalar
    weight, which then scales the input features.
    """

    def __init__(self, hidden_dim):
        super(Attention2, self).__init__()
        self.hidden_dim = hidden_dim
        self.dense = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, features, mean=True):
        batch_size, time_step, hidden_dim = features.size()
        weight = F.relu(self.dense(features))
        # Time steps whose features are all zero receive a large negative
        # weight; since the features there are zero the product stays zero.
        mask_idx = torch.sign(torch.abs(features).sum(dim=-1))
        mask_idx = mask_idx.unsqueeze(-1).expand(batch_size, time_step, hidden_dim)
        paddings = torch.ones_like(mask_idx) * (-2 ** 32 + 1)
        weight = torch.where(torch.eq(mask_idx, 1), weight, paddings)
        weight = weight.transpose(2, 1)
        if mean:
            weight = weight.mean(dim=1)
            weight = weight.unsqueeze(1)
            weight = weight.expand(batch_size, hidden_dim, time_step)
        weight = weight.transpose(2, 1)
        return weight * features


class KeyValueMemoryNetwork(nn.Module):
    """Masked key-value attention over pre-embedded keys and values."""

    def __init__(self, vocab_size, feature_vocab_size, emb_size):
        super(KeyValueMemoryNetwork, self).__init__()
        self.key_embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.value_embedding = nn.Embedding(feature_vocab_size, emb_size, padding_idx=0)
        self.scale = np.power(emb_size, 0.5)

    def forward(self, key_embed, value_embed, hidden, mask_matrix):
        """Return ``p * value_embed`` where ``p`` is the masked, normalised
        ``exp(hidden @ key_embed^T / scale)``."""
        u = torch.bmm(hidden.float(), key_embed.transpose(1, 2)) / self.scale
        delta_exp_u = torch.exp(u).float() * mask_matrix.float()
        # keepdim + broadcasting replaces stacking the row sums N times.
        row_sum = delta_exp_u.sum(dim=2, keepdim=True)
        p = delta_exp_u / (row_sum + 1e-10)
        return p.float() * value_embed.float()


class REModel(nn.Module):
    """BERT-based cascade tagger: subject start/end tagging first, then
    (predicate, object) start/end tagging conditioned on a subject span."""

    def __init__(self):
        super(REModel, self).__init__()
        self.bert = BertModel.from_pretrained(BERT_PATH)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.linear = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        self.sub_output = nn.Linear(768, 2)
        # NOTE: several sub-modules below are not used by ``forward``; they
        # are kept so that parameter names stay compatible with checkpoints.
        self.suopand = nn.Linear(1024, 768)
        self.cat_output = nn.Linear(1024, 768)
        self.obj_output = nn.Linear(768, len(predicate2id) * 2)
        self.sub_pos_emb = nn.Embedding(256, 768)  # subject position embedding
        self.layernorm = BertLayerNorm(768, eps=1e-12)
        self.GRU = GRUnet(23923, 768, 1024, 6, 768)
        self.biaffine = BiaffineTagger(768, 2)
        self.attention2 = Attention2(hidden_dim=768)
        self.gcu1 = GraphConv1(batch=args.batch1, h=[16, 32, 64, 128, 256], w=[16, 32, 64, 128, 256],
                               d=[768, 512], V=[2, 4, 8, 32], outfeatures=[256, 128])
        self.cov = nn.Conv2d(768, 768, 1)
        # NOTE(review): ``opt`` is not defined in this file; presumably it is
        # provided by one of the star imports - confirm.
        self.GCN_model = GCNClassifier(opt, emb_matrix=None)
        self.emb = nn.Embedding(23923, 768)
        self.emb1 = nn.Embedding(37, 256)
        self.keyvalue = KeyValueMemoryNetwork(23923, 23923, 768)

    def forward(self, token_ids, seg_ids, sub_ids=None):
        """Run the tagger.

        Args:
            token_ids: token ids, padded to ``maxlen``.
            seg_ids: segment ids, same shape as ``token_ids``.
            sub_ids: optional ``[batch, 2]`` subject (start, end) positions.

        Returns:
            ``sub_preds`` when ``sub_ids`` is None, otherwise
            ``(sub_preds, obj_preds)``.  Both are sigmoid probabilities.
        """
        out, _ = self.bert(token_ids, token_type_ids=seg_ids,
                           output_all_encoded_layers=False)
        out = self.attention2(out)
        sub_preds = torch.sigmoid(self.sub_output(out))
        if sub_ids is None:
            return sub_preds

        # Inject the subject: position embeddings plus the encoder states at
        # the subject's start and end positions.
        sub_pos_start = self.sub_pos_emb(sub_ids[:, :1])
        sub_pos_end = self.sub_pos_emb(sub_ids[:, 1:])
        sub_id1 = sub_ids[:, :1].unsqueeze(-1).repeat(1, 1, out.shape[-1])
        sub_id2 = sub_ids[:, 1:].unsqueeze(-1).repeat(1, 1, out.shape[-1])
        sub_start = torch.gather(out, 1, sub_id1)
        sub_end = torch.gather(out, 1, sub_id2)
        sub_start = sub_pos_start + sub_start
        sub_end = sub_pos_end + sub_end
        out1 = out + sub_start + sub_end

        # View the 256-token sequence as a 16 x 16 grid for the 1x1 conv.
        out1 = torch.reshape(out1, (-1, 16, 16, 768))
        out1 = out1.permute(0, 3, 1, 2)
        # A fresh, untrained RGCN(n_layers=2) used to be built and called here
        # on every forward pass.  With two layers that module is an identity,
        # so the call only cost time and has been removed.
        out1 = self.cov(out1)
        out1 = rearrange(out1, 'b c h w -> b c (h w)')
        out1 = out1.permute(0, 2, 1)
        out1 = self.layernorm(out1)
        out1 = F.dropout(out1, p=0.5, training=self.training)

        output = self.relu(self.linear(out1))
        output = F.dropout(output, p=0.4, training=self.training)
        output = torch.sigmoid(self.obj_output(output))
        obj_preds = output.view(-1, output.shape[1], len(predicate2id), 2)
        return sub_preds, obj_preds


net = REModel().to(DEVICE)
print(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)


def get_long_tensor(tokens_list, batch_size):
    """Convert a list of token-id lists to a zero-padded LongTensor."""
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.zeros(batch_size, token_len, dtype=torch.long)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens


class ValidDataset(Dataset):
    """Validation data set yielding ``(token_ids, seg_ids, sample_dict)``."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        t = self.data[i]
        # Truncate in place so that [CLS] + text + [SEP] fits into maxlen;
        # ``evaluate`` reads the (truncated) text back from the same dicts.
        if len(t['text']) > 254:
            t['text'] = t['text'][:254]
        x = ["[CLS]"] + tokenizer.tokenize(t['text']) + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(x)
        seg_ids = [0] * len(token_ids)
        assert len(token_ids) == len(t['text']) + 2
        token_ids = torch.LongTensor(self.sequence_padding(token_ids, maxlen=maxlen))
        seg_ids = torch.LongTensor(self.sequence_padding(seg_ids, maxlen=maxlen))
        return token_ids, seg_ids, t

    def __len__(self):
        return len(self.data)

    def sequence_padding(self, x, maxlen, padding=0):
        """Pad or truncate ``x`` to exactly ``maxlen`` items."""
        return _pad_sequence(x, maxlen, padding)


valid_dataset = ValidDataset(valid_data)
# NOTE(review): ``extract_spoes`` / ``evaluate`` only look at the first sample
# of each batch, so ``args.batch2`` is presumably 1 - confirm.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch2, shuffle=False, drop_last=True)


def extract_spoes(data, model, device):
    """Extract (subject, predicate, object) string triples for one text.

    Args:
        data: ``(token_ids, seg_ids, text)`` where the tensors carry a batch
            dimension of which only the first sample is used, and ``text`` is
            the raw string.
        model: the tagger; called once for subjects, once for objects.
        device: device on which the model runs.

    Returns:
        A list of ``(subject_text, predicate_name, object_text)`` tuples.
    """
    token_ids, seg_ids, text = data[0], data[1], data[2]

    sub_preds = model(token_ids.to(device), seg_ids.to(device))
    sub_preds = sub_preds.detach().cpu().numpy()
    start = np.where(sub_preds[0, :, 0] > 0.5)[0]
    end = np.where(sub_preds[0, :, 1] > 0.5)[0]

    # Pair every predicted start with the nearest end at or after it.
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            subjects.append((i, j[0]))
    if not subjects:
        return []

    # One copy of the sentence per candidate subject.  Tensor.repeat replaces
    # np.repeat on a torch tensor, which only worked through an implicit
    # numpy round trip.
    token_ids = token_ids[:1].repeat(len(subjects), 1)
    seg_ids = seg_ids[:1].repeat(len(subjects), 1)
    subjects = np.array(subjects)
    # Feed the subjects to extract objects and predicates.
    _, object_preds = model(token_ids.to(device), seg_ids.to(device),
                            torch.as_tensor(subjects, dtype=torch.long).to(device))
    object_preds = object_preds.detach().cpu().numpy()

    spoes = []
    for sub, obj_pred in zip(subjects, object_preds):
        start = np.where(obj_pred[:, :, 0] > 0.3)
        end = np.where(obj_pred[:, :, 1] > 0.3)
        for _start, predicate1 in zip(*start):
            for _end, predicate2 in zip(*end):
                if _start <= _end and predicate1 == predicate2:
                    # "- 1" converts token positions to character offsets
                    # (the leading [CLS] token is skipped).
                    spoes.append(
                        ((sub[0] - 1, sub[1] - 1), predicate1, (_start - 1, _end - 1))
                    )
                    break
    return [(text[s[0]:s[1] + 1], id2predicate[str(p)], text[o[0]:o[1] + 1])
            for s, p, o in spoes]


def evaluate(valid_data, valid_load, model, device):
    """Compute micro F1 / precision / recall over the validation loader.

    Predictions are also written to ``./data/CMED/dev_pred.json`` while the
    running F1 is above 0.5.

    Returns:
        ``(f1, precision, recall)``; all 0.0 when the loader is empty.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f1 = precision = recall = 0.0
    with open("./data/CMED/dev_pred.json", 'w', encoding='utf-8') as f:
        pbar = tqdm(valid_load)
        for idx, data in enumerate(pbar):
            sample = data[0], data[1], data[2]['text'][0]
            R = set(extract_spoes(sample, model, device))
            # Assumes batch index == sample index (batch size 1, no shuffle).
            T = set(valid_data[idx]['triple_list'])
            X += len(R & T)
            Y += len(R)
            Z += len(T)
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
            pbar.set_description(
                'F1: %.5f, \tPrecision: %.5f, \tRecall: %.5f' % (f1, precision, recall)
            )
            if f1 > 0.5:
                s = json.dumps({
                    'text': valid_data[idx]['text'],
                    'triple_list': list(T),
                    'triple_list_pred': list(R),
                    'new': list(R - T),
                    'lack': list(T - R),
                }, ensure_ascii=False, indent=4)
                f.write(s + '\n')
        pbar.close()
    return f1, precision, recall


class Logger(object):
    """Tee-like stream: everything written goes to stdout and a log file."""

    def __init__(self, fileN="default.log"):
        self.terminal = sys.stdout
        self.log = open(fileN, "a", encoding='utf-8')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Previously a no-op, which could lose buffered log output.
        self.terminal.flush()
        self.log.flush()


class FocalLoss(nn.Module):
    """Multi-class focal loss on top of cross entropy.

    Args:
        alpha: global scaling factor.
        gamma: focusing exponent.
        size_average: average over the non-ignored targets when True,
            otherwise sum (the flag used to be ignored).
        ignore_index: target value excluded from the loss.
    """

    def __init__(self, alpha=1, gamma=2, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        # Was ``F.class_entropy``, which does not exist.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.size_average:
            valid = (targets != self.ignore_index).sum().clamp(min=1)
            return focal_loss.sum() / valid
        return focal_loss.sum()


def _focal_bce(preds, labels, gamma=2):
    """Element-wise focal binary cross entropy on probabilities.

    ``preds`` must already be probabilities (the model applies the sigmoid),
    so they are used directly for the modulating factor.
    """
    bce = F.binary_cross_entropy(preds, labels, reduction='none')
    p_t = preds * labels + (1 - preds) * (1 - labels)
    return bce * (1 - p_t) ** gamma


def train(model, train_loader, optimizer, epoches, device):
    """Train ``model`` for ``epoches`` epochs and evaluate after each one.

    After every epoch the weights are saved to ``./checkpoints/best_re.pth``
    and validation F1 / precision / recall are printed.  The per-epoch mean
    training losses are printed at the end.
    """
    torch.backends.cudnn.enabled = False
    loss_history = []
    os.makedirs("./checkpoints", exist_ok=True)
    for epoch in range(epoches):
        print('epoch: ', epoch + 1)
        start = time.time()
        train_loss_sum = 0.0
        num_batches = 0
        # Dropout must be active while training and off while evaluating.
        model.train()
        for batch_idx, x in tqdm(enumerate(train_loader)):
            token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device)
            mask = (token_ids > 0).float()  # zero for padding positions
            sub_labels, obj_labels = x[3].float().to(device), x[4].float().to(device)
            # (batch, maxlen, 2), (batch, maxlen, n_predicates, 2)
            sub_preds, obj_preds = model(token_ids, seg_ids, sub_ids)

            # Focal BCE.  The model outputs are already sigmoid
            # probabilities; the second sigmoid that used to be applied
            # squeezed them into (0.5, 0.73) and defeated the focal weight.
            loss_sub = _focal_bce(sub_preds, sub_labels).mean(2)
            loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask)
            loss_obj = _focal_bce(obj_preds, obj_labels).mean(3).sum(2)
            loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask)
            loss = loss_sub + loss_obj

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.cpu().item()
            num_batches += 1
            if (batch_idx + 1) % 31 == 0:
                print('loss: ', train_loss_sum / (batch_idx + 1), 'time: ', time.time() - start)

        if num_batches:
            loss_history.append(train_loss_sum / num_batches)
        torch.save(model.state_dict(), "./checkpoints/best_re.pth")

        model.eval()
        with torch.no_grad():
            val_f1, pre, rec = evaluate(valid_data, valid_loader, model, device)
        print('F1_score: %.5f, Precision: %.5f, Recall: %.5f' % (val_f1, pre, rec))
    print(loss_history)


if __name__ == '__main__':
    # net.load_state_dict(torch.load("RE/data/bert_re.pth"))
    train(net, train_loader, optimizer, 600, DEVICE)