本程序是一個基于 Gradio 和 Ollama API 構建的支持多輪對話的寫作助手。相較于上一版本,本版本新增了對話歷史記錄、Token 計數、參數調節和清空對話功能,顯著提升了用戶體驗和交互靈活性。
程序通過抽象基類 LLMAgent
實現模塊化設計,當前使用 OllamaAgent
作為具體實現,調用本地部署的 Ollama 大語言模型(如 qwen3:8b
)生成寫作建議,并提供一個交互式的 Web 界面供用戶操作。
設計支持未來擴展到其他 LLM 平臺(如 OpenAI、HuggingFace),只需實現新的 LLMAgent
子類即可。
環境配置
依賴安裝
需要以下 Python 庫:
gradio
:用于創建交互式 Web 界面。requests
:向 Ollama API 發送 HTTP 請求。json
:解析 API 響應數據(Python 內置)。logging
:記錄運行日志(Python 內置)。abc
:定義抽象基類(Python 內置)。tiktoken
:精確計算 Token 數量以管理輸入和歷史長度。
安裝命令:
pip install gradio requests tiktoken
建議使用 Python 3.8 或更高版本。
Ollama 服務配置
-
安裝 Ollama
從 https://ollama.ai/ 下載并安裝。 -
啟動 Ollama 服務
ollama serve
- 默認監聽地址:
http://localhost:11434
。
- 默認監聽地址:
-
下載模型
ollama pull qwen3:8b
-
驗證模型
ollama list
運行程序
- 將代碼保存為
writing_assistant.py
。 - 確保 Ollama 服務正在運行。
- 執行程序:
python writing_assistant.py
- 打開瀏覽器訪問界面(通常為
http://127.0.0.1:7860
)。 - 輸入寫作提示,調整參數后點擊“獲取寫作建議”,查看結果和對話歷史。
代碼說明
1. 依賴導入
import gradio as gr
import requests
import json
import logging
from abc import ABC, abstractmethod
import tiktoken
tiktoken
:精確計算 Token 數量,優化輸入控制。
2. 日志配置
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
- 配置日志級別為
INFO
,記錄 API 調用和錯誤信息,便于調試。
3. 抽象基類:LLMAgent
class LLMAgent(ABC):@abstractmethoddef generate_response(self, prompt):pass
- 定義通用 LLM 代理接口,要求實現
generate_response
方法。 - 支持未來擴展到其他 LLM 平臺(如 OpenAI、Anthropic)。
4. 具體實現:OllamaAgent
class OllamaAgent(LLMAgent):def __init__(self, config): ...def set_max_history_length(self, max_rounds): ...def set_parameters(self, max_tokens, temperature): ...def generate_response(self, prompt): ...def clear_history(self): ...
- 對話歷史:維護
history
列表,支持多輪對話。 - 參數調節:動態調整
max_tokens
和temperature
。 - Token 管理:自動截斷歷史記錄,防止超出模型上下文限制。
- 錯誤處理:捕獲網絡請求失敗和 JSON 解析錯誤,返回用戶友好的提示。
5. Token 計數函數
def calculate_tokens(text): ...
def calculate_history_tokens(history): ...
- 使用
tiktoken
精確估算 Token 數量,提升輸入長度控制能力。
6. 歷史格式化
def format_history_for_chatbot(history): ...
- 將內部
history
結構轉換為 Gradio 的Chatbot
格式[user_msg, assistant_msg]
。
7. 核心邏輯:generate_assistance
def generate_assistance(prompt, agent, max_rounds, max_tokens, temperature): ...
- 設置最大對話輪數和生成參數。
- 調用
agent.generate_response
獲取響應。 - 返回格式化的對話歷史、最新回復和 Token 計數。
8. 輔助函數
def update_token_count(prompt): ...
def clear_conversation(agent): ...
- 實時更新輸入 Token 數量。
- 清空對話歷史并重置狀態。
9. 主函數:main
def main():config = { ... }agent = OllamaAgent(config)with gr.Blocks(...) as demo:...demo.launch()
- 增強的 UI:包含輸入框、Token 顯示、參數調節滑塊和清空按鈕。
- Gradio 事件綁定:
prompt_input.change()
:動態更新 Token 計數。submit_button.click()
:觸發寫作建議生成。clear_button.click()
:重置對話歷史。
運行流程圖
graph TDA[用戶輸入提示] --> B[點擊 submit_button]B --> C[調用 generate_assistance(prompt, agent, 參數)]C --> D[調用 agent.set_* 設置參數]D --> E[調用 agent.generate_response(prompt)]E --> F[向 Ollama API 發送 POST 請求]F --> G[接收并解析 JSON 響應]G --> H[更新聊天歷史和輸出結果]
注意事項
- Ollama 服務:確保服務運行并監聽在
http://localhost:11434/v1
。 - 模型可用性:確認
qwen3:8b
已下載。 - Token 上限:注意模型的最大上下文長度(如 4096 Tokens),避免歷史過長導致超限。
- 參數影響:
temperature
:控制生成隨機性(較低值更確定,較高值更具創造性)。max_tokens
:限制輸出長度。
- 調試信息:查看終端日志,確認 API 響應是否正常或是否有錯誤。
未來改進建議
- 多模型支持:添加
OpenAIAgent
等子類,通過下拉菜單切換模型。 - 配置文件化:將硬編碼配置移至 JSON/YAML 文件。
- 異步請求:使用
aiohttp
替換requests
,提升并發性能。 - 對話持久化:將歷史對話保存到本地文件或數據庫。
- 用戶認證:區分不同用戶的對話記錄。
- 移動端適配:優化界面布局以適配手機端。
示例使用
- 啟動程序后訪問
http://127.0.0.1:7860
。 - 輸入提示:“幫我寫一段關于環保的文章。”
- 調整參數(如
max_tokens=1000
,temperature=0.2
)。 - 點擊“獲取寫作建議”,查看類似以下輸出:
在這里插入圖片描述
代碼
import gradio as gr
import requests
import json
import logging
from abc import ABC, abstractmethod
import tiktoken# 設置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)# 抽象基類:定義通用的 LLM Agent 接口
class LLMAgent(ABC):@abstractmethoddef generate_response(self, prompt):pass# Ollama 特定的實現
class OllamaAgent(LLMAgent):def __init__(self, config):self.model = config["model"]self.base_url = config["base_url"]self.api_key = config["api_key"]self.max_tokens = config["max_tokens"]self.temperature = config["temperature"]self.history = []self.max_history_length = 10def set_max_history_length(self, max_rounds):self.max_history_length = int(max_rounds * 2)if len(self.history) > self.max_history_length:self.history = self.history[-self.max_history_length:]def set_parameters(self, max_tokens, temperature):self.max_tokens = int(max_tokens)self.temperature = float(temperature)def generate_response(self, prompt):self.history.append({"role": "user", "content": prompt})if len(self.history) > self.max_history_length:self.history = self.history[-self.max_history_length:]url = f"{self.base_url}/chat/completions"headers = {"Authorization": f"Bearer {self.api_key}","Content-Type": "application/json"}payload = {"model": self.model,"messages": self.history,"max_tokens": self.max_tokens,"temperature": self.temperature}try:response = requests.post(url, headers=headers, json=payload)response.raise_for_status()result = response.json()content = result['choices'][0]['message']['content']logger.info(f"API 響應: {content}")self.history.append({"role": "assistant", "content": content})return contentexcept requests.exceptions.RequestException as e:logger.error(f"API 請求失敗: {str(e)}")return f"錯誤:無法連接到 Ollama API: {str(e)}"except KeyError as e:logger.error(f"解析響應失敗: {str(e)}")return f"錯誤:解析響應失敗: {str(e)}"def clear_history(self):self.history = []def calculate_tokens(text):if not text:return 0cleaned_text = text.strip().replace('\n', '')try:encoding = tiktoken.get_encoding("cl100k_base")tokens = encoding.encode(cleaned_text)return len(tokens)except Exception as e:logger.error(f"Token 計算失敗: {str(e)}")return len(cleaned_text)def calculate_history_tokens(history):total_tokens = 0try:encoding = tiktoken.get_encoding("cl100k_base")for message in history:content = message["content"].strip()tokens = encoding.encode(content)total_tokens += len(tokens)return total_tokensexcept Exception as e:logger.error(f"歷史 Token 計算失敗: {str(e)}")return sum(len(msg["content"].strip()) for msg in history)def format_history_for_chatbot(history):"""將 agent.history 轉換為 gr.Chatbot 所需格式:List[List[str, str]]"""messages = []for i in range(0, len(history) - 1, 2):if history[i]["role"] == "user" and history[i+1]["role"] == "assistant":messages.append([history[i]["content"], history[i+1]["content"]])return messagesdef generate_assistance(prompt, agent, max_rounds, max_tokens, temperature):agent.set_max_history_length(max_rounds)agent.set_parameters(max_tokens, temperature)response = agent.generate_response(prompt)history_tokens = calculate_history_tokens(agent.history)chatbot_format_history = format_history_for_chatbot(agent.history)return chatbot_format_history, response, f"歷史總 token 數(估算):{history_tokens}"def update_token_count(prompt):return f"當前輸入 token 數(精確):{calculate_tokens(prompt)}"def clear_conversation(agent):agent.clear_history()return [], "對話已清空", "歷史總 token 數(估算):0"def main():config = {"api_type": "ollama","model": "qwen3:8b","base_url": "http://localhost:11434/v1","api_key": "ollama","max_tokens": 1000,"temperature": 0.2}agent = OllamaAgent(config)with gr.Blocks(title="寫作助手") as demo:gr.Markdown("# 寫作助手(支持多輪對話)")gr.Markdown("輸入您的寫作提示,獲取建議和指導!支持連續對話,調整對話輪數、max_tokens 和 temperature,或點擊“清空對話”重置。")with gr.Row():with gr.Column():prompt_input = gr.Textbox(label="請輸入您的提示",placeholder="例如:幫我寫一段關于環保的文章",lines=3)token_count = gr.Textbox(label="輸入 token 數",value="當前輸入 token 數(精確):0",interactive=False)history_token_count = gr.Textbox(label="歷史 token 數",value="歷史總 token 數(估算):0",interactive=False)max_rounds = gr.Slider(minimum=1,maximum=10,value=5,step=1,label="最大對話輪數",info="設置保留的對話輪數(每輪包含用戶和模型消息)")max_tokens = gr.Slider(minimum=100,maximum=2000,value=1000,step=100,label="最大生成 token 數",info="控制單次生成的最大 token 數")temperature = gr.Slider(minimum=0.0,maximum=1.0,value=0.2,step=0.1,label="Temperature",info="控制生成隨機性,0.0 為確定性,1.0 為較隨機")submit_button = gr.Button("獲取寫作建議")clear_button = gr.Button("清空對話")with gr.Column():chatbot = gr.Chatbot(label="對話歷史")output = gr.Textbox(label="最新生成結果", lines=5)prompt_input.change(fn=update_token_count,inputs=prompt_input,outputs=token_count)submit_button.click(fn=generate_assistance,inputs=[prompt_input, gr.State(value=agent), max_rounds, max_tokens, temperature],outputs=[chatbot, output, history_token_count])clear_button.click(fn=clear_conversation,inputs=gr.State(value=agent),outputs=[chatbot, output, history_token_count])demo.launch()if __name__ == "__main__":main()