Kaggle知识点:BERT Prompt文本分类

Prompt介绍

Prompt是NLP的一个新领域。在Prompt中任务的描述被嵌入到输入中,提供了一种新的方式来控制机器学习模型输出。

Prompt是利用预训练语言模型在大量文本数据上获得的知识,来解决各种下游任务。Prompt的优势是它可以减少或者避免对预训练模型进行微调,节省计算资源和时间,同时保持或者提高模型的性能和泛化能力。

Prompt的方法是根据不同的任务和数据,设计合适的输入格式,包括问题,上下文,前缀,后缀,分隔符等。

BERT与Prompt使用

Prompt可用于提高BERT的句子表示能力,通过在BERT的输入中加入一些特定的词语作为Prompt,引导BERT生成更好的句子向量

  • 方法1:在句子的开头或结尾加入Prompt
  • 方法2:在句子的中间加入Prompt

Prompt搜索方法

Prompt的搜索方法找到最优的Prompt,能最大化BERT表示能力的Prompt。目前有三种主要的搜索方法:

  • 随机搜索:随机生成一些Prompt,然后用它们作为BERT的输入,计算BERT的输出向量与目标向量的相似度,选择相似度最高的Prompt作为最优的Prompt。
  • 贪心搜索:从一个空的Prompt开始,每次在Prompt的末尾加入一个词,然后用它作为BERT的输入,计算BERT的输出向量与目标向量的相似度,选择相似度最高的词作为Prompt的一部分,直到达到一个预设的长度或者相似度阈值。
  • 强化学习搜索:将Prompt的生成视为一个序列决策问题,使用强化学习的算法,来优化一个策略网络,根据一个奖励函数来更新网络的参数。

Prompt方法局限性

BERT + Prompt的优势是能够利用Prompt来引导BERT生成更好的句子向量,从而提高句子表示的质量和多样性

句子相似度,文本分类,文本检索等,BERT + Prompt可能会比原始BERT模型有效。文本生成的任务,如文本摘要,文本复述,文本续写等,BERT + Prompt可能不一定比原始BERT模型有效。

Prompt适合进行多任务进行建模,比如多个文本任务一起进行训练。因此在单个任务中,Prompt并不会增加模型精度。在现有文本分类比赛中暂时还没看到Prompt的使用案例。

案例:Prompt文本分类

输入文本:

It was [mask]. 文本输入样例

将[MASK]输出接全连接层,进行分类。

步骤1:定义模型

class Bert_Model(nn.Module):
    def __init__(self,  bert_path ,config_file ):
        super(Bert_Model, self).__init__()
        self.bert = BertForMaskedLM.from_pretrained(bert_path,config=config_file) # 加载预训练模型权重 def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids) #masked LM 输出的是 mask的值 对应的ids的概率 ,输出 会是词表大小,里面是概率  logit = outputs[0] # 池化后的输出 [bs, config.hidden_size]

return logit


步骤2:定义数据集

class MyDataSet(Data.Dataset):
    def __init__(self, sen , mask , typ ,label ):
        super(MyDataSet, self).__init__()
        self.sen = torch.tensor(sen,dtype=torch.long)
        self.mask = torch.tensor(mask,dtype=torch.long)
        self.typ =torch.tensor( typ,dtype=torch.long)
        self.label = torch.tensor(label,dtype=torch.long)
 
    def __len__(self): return self.sen.shape[0]
 
    def __getitem__(self, idx): return self.sen[idx], self.mask[idx],self.typ[idx],self.label[idx]

步骤3:对文本加入Prompt

prefix = 'It was [mask]. '

for i in range(len(x_train)):
text_ = prefix+x_train[i][0]
encode_dict = tokenizer.encode_plus(text_,max_length=60,padding="max_length",truncation=True)


步骤4:模型训练与预测

optimizer = AdamW(model.parameters(),lr=2e-5,weight_decay=1e-4) #使用Adam优化器 loss_func = nn.CrossEntropyLoss(ignore_index=-1)

for idx,(ids,att_mask,type,y) in enumerate(train_dataset):
ids,att_mask,type,y = ids.to(device),att_mask.to(device),type.to(device),y.to(device)
out_train = model(ids,att_mask,type)
loss = loss_func(out_train.view(-1, tokenizer.vocab_size),y.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss_sum += loss.item()


参考文献

PromptBERT: Improving BERT Sentence Embeddings with Prompts

【竞赛报名/项目咨询请加微信:mollywei007】

上一篇

在美国读商科需要具备什么条件?

下一篇

求真书院普特南数学竞赛模拟赛成绩力压美名校

你也可能喜欢

  • 暂无相关文章!

评论已经被关闭。

插入图片
返回顶部
Baidu
map