這段代碼定義了一個簡單的對話生成系統,包括模型加載、詞匯表加載、以及基于給定提示生成文本的功能。下面是對代碼的解析:
-
load_model_and_voc(device="cpu")
:- 該函數用于加載預訓練的模型和詞匯表(vocabulary)。它首先從文件
total_voc.pkl
中加載詞匯表,并創建一個名為SamOut
的神經網絡實例。 - 模型參數的數量被打印出來以供參考。
- 然后嘗試加載指定路徑下的預訓練權重到模型中,并將模型移動到指定的設備(CPU 或 GPU)上。
- 最后設置模型為評估模式(
.eval()
),并返回模型和詞匯表。
- 該函數用于加載預訓練的模型和詞匯表(vocabulary)。它首先從文件
-
gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu")
:- 這個函數負責根據提供的提示(prompt)生成新的文本序列。
- 它接受多個參數,包括詞匯表、模型、初始提示、最大生成長度等。
- 函數內部實現了重復抑制、溫度調整和top-k采樣等技術來控制生成文本的質量。
- 使用softmax函數對模型輸出進行處理,并通過多類別抽樣選擇下一個token。
- 如果生成了特殊的開始標記
<|sos|>
,則停止生成過程。 - 生成的每個token會立即打印在屏幕上,形成即時響應的效果。
-
t_infre()
:- 此函數是交互式推理循環,允許用戶輸入文本,然后調用
gen_token
函數來生成回應。 - 它是一個無限循環,持續等待用戶的輸入直到程序被手動終止。
- 此函數是交互式推理循環,允許用戶輸入文本,然后調用
-
if __name__ == '__main__':
- 這部分代碼確保當腳本作為主程序運行時,會執行某些特定的操作或測試。
- 注釋掉的代碼可能是之前用于數據預處理、訓練或其他實驗的部分。
- 最終調用了
t_infre()
函數來啟動交互式推理。
需要注意的是,這里使用的 SamOut
類并沒有在給出的代碼片段中定義,因此你可能需要確保這個類已經被正確實現并在其他地方導入。此外,為了使代碼能夠正常工作,你需要確保所有依賴庫(如 PyTorch 和 pandas)已經安裝,并且所有提及的數據文件和模型權重文件都存在于正確的路徑下。
def load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)# net = SamOut(len(voc["voc"]), 512, 32, 8)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))# net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))# net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))# net.load_state_dict(torch.load("pretrain.pth", map_location=device))net.to(device)net.eval()return net, vocdef gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):print("agent:", end="", flush=True)for _ in range(max_len):prompt_list = []for i in prompt:if i not in voc["voc"]:prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]else:prompt_list.append(voc["voc"].index(i))out, _ = model(torch.Tensor([prompt_list]).to(device).long())out = out[:, -1:]# 重復抑制for token_id in enumerate(prompt_list):out[:, :, token_id] /= rpscore = torch.softmax(out, -1)[0, 0]score, score_index = torch.sort(score,descending=True)score=score.detach().numpy()score_sum = np.cumsum(score)score_index = score_index.detach().numpy()score1=score[score_sum<0.8]if score1.size==0:score=score[:1]else:score=score1score_index=score_index[:score.size]out = score / tempv= out[:min(top_k, score.size)]idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":breakprompt += [voc["voc"][score_index[idx_next.item()]]]print(prompt[-1], end="", flush=True)def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()if __name__ == '__main__':# print(pd.read_pickle("loss916"))# gen_one_voc()# gen_voc()# for i in range(17,18):# gen_pre_data_align(i, 16)# train()# gen_sft_single_data_align()# train_single()# sft 推理 一本正經的胡說八道已練成t_infre()