【CCF BDCI 2023】多模態多方對話場景下的發言人識別 Baseline 0.71 NLP 部分

【CCF BDCI 2023】多模態多方對話場景下的發言人識別 Baseline 0.71 NLP 部分

  • 概述
  • NLP 簡介
    • 文本處理
    • 詞嵌入
    • 上下文理解
  • 文本數據加載
    • to_device 函數
    • 構造
    • 數據加載
    • 樣本數量 len
    • 獲取樣本 getitem
  • 分詞
    • 構造函數
    • 調用函數
    • 輪次嵌入
  • Roberta
    • Roberta 創新點
    • NSP (Next Sentence Prediction)
    • Roberta 構造函數
    • Roberta 前向傳播
    • 計算發言人相似度
      • 推理模式 (Inference Mode)
      • 訓練模式 (Training Mode)
  • Deberta
    • Deberta 創新點
    • Deberta 構造函數
    • Deberta 前向傳播
  • 訓練
  • 驗證
  • 參考文獻

概述

現今技術日新月異, Artificial Intelligence 的發展正在迅速的改變我們的生活和工作方式. 尤其是在自然語言處理 (Natural Linguistic Processing) 和計算機視覺 (Computer Vision) 等領域.

傳統的多模態對話研究主要集中在單一用戶與系統之間的交互, 而忽視了多用戶場景的復雜性. 視覺信息 (Visual Info) 往往會被邊緣化, 僅作為維嘉信息而非對話的核心部分. 在實際應用中, 算法需要 “觀察” 并與多個用戶的交互, 這些用戶有可能不是當前的發言人.

【CCF BDCI 2023】多模態多方對話場景下的發言人識別, 核心思想是通過多輪連續對話的內容和每輪對應的幀, 以及對應的人臉 bbox 和 name label, 從每輪對話中識別出發言人 (speaker).

NLP 簡介

書接上文, 在上一篇博客中小白帶大家詳解了 Baseline 中的 CNN 模型部分. 今天我們來詳解一下 NLP 部分. 包括 Roberta 和 Deberta 模型及其應用.

NLP 部分

文本處理

文本處理是 NLP 任務的第一步. 我們需要將原始文本轉化成模型可以處理的格式.

步驟包含:

  1. 清洗 (Cleaning): 去除無用信息, 常見的有標點符號, 特殊字符, html, 停用詞等
  2. 分詞 (Tokenization): 將文本按詞 (Word) 為單位進行分割, 并轉換為數字數據.
    • 常見單詞, 例如數據中的人名:
      • Rachel對應 token id 5586
      • Chandler對應 token id 13814
      • Phoebe對應 token id 18188
      • 上述 token id 對應 bert 的 vocab 中, roberta 的 vocab 表在服務器上, 懶得找了
    • 特殊字符:
      • [CLS]: token id 101, 表示句子的開始
      • [SEP]: token id 102, 表示分隔句子或文本片段
      • [PAD]: token id 0, 表示填充 (Padding), 當文本為達到指定長度時, 例如 512, 會用[PAD]進行填充
      • [MASK]: token id 0, 表示填充 (Padding), 當文本為達到指定長度時, 例如 512, 會用[PAD]進行填充

上述字符在 Bert & Bert-like 模型中扮演著至關重要的角色, 在不同的任務重, 這些 Token ID 都是固定的, 例如 Bert 為 30522 個.

FYI: 上面的超鏈接是 jieba 分詞的一個簡單示例.

詞嵌入

詞嵌入 (Word Embedding) 是將文本中的詞匯映射到向量空間的過程. 詞向量 (Word Vector) 對應為詞匯的語義信息, 具有相似含義的詞匯在向量空間中距離接近.

詞嵌入

常見的詞嵌入技術包含:

  • Word2Vec: 通過神經網絡模型學習詞匯的分布式
  • GloVe: 基于全局詞共現統計信息構建詞向量 (Word Vector)
  • Bert Embedding: 使用 Bert 模型生成上下文相關的詞嵌入

FYI: 想要了解詞向量和 Word2Vec 的具體原理, 參考我上面超鏈接的博客.

上下文理解

在多方對話中, 上下文的理解至關重要, 包括對話的語境, 參與者之間的關系和對話的流程.

具體技術:

  • Transformers 模型, 如 Bert, Roberta, Deberta 等, 通過捕捉長距離依賴關系, 理解整個句子 / 對話的上下文
  • 注意力機制 (Attention Mechanism): 模型在處理一個單詞 / 短語時, 考慮到其他相關部分的信息.

文本數據加載

SpeakerIdentificationDataset是用于加載多模態多方對話場景下的發言人識別任務中的數據的一個類. 下面小白帶大家來逐行解析.

數據加載

to_device 函數

to_device 函數左右為將數據移動到指定設備, 例如 GPU:0.

