Building a GraphRAG-Based Dream of the Red Chamber (紅樓夢) Project from Scratch, Part 3
01 Building the Backend Service
Create a Python file named server.py. The complete source code is at the end of this article.
1.1 GraphRAG-related imports
# GraphRAG-related imports
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
1.2 Configuration
# Configure the logging format
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants and configuration; change INPUT_DIR to point at your own GraphRAG output folder
INPUT_DIR = "/Volumes/tesla/dev/rag/GraphRAG001/output/20250820-212616/artifacts"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"
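Before wiring these tables into the query engines, it can be worth confirming that the artifacts directory actually contains every expected parquet file. A minimal sketch (pathlib only; the constants are the ones defined above):

from pathlib import Path

# Sketch: verify that every expected parquet artifact exists under INPUT_DIR
expected_tables = [
    COMMUNITY_REPORT_TABLE, ENTITY_TABLE, ENTITY_EMBEDDING_TABLE,
    RELATIONSHIP_TABLE, COVARIATE_TABLE, TEXT_UNIT_TABLE,
]
missing = [t for t in expected_tables if not (Path(INPUT_DIR) / f"{t}.parquet").exists()]
if missing:
    raise FileNotFoundError(f"Missing parquet artifacts in {INPUT_DIR}: {missing}")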
1.3 LLM configuration
# Read the OpenAI API key from the environment
api_key = os.environ["OPENAI_API_KEY"]
API_BASE_URL = "https://api.siliconflow.cn/v1"
CHAT_MODEL = "Qwen/Qwen3-32B"
EMBEDDING_MODEL = "BAAI/bge-m3"
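Note that os.environ["OPENAI_API_KEY"] raises a bare KeyError if the variable is unset. A small guard, shown here as a sketch using the same variable name, makes the failure mode explicit:

import os

# Fail fast with a clear message if the key is missing (sketch; same variable name as above)
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("OPENAI_API_KEY is not set; export it before starting server.py")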
1.4 Building the service API with FastAPI
# GET endpoint that returns the list of available models
@app.get("/v1/models")
async def list_models():
    logger.info("Received model list request")
    current_time = int(time.time())
    models = [
        {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"},
        {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"},
        {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"}
    ]
    response = {
        "object": "list",
        "data": models
    }
    logger.info(f"Sending model list: {response}")
    return JSONResponse(content=response)

if __name__ == "__main__":
    logger.info(f"Starting server on port {PORT}")
    # uvicorn is a lightweight, very fast ASGI server implementation,
    # used to serve async Python web apps built on FastAPI
    uvicorn.run(app, host="0.0.0.0", port=PORT)
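To sanity-check that the service is reachable, you can hit the models endpoint directly; a minimal sketch assuming the server is already running locally on port 8012:

import requests

# List the models the service exposes
resp = requests.get("http://localhost:8012/v1/models")
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"], "--", model["owned_by"])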
02 Testing the Q&A
Create a file named test.py:
import requests
import json

url = "http://localhost:8012/v1/chat/completions"
headers = {"Content-Type": "application/json"}

# 1. Test global search: graphrag-global-search:latest
global_data = {
    "model": "graphrag-global-search:latest",
    "messages": [{"role": "user", "content": "這個故事的首要主題是什么?記住請使用中文進行回答,不要用英文。"}],
    "temperature": 0.7,
    # "stream": True,  # True or False
}

# 2. Test local search: graphrag-local-search:latest
local_data = {
    "model": "graphrag-local-search:latest",
    "messages": [{"role": "user", "content": "賈政是誰,他的主要關系是什么?記住請使用中文進行回答,不要用英文。"}],
    "temperature": 0.7,
    "stream": False,  # must stay False here; this script parses a single JSON response
}

response = requests.post(url, headers=headers, data=json.dumps(global_data))
print(response.json()['choices'][0]['message']['content'])
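For reference, the non-streaming endpoint returns an OpenAI-style completion object; the shape below follows the ChatCompletionResponse model defined in server.py, with illustrative placeholder values:

# Illustrative shape of a non-streaming response (placeholder values only)
example_response = {
    "id": "chatcmpl-<hex>",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "graphrag-global-search:latest",
    "choices": [
        {"index": 0,
         "message": {"role": "assistant", "content": "..."},
         "finish_reason": "stop"}
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 240, "total_tokens": 252}
}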
2.1 Test results
03 Testing the Q&A (Streaming)
Create a file named test_stream.py:
import requests
import json

url = "http://localhost:8012/v1/chat/completions"
headers = {"Content-Type": "application/json"}

# 1. Test global search: graphrag-global-search:latest
global_data = {
    "model": "graphrag-global-search:latest",
    "messages": [{"role": "user", "content": "這個故事的首要主題是什么?記住請使用中文進行回答,不要用英文。"}],
    "temperature": 0.7,
    "stream": True
}

# 2. Test local search: graphrag-local-search:latest
local_data = {
    "model": "graphrag-local-search:latest",
    "messages": [{"role": "user", "content": "賈政是誰,他的主要關系是什么?記住請使用中文進行回答,不要用英文。"}],
    "temperature": 0.7,
    "stream": True,  # must be True here; this script parses an SSE stream
}

# 3. Test combined global + local search: full-model:latest
full_data = {
    "model": "full-model:latest",
    "messages": [{"role": "user", "content": "林黛玉是誰,她的主要關系是什么?記住請使用中文進行回答,不要用英文。"}],
    "temperature": 0.7,
    # "stream": True,  # True or False
}

# Consume the streaming output
try:
    with requests.post(url, stream=True, headers=headers, data=json.dumps(local_data)) as response:
        for line in response.iter_lines():
            if not line:
                continue
            # Remove the SSE "data: " prefix. Note that str.strip("data: ") would be wrong here:
            # strip() removes a *set* of characters from both ends, not a prefix
            json_str = line.decode('utf-8')
            if json_str.startswith('data: '):
                json_str = json_str[len('data: '):]
            json_str = json_str.strip()
            # Skip empty or invalid strings
            if not json_str:
                print("Received an empty string, skipping...")
                continue
            # Handle the end-of-stream marker
            if json_str == '[DONE]':
                print("Streaming completed successfully")
                break
            # Make sure the string looks like a JSON object
            if json_str.startswith('{') and json_str.endswith('}'):
                try:
                    data = json.loads(json_str)
                    # Check whether the delta carries content
                    delta = data['choices'][0]['delta']
                    if 'content' in delta:
                        print(f"Received JSON data: {delta['content']}")
                    else:
                        print("Received a delta without content")
                except json.JSONDecodeError as e:
                    print(f"JSON decoding failed: {e}")
            else:
                print(f"Invalid JSON format: {json_str}")
except Exception as e:
    print(f"An error occurred: {e}")
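Each event the server emits is a single data: line carrying a chat.completion.chunk object, followed by a final stop chunk and a [DONE] marker. The lines below follow generate_stream in server.py, with illustrative placeholder values:

data: {"id": "chatcmpl-<hex>", "object": "chat.completion.chunk", "created": 1700000000, "model": "graphrag-local-search:latest", "choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": null}]}
data: {"id": "chatcmpl-<hex>", "object": "chat.completion.chunk", "created": 1700000000, "model": "graphrag-local-search:latest", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
data: [DONE]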
3.1 Test results
Complete backend code
server.py
import os
import asyncio
import time
import uuid
import json
import re
import pandas as pd
import tiktoken
import logging
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from contextlib import asynccontextmanager
import uvicorn

# GraphRAG-related imports
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Configure the logging format
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants and configuration; change INPUT_DIR to point at your own GraphRAG output folder
INPUT_DIR = "/Volumes/tesla/dev/rag/GraphRAG001/output/20250820-212616/artifacts"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"

# Community level in the Leiden community hierarchy from which we will load the community reports;
# a higher value means we use reports from more fine-grained communities (at the cost of higher computation)
COMMUNITY_LEVEL = 2
PORT = 8012

# Read the OpenAI API key from the environment
api_key = os.environ["OPENAI_API_KEY"]
API_BASE_URL = "https://api.siliconflow.cn/v1"
CHAT_MODEL = "Qwen/Qwen3-32B"
EMBEDDING_MODEL = "BAAI/bge-m3"

# Global variables holding the search engines and the question generator
local_search_engine = None
global_search_engine = None
question_generator = None

# Message model
class Message(BaseModel):
    role: str
    content: str

# ChatCompletionRequest model
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

# ChatCompletionResponseChoice model
class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: Message
    finish_reason: Optional[str] = None

# Usage model
class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

# ChatCompletionResponse model
class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: Usage
    system_fingerprint: Optional[str] = None

# Set up the language model (LLM), token encoder (TokenEncoder), and text embedder (TextEmbedder)
async def setup_llm_and_embedder():
    logger.info("Setting up the LLM and the embedder")
    # Instantiate a ChatOpenAI client
    llm = ChatOpenAI(
        api_base=API_BASE_URL,  # base URL of the API service
        api_key=api_key,  # API key
        model=CHAT_MODEL,  # chat model to use
        api_type=OpenaiApiType.OpenAI,
    )
    # Initialize the token encoder
    token_encoder = tiktoken.get_encoding("cl100k_base")
    # Instantiate the OpenAI-compatible embedding model
    text_embedder = OpenAIEmbedding(
        api_base=API_BASE_URL,  # base URL of the API service
        api_key=api_key,  # API key
        model=EMBEDDING_MODEL,
        deployment_name=EMBEDDING_MODEL,
        api_type=OpenaiApiType.OpenAI,
        max_retries=20,
    )
    logger.info("LLM and embedder set up")
    return llm, token_encoder, text_embedder

# Load the context data: entities, relationships, reports, text units, and covariates
async def load_context():
    logger.info("Loading context data")
    try:
        # Read the entity table ENTITY_TABLE (Parquet) into a DataFrame
        entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
        # Read the entity embedding table ENTITY_EMBEDDING_TABLE into a DataFrame
        entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")
        # Combine both DataFrames at the configured COMMUNITY_LEVEL into the final entity objects
        entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)
        # Create a LanceDBVectorStore for the entity description embeddings,
        # backed by the "entity_description_embeddings" collection
        description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings")
        # Connect to the LanceDB database at LANCEDB_URI
        description_embedding_store.connect(db_uri=LANCEDB_URI)
        # Store the processed entities in the vector store for semantic search and other uses
        store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store)
        relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
        relationships = read_indexer_relationships(relationship_df)
        report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
        reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
        text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
        text_units = read_indexer_text_units(text_unit_df)
        covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
        claims = read_indexer_covariates(covariate_df)
        logger.info(f"Number of claim records: {len(claims)}")
        covariates = {"claims": claims}
        logger.info("Context data loaded")
        return entities, relationships, reports, text_units, description_embedding_store, covariates
    except Exception as e:
        logger.error(f"Error while loading context data: {str(e)}")
        raise

# Set up the local and global search engines, the context builder (ContextBuilder), and related parameters
async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
                               description_embedding_store, covariates):
    logger.info("Setting up the search engines")
    # Set up the local search engine
    local_context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=covariates,
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )
    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
        # "max_tokens": 12_000,
        "max_tokens": 4096,
    }
    local_llm_params = {
        # "max_tokens": 2_000,
        "max_tokens": 4096,
        "temperature": 0.0,
    }
    local_search_engine = LocalSearch(
        llm=llm,
        context_builder=local_context_builder,
        token_encoder=token_encoder,
        llm_params=local_llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",
    )
    # Set up the global search engine
    global_context_builder = GlobalCommunityContext(
        community_reports=reports,
        entities=entities,
        token_encoder=token_encoder,
    )
    global_context_builder_params = {
        "use_community_summary": False,
        "shuffle_data": True,
        "include_community_rank": True,
        "min_community_rank": 0,
        "community_rank_name": "rank",
        "include_community_weight": True,
        "community_weight_name": "occurrence weight",
        "normalize_community_weight": True,
        # "max_tokens": 12_000,
        "max_tokens": 4096,
        "context_name": "Reports",
    }
    map_llm_params = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "response_format": {"type": "json_object"},
    }
    reduce_llm_params = {
        "max_tokens": 2000,
        "temperature": 0.0,
    }
    global_search_engine = GlobalSearch(
        llm=llm,
        context_builder=global_context_builder,
        token_encoder=token_encoder,
        # max_data_tokens=12_000,
        max_data_tokens=4096,
        map_llm_params=map_llm_params,
        reduce_llm_params=reduce_llm_params,
        allow_general_knowledge=False,
        json_mode=True,
        context_builder_params=global_context_builder_params,
        concurrent_coroutines=32,
        response_type="multiple paragraphs",
    )
    logger.info("Search engines set up")
    return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params

# Format a response: split it into paragraphs, add line breaks, and mark code blocks for more readable output
def format_response(response):
    # Split the response into paragraphs on two or more consecutive newlines
    paragraphs = re.split(r'\n{2,}', response)
    # Collect the formatted paragraphs here
    formatted_paragraphs = []
    for para in paragraphs:
        # Check whether the paragraph contains a code-block marker
        if '```' in para:
            # Split on ``` so that code blocks and plain text alternate
            parts = para.split('```')
            for i, part in enumerate(parts):
                # Odd indices are code blocks
                if i % 2 == 1:
                    # Wrap the code block in newlines and ``` markers, trimming extra whitespace
                    parts[i] = f"\n```\n{part.strip()}\n```\n"
            # Reassemble the parts into a single string
            para = ''.join(parts)
        else:
            # Otherwise insert a newline after each sentence-ending period so sentences are clearly separated
            para = para.replace('. ', '.\n')
        # strip() removes leading/trailing whitespace (spaces, tabs \t, newlines \n, etc.)
        formatted_paragraphs.append(para.strip())
    # Join all paragraphs with blank lines in between for clear separation
    return '\n\n'.join(formatted_paragraphs)

# lifespan is an async function that receives the FastAPI app instance and manages
# the application's lifecycle, i.e. what happens at startup and shutdown
# On startup it initializes the search engines, loads the context data, and creates the question generator
# On shutdown it performs cleanup
# The @asynccontextmanager decorator turns it into an async context manager: code before the yield
# runs at startup, code after the yield runs at shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs at startup
    # Declare the globals that are initialized here and used throughout the app
    global local_search_engine, global_search_engine, question_generator
    try:
        logger.info("Initializing search engines and question generator...")
        # Set up the language model (LLM), token encoder, and text embedder;
        # await suspends execution until the async call completes
        llm, token_encoder, text_embedder = await setup_llm_and_embedder()
        # Load entities, relationships, reports, text units, the description embedding store,
        # and covariates, which are used to build the search engines and the question generator
        entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context()
        # Set up the local and global search engines, context builder, and related parameters
        local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
            llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
            description_embedding_store, covariates
        )
        # Create a local question generator from the components initialized above
        question_generator = LocalQuestionGen(
            llm=llm,
            context_builder=local_context_builder,
            token_encoder=token_encoder,
            llm_params=local_llm_params,
            context_builder_params=local_context_params,
        )
        logger.info("Initialization complete")
    except Exception as e:
        logger.error(f"Error during initialization: {str(e)}")
        # Re-raise so the app does not keep running in a broken state
        raise
    # yield hands control back to FastAPI and the app starts serving;
    # code before the yield runs at startup, code after it runs at shutdown
    yield
    # Runs at shutdown
    logger.info("Shutting down...")

# The lifespan argument hooks the initialization/cleanup above into the app's lifecycle
app = FastAPI(lifespan=lifespan)

# Run a full-model search: both local and global retrieval
async def full_model_search(prompt: str):
    local_result = await local_search_engine.asearch(prompt)
    global_result = await global_search_engine.asearch(prompt)
    # Format the combined result
    formatted_result = "#綜合搜索結果:\n\n"
    formatted_result += "##本地檢索結果:\n"
    formatted_result += format_response(local_result.response) + "\n\n"
    formatted_result += "##全局檢索結果:\n"
    formatted_result += format_response(global_result.response) + "\n\n"
    return formatted_result

# POST endpoint: knowledge Q&A against the model
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not local_search_engine or not global_search_engine:
        logger.error("Search engines are not initialized")
        raise HTTPException(status_code=500, detail="Search engines are not initialized")
    try:
        logger.info(f"Received chat completion request: {request}")
        prompt = request.messages[-1].content
        logger.info(f"Processing prompt: {prompt}")
        # Pick the search method based on the requested model
        if request.model == "graphrag-global-search:latest":
            result = await global_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)
        elif request.model == "full-model:latest":
            formatted_response = await full_model_search(prompt)
        else:  # default to local search (graphrag-local-search:latest)
            result = await local_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)
        logger.info(f"Formatted search result:\n {formatted_response}")
        # Streaming and non-streaming responses are handled separately
        if request.stream:
            # Async generator that produces the streamed chunks
            async def generate_stream():
                # One unique chunk_id shared by every chunk of this stream
                chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
                # Split the formatted response into lines
                lines = formatted_response.split('\n')
                # Iterate over the lines and build a response chunk for each
                for i, line in enumerate(lines):
                    # A single chunk of the streamed response
                    chunk = {
                        "id": chunk_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"content": line + '\n'},
                                "finish_reason": None
                            }
                        ]
                    }
                    # Serialize the chunk as JSON and yield it; SSE events end with a blank line
                    yield f"data: {json.dumps(chunk)}\n\n"
                    # Wait 0.5 s after each chunk
                    await asyncio.sleep(0.5)
                # Final chunk marking the end of the stream
                final_chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "finish_reason": "stop"
                        }
                    ]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
                yield "data: [DONE]\n\n"
            # Stream the data; media_type is text/event-stream to match the SSE (Server-Sent Events) format
            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        # Non-streaming response
        else:
            response = ChatCompletionResponse(
                model=request.model,
                choices=[
                    ChatCompletionResponseChoice(
                        index=0,
                        message=Message(role="assistant", content=formatted_response),
                        finish_reason="stop"
                    )
                ],
                # Usage statistics (rough whitespace-based token counts)
                usage=Usage(
                    # Number of tokens in the prompt
                    prompt_tokens=len(prompt.split()),
                    # Number of tokens in the completion
                    completion_tokens=len(formatted_response.split()),
                    # Total number of tokens
                    total_tokens=len(prompt.split()) + len(formatted_response.split())
                )
            )
            logger.info(f"Sending response: \n\n{response}")
            # Return a JSONResponse built from the response model converted to a dict
            return JSONResponse(content=response.dict())
    except Exception as e:
        logger.error(f"Error while handling chat completion:\n\n {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# GET endpoint that returns the list of available models
@app.get("/v1/models")
async def list_models():
    logger.info("Received model list request")
    current_time = int(time.time())
    models = [
        {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"},
        {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"},
        {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"}
    ]
    response = {
        "object": "list",
        "data": models
    }
    logger.info(f"Sending model list: {response}")
    return JSONResponse(content=response)

if __name__ == "__main__":
    logger.info(f"Starting server on port {PORT}")
    # uvicorn is a lightweight, very fast ASGI server implementation,
    # used to serve async Python web apps built on FastAPI
    uvicorn.run(app, host="0.0.0.0", port=PORT)
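Because the endpoints mimic the OpenAI API, any OpenAI-compatible client can talk to this service. A minimal sketch, assuming the openai Python package (v1+) is installed and the server is running on port 8012; the api_key value is a placeholder, since the server never checks it:

from openai import OpenAI

# Point an OpenAI-compatible client at the local GraphRAG service
client = OpenAI(base_url="http://localhost:8012/v1", api_key="not-needed")

completion = client.chat.completions.create(
    model="graphrag-local-search:latest",
    messages=[{"role": "user", "content": "賈政是誰,他的主要關系是什么?"}],
)
print(completion.choices[0].message.content)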