文章目錄
- Toolkit 作用
- Toolkit 逐函數解析
- 1. 獲取默認配置
- 2. update_config
- 3. config
- 4. `__init__`
- 5. get_reddit_news
- 6. get_finnhub_news
- 7. get_reddit_stock_info
- 8. get_chinese_social_sentiment
- 9. get_finnhub_company_insider_sentiment
- 10. get_YFin_data
- 11. get_YFin_data_online
- 12. get_stockstats_indicators_report
- 13. get_stockstats_indicators_report_online
- 14. get_simfin_balance_sheet 資產負債表
- 15. get_simfin_income_statement利潤表
- 16. get_simfin_cashflow_statement 現金流量表
- 17. get_simfin_ratios 財務指標比率
- 18. get_simfin_company_info 公司基本信息
- 19. get_simfin_shareprices 歷史股價數據
- 20. get_fundamentals_openai 財務+估值
- 21. get_china_fundamentals A 股財務數據
- 22. get_stock_fundamentals_unified 整合財務+估值
- create_msg_delete函數(非Toolkit類)
- ChromaDBManager
- 類說明
- 逐函數解析
- 類屬性
- `__new__(cls)`
- `__init__(self)`
- get_or_create_collection(self, name: str)
- FinancialSituationMemory類
- 類說明
- 簡化后代碼
- Example
- 逐函數解析
- `__init__(self, config)`
- `_smart_text_truncation(self, text, max_length)`
- get_embedding(self, text, max_length=1000)
- get_embedding_config_status(self)
- get_last_text_info(self)
- add_situations(self, situations)
- get_cache_info(self)
- get_memories(self, situation, n_results=5)
Toolkit 作用
-
Toolkit(tradingagents/agents/utils/agent_utils.py)
是 TradingAgents 框架里的工具集,封裝了各種外部數據源的調用接口(新聞、財務、行情、情緒、指標等),并通過@tool
裝飾器暴露給 LLM 使用。 -
TradingAgents Toolkit 功能總覽表
模塊類別 | 主要函數/工具 | 功能描述 | 數據來源 | 狀態 |
---|---|---|---|---|
市場行情 (Market) | get_market_data_unified | 獲取股票/指數的市場行情數據(價格、成交量、技術指標) | Yahoo Finance, Tushare, 東方財富等 | |
get_price_history | 獲取歷史 K 線數據 | 同上 | ||
get_technical_indicators | 生成技術指標(MA, RSI, MACD 等) | 本地計算 | ||
新聞 (News) | get_news_articles | 抓取金融新聞 | Google News, 東方財富新聞等 | |
summarize_news | 對新聞做摘要 | LLM 處理 | ||
社交媒體 (Social) | get_social_sentiment | 獲取社交媒體情緒(Twitter, 微博) | API / 爬蟲 | 數據源可能受限 |
analyze_sentiment | 使用 LLM 對評論、帖子做情緒分析 | LLM | ||
基本面 (Fundamentals) | get_stock_fundamentals_unified | 統一接口:獲取美股/A股/港股的財務數據與估值 | Yahoo Finance, SimFin, Tushare, AKShare, 東方財富 | |
get_fundamentals_openai | 使用 OpenAI Agent 調用財務數據 | OpenAI Agent | 廢棄 | |
get_china_fundamentals | 獲取中國 A 股財務數據(舊接口) | Tushare, AKShare | 廢棄 | |
風險分析 (Risk) | get_risk_metrics | 計算風險指標(波動率、夏普比率、回撤) | 本地計算 | |
stress_test | 壓力測試(不同市場情景下的資產表現) | 本地模擬 | ||
portfolio_risk_analysis | 投資組合風險分析 | 本地計算 + 歷史行情 |
Toolkit 逐函數解析
- 大部分函數都是通過
tradingagents.dataflows.interface
獲取數據。
1. 獲取默認配置
from tradingagents.default_config import DEFAULT_CONFIG
_config = DEFAULT_CONFIG.copy()
- 獲取默認配置,默認配置如下:
import osDEFAULT_CONFIG = {"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),"data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),"data_cache_dir": os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),"dataflows/data_cache",),# LLM settings"llm_provider": "openai","deep_think_llm": "o4-mini","quick_think_llm": "gpt-4o-mini","backend_url": "https://api.openai.com/v1",# Debate and discussion settings"max_debate_rounds": 1,"max_risk_discuss_rounds": 1,"max_recur_limit": 100,# Tool settings"online_tools": True,# Note: Database and cache configuration is now managed by .env file and config.database_manager# No database/cache settings in default config to avoid configuration conflicts
}
2. update_config
@classmethod
def update_config(cls, config):"""Update the class-level configuration."""cls._config.update(config)
- 更新
Toolkit
的全局配置(類級別),比如數據源 API key、默認參數等。
3. config
@property
def config(self):"""Access the configuration."""return self._config
- 返回當前配置。
4. __init__
def __init__(self, config=None):if config:self.update_config(config)
- 構造函數,可傳入配置并更新默認配置。
5. get_reddit_news
@staticmethod
@tool
def get_reddit_news(curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:"""Retrieve global news from Reddit within a specified time frame.Args:curr_date (str): Date you want to get news for in yyyy-mm-dd formatReturns:str: A formatted dataframe containing the latest global news from Reddit in the specified time frame."""global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)return global_news_result
- 從 Reddit 獲取某天起過去 7 天內的 全球新聞(最多 5 條),主要用于宏觀輿情分析。
6. get_finnhub_news
@staticmethod
@tool
def get_finnhub_news(ticker: Annotated[str, "Search query of a company, e.g. 'AAPL, TSM, etc."],start_date: Annotated[str, "Start date in yyyy-mm-dd format"],end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):"""Retrieve the latest news about a given stock from Finnhub within a date rangeArgs:ticker (str): Ticker of a company. e.g. AAPL, TSMstart_date (str): Start date in yyyy-mm-dd formatend_date (str): End date in yyyy-mm-dd formatReturns:str: A formatted dataframe containing news about the company within the date range from start_date to end_date"""end_date_str = end_dateend_date = datetime.strptime(end_date, "%Y-%m-%d")start_date = datetime.strptime(start_date, "%Y-%m-%d")look_back_days = (end_date - start_date).daysfinnhub_news_result = interface.get_finnhub_news(ticker, end_date_str, look_back_days)return finnhub_news_result
- 調用 Finnhub API 獲取指定股票在
start_date ~ end_date
的新聞。
7. get_reddit_stock_info
@staticmethod
@tool
def get_reddit_stock_info(ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:"""Retrieve the latest news about a given stock from Reddit, given the current date.Args:ticker (str): Ticker of a company. e.g. AAPL, TSMcurr_date (str): current date in yyyy-mm-dd format to get news forReturns:str: A formatted dataframe containing the latest news about the company on the given date"""stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)return stock_news_results
- 從 Reddit 獲取某只股票的近期新聞和討論。
8. get_chinese_social_sentiment
@staticmethod
@tool
def get_chinese_social_sentiment(ticker: Annotated[str, "股票代碼,例如 '600519.SS' 或 '000001.SZ'"],curr_date: Annotated[str, "要獲取的日期,yyyy-mm-dd 格式"],
) -> str:"""獲取某只股票在中國社交媒體上的情緒分析。Args:ticker (str): 股票代碼 (A股格式)curr_date (str): 要獲取的日期Returns:str: 包含中國社交媒體情緒分析的格式化 dataframe"""try:sentiment_result = interface.get_chinese_sentiment(ticker, curr_date, 7)except Exception as e:logger.warning(f"中國輿情數據獲取失敗,回退到 Reddit: {e}")sentiment_result = interface.get_reddit_company_news(ticker, curr_date, 7, 5)return sentiment_result
- 獲取 A 股股票在 中國本土社交媒體 上的輿情/情緒數據。
- 如果中國數據源失敗 → 自動回退到 Reddit。
- 常用于國內公司投資情緒分析。
9. get_finnhub_company_insider_sentiment
@staticmethod
@tool
def get_finnhub_company_insider_sentiment(ticker: Annotated[str, "股票代碼,例如 'AAPL', 'TSM'"],
) -> str:"""獲取公司內部人買賣行為的情緒數據。Args:ticker (str): 公司代碼Returns:str: 格式化的內部人情緒數據"""insider_sentiment = interface.get_finnhub_insider_sentiment(ticker)return insider_sentiment
- 調用 Finnhub API 獲取 內部人買賣股票的情緒指標(比如 CEO、CFO 買入/賣出, 屬于“聰明錢”指標)。
10. get_YFin_data
@staticmethod
@tool
def get_YFin_data(ticker: Annotated[str, "公司股票代碼,例如 'AAPL', 'TSM'"],period: Annotated[str, "數據周期,例如 '1mo', '6mo', '1y'"],
) -> str:"""獲取 Yahoo Finance 歷史行情數據。Args:ticker (str): 股票代碼period (str): 時間范圍Returns:str: 格式化的歷史行情 dataframe"""df = interface.get_yahoo_finance_history(ticker, period)return df.to_string()
- 從 Yahoo Finance 獲取某只股票的歷史行情(K 線數據)。
11. get_YFin_data_online
@staticmethod
@tool
def get_YFin_data_online(ticker: Annotated[str, "公司股票代碼"],start_date: Annotated[str, "開始日期 yyyy-mm-dd"],end_date: Annotated[str, "結束日期 yyyy-mm-dd"],
) -> str:"""獲取 Yahoo Finance 區間行情。"""df = interface.get_yahoo_finance_range(ticker, start_date, end_date)return df.to_string()
- 類似
get_YFin_data
,但支持 指定開始和結束日期。
12. get_stockstats_indicators_report
@staticmethod
@tool
def get_stockstats_indicators_report(ticker: Annotated[str, "公司股票代碼"],period: Annotated[str, "分析周期,例如 '6mo'"],
) -> str:"""獲取技術指標分析報告(基于 stockstats)。"""df = interface.get_stockstats_indicators(ticker, period)return df.to_string()
- 基于 stockstats 庫計算技術指標(均線、RSI、MACD 等)。
13. get_stockstats_indicators_report_online
@staticmethod
@tool
def get_stockstats_indicators_report_online(ticker: Annotated[str, "公司股票代碼"],start_date: Annotated[str, "開始日期"],end_date: Annotated[str, "結束日期"],
) -> str:"""獲取指定區間的技術指標分析。"""df = interface.get_stockstats_indicators_range(ticker, start_date, end_date)return df.to_string()
get_stockstats_indicators_report
類似,但支持 自定義區間。
14. get_simfin_balance_sheet 資產負債表
@staticmethod
@tool
def get_simfin_balance_sheet(ticker: Annotated[str, "公司股票代碼,例如 'AAPL', 'TSM'"],report_type: Annotated[str, "報告類型,例如 'annual', 'quarterly'"],
) -> str:"""獲取公司資產負債表 (Balance Sheet)。"""df = interface.get_simfin_balance_sheet(ticker, report_type)return df.to_string()
- 從 SimFin API 獲取 資產負債表。
- 可選年度 (annual) 或季度 (quarterly)。
- 返回格式化的 dataframe。
15. get_simfin_income_statement利潤表
@staticmethod
@tool
def get_simfin_income_statement(ticker: Annotated[str, "公司股票代碼"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""獲取公司利潤表 (Income Statement)。"""df = interface.get_simfin_income_statement(ticker, report_type)return df.to_string()
- 獲取 利潤表(營業收入、凈利潤、毛利率等)。
- 主要用于盈利能力分析。
16. get_simfin_cashflow_statement 現金流量表
@staticmethod
@tool
def get_simfin_cashflow_statement(ticker: Annotated[str, "公司股票代碼"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""獲取公司現金流量表 (Cashflow Statement)。"""df = interface.get_simfin_cashflow_statement(ticker, report_type)return df.to_string()
- 獲取 現金流量表(經營活動、投資活動、融資活動現金流)。
- 用于衡量公司“造血能力”和資金鏈穩定性。
17. get_simfin_ratios 財務指標比率
@staticmethod
@tool
def get_simfin_ratios(ticker: Annotated[str, "公司股票代碼"],report_type: Annotated[str, "annual 或 quarterly"],
) -> str:"""獲取公司財務比率 (Ratios),例如 PE、ROE、負債率。"""df = interface.get_simfin_ratios(ticker, report_type)return df.to_string()
- 獲取 財務指標比率(PE, PB, ROE, 負債率, 流動比率)。
- 適合做跨公司對比。
18. get_simfin_company_info 公司基本信息
@staticmethod
@tool
def get_simfin_company_info(ticker: Annotated[str, "公司股票代碼"]
) -> str:"""獲取公司基本信息(行業、地區、規模等)。"""df = interface.get_simfin_company_info(ticker)return df.to_string()
- 獲取公司的 基本信息(行業分類、上市地、公司規模)。
- 在做行業對比或聚類分析時很有用。
19. get_simfin_shareprices 歷史股價數據
@staticmethod
@tool
def get_simfin_shareprices(ticker: Annotated[str, "公司股票代碼"],start_date: Annotated[str, "開始日期"],end_date: Annotated[str, "結束日期"],
) -> str:"""獲取公司歷史股價 (Share Prices)。"""df = interface.get_simfin_shareprices(ticker, start_date, end_date)return df.to_string()
- 獲取 SimFin 的歷史股價數據。
- 類似 Yahoo Finance,但數據源不同。
20. get_fundamentals_openai 財務+估值
@staticmethod
@tool
def get_fundamentals_openai(ticker: Annotated[str, "公司代碼,例如 AAPL, TSLA"],report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:"""使用 OpenAI Agent 獲取公司財務和估值信息。(已廢棄,推薦使用 get_stock_fundamentals_unified)"""logger.warning("?? [DEPRECATED] 推薦使用 get_stock_fundamentals_unified() 代替")return interface.get_fundamentals_openai(ticker, report_type)
- 原始版本,用 OpenAI Agent 來獲取財務+估值。
- 已經 廢棄,現在統一整合進
get_stock_fundamentals_unified
。
- 已經 廢棄,現在統一整合進
21. get_china_fundamentals A 股財務數據
@staticmethod
@tool
def get_china_fundamentals(ticker: Annotated[str, "中國股票代碼,例如 600519"],report_type: Annotated[str, "年度/季度"] = "annual",
) -> str:"""獲取中國 A 股財務數據(通過 Tushare 或 AKShare)。(已廢棄,推薦使用 get_stock_fundamentals_unified)"""logger.warning("?? [DEPRECATED] 推薦使用 get_stock_fundamentals_unified() 代替")return interface.get_china_fundamentals(ticker, report_type)
- 早期用于獲取 A 股財務數據。
- 數據源:Tushare / AKShare。
- 已 廢棄,功能已被統一接口替代。
22. get_stock_fundamentals_unified 整合財務+估值
@staticmethod
@tool
def get_stock_fundamentals_unified(ticker: Annotated[str, "股票代碼,例如 AAPL, 600519, 00700.HK"],market: Annotated[str, "市場類型:us / cn / hk"] = "us",report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:"""統一接口:自動識別市場并獲取財務數據與估值。"""return interface.get_stock_fundamentals_unified(ticker, market, report_type)
- 核心統一入口,整合了所有財務+估值工具:
- 美股 → Yahoo Finance + SimFin
- A 股 → Tushare + AKShare
- 港股 → Yahoo Finance + 東方財富
- 自動根據
market
參數(us/cn/hk)選擇合適數據源。 - 返回內容包括:
- 財報(資產負債表、利潤表、現金流)
- 估值指標(PE, PB, PEG, ROE 等)
- 行業對比
create_msg_delete函數(非Toolkit類)
def create_msg_delete():def delete_messages(state):"""Clear messages and add placeholder for Anthropic compatibility"""messages = state["messages"]# Remove all messagesremoval_operations = [RemoveMessage(id=m.id) for m in messages]# Add a minimal placeholder messageplaceholder = HumanMessage(content="Continue")return {"messages": removal_operations + [placeholder]}return delete_messages
- 清空消息歷史,并插入一個占位的
HumanMessage("Continue")
,保證在像 Anthropic 這類模型里保持對話兼容性。
ChromaDBManager
類說明
- 目標:保證整個項目里只存在一個 ChromaDB 客戶端,避免多線程或多進程同時初始化帶來的沖突。
- 關鍵點:用了 單例模式 + 線程鎖 來保證全局唯一性。
class ChromaDBManager:"""單例ChromaDB管理器,避免并發創建集合的沖突"""_instance = None_lock = threading.Lock()_collections: Dict[str, any] = {}_client = Nonedef __new__(cls):if cls._instance is None:...cls._instance = super(ChromaDBManager, cls).__new__(cls)cls._instance._initialized = Falsereturn cls._instancedef __init__(self):if not self._initialized:try:...self._initialized = Truedef get_or_create_collection(self, name: str):"""線程安全地獲取或創建集合"""with self._lock:if name in self._collections:logger.info(f"📚 [ChromaDB] 使用緩存集合: {name}")return self._collections[name]try:# 嘗試獲取現有集合collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 獲取現有集合: {name}")except Exception:try:# 創建新集合...# 緩存集合self._collections[name] = collectionreturn collection
逐函數解析
類屬性
_instance = None # 存儲類的唯一實例
_lock = threading.Lock() # 線程鎖,保證并發安全
_collections: Dict[str, any] = {} # 緩存已經創建/獲取的集合,避免重復創建
_client = None # 底層 ChromaDB 客戶端實例
__new__(cls)
def __new__(cls):if cls._instance is None:with cls._lock:if cls._instance is None:cls._instance = super(ChromaDBManager, cls).__new__(cls)cls._instance._initialized = Falsereturn cls._instance
-
作用:實現單例模式,確保只創建一個實例。
-
邏輯:
- 如果
_instance
還沒創建 → 加鎖。 - 再次檢查
_instance
(雙重檢查鎖 DCL,避免競態)。 - 創建實例,并標記
_initialized=False
,表示還沒初始化。
- 如果
-
返回值:類的唯一實例。
__init__(self)
def __init__(self):if not self._initialized:try:# 自動檢測操作系統版本并使用最優配置import platformsystem = platform.system()if system == "Windows":# 使用改進的Windows 11檢測from .chromadb_win11_config import is_windows_11if is_windows_11():# Windows 11 或更新版本,使用優化配置from .chromadb_win11_config import get_win11_chromadb_clientself._client = get_win11_chromadb_client()logger.info(f"📚 [ChromaDB] Windows 11優化配置初始化完成 (構建號: {platform.version()})")else:# Windows 10 或更老版本,使用兼容配置from .chromadb_win10_config import get_win10_chromadb_clientself._client = get_win10_chromadb_client()logger.info(f"📚 [ChromaDB] Windows 10兼容配置初始化完成")else:# 非Windows系統,使用標準配置settings = Settings(allow_reset=True,anonymized_telemetry=False,is_persistent=False)self._client = chromadb.Client(settings)logger.info(f"📚 [ChromaDB] {system}標準配置初始化完成")self._initialized = Trueexcept Exception as e:logger.error(f"? [ChromaDB] 初始化失敗: {e}")# 使用最簡單的配置作為備用try:settings = Settings(allow_reset=True,anonymized_telemetry=False, # 關鍵:禁用遙測is_persistent=False)self._client = chromadb.Client(settings)logger.info(f"📚 [ChromaDB] 使用備用配置初始化完成")except Exception as backup_error:# 最后的備用方案self._client = chromadb.Client()logger.warning(f"?? [ChromaDB] 使用最簡配置初始化: {backup_error}")self._initialized = True
-
作用:在第一次創建時初始化
ChromaDB
客戶端。 -
邏輯:
- 檢測操作系統(Windows / Linux / Mac)。
- Windows:進一步區分 Windows 11 優化配置 和 Windows 10 兼容配置。
- 其它系統:用標準配置
Settings(...)
。 - 初始化失敗 → 嘗試 備用配置(禁用遙測、非持久化)。
- 如果還失敗 → 用 最簡配置
chromadb.Client()
。
-
輸出:無(初始化
_client
)。
get_or_create_collection(self, name: str)
def get_or_create_collection(self, name: str):"""線程安全地獲取或創建集合(輸出ChromaDB 集合對象)"""with self._lock:if name in self._collections:logger.info(f"📚 [ChromaDB] 使用緩存集合: {name}")return self._collections[name]try:# 嘗試獲取現有集合collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 獲取現有集合: {name}")except Exception:try:# 創建新集合collection = self._client.create_collection(name=name)logger.info(f"📚 [ChromaDB] 創建新集合: {name}")except Exception as e:# 可能是并發創建,再次嘗試獲取try:collection = self._client.get_collection(name=name)logger.info(f"📚 [ChromaDB] 并發創建后獲取集合: {name}")except Exception as final_error:logger.error(f"? [ChromaDB] 集合操作失敗: {name}, 錯誤: {final_error}")raise final_error# 緩存集合self._collections[name] = collectionreturn collection
-
作用:獲取或創建一個集合(類似數據庫里的表)。
-
線程安全:加鎖保證多個線程不會同時創建同一個集合。
-
邏輯:
- 如果緩存
_collections
里已有 → 直接返回。 - 否則嘗試
get_collection(name)
。 - 如果失敗(說明不存在) →
create_collection(name)
。 - 如果創建也失敗(可能是并發競爭) → 再嘗試
get_collection(name)
。 - 如果還是失敗 → 拋出異常。
- 成功則緩存集合,并返回。
- 如果緩存
FinancialSituationMemory類
類說明
- 類可以視為一個 “財務情況記憶庫”:
- 寫入:新情況 + 建議 → 生成 embedding → 存入 ChromaDB
- 讀取:新情況 → 生成 embedding → 相似度搜索 → 找到最相關的歷史建議
簡化后代碼
class FinancialSituationMemory:def __init__(self, name, config):self.config = configself.llm_provider = config.get("llm_provider", "openai").lower()# 配置向量緩存的長度限制(向量緩存默認啟用長度檢查)self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默認50K字符self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量緩存默認啟用# 根據LLM提供商選擇嵌入模型和客戶端self.fallback_available = False # 初始化降級選項標志if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":self.embedding = "text-embedding-v3"self.client = None # DashScope不需要OpenAI客戶端# 設置DashScope API密鑰dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 嘗試導入和初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_keylogger.info(f"? DashScope API密鑰已配置,啟用記憶功能")except ImportError as e:# DashScope包未安裝 ...except Exception as e:# 其他初始化錯誤 ...else:# 沒有DashScope密鑰,禁用記憶功能 ...elif self.llm_provider == "deepseek":...# 使用單例ChromaDB管理器self.chroma_manager = ChromaDBManager()self.situation_collection = self.chroma_manager.get_or_create_collection(name)def _smart_text_truncation(self, text, max_length=8192):"""智能文本截斷,保持語義完整性和緩存兼容性"""if len(text) <= max_length:return text, False # 返回原文本和是否截斷的標志# 嘗試在句子邊界截斷sentences = text.split('。')...# 嘗試在段落邊界截斷paragraphs = text.split('\n')...return truncated, Truedef get_embedding(self, text):"""Get embedding for a text using the configured provider"""# 檢查記憶功能是否被禁用if self.client == "DISABLED":# 內存功能已禁用,返回空向量 ...if len(text) == 0: ... # 輸入文本長度為0,返回空向量# 檢查是否啟用長度限制if self.enable_embedding_length_check and text_length > self.max_embedding_length: # 文本過長跳過向量化并存儲跳過信息 ...return [0.0] * 1024# 存儲文本處理信息self._last_text_info = { ... }if (self.llm_provider == "dashscope" orself.llm_provider == "alibaba" or(self.llm_provider == "google" and self.client is None) or(self.llm_provider == "deepseek" and self.client is None) or(self.llm_provider == "openrouter" and self.client is None)):# 使用阿里百煉的嵌入模型try: ...return embeddingelse:...return [0.0] * 1024def get_embedding_config_status(self):"""獲取向量緩存配置狀態"""return {'enabled': self.enable_embedding_length_check,'max_embedding_length': self.max_embedding_length,'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",'provider': self.llm_provider,'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'}def add_situations(self, situations_and_advice):"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""...self.situation_collection.add(documents=situations,metadatas=[{"recommendation": rec} for rec in advice],embeddings=embeddings,ids=ids,)
Example
# Example usagematcher = FinancialSituationMemory()# Example dataexample_data = [("High inflation rate with rising interest rates and declining consumer spending","Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",),("Tech sector showing high volatility with increasing institutional selling pressure","Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",),("Strong dollar affecting emerging markets with increasing forex volatility","Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",),("Market showing signs of sector rotation with rising yields","Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",),]# Add the example situations and recommendationsmatcher.add_situations(example_data)# Example querycurrent_situation = """Market showing increased volatility in tech sector, with institutional investors reducing positions and rising interest rates affecting growth stock valuations"""try:recommendations = matcher.get_memories(current_situation, n_matches=2)for i, rec in enumerate(recommendations, 1):logger.info(f"\nMatch {i}:")logger.info(f"Similarity Score: {rec.get('similarity', 0):.2f}")logger.info(f"Matched Situation: {rec.get('situation', '')}")logger.info(f"Recommendation: {rec.get('recommendation', '')}")except Exception as e:logger.error(f"Error during recommendation: {str(e)}")
逐函數解析
__init__(self, config)
-
作用
初始化類,設置向量數據庫(ChromaDB),并根據config
和環境變量自動選擇合適的 Embedding 服務。 -
主要邏輯
- 保存
config
。 - 初始化一些內部狀態(provider、model、status、last_text_info 等)。
- 根據
config["provider"]
來選擇 embedding 服務:- DashScope / Alibaba → 用阿里云的 embedding API
- DeepSeek → 優先用 DashScope,其次 OpenAI,最后 DeepSeek 自己的 embedding
- Google → 優先 DashScope,如果配置里有
openai_api_key
就啟用 fallback - OpenRouter → DashScope embedding
- 本地 Ollama (localhost:11434) → 使用 nomic-embed-text
- 默認 → 嘗試 OpenAI embedding,失敗則禁用
- 保存
- 創建一個 ChromaDB 客戶端,并建立/獲取一個集合
financial_situations
。
- 返回值 : 無(構造函數)。
def __init__(self, name, config):self.config = configself.llm_provider = config.get("llm_provider", "openai").lower()# 配置向量緩存的長度限制(向量緩存默認啟用長度檢查)self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默認50K字符self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量緩存默認啟用# 根據LLM提供商選擇嵌入模型和客戶端# 初始化降級選項標志self.fallback_available = Falseif self.llm_provider == "dashscope" or self.llm_provider == "alibaba":self.embedding = "text-embedding-v3"self.client = None # DashScope不需要OpenAI客戶端# 設置DashScope API密鑰dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 嘗試導入和初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_keylogger.info(f"? DashScope API密鑰已配置,啟用記憶功能")# 可選:測試API連接(簡單驗證)# 這里不做實際調用,只驗證導入和密鑰設置except ImportError as e:# DashScope包未安裝logger.error(f"? DashScope包未安裝: {e}")self.client = "DISABLED"logger.warning(f"?? 記憶功能已禁用")except Exception as e:# 其他初始化錯誤logger.error(f"? DashScope初始化失敗: {e}")self.client = "DISABLED"logger.warning(f"?? 記憶功能已禁用")else:# 沒有DashScope密鑰,禁用記憶功能self.client = "DISABLED"logger.warning(f"?? 未找到DASHSCOPE_API_KEY,記憶功能已禁用")logger.info(f"💡 系統將繼續運行,但不會保存或檢索歷史記憶")elif self.llm_provider == "deepseek":# 檢查是否強制使用OpenAI嵌入force_openai = os.getenv('FORCE_OPENAI_EMBEDDING', 'false').lower() == 'true'if not force_openai:# 嘗試使用阿里百煉嵌入dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 測試阿里百煉是否可用import dashscopefrom dashscope import TextEmbeddingdashscope.api_key = dashscope_key# 驗證TextEmbedding可用性(不需要實際調用)self.embedding = "text-embedding-v3"self.client = Nonelogger.info(f"💡 DeepSeek使用阿里百煉嵌入服務")except ImportError as e:logger.error(f"?? DashScope包未安裝: {e}")dashscope_key = None # 強制降級except Exception as e:logger.error(f"?? 阿里百煉嵌入初始化失敗: {e}")dashscope_key = None # 強制降級else:dashscope_key = None # 跳過阿里百煉if not dashscope_key or force_openai:# 降級到OpenAI嵌入self.embedding = "text-embedding-3-small"openai_key = os.getenv('OPENAI_API_KEY')if openai_key:self.client = OpenAI(api_key=openai_key,base_url=config.get("backend_url", "https://api.openai.com/v1"))logger.warning(f"?? DeepSeek回退到OpenAI嵌入服務")else:# 最后嘗試DeepSeek自己的嵌入deepseek_key = os.getenv('DEEPSEEK_API_KEY')if deepseek_key:try:self.client = OpenAI(api_key=deepseek_key,base_url="https://api.deepseek.com")logger.info(f"💡 DeepSeek使用自己的嵌入服務")except Exception as e:logger.error(f"? DeepSeek嵌入服務不可用: {e}")# 禁用內存功能self.client = "DISABLED"logger.info(f"🚨 內存功能已禁用,系統將繼續運行但不保存歷史記憶")else:# 禁用內存功能而不是拋出異常self.client = "DISABLED"logger.info(f"🚨 未找到可用的嵌入服務,內存功能已禁用")elif self.llm_provider == "google":# Google AI使用阿里百煉嵌入(如果可用),否則禁用記憶功能dashscope_key = os.getenv('DASHSCOPE_API_KEY')openai_key = os.getenv('OPENAI_API_KEY')if dashscope_key:try:# 嘗試初始化DashScopeimport dashscopefrom dashscope import TextEmbeddingself.embedding = "text-embedding-v3"self.client = Nonedashscope.api_key = dashscope_key# 檢查是否有OpenAI密鑰作為降級選項if openai_key:logger.info(f"💡 Google AI使用阿里百煉嵌入服務(OpenAI作為降級選項)")self.fallback_available = Trueself.fallback_client = OpenAI(api_key=openai_key, base_url=config["backend_url"])self.fallback_embedding = "text-embedding-3-small"else:logger.info(f"💡 Google AI使用阿里百煉嵌入服務(無降級選項)")self.fallback_available = Falseexcept ImportError as e:logger.error(f"? DashScope包未安裝: {e}")self.client = "DISABLED"logger.warning(f"?? Google AI記憶功能已禁用")except Exception as e:logger.error(f"? DashScope初始化失敗: {e}")self.client = "DISABLED"logger.warning(f"?? Google AI記憶功能已禁用")else:# 沒有DashScope密鑰,禁用記憶功能self.client = "DISABLED"self.fallback_available = Falselogger.warning(f"?? Google AI未找到DASHSCOPE_API_KEY,記憶功能已禁用")logger.info(f"💡 系統將繼續運行,但不會保存或檢索歷史記憶")elif self.llm_provider == "openrouter":# OpenRouter支持:優先使用阿里百煉嵌入,否則禁用記憶功能dashscope_key = os.getenv('DASHSCOPE_API_KEY')if dashscope_key:try:# 嘗試使用阿里百煉嵌入import dashscopefrom dashscope import TextEmbeddingself.embedding = "text-embedding-v3"self.client = Nonedashscope.api_key = dashscope_keylogger.info(f"💡 OpenRouter使用阿里百煉嵌入服務")except ImportError as e:logger.error(f"? DashScope包未安裝: {e}")self.client = "DISABLED"logger.warning(f"?? OpenRouter記憶功能已禁用")except Exception as e:logger.error(f"? DashScope初始化失敗: {e}")self.client = "DISABLED"logger.warning(f"?? OpenRouter記憶功能已禁用")else:# 沒有DashScope密鑰,禁用記憶功能self.client = "DISABLED"logger.warning(f"?? OpenRouter未找到DASHSCOPE_API_KEY,記憶功能已禁用")logger.info(f"💡 系統將繼續運行,但不會保存或檢索歷史記憶")elif config["backend_url"] == "http://localhost:11434/v1":self.embedding = "nomic-embed-text"self.client = OpenAI(base_url=config["backend_url"])else:self.embedding = "text-embedding-3-small"openai_key = os.getenv('OPENAI_API_KEY')if openai_key:self.client = OpenAI(api_key=openai_key,base_url=config["backend_url"])else:self.client = "DISABLED"logger.warning(f"?? 未找到OPENAI_API_KEY,記憶功能已禁用")# 使用單例ChromaDB管理器self.chroma_manager = ChromaDBManager()self.situation_collection = self.chroma_manager.get_or_create_collection(name)
_smart_text_truncation(self, text, max_length)
def _smart_text_truncation(self, text, max_length=8192):"""智能文本截斷,保持語義完整性和緩存兼容性"""if len(text) <= max_length:return text, False # 返回原文本和是否截斷的標志# 嘗試在句子邊界截斷sentences = text.split('。')if len(sentences) > 1:truncated = ""for sentence in sentences:if len(truncated + sentence + '。') <= max_length - 50: # 留50字符余量truncated += sentence + '。'else:breakif len(truncated) > max_length // 2: # 至少保留一半內容logger.info(f"📝 智能截斷:在句子邊界截斷,保留{len(truncated)}/{len(text)}字符")return truncated, True# 嘗試在段落邊界截斷paragraphs = text.split('\n')if len(paragraphs) > 1:truncated = ""for paragraph in paragraphs:if len(truncated + paragraph + '\n') <= max_length - 50:truncated += paragraph + '\n'else:breakif len(truncated) > max_length // 2:logger.info(f"📝 智能截斷:在段落邊界截斷,保留{len(truncated)}/{len(text)}字符")return truncated, True# 最后選擇:保留前半部分和后半部分的關鍵信息front_part = text[:max_length//2]back_part = text[-(max_length//2-100):] # 留100字符給連接符truncated = front_part + "\n...[內容截斷]...\n" + back_partlogger.warning(f"?? 強制截斷:保留首尾關鍵信息,{len(text)}字符截斷為{len(truncated)}字符")return truncated, True
-
作用
保證輸入文本不會超過max_length
,但盡量保持語義完整(比如按句子、段落截斷)。 -
主要邏輯
- 如果文本長度 ≤
max_length
→ 原樣返回。 - 如果太長:
- 嘗試在句號(。.!?)之后截斷,保留前面一段完整句子。
- 否則嘗試按段落
\n
截斷。 - 如果都不行,就取前
max_length//2
和后max_length//2
拼接,中間插入...
。
- 把截斷后的文本保存到
self.last_text_info
,記錄是否截斷、采用哪種策略、原始/最終長度等。
- 如果文本長度 ≤
-
輸入
text
(str):原始文本max_length
(int):允許的最大長度
-
輸出
- 截斷后的文本 (str)
get_embedding(self, text, max_length=1000)
def get_embedding(self, text):"""Get embedding for a text using the configured provider"""# 檢查記憶功能是否被禁用if self.client == "DISABLED":# 內存功能已禁用,返回空向量logger.debug(f"?? 記憶功能已禁用,返回空向量")return [0.0] * 1024 # 返回1024維的零向量# 驗證輸入文本if not text or not isinstance(text, str):logger.warning(f"?? 輸入文本為空或無效,返回空向量")return [0.0] * 1024text_length = len(text)if text_length == 0:logger.warning(f"?? 輸入文本長度為0,返回空向量")return [0.0] * 1024# 檢查是否啟用長度限制if self.enable_embedding_length_check and text_length > self.max_embedding_length:logger.warning(f"?? 文本過長({text_length:,}字符 > {self.max_embedding_length:,}字符),跳過向量化")# 存儲跳過信息self._last_text_info = {'original_length': text_length,'processed_length': 0,'was_truncated': False,'was_skipped': True,'provider': self.llm_provider,'strategy': 'length_limit_skip','max_length': self.max_embedding_length}return [0.0] * 1024# 記錄文本信息(不進行任何截斷)if text_length > 8192:logger.info(f"📝 處理長文本: {text_length}字符,提供商: {self.llm_provider}")# 存儲文本處理信息self._last_text_info = {'original_length': text_length,'processed_length': text_length, # 不截斷,保持原長度'was_truncated': False, # 永不截斷'was_skipped': False,'provider': self.llm_provider,'strategy': 'no_truncation_with_fallback' # 標記策略}if (self.llm_provider == "dashscope" orself.llm_provider == "alibaba" or(self.llm_provider == "google" and self.client is None) or(self.llm_provider == "deepseek" and self.client is None) or(self.llm_provider == "openrouter" and self.client is None)):# 使用阿里百煉的嵌入模型try:# 導入DashScope模塊import dashscopefrom dashscope import TextEmbedding# 檢查DashScope API密鑰是否可用if not hasattr(dashscope, 'api_key') or not dashscope.api_key:logger.warning(f"?? DashScope API密鑰未設置,記憶功能降級")return [0.0] * 1024 # 返回空向量# 嘗試調用DashScope APIresponse = TextEmbedding.call(model=self.embedding,input=text)# 檢查響應狀態if response.status_code == 200:# 成功獲取embeddingembedding = response.output['embeddings'][0]['embedding']logger.debug(f"? DashScope embedding成功,維度: {len(embedding)}")return embeddingelse:# API返回錯誤狀態碼error_msg = f"{response.code} - {response.message}"# 檢查是否為長度限制錯誤if any(keyword in error_msg.lower() for keyword in ['length', 'token', 'limit', 'exceed']):logger.warning(f"?? DashScope長度限制: {error_msg}")# 檢查是否有降級選項if hasattr(self, 'fallback_available') and self.fallback_available:logger.info(f"💡 嘗試使用OpenAI降級處理長文本")try:response = self.fallback_client.embeddings.create(model=self.fallback_embedding,input=text)embedding = response.data[0].embeddinglogger.info(f"? OpenAI降級成功,維度: {len(embedding)}")return embeddingexcept Exception as fallback_error:logger.error(f"? OpenAI降級失敗: {str(fallback_error)}")logger.info(f"💡 所有降級選項失敗,記憶功能降級")return [0.0] * 1024else:logger.info(f"💡 無可用降級選項,記憶功能降級")return [0.0] * 1024else:logger.error(f"? DashScope API錯誤: {error_msg}")return [0.0] * 1024 # 返回空向量而不是拋出異常except Exception as e:error_str = str(e).lower()# 檢查是否為長度限制錯誤if any(keyword in error_str for keyword in ['length', 'token', 'limit', 'exceed', 'too long']):logger.warning(f"?? DashScope長度限制異常: {str(e)}")# 檢查是否有降級選項if hasattr(self, 'fallback_available') and self.fallback_available:logger.info(f"💡 嘗試使用OpenAI降級處理長文本")try:response = self.fallback_client.embeddings.create(model=self.fallback_embedding,input=text)embedding = response.data[0].embeddinglogger.info(f"? OpenAI降級成功,維度: {len(embedding)}")return embeddingexcept Exception as fallback_error:logger.error(f"? OpenAI降級失敗: {str(fallback_error)}")logger.info(f"💡 所有降級選項失敗,記憶功能降級")return [0.0] * 1024else:logger.info(f"💡 無可用降級選項,記憶功能降級")return [0.0] * 1024elif 'import' in error_str:logger.error(f"? DashScope包未安裝: {str(e)}")elif 'connection' in error_str:logger.error(f"? DashScope網絡連接錯誤: {str(e)}")elif 'timeout' in error_str:logger.error(f"? DashScope請求超時: {str(e)}")else:logger.error(f"? DashScope embedding異常: {str(e)}")logger.warning(f"?? 記憶功能降級,返回空向量")return [0.0] * 1024else:# 使用OpenAI兼容的嵌入模型if self.client is None:logger.warning(f"?? 嵌入客戶端未初始化,返回空向量")return [0.0] * 1024 # 返回空向量elif self.client == "DISABLED":# 內存功能已禁用,返回空向量logger.debug(f"?? 內存功能已禁用,返回空向量")return [0.0] * 1024 # 返回1024維的零向量# 嘗試調用OpenAI兼容的embedding APItry:response = self.client.embeddings.create(model=self.embedding,input=text)embedding = response.data[0].embeddinglogger.debug(f"? {self.llm_provider} embedding成功,維度: {len(embedding)}")return embeddingexcept Exception as e:error_str = str(e).lower()# 檢查是否為長度限制錯誤length_error_keywords = ['token', 'length', 'too long', 'exceed', 'maximum', 'limit','context', 'input too large', 'request too large']is_length_error = any(keyword in error_str for keyword in length_error_keywords)if is_length_error:# 長度限制錯誤:直接降級,不截斷重試logger.warning(f"?? {self.llm_provider}長度限制: {str(e)}")logger.info(f"💡 為保證分析準確性,不截斷文本,記憶功能降級")else:# 其他類型的錯誤if 'attributeerror' in error_str:logger.error(f"? {self.llm_provider} API調用錯誤: {str(e)}")elif 'connectionerror' in error_str or 'connection' in error_str:logger.error(f"? {self.llm_provider}網絡連接錯誤: {str(e)}")elif 'timeout' in error_str:logger.error(f"? {self.llm_provider}請求超時: {str(e)}")elif 'keyerror' in error_str:logger.error(f"? {self.llm_provider}響應格式錯誤: {str(e)}")else:logger.error(f"? {self.llm_provider} embedding異常: {str(e)}")logger.warning(f"?? 記憶功能降級,返回空向量")return [0.0] * 1024
-
作用 : 把文本轉成向量 embedding。
-
主要邏輯
- 調用
_smart_text_truncation
保證長度安全。 - 根據 provider 選擇對應的 API:
- DashScope → 調用
dashscope.TextEmbedding.call
,模型是text-embedding-v2
- OpenAI → 調用
client.embeddings.create(model="text-embedding-3-small")
- DeepSeek → 可能走 DeepSeek 自己的 embedding API
- Ollama → 調用
client.embeddings.create(model="nomic-embed-text")
- 其它情況 → 返回
[0.0] * 1024
(代表禁用)
- DashScope → 調用
- 如果調用失敗,捕獲異常并返回零向量。
- 調用
-
輸入
text
(str):輸入文本max_length
(int):最大允許長度(默認 1000)
-
輸出
- 向量 embedding (list[float])
get_embedding_config_status(self)
def get_embedding_config_status(self):"""獲取向量緩存配置狀態"""return {'enabled': self.enable_embedding_length_check,'max_embedding_length': self.max_embedding_length,'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",'provider': self.llm_provider,'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'}
- 作用 返回當前 embedding 的配置信息,主要用于調試/檢查狀態。
get_last_text_info(self)
def get_last_text_info(self):"""獲取最后處理的文本信息"""return getattr(self, '_last_text_info', None)
- 作用:返回最近一次
_smart_text_truncation
的信息(調試用)。
add_situations(self, situations)
def add_situations(self, situations_and_advice):"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""situations = []advice = []ids = []embeddings = []offset = self.situation_collection.count()for i, (situation, recommendation) in enumerate(situations_and_advice):situations.append(situation)advice.append(recommendation)ids.append(str(offset + i))embeddings.append(self.get_embedding(situation))self.situation_collection.add(documents=situations,metadatas=[{"recommendation": rec} for rec in advice],embeddings=embeddings,ids=ids,)
- 作用:把一組新的“財務情況 + 建議”存入數據庫。
主要邏輯
-
遍歷
situations
,每個元素是(situation, recommendation)
。 -
對
situation
文本生成 embedding。 -
構造數據項:
{"id": str(uuid.uuid4()),"embedding": 向量,"metadata": {"recommendation": 建議} }
-
批量寫入
self.collection
(ChromaDB)。
get_cache_info(self)
def get_cache_info(self):"""獲取緩存相關信息,用于調試和監控"""info = {'collection_count': self.situation_collection.count(),'client_status': 'enabled' if self.client != "DISABLED" else 'disabled','embedding_model': self.embedding,'provider': self.llm_provider}# 添加最后一次文本處理信息if hasattr(self, '_last_text_info'):info['last_text_processing'] = self._last_text_inforeturn info
- 作用
返回當前 ChromaDB 集合的一些元信息(比如名稱、模型、provider 等),幫助確認系統運行狀態。
get_memories(self, situation, n_results=5)
def get_memories(self, current_situation, n_matches=1):"""Find matching recommendations using embeddings with smart truncation handling"""# 獲取當前情況的embeddingquery_embedding = self.get_embedding(current_situation)# 檢查是否為空向量(記憶功能被禁用或出錯)if all(x == 0.0 for x in query_embedding):logger.debug(f"?? 查詢embedding為空向量,返回空結果")return []# 檢查是否有足夠的數據進行查詢collection_count = self.situation_collection.count()if collection_count == 0:logger.debug(f"📭 記憶庫為空,返回空結果")return []# 調整查詢數量,不能超過集合中的文檔數量actual_n_matches = min(n_matches, collection_count)try:# 執行相似度查詢results = self.situation_collection.query(query_embeddings=[query_embedding],n_results=actual_n_matches)# 處理查詢結果memories = []if results and 'documents' in results and results['documents']:documents = results['documents'][0]metadatas = results.get('metadatas', [[]])[0]distances = results.get('distances', [[]])[0]for i, doc in enumerate(documents):metadata = metadatas[i] if i < len(metadatas) else {}distance = distances[i] if i < len(distances) else 1.0memory_item = {'situation': doc,'recommendation': metadata.get('recommendation', ''),'similarity': 1.0 - distance, # 轉換為相似度分數'distance': distance}memories.append(memory_item)# 記錄查詢信息if hasattr(self, '_last_text_info') and self._last_text_info.get('was_truncated'):logger.info(f"🔍 截斷文本查詢完成,找到{len(memories)}個相關記憶")logger.debug(f"📊 原文長度: {self._last_text_info['original_length']}, "f"處理后長度: {self._last_text_info['processed_length']}")else:logger.debug(f"🔍 記憶查詢完成,找到{len(memories)}個相關記憶")return memoriesexcept Exception as e:logger.error(f"? 記憶查詢失敗: {str(e)}")return []
-
作用:檢索最相似的歷史情況,返回對應的建議。
-
主要邏輯
- 對輸入
situation
生成 embedding。 - 調用
self.collection.query
,檢索最相似的n_results
條記錄。 - 解析返回結果:提取文本、embedding 相似度/距離、推薦建議。
- 返回一個結果列表。
- 對輸入