def to_device(obj, dev):if isinstance(obj, dict):return {k: to_device(v, dev) for k, v in obj.items()}if isinstance(obj, list):return [to_device(v, dev) for v in obj]if isinstance(obj, tuple):return tuple([to_device(v, dev) for v in obj])if isinstance(obj, torch.Tensor):return obj.to(dev)return obj
  • 如果傳入對象為 obj, dict則遞歸的對這個字典的每個值進行to_device操作, 將結果匯總在一個新的字典上, key 不變, value.to(device)
  • 如果傳入對象為 obj, list則遞歸對列表的每個元素鏡像``to_device```操作, 將結果匯總在一個新的列表上
  • 如果傳入對象為obj, tuple, 同理, 返回元組
  • 如果傳入對象為obj, torch.tensor, 將張量移動到指定的設備, 如: CPU->GPU

構造

class SpeakerIdentificationDataset:def __init__(self, base_folder, bos_token='<bos>', split='train', dataset='friends', data_aug=False, debug=False):self.base_folder = base_folderself.debug = debugself.dataset = datasetself.split = splitself.bos_token = bos_token
  • base_folder: 數據集存放路徑
  • bos_token: 句子開始時的特殊字符
  • split: 分割 (train, valid, test)
  • dataset: 默認 friends

數據加載

if dataset == 'friends':if split == 'test':metadata = json.load(open(os.path.join(base_folder, 'test-metadata.json')))else:if data_aug:metadata = json.load(open(os.path.join(base_folder, 'train-metadata-aug.json')))else:metadata = json.load(open(os.path.join(base_folder, 'train-metadata.json')))self.examples = list()for dialog_data in metadata:# 我們選擇s01作為驗證集好了if split == 'valid' and not dialog_data[0]['frame'].startswith('s01'):continueif split == 'train' and dialog_data[0]['frame'].startswith('s01'):continueself.examples.append(dialog_data)
else:if dataset == 'ijcai2019':self.examples = [json.loads(line) for line in open(os.path.join(base_folder, '%s.json' % (split.replace('valid', 'dev'))))]if dataset == 'emnlp2016':self.examples = [json.loads(line) for line in open(os.path.join(base_folder, '10_%s.json' % (split.replace('valid', 'dev'))))]self.examples = [example for example in self.examples if len(example['ctx_spk']) != len(set(example['ctx_spk']))]

和前面的 CNN Dataset 一樣, 還是使用 s01 的 dialog 數據做為 valid, 剩下的作為 train.

樣本數量 len

def __len__(self):return len(self.examples) if not self.debug else 32
  • 和 CNN 的 Dataset 一樣, 非 Debug 模式下返回范本數量, Debug 模型下返回 32

獲取樣本 getitem

def __getitem__(self, index):example = self.examples[index]if self.dataset == 'friends':speakers, contents, frame_names = [i['speaker'] for i in example], [i['content'] for i in example], [i['frame'] for i in example]else:speakers, contents = example['ctx_spk'], example['context']frame_names = ['%d-%d' % (index, i) for i in range(len(speakers))]labels = list()for i, speaker_i in enumerate(speakers):for j, speaker_j in enumerate(speakers):if i != j and speaker_i == speaker_j:labels.append([i, j])input_text = self.bos_token + self.bos_token.join(contents)return input_text, labels, frame_names
  • 從數據集提取單個 Sample
  • 提取發言人, 對話內容和幀名
  • 生成標簽, 并標記發言人的位置
  • 將對話內容拼接成一個長文本, 用于模型輸入

這么說可能大家有點暈, 我來大大家拿 train 的第一個 dialog 演示一下.

Dialog[0] (sample), 5 句話組成:

[{"frame": "s06e07-000377", "speaker": "phoebe", "content": "Yeah, I know because you have all the good words. What do I get? I get \"it\u2019s,\" \"and\" oh I'm sorry, I have \"A.\" Forget it.", "start": 297, "end": 491, "faces": [[[752, 135, 881, 336], "rachel"], [[395, 111, 510, 329], "leslie"]]}, {"frame": "s06e07-000504", "speaker": "rachel", "content": "Phoebe, come on that's silly.", "start": 498, "end": 535, "faces": [[[466, 129, 615, 328], "phoebe"]]}, {"frame": "s06e07-000552", "speaker": "phoebe", "content": "All right, so let's switch.", "start": 535, "end": 569, "faces": [[[426, 120, 577, 320], "phoebe"]]}, {"frame": "s06e07-000629", "speaker": "rachel", "content": "No, I have all of the good words. OK, fine, fine, we can switch.", "start": 569, "end": 689, "faces": [[[420, 125, 559, 328], "phoebe"], [[652, 274, 771, 483], "rachel"]]}, {"frame": "s06e07-000892", "speaker": "phoebe", "content": "Please...wait, how did you do that?", "start": 879, "end": 906, "faces": [[[424, 133, 573, 334], "phoebe"], [[816, 197, 925, 399], "bonnie"]]}]

得到的 input_test:

<bos>Yeah, I know because you have all the good words. What do I get? I get "it’s," "and" oh I'm sorry, I have "A." Forget it.<bos>Phoebe, come on that's silly.<bos>All right, so let's switch.<bos>No, I have all of the good words. OK, fine, fine, we can switch.<bos>Please...wait, how did you do that?

得到的 labels:

[[0, 2], [0, 4], [1, 3], [2, 0], [2, 4], [3, 1], [4, 0], [4, 2]]

得到的 frame_names:

['s06e07-000377', 's06e07-000504', 's06e07-000552', 's06e07-000629', 's06e07-000892']

具體說明一下 labels 部分, 上述的 dialog 中的 5 個發言人, 依次為:

  1. Phoebe
  2. Rachel
  3. Phoebe
  4. Rachel
  5. Phoebe

其中:

  • Phoebe 在 1, 3, 5 句子中發言
  • Rachel 在 2, 4 句子中發言

所以我們可以得到:

  • [0, 2]: 1, 3 句子都是同一個人發言 (Phoebe)
  • [0, 4[: 1, 5 句子都是同一個人發言 (Phoebe)
  • [1, 3]: 2, 4 句子都素同一個人發言 (Rachel)
  • [2, 0]: 3, 1 句子都是通一個人發言 (Phoebe)
  • [2, 4]: 3, 5 句子都是同一個人發言 (Phoebe)
  • [3, 1]: 4, 2 句子都是同一個人發言 (Rachel)
  • [4, 0]: 5, 1 句子都是同一個人發言 (Phoebe)
  • [4, 2]: 5, 3 句子都是同一個人發言 (Phoebe)

然后補充一下 input_text 部分:

  • 在上面我們提到了一些特殊Token ID, <bos>就是一個特殊的 Token ID, 用于表示句子的開始, 幫助模型在生成文本和處理序列時確定起始點
  • 在處理對話時, <bos>可以用來分隔不同的語句

補充, <sep><bos>區別:

  • <bos>用于標記句子的開始, <sep>用于分隔句子的不同部分

在這里插入圖片描述

分詞

Collator 類的主要作用是將批次 (Batch) 樣本, tokenize 后轉換為模型需要的輸入格式.

構造函數

def __init__(self, tokenizer, max_length=512, temperature=1.0, use_turn_emb=False):self.tokenizer = tokenizerself.max_length = max_lengthself.temperature = temperatureself.use_turn_emb = use_turn_embself.print_debug = True
  • tokenizer: 用于文本 tokenize, 例如: RobertaTokenizer
  • max_length: 最大長度限制, 默認為 512
  • temperature: 模型溫度參數, 默認為 1
  • use_turn_emb: 是否使用輪次嵌入

調用函數

def __call__(self, examples):input_texts = [i[0] for i in examples]labels = [i[1] for i in examples]frame_names = [i[2] for i in examples]model_inputs = self.tokenizer(input_texts, add_special_tokens=False, truncation=True, padding='longest', max_length=self.max_length, return_tensors='pt')model_inputs = dict(model_inputs)
  • 獲取 input_texts, labels, frame_names
  • tokenize 文本
new_labels = list()
for input_id, label in zip(model_inputs['input_ids'], labels):num_bos_tokens = torch.sum(input_id == self.tokenizer.bos_token_id).item()label = [l for l in label if l[0] < num_bos_tokens and l[1] < num_bos_tokens]      # 如果遇到了truncation,將被truncate掉的turn刪除new_labels.append(torch.tensor(label))
model_inputs['labels'] = new_labels
  • 創建空列表存放標簽
  • 遍歷每個樣本
  • 計算 bos 標記數量
  • 更新標簽

舉個例子:

input_text:

Yeah, I know because you have all the good words. What do I get? I get "it’s," "and" oh I'm sorry, I have "A." Forget it.[CLS]Phoebe, come on that's silly.[CLS]All right, so let's switch.[CLS]No, I have all of the good words. OK, fine, fine, we can switch.[CLS]Please...wait, how did you do that?

tokenize 后:

[101, 3398, 1010, 1045, 2113, 2138, 2017, 2031, 2035, 1996, 2204, 2616, 1012, 2054, 2079, 1045, 2131, 1029, 1045, 2131, 1000, 2009, 1521, 1055, 1010, 1000, 1000, 1998, 1000, 2821, 1045, 1005, 1049, 3374, 1010, 1045, 2031, 1000, 1037, 1012, 1000, 5293, 2009, 1012, 101, 18188, 1010, 2272, 2006, 2008, 1005, 1055, 10021, 1012, 101, 2035, 2157, 1010, 2061, 2292, 1005, 1055, 6942, 1012, 101, 2053, 1010, 1045, 2031, 2035, 1997, 1996, 2204, 2616, 1012, 7929, 1010, 2986, 1010, 2986, 1010, 2057, 2064, 6942, 1012, 101, 3531, 1012, 1012, 1012, 3524, 1010, 2129, 2106, 2017, 2079, 2008, 1029, 102]

注: 這邊我用的是 Bert [CLS], 等同于<bos>

new_labels:

[[0, 2], [0, 4], [1, 3], [2, 0], [2, 4], [3, 1], [4, 0], [4, 2]]

因為上面的 5 個句子加起來并沒有達到 512 個詞, 所以 label 并沒有進行刪減. 如果 比如<bos只有4 個, 即最后一個句子被裁剪 (truncation) 了, 此時就要去掉所有包括句子 5 的 label.

假設上面句子只有三句半, new_labels 為:

[[0, 2], [1, 3], [2, 0], [3, 1]]

輪次嵌入

if self.use_turn_emb:model_inputs['token_type_ids'] = torch.cumsum(model_inputs['input_ids'] == self.tokenizer.bos_token_id, dim=1)
model_inputs['frame_names'] = frame_names
model_inputs['temperature'] = self.temperature

計算輪次嵌入: 使用torch.cumsum函數計算累積和. model_inputs['input_ids'] == self.tokenizer.bos_token_id創建了一個布爾張亮, 每個句子開始<bos>標記的位置為 True, 其他位置為 False

輪次嵌入

對話:

Yeah, I know because you have all the good words. What do I get? I get "it’s," "and" oh I'm sorry, I have "A." Forget it.
[CLS]
Phoebe, come on that's silly.
[CLS]
All right, so let's switch.
[CLS]
No, I have all of the good words. OK, fine, fine, we can switch.
[CLS]
Please...wait, how did you do that?

輪次嵌入前的 token_type_ids:

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

輪次嵌入后的 token_type_ids:

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]

類似:

1: ["Yeah", ",", "I", "know", "because", "you", "have", "all", "the", "good", "words", ".", "What", "do", "I", "get", "?", "I", "get", "it’s", ",", "and", "oh", "I", "'m", "sorry", ",", "I", "have", "A", ".", "Forget", "it", "."]
2: ["Phoebe", ",", "come", "on", "that", "'s", "silly", "."]
3: ["All", "right", ",", "so", "let", "'s", "switch", "."]
4: ["No", ",", "I", "have", "all", "of", "the", "good", "words", ".", "OK", ",", "fine", ",", "fine", ",", "we", "can", "switch", "."]
5: ["Please", "...", "wait", ",", "how", "did", "you", "do", "that", "?"]

輪次嵌入的作用:

輪次嵌入對處理對話和交互文本時至關重要. 輪次嵌入為模型提供了關于每個單詞屬于哪個對話, 對于模型理解對話結構和上下文非常重要.

在上面的例子中, 我們有 5 句話組成的 dialog, 經過輪次嵌入, 第一句的單詞會被標記為 1, 第二句為 2, 第三句為 3, 第四句為 4, 第五句為 5. 通過 1, 2, 3, 4, 5 的標記, 可以幫助模型區分不同句子的語境, 以更好的處理每個對話.

Roberta

Roberta (Robustly Optimized BERT Approach) 是一種基于 BERT (Bidirectional Encoder Representations from Transformers) 的 NLP 模型.

Roberta 創新點

Roberta 在 Bert 的基礎上創新了訓練過程和數據處理方式. First, Roberta 使用的語料庫更大, 數據更難多, 模型更更好的理解和處理復雜的語言模式. Second, Roberta 取消了 Bert 中的下句預測 (Next Sentence Prediction). Third, Roberta 對輸入數據的處理方式也進行了優化, 具體表現為更長序列進行的訓練, 因此 Roberta 的長文本處理能力也更為優秀.

NSP (Next Sentence Prediction)

  • NSP (Next Sentence Prediction) 目的是改善模型 (Bert) 對句子關系的理解, 特別是在理解段落或文檔中句子之間的關系方面
  • NSP 任務重, 模型唄訓練來預測兩個句子是否在原始文本中相鄰. 舉個栗子: A & B 倆句子, 模型需要判斷 B 是否是緊跟在 A 后面的下一句. 在 Training 過沖中, Half time B 確實是 A 的下一句, 另一半時間 B 則是從語料庫中隨機選取的與 A 無關的句子. NSP 就是基于這些句子判斷他們是否是連續的
    • 句子 A: “我是小白呀今年才 18 歲”
    • 句子 B: “真年輕”
      • NSP: 連續, B 是對 A 的回應 (年齡), 表達了作者 “我” 十分年輕
    • 句子 A: “意大利面要拌”
    • 句子 B: “42 號混凝土”
      • NSP: 不連續, B 和 A 內容完全無關
  • NSP 對諸如系統問答, 文本摘要等任務十分重要, 但是 Roberta 發現去除也一樣, 因為 Bert 底層的雙向結構十分強大. 后續的新模型, Roberta, Xlnet, Deberta 都去除了 NSP

Roberta 構造函數

構造函數:

def __init__(self, config):super().__init__(config)self.bos_token_id = config.bos_token_idself.loss_fct = CrossEntropyLoss(reduction='none')...以下省略
  • bos_token_id: 句子起始標記
  • los_fct: 損失函數, 這邊為交叉熵損失 (CrossEntropyLoss)

Roberta 前向傳播

def forward(...):...以上省略outputs = self.roberta(...)last_hidden_state = outputs[0]...以下省略
  • last_hidden_state: 獲取 Roberta 輸出的隱層狀態

計算發言人相似度

這邊的計算發言人相似度分為兩個模式, 分別為推理模式 (Inference Mode) 和訓練模式 (Training Mode).

推理模式 (Inference Mode)

在 labels == None 的時候, 模型進行推理模式 (Inference Mode). 在這種模式下, 模型的主要任務是計算并返回每個句子的隱層狀態和相似度得分, 而不是進行模型的訓練. 用于 valid 和 test.

if labels is None:# inference modeselected_hidden_state_list, logits_list = list(), list()for i, (hidden_state, input_id) in enumerate(zip(last_hidden_state, input_ids)):indices = input_id == self.bos_token_idselected_hidden_state = hidden_state[indices]if not self.linear_sim:selected_hidden_state = F.normalize(selected_hidden_state, p=2, dim=-1)logits = torch.matmul(selected_hidden_state, selected_hidden_state.t())logits += torch.eye(len(logits), device=logits.device) * -100000.0        # set elements on the diag to -infelse:num_sents, hidden_size = selected_hidden_state.size()# concatenated_hidden_state = torch.zeros(num_sents, num_sents, hidden_size*4, device=selected_hidden_state.device)concatenated_hidden_state = torch.zeros(num_sents, num_sents, hidden_size*3, device=selected_hidden_state.device)concatenated_hidden_state[:, :, :hidden_size] = selected_hidden_state.unsqueeze(1)concatenated_hidden_state[:, :, hidden_size:hidden_size*2] = selected_hidden_state.unsqueeze(0)concatenated_hidden_state[:, :, hidden_size*2:hidden_size*3] = torch.abs(selected_hidden_state.unsqueeze(0) - selected_hidden_state.unsqueeze(1))# concatenated_hidden_state[:, :, hidden_size*3:hidden_size*4] = selected_hidden_state.unsqueeze(0) + selected_hidden_state.unsqueeze(1)logits = self.sim_head(self.dropout(concatenated_hidden_state)).squeeze()     # 但要注意,這里的logits就不能保證是在0-1之間了。需要過sigmoid才能應用在之后的任務中selected_hidden_state_list.append(selected_hidden_state)        # 不同對話的輪數可能不一樣,所以結果可能不能stack起來。logits_list.append(logits)

推理模式下的步驟:

  1. 提取隱層狀態 (Hidden State): 從前面的 Roberta 模型提取每個句子的 hidden state
  2. 計算相似度得分: 線性相似度頭sim_head來計算不同句子之間的相似度得分. 這些得分表示句子間的相似性, 用于判斷是否是同一個發言人 (Speaker)

訓練模式 (Training Mode)

當 label != None, 模型進行訓練模式 (Training Mode). 在這種模式下, 模型的主要任務是通過損失函數來優化模型.

else:# training modeselected_hidden_state_list = list()batch_size = len(labels)for hidden_state, input_id in zip(last_hidden_state, input_ids):indices = input_id == self.bos_token_idselected_hidden_state = hidden_state[indices]if not self.linear_sim:selected_hidden_state = F.normalize(selected_hidden_state, p=2, dim=-1)selected_hidden_state_list.append(selected_hidden_state)losses, logits_list = list(), list()for i, (selected_hidden_state, label) in enumerate(zip(selected_hidden_state_list, labels)):if not self.linear_sim:other_selected_hidden_states = torch.cat([selected_hidden_state_list[j] for j in range(batch_size) if j != i])all_selected_hidden_states = torch.cat([selected_hidden_state, other_selected_hidden_states])logits = torch.matmul(selected_hidden_state, all_selected_hidden_states.t())logits += torch.cat([torch.eye(len(logits), device=logits.device) * -100000.0, torch.zeros(len(logits), len(other_selected_hidden_states), device=logits.device)], dim=-1)if label.numel():losses.append(self.loss_fct(logits[label[:, 0]] / temperature, label[:, 1]))else:num_sents, hidden_size = selected_hidden_state.size()# concatenated_hidden_state = torch.zeros(num_sents, num_sents, hidden_size*4, device=selected_hidden_state.device)concatenated_hidden_state = torch.zeros(num_sents, num_sents, hidden_size*3, device=selected_hidden_state.device)concatenated_hidden_state[:, :, :hidden_size] = selected_hidden_state.unsqueeze(1)concatenated_hidden_state[:, :, hidden_size:hidden_size*2] = selected_hidden_state.unsqueeze(0)concatenated_hidden_state[:, :, hidden_size*2:hidden_size*3] = torch.abs(selected_hidden_state.unsqueeze(0) - selected_hidden_state.unsqueeze(1))# concatenated_hidden_state[:, :, hidden_size*3:hidden_size*4] = selected_hidden_state.unsqueeze(0) + selected_hidden_state.unsqueeze(1)logits = self.sim_head(self.dropout(concatenated_hidden_state)).squeeze()     # 但要注意,這里的logits就不能保證是在0-1之間了。需要過sigmoid才能應用在之后的任務中# 使用mse作為loss。loss包括兩部分,一個是和gold的,一個是和自己的轉置的logits = nn.Sigmoid()(logits)real_labels = torch.zeros_like(logits)if label.numel():real_labels[label[:, 0], label[:, 1]] = 1real_labels += torch.eye(len(logits), device=logits.device)loss = nn.MSELoss()(real_labels, logits) + nn.MSELoss()(logits, logits.transpose(0, 1))losses.append(loss)logits_list.append(logits)loss = torch.mean(torch.stack(losses))return MaskedLMOutput(loss=loss, logits=logits_list, hidden_states=selected_hidden_state_list)

訓練模式下的具體步驟:

  1. 提取隱層狀態 (Hidden State): 從前面的 Roberta 模型提取每個句子的 hidden state
  2. 計算相似度得分: 使用模型輸出的相似度 logits
  3. 計算損失函數: 通過計算 logits 和 real_label 之間的差異
  4. 優化模型: 根據 loss 進行梯度下降, 反向傳播 (Backpropagation)

以防大家沒看懂, 下面我們來逐行解析:

提取隱層狀態:

selected_hidden_state_list = list()
for hidden_state, input_id in zip(last_hidden_state, input_ids):indices = input_id == self.bos_token_idselected_hidden_state = hidden_state[indices]if not self.linear_sim:selected_hidden_state = F.normalize(selected_hidden_state, p=2, dim=-1)selected_hidden_state_list.append(selected_hidden_state)
  • 通過<bos>標注每個句子開始, 并選取對應句子的隱藏狀態

線性層計算相似度:

losses, logits_list = list(), list()
for i, (selected_hidden_state, label) in enumerate(zip(selected_hidden_state_list, labels)):# 根據配置選擇相似度計算方法if not self.linear_sim:# 非線性相似度計算...else:# 線性相似度計算concatenated_hidden_state = ...logits = self.sim_head(self.dropout(concatenated_hidden_state)).squeeze()logits = nn.Sigmoid()(logits)real_labels = torch.zeros_like(logits)if label.numel():real_labels[label[:, 0], label[:, 1]] = 1real_labels += torch.eye(len(logits), device=logits.device)loss = nn.MSELoss()(real_labels, logits) + nn.MSELoss()(logits, logits.transpose(0, 1))losses.append(loss)logits_list.append(logits)
  • 在線相似度計算中, 使用sim_head來計算句子間的相似度得分
  • 計算損失函數. 損失函數的計算分為兩部分:
    • 第一部分: y_predict 和 y_true 之間的差異. 具體為loss_similarity = nn.MSELoss()(real_labels, logits)
    • 第二部分: 計算矩陣的對稱性. 因為句子的相似度是雙向的 (A -> B & B -> A 的相似度應該相同) 所以這邊有一個對稱項來確保 loss 矩陣的對稱性: loss_symmetry = nn.MSELoss()(logits, logits.transpose(0, 1))
    • 相加: loss = loss_similarity + loss_symmetry

注: Baseline 代碼為loss = nn.MSELoss()(real_labels, logits) + nn.MSELoss()(logits, logits.transpose(0, 1)), 我就是拆開了而已, 勿噴.

Deberta

Deberta (Decoding-enhanced Bert with Disentangled Attention) 也是一種 NLP 模型. Deberta 在 Bert (Bidirectional Encoder Representations from Transformers) 和 Roberta (Robustly Optimized Bert Approach) 的基礎上進行了創新和改進, 主要為獨特的注意力機制 (Attention) 和編碼策略, 使得 Deberta 在 NLP 任務重表現出色.

Deberta 創新點

Deberta 的主要創新點:

  1. 解耦注意力機制 (Disentangled Attention Mechanism): Deberta 的解耦注意力機制, 將內容和位置信息分開處理. 在傳統 Bert 和 Roberta 模型重, 注意力機制 (Attention) 同時考慮了內容和位置信息. Deberta 將這兩種信息分離, 允許模型更靈活的學習單詞之間的以來關系
  2. 增強的位置編碼 (Positional Encoding). Deberta 的位置編碼方案不僅考慮了單詞之間相對位置, 還考慮他們在序列中的絕對位置. 這種雙重位置編碼使得 Deberta 能夠更準確的捕捉文本中的結構信息
  3. 動態卷積 (Dynamic Convolution): 相較于 CNN 中的標準卷積, 動態卷積具有更高的靈活性和適應性:
    • 權重的動態生成: 標準卷積中, 權重 (W) 在整個測試集上是固定不變的, 而動態卷積是動態生成的, 根據輸入數據不同而改變
    • 適應性強: 由于卷積核的權重是針對每個輸入樣本動態生成的, 能更好的適應不同的語言模式和上下文環境
    • 捕獲局部依賴: 動態卷積特別刪除捕捉文本中的局部依賴關系, 如短語或局部語義結構, 對于理解復雜的語言表達至關重要

Deberta 構造函數

Deberta 構造函數:

def __init__(self, config):super().__init__(config)self.bos_token_id = config.bos_token_idself.loss_fct = CrossEntropyLoss(reduction='none')

Deberta 前向傳播

同 Roberta

訓練

同 cnn

驗證

同 cnn

參考文獻

比賽鏈接

Baseline 完整代碼

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/215703.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/215703.shtml
英文地址,請注明出處:http://en.pswp.cn/news/215703.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

23種設計模式之裝飾者模式(被裝飾者,接口層,裝飾抽象層,具體裝飾者)

23種設計模式之裝飾者模式 文章目錄 23種設計模式之裝飾者模式設計思想裝飾者模式的優點裝飾者模式的缺點裝飾者模式的優化方法UML 解析預設場景 代碼釋義總結 設計思想 原文:裝飾器模式&#xff08;Decorator Pattern&#xff09;允許向一個現有的對象添加新的功能&#xff0…

應用在LED燈光控制觸摸屏中的觸摸芯片

LED燈光控制觸摸屏方法&#xff0c;包括&#xff1a;建立觸摸屏的觸摸軌跡信息與LED燈光驅動程序的映射關系&#xff1b;檢測用戶施加在觸摸屏上的觸摸軌跡&#xff0c;生成觸摸軌跡信息&#xff1b;根據生成的觸摸軌跡信息&#xff0c;調用對應的LED燈光驅動程序&#xff0c;控…

HJ14 字符串排序

一、題目 描述 給定 n 個字符串&#xff0c;請對 n 個字符串按照字典序排列。數據范圍&#xff1a; 1 \le n \le 1000 \1≤n≤1000 &#xff0c;字符串長度滿足 1 \le len \le 100 \1≤len≤100 輸入描述&#xff1a; 輸入第一行為一個正整數n(1≤n≤1000),下面n行為n個字符…

智能優化算法應用:基于頭腦風暴算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼

智能優化算法應用&#xff1a;基于頭腦風暴算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼 文章目錄 智能優化算法應用&#xff1a;基于頭腦風暴算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼1.無線傳感網絡節點模型2.覆蓋數學模型及分析3.頭腦風暴算法4.實驗參數設定5.算法結果6.…

說說React中的虛擬dom?在虛擬dom計算的時候diff和key之間有什么關系?

虛擬 DOM&#xff08;Virtual DOM&#xff09;是 React 中的一種機制&#xff0c;通過在內存中構建一棵輕量級的虛擬 DOM 樹來代替操作瀏覽器 DOM&#xff0c;從而提高組件的渲染性能和用戶體驗。 在 React 中&#xff0c;當組件的 Props 或 State 發生變化時&#xff0c;Reac…

智能優化算法應用:基于蝙蝠算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼

智能優化算法應用&#xff1a;基于蝙蝠算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼 文章目錄 智能優化算法應用&#xff1a;基于蝙蝠算法3D無線傳感器網絡(WSN)覆蓋優化 - 附代碼1.無線傳感網絡節點模型2.覆蓋數學模型及分析3.蝙蝠算法4.實驗參數設定5.算法結果6.參考文獻7.MA…

酷開科技多維度賦能營銷,實力斬獲三項大獎

在數智化新階段、廣告新生態、傳播新業態的背景下&#xff0c;“第30屆中國國際廣告節廣告主盛典暨網易傳媒態度營銷峰會”于11月18日在廈門國際會展中心盛大舉行。來自全國的品牌方、戰略決策者、媒體平臺和品牌服務機構等匯聚一堂。在50000&#xff0b;現場觀眾和數千萬線上觀…

openssl的x509命令工具

X509命令是一個多用途的證書工具。它可以顯示證書信息、轉換證書格式、簽名證書請求以及改變證書的信任設置等。 用法&#xff1a; openssl x509 [-inform DER|PEM|NET] [-outform DER|PEM|NET] [-keyform DER|PEM] [-CAform DER|PEM] [-CAkeyform DER|PEM] [-in filename…

vue elementui點擊按鈕新增輸入框(點多少次就新增多少個輸入框,無限新增)

效果如圖&#xff1a; 核心代碼&#xff1a; <div v-for"(item,index) in arrayData" :key"item.id">//上面這個是關鍵代碼&#xff0c;所有思路靠這個打通<el-inputtype"input" //除了輸入框&#xff0c;還有textarea等placeholder&…

SequentialChain

以下是使用SequentialChain創建Java代碼的單元測試代碼的示例&#xff1a; import sequentialchain.SequentialChain; import static org.junit.Assert.assertEquals; import org.junit.Test;public class SequentialChainTest {Testpublic void testAdd() {SequentialChain&l…

k8s詳細教程(一)

—————————————————————————————————————————————— 博主介紹&#xff1a;Java領域優質創作者,博客之星城市賽道TOP20、專注于前端流行技術框架、Java后端技術領域、項目實戰運維以及GIS地理信息領域。 &#x1f345;文末獲取源碼…

《Spring Cloud Alibaba 從入門到實戰》分布式消息(事件)驅動

分布式消息&#xff08;事件&#xff09;驅動 1、簡介 事件驅動架構(Event-driven 架構&#xff0c;簡稱 EDA)是軟件設計領域內的一套程序設計模型。 這套模型的意義是所有的操作通過事件的發送/接收來完成。 傳統軟件設計 舉個例子&#xff0c;比如一個訂單的創建在傳統軟…

「差生文具多系列」推薦兩個好看的 Redis 客戶端

&#x1f4e2;?聲明&#xff1a; &#x1f344; 大家好&#xff0c;我是風箏 &#x1f30d; 作者主頁&#xff1a;【古時的風箏CSDN主頁】。 ?? 本文目的為個人學習記錄及知識分享。如果有什么不正確、不嚴謹的地方請及時指正&#xff0c;不勝感激。 直達博主&#xff1a;「…

ModuleNotFoundError: No module named ‘huggingface_hub.snapshot_download‘

ModuleNotFoundError: No module named ‘huggingface_hub.snapshot_download’ 的解決方法 根據提示顯示XXX模塊不存在&#xff0c;一般會直接安裝XXX模塊&#xff0c;但是這個不需要顯式安裝huggingface-hub。 只需要升級sentence-transformers即可。 pip install -U sente…

Innosetup 安裝包 在安裝前判斷是否有其他安裝程序正在安裝...

方法有&#xff1a; 1.使用系統服務WinMgmts 系統信息通過 "winmgmts:\\.\root\CIMV2" 遍歷進程列表。 var FSWbemLocator: Variant; FWMIService : Variant; FWbemObjectSet: Variant; begin Result : false; FSWbemLocator : CreateOleObject(WBEMScripti…

Fabric使用自己的鏈碼進行測試-go語言

書接前文 Fabric鏈碼部署-go語言 通過上面這篇文章&#xff0c;你可以部署好自己的鏈碼 &#xff08;后面很多命令是否需要修改&#xff0c;都是根據上面這篇文章來的&#xff0c;如果零基礎的話建議先看上面這篇&#xff09; 就進行下一步 在測試網絡上運行自己的鏈碼 目…

PDF文件的限制編輯,如何設置?

想要給PDF文件設置一個密碼防止他人對文件進行編輯&#xff0c;那么我們可以對PDF文件設置限制編輯&#xff0c;設置方法很簡單&#xff0c;我們在PDF編輯器中點擊文件 – 屬性 – 安全&#xff0c;在權限下拉框中選中【密碼保護】 然后在密碼保護界面中&#xff0c;我們勾選【…

系列十、SpringBoot + MyBatis + Redis實現分布式緩存(基于注解方式)

一、概述 上篇文章 系列九、SpringBoot MyBatis Redis實現分布式緩存 介紹了基于xml方式實現分布式緩存的效果&#xff0c;當前大家使用的技術棧基本是springboot各種框架的組合&#xff0c;而springboot顯著的一個特點就是去xml配置&#xff0c;那么在無xml配置的情形下&…

CStdioFile

CStdioFile 文件創建、數據寫入、寫入路徑 void StdReferenceDWG::RefDrawCrvt(StdOneReference& ref) {char* old_locale _strdup(setlocale(LC_CTYPE, NULL));setlocale(LC_CTYPE, "chs");//設定CString strPath StdTool::GetCurPath() _T("襯圖\\Re…

界面控件DevExpress中文教程 - 如何用Office File API組件填充PDF表單

DevExpress Office File API是一個專為C#, VB.NET 和 ASP.NET等開發人員提供的非可視化.NET庫。有了這個庫&#xff0c;不用安裝Microsoft Office&#xff0c;就可以完全自動處理Excel、Word等文檔。開發人員使用一個非常易于操作的API就可以生成XLS, XLSx, DOC, DOCx, RTF, CS…