Spring AI integrates most large language models and provides good abstractions for the common LLM application architectures, which makes development very convenient. Different models accept different input types and produce different output types, so Spring AI classifies models by their inputs and outputs.
Most LLM application development is based on chat models, that is, models whose output is natural language or code. Among the models Spring AI supports, the OpenAI and Ollama platforms have the most complete support.
1. Spring AI Quick Start
1.1 Project Setup
Create a Spring Boot project and add the Spring AI base dependencies.
Spring AI fully supports Spring Boot auto-configuration and provides a dedicated starter for each model platform, for example:
Model/Platform | starter
---|---
Anthropic | spring-ai-anthropic-spring-boot-starter
Azure OpenAI | spring-ai-azure-openai-spring-boot-starter
DeepSeek | spring-ai-openai-spring-boot-starter (OpenAI-compatible API)
Hugging Face | spring-ai-huggingface-spring-boot-starter
Ollama | spring-ai-ollama-spring-boot-starter
OpenAI | spring-ai-openai-spring-boot-starter

(In 1.0.0-M6 the artifact ids follow the spring-ai-{platform}-spring-boot-starter pattern.)
Choose the dependency that matches your platform. Here we use Ollama as the example.
First, add the Spring AI version property to the project's pom.xml:
<spring-ai.version>1.0.0-M6</spring-ai.version>
Then, add the Spring AI dependency management entry:
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>${spring-ai.version}</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
Finally, add the spring-ai-ollama starter dependency:
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
To simplify later development, also add a Lombok dependency manually:
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <version>1.18.22</version>
</dependency>
Note: do not use the Lombok dependency generated by start.spring.io; that version has a bug!
1.2 Configure the Model
Configure the model parameters in application.yml:
spring:
  application:
    name: ai-demo
  ai:
    ollama:
      base-url: http://localhost:11434 # Ollama server address (this is the default)
      chat:
        model: deepseek-r1:7b # model name
        options:
          temperature: 0.8 # sampling temperature; lower values make output more deterministic
1.3 Wrap a ChatClient
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class CommonConfiguration {

    // The injected model parameter is the model in use; here it is Ollama,
    // but OpenAiChatModel can be injected instead.
    @Bean
    public ChatClient chatClient(OllamaChatModel model) {
        return ChatClient.builder(model) // create the ChatClient builder
                .build(); // build the ChatClient instance
    }
}
- ChatClient.builder: returns a ChatClient.Builder factory object, which lets you choose a model and add all kinds of custom configuration.
- OllamaChatModel: if you added the Ollama starter, an OllamaChatModel bean is available for automatic injection. The OpenAI starter works the same way.
- By default Spring uses the method name as the bean name, so this creates a bean named chatClient.
1.4 Synchronous Calls
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class ChatController {

    private final ChatClient chatClient;

    // Keep the request method and path unchanged; the frontend depends on them
    @RequestMapping("/chat")
    public String chat(@RequestParam(defaultValue = "Tell me a joke") String prompt) {
        return chatClient.prompt(prompt) // pass the user prompt
                .call()     // synchronous call; waits for the complete AI response
                .content(); // return the response content
    }
}
Note that call() is a synchronous invocation: the full response must be received before anything is returned to the frontend.
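The difference from the streaming mode introduced next can be illustrated with a plain-Java sketch (no Spring AI types; the chunk list is made up): a synchronous call joins all chunks before returning anything, while streaming hands each chunk to the caller as it is produced.

```java
import java.util.*;
import java.util.function.Consumer;

// Illustrative sketch of synchronous vs streaming delivery of model output.
class DeliverySketch {
    static final List<String> CHUNKS = List.of("Why ", "did the ", "chicken...", " [done]");

    // synchronous: wait for all chunks, return the joined result
    static String call() {
        StringBuilder sb = new StringBuilder();
        CHUNKS.forEach(sb::append);
        return sb.toString();
    }

    // streaming: push each chunk to the consumer as it arrives
    static void stream(Consumer<String> onChunk) {
        CHUNKS.forEach(onChunk);
    }
}
```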
1.5 Streaming Calls
With synchronous calls the user waits a long time before anything appears on the page, which is a poor experience. To fix this, we switch to streaming calls. Spring AI implements streaming with WebFlux.
// Note the return type: Flux<String>, a streaming result. Also set the response
// content type and charset, otherwise the frontend may show garbled characters.
@RequestMapping(value = "/chat", produces = "text/html;charset=UTF-8")
public Flux<String> chat(@RequestParam(defaultValue = "Tell me a joke") String prompt) {
    return chatClient.prompt(prompt)
            .stream()   // streaming call
            .content();
}
1.6 Setting the System Prompt
When we ask the AI who it is, it answers that it is DeepSeek-R1; that identity is built into the model. If we want the AI to follow a new persona, we need to give it System background information.
In Spring AI this is very convenient: instead of wrapping the system text into a Message on every request, we specify it once when building the ChatClient.
Modify the configuration class to give the ChatClient a default System message:
@Bean
public ChatClient chatClient(OllamaChatModel model) {
    return ChatClient.builder(model) // create the ChatClient builder
            .defaultSystem("You are an experienced software engineer. Please answer students' questions in a friendly, helpful, and cheerful manner.")
            .defaultAdvisors(new SimpleLoggerAdvisor())
            .build(); // build the ChatClient instance
}
1.7 Logging
By default, interactions with the AI are not logged, so we cannot see what prompt Spring AI actually assembled or whether it has problems. That makes debugging inconvenient.
1.7.1 Advisor
Spring AI uses an AOP-style mechanism to enhance, intercept, and modify the conversation with the model. Every such enhancement implements the Advisor interface.
Spring provides several default Advisor implementations for common needs:
- SimpleLoggerAdvisor: logging
- MessageChatMemoryAdvisor: conversation memory
- QuestionAnswerAdvisor: RAG
1.7.2 Adding the Logging Advisor
First, modify the configuration class to add the logging Advisor to the ChatClient:
@Bean
public ChatClient chatClient(OllamaChatModel model) {
    return ChatClient.builder(model) // create the ChatClient builder
            .defaultSystem("You are a warm, adorable assistant named Xiaotuantuan. Please answer questions in Xiaotuantuan's persona and voice.")
            .defaultAdvisors(new SimpleLoggerAdvisor()) // default Advisor that logs requests and responses
            .build(); // build the ChatClient instance
}
1.7.3 Changing the Log Level
Next, add logging configuration to application.yml to raise the log level:
logging:
  level:
    org.springframework.ai: debug # log level for AI conversations
    com.lgh.ai: debug # log level for this project
1.8 Conversation Memory
Spring AI ships with a conversation memory feature: it saves past messages for us and automatically prepends them to the next request to the AI.
1.8.1 ChatMemory
Conversation memory is also implemented via AOP. Spring provides a MessageChatMemoryAdvisor, which we can add to the ChatClient just like the logging advisor. Note, however, that MessageChatMemoryAdvisor requires a ChatMemory instance, which determines how the conversation history is stored.
The ChatMemory interface (provided by Spring AI) is declared as follows:
public interface ChatMemory {

    // TODO: consider a non-blocking interface for streaming usages
    default void add(String conversationId, Message message) {
        this.add(conversationId, List.of(message));
    }

    // append messages to the history of the given conversationId
    void add(String conversationId, List<Message> messages);

    // fetch the last N messages of the given conversationId
    List<Message> get(String conversationId, int lastN);

    // clear the history of the given conversationId
    void clear(String conversationId);
}
All conversation memory is associated with a conversationId, the conversation id; memories for different conversation ids are naturally managed separately.
Spring AI currently provides two ChatMemory implementations:
- InMemoryChatMemory: keeps the conversation history in memory
- CassandraChatMemory: keeps the history in a Cassandra database (requires extra dependencies and is tied to the vector store, so it is less flexible)
The built-in memory-based ChatMemory (shipped with Spring AI):
public class InMemoryChatMemory implements ChatMemory {

    Map<String, List<Message>> conversationHistory = new ConcurrentHashMap<>();

    public InMemoryChatMemory() {
    }

    public void add(String conversationId, List<Message> messages) {
        this.conversationHistory.putIfAbsent(conversationId, new ArrayList<>());
        this.conversationHistory.get(conversationId).addAll(messages);
    }

    public List<Message> get(String conversationId, int lastN) {
        List<Message> all = this.conversationHistory.get(conversationId);
        return all != null
                ? all.stream().skip(Math.max(0, all.size() - lastN)).toList()
                : List.of();
    }

    public void clear(String conversationId) {
        this.conversationHistory.remove(conversationId);
    }
}
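The interesting part is the get method: it returns only the last lastN messages by skipping the first size - lastN entries. A self-contained sketch of that retrieval logic, with plain strings standing in for Spring AI's Message type:

```java
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

// Minimal sketch of conversation-scoped memory with lastN retrieval,
// mirroring the InMemoryChatMemory semantics shown above.
class MemorySketch {
    private final Map<String, List<String>> history = new ConcurrentHashMap<>();

    void add(String conversationId, List<String> messages) {
        history.computeIfAbsent(conversationId, k -> new ArrayList<>()).addAll(messages);
    }

    List<String> get(String conversationId, int lastN) {
        List<String> all = history.get(conversationId);
        if (all == null) return List.of();
        // keep only the last N messages
        return all.stream().skip(Math.max(0, all.size() - lastN)).toList();
    }
}
```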
A Redis-based ChatMemory implementation:
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.lgh.web.manager.springai.model.Msg;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Redis-backed ChatMemory implementation
 * @Author GuihaoLv
 */
@RequiredArgsConstructor
@Component
public class RedisChatMemory implements ChatMemory {

    private final StringRedisTemplate redisTemplate;
    private final ObjectMapper objectMapper;

    private final static String PREFIX = "chat:";

    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<String> list = messages.stream()
                .map(Msg::new)
                .map(msg -> {
                    try {
                        return objectMapper.writeValueAsString(msg);
                    } catch (JsonProcessingException e) {
                        throw new RuntimeException(e);
                    }
                })
                .toList();
        redisTemplate.opsForList().leftPushAll(PREFIX + conversationId, list);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<String> list = redisTemplate.opsForList().range(PREFIX + conversationId, 0, lastN);
        if (list == null || list.isEmpty()) {
            return List.of();
        }
        return list.stream()
                .map(s -> {
                    try {
                        return objectMapper.readValue(s, Msg.class);
                    } catch (JsonProcessingException e) {
                        throw new RuntimeException(e);
                    }
                })
                .map(Msg::toMessage)
                .toList();
    }

    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(PREFIX + conversationId);
    }
}
The Msg wrapper class:
/**
 * Message wrapper for JSON serialization
 * @Author GuihaoLv
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class Msg {

    MessageType messageType;                    // message type (enum)
    String text;                                // message text
    Map<String, Object> metadata;               // message metadata (extra info)
    List<AssistantMessage.ToolCall> toolCalls;  // tool calls (only assistant messages may have them)

    public Msg(Message message) {
        this.messageType = message.getMessageType();
        this.text = message.getText();
        this.metadata = message.getMetadata();
        // copy toolCalls only when the original message is an assistant message
        if (message instanceof AssistantMessage am) {
            this.toolCalls = am.getToolCalls();
        }
    }

    public Message toMessage() {
        return switch (messageType) {
            case SYSTEM -> new SystemMessage(text);
            case USER -> new UserMessage(text, List.of(), metadata);
            case ASSISTANT -> new AssistantMessage(text, metadata, toolCalls, List.of());
            default -> throw new IllegalArgumentException("Unsupported message type: " + messageType);
        };
    }
}
1.8.2 Adding the Memory Advisor
Register a ChatMemory bean:
@Bean
public ChatMemory chatMemory() {
    return new InMemoryChatMemory();
}

Or, for the Redis-backed implementation (register only one of the two ChatMemory beans):

@Autowired
private StringRedisTemplate stringRedisTemplate;

@Bean
public ObjectMapper objectMapper() {
    ObjectMapper objectMapper = new ObjectMapper();
    objectMapper.configure(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
    // support for Java 8 date/time types
    objectMapper.registerModule(new JavaTimeModule());
    return objectMapper;
}

@Bean
public ChatMemory chatMemory() {
    return new RedisChatMemory(stringRedisTemplate, objectMapper());
}
Add a MessageChatMemoryAdvisor to the ChatClient:
@Bean
public ChatClient chatClient(OllamaChatModel model, ChatMemory chatMemory) {
    return ChatClient.builder(model) // create the ChatClient builder
            .defaultSystem("Your name is Xiaohei. Please answer students' questions in a friendly, helpful, and cheerful manner.")
            .defaultAdvisors(new SimpleLoggerAdvisor()) // default Advisor that logs requests and responses
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory))
            .build(); // build the ChatClient instance
}
The chat conversation now has memory.
1.9 Conversation History
Conversation history and conversation memory are two different things:
Conversation memory: the model remembers the content of each round of dialogue, so it does not forget the previous question as soon as the next one arrives.
Conversation history: a record of how many distinct conversations exist.
ChatMemory records every message of one conversation, keyed by conversationId with a List<Message> as the value. Those stored messages are what let the model keep answering in context; that is conversation memory. Conversation history, on the other hand, is the set of conversationIds themselves; given a conversationId, you can then look up its List<Message>.
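The two concepts map to two different data structures. A self-contained sketch of that relationship, with plain strings standing in for Message (the class and method names here are illustrative, not Spring AI API):

```java
import java.util.*;

// Sketch: chat history records which conversation ids exist (per business type),
// while chat memory records the messages inside each conversation.
class HistoryVsMemory {
    // chat history: business type -> list of conversationIds
    final Map<String, List<String>> chatHistory = new HashMap<>();
    // chat memory: conversationId -> messages
    final Map<String, List<String>> chatMemory = new HashMap<>();

    void record(String type, String chatId, String message) {
        List<String> ids = chatHistory.computeIfAbsent(type, k -> new ArrayList<>());
        if (!ids.contains(chatId)) ids.add(chatId); // history stores each id once
        chatMemory.computeIfAbsent(chatId, k -> new ArrayList<>()).add(message);
    }
}
```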
1.9.1 Managing Conversation Ids
Since conversation memory is managed by conversationId (hereafter abbreviated chatId), querying the conversation history really means querying which chatIds exist. To support that query we must record every chatId, so we define a standard interface for managing the conversation history.
/**
 * Conversation history repository
 * @Author GuihaoLv
 */
public interface ChatHistoryRepository {

    /**
     * Save a conversation record
     * @param type business type, e.g. chat, service, pdf
     * @param chatId conversation id
     */
    void save(String type, String chatId);

    /**
     * Get the conversation id list
     * @param type business type, e.g. chat, service, pdf
     * @return conversation id list
     */
    List<String> getChatIds(String type);
}
An in-memory conversation history implementation:
/**
 * In-memory conversation history management
 * @Author GuihaoLv
 */
@Slf4j
//@Component
@RequiredArgsConstructor
public class InMemoryChatHistoryRepository implements ChatHistoryRepository {

    private Map<String, List<String>> chatHistory;
    private final ObjectMapper objectMapper;
    private final ChatMemory chatMemory;

    @Override
    public void save(String type, String chatId) {
        List<String> chatIds = chatHistory.computeIfAbsent(type, k -> new ArrayList<>());
        if (chatIds.contains(chatId)) {
            return;
        }
        chatIds.add(chatId);
    }

    @Override
    public List<String> getChatIds(String type) {
        return chatHistory.getOrDefault(type, List.of());
    }

    @PostConstruct
    private void init() {
        // 1. initialize the history map
        this.chatHistory = new HashMap<>();
        // 2. read persisted history and memory from local files
        FileSystemResource historyResource = new FileSystemResource("chat-history.json");
        FileSystemResource memoryResource = new FileSystemResource("chat-memory.json");
        if (!historyResource.exists()) {
            return;
        }
        try {
            // conversation history
            Map<String, List<String>> chatIds = this.objectMapper.readValue(
                    historyResource.getInputStream(), new TypeReference<>() {});
            if (chatIds != null) {
                this.chatHistory = chatIds;
            }
            // conversation memory
            Map<String, List<Msg>> memory = this.objectMapper.readValue(
                    memoryResource.getInputStream(), new TypeReference<>() {});
            if (memory != null) {
                memory.forEach(this::convertMsgToMessage);
            }
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private void convertMsgToMessage(String chatId, List<Msg> messages) {
        this.chatMemory.add(chatId, messages.stream().map(Msg::toMessage).toList());
    }

    @PreDestroy
    private void persistent() {
        String history = toJsonString(this.chatHistory);
        String memory = getMemoryJsonString();
        FileSystemResource historyResource = new FileSystemResource("chat-history.json");
        FileSystemResource memoryResource = new FileSystemResource("chat-memory.json");
        try (PrintWriter historyWriter = new PrintWriter(historyResource.getOutputStream(), true, StandardCharsets.UTF_8);
             PrintWriter memoryWriter = new PrintWriter(memoryResource.getOutputStream(), true, StandardCharsets.UTF_8)) {
            historyWriter.write(history);
            memoryWriter.write(memory);
        } catch (IOException ex) {
            log.error("IOException occurred while saving conversation files.", ex);
            throw new RuntimeException(ex);
        } catch (SecurityException ex) {
            log.error("SecurityException occurred while saving conversation files.", ex);
            throw new RuntimeException(ex);
        }
    }

    private String getMemoryJsonString() {
        Class<InMemoryChatMemory> clazz = InMemoryChatMemory.class;
        try {
            // InMemoryChatMemory does not expose its map, so read it via reflection
            Field field = clazz.getDeclaredField("conversationHistory");
            field.setAccessible(true);
            Map<String, List<Message>> memory = (Map<String, List<Message>>) field.get(chatMemory);
            Map<String, List<Msg>> memoryToSave = new HashMap<>();
            memory.forEach((chatId, messages) ->
                    memoryToSave.put(chatId, messages.stream().map(Msg::new).toList()));
            return toJsonString(memoryToSave);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    private String toJsonString(Object object) {
        ObjectWriter objectWriter = this.objectMapper.writerWithDefaultPrettyPrinter();
        try {
            return objectWriter.writeValueAsString(object);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error serializing to JSON.", e);
        }
    }
}
A Redis-based conversation history implementation:
/**
 * Redis-backed ChatHistoryRepository implementation
 * @Author GuihaoLv
 */
@RequiredArgsConstructor
@Component
public class RedisChatHistory implements ChatHistoryRepository {

    private final StringRedisTemplate redisTemplate;

    private final static String CHAT_HISTORY_KEY_PREFIX = "chat:history:";

    @Override
    public void save(String type, String chatId) {
        redisTemplate.opsForSet().add(CHAT_HISTORY_KEY_PREFIX + type, chatId);
    }

    @Override
    public List<String> getChatIds(String type) {
        Set<String> chatIds = redisTemplate.opsForSet().members(CHAT_HISTORY_KEY_PREFIX + type);
        if (chatIds == null || chatIds.isEmpty()) {
            return Collections.emptyList();
        }
        return chatIds.stream().sorted(String::compareTo).toList();
    }
}
1.9.2 Saving the chatId
Next, modify the chat method in ChatController to do three things:
- Add a request parameter chatId; the frontend must pass it on every request to the AI
- On each request, store the chatId in the ChatHistoryRepository
- Pass our chatId to the AI model on every request
@CrossOrigin("*")
@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class ChatController {

    private final ChatClient chatClient;
    private final ChatMemory chatMemory;
    private final ChatHistoryRepository chatHistoryRepository;

    @RequestMapping(value = "/chat", produces = "text/html;charset=UTF-8")
    public Flux<String> chat(@RequestParam(defaultValue = "Tell me a joke") String prompt, String chatId) {
        // record the chatId under the "chat" business type
        chatHistoryRepository.save("chat", chatId);
        return chatClient.prompt(prompt)
                .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
                .stream()
                .content();
    }
}
Here the chatId is passed to the Advisor through the advisor context, that is, stored into the per-request context as a key-value pair:
chatClient.advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
CHAT_MEMORY_CONVERSATION_ID_KEY is a constant key defined in AbstractChatMemoryAdvisor, so MessageChatMemoryAdvisor can read the chatId from the context while it executes.
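Conceptually, the advisor context is just a per-request key-value map that advisors consult during execution. A minimal sketch of that lookup (the names here are hypothetical, not the Spring AI API):

```java
import java.util.*;

// Sketch: an advisor receives a context map and looks up the conversation id
// under a well-known key, falling back to a default when it is absent.
class AdvisorContextSketch {
    static final String CONVERSATION_ID_KEY = "chat_memory_conversation_id";

    static String resolveConversationId(Map<String, Object> context) {
        return (String) context.getOrDefault(CONVERSATION_ID_KEY, "default");
    }
}
```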
1.9.3 Querying the Conversation History
We define a new controller dedicated to history queries, with two endpoints:
- Query the conversation list by business type (we will have three different businesses whose histories are recorded separately; your business might record by userId and query by userId instead)
- Query the historical messages of a given conversation by chatId
The history messages form a Message collection, but Message does not match what the page needs, so we define a VO of our own.
/**
 * Message query result
 * @Author GuihaoLv
 */
@NoArgsConstructor
@Data
public class MessageVO {

    private String role;
    private String content;

    public MessageVO(Message message) {
        switch (message.getMessageType()) {
            case USER:
                role = "user";
                break;
            case ASSISTANT:
                role = "assistant";
                break;
            default:
                role = "";
                break;
        }
        this.content = message.getText();
    }
}
/**
 * AI conversation history
 * @author GuihaoLv
 */
@RestController
@RequestMapping("/web/aiHistory")
@Tag(name = "AI conversation history")
@Slf4j
public class ChatHistoryController {

    @Autowired
    private RedisChatHistory chatHistoryRepository;

    @Autowired
    private RedisChatMemory chatMemory;

    /**
     * Get the conversation id list
     */
    @GetMapping("/{type}")
    @Operation(summary = "Get the conversation id list")
    public List<String> getChatIds(@PathVariable("type") String type) {
        return chatHistoryRepository.getChatIds(type);
    }

    /**
     * Get the messages of one conversation
     */
    @GetMapping("/{type}/{chatId}")
    @Operation(summary = "Get the conversation messages")
    public List<MessageVO> getChatHistory(@PathVariable("type") String type,
                                          @PathVariable("chatId") String chatId) {
        List<Message> messages = chatMemory.get(chatId, Integer.MAX_VALUE);
        if (messages == null) {
            return List.of();
        }
        return messages.stream().map(MessageVO::new).toList();
    }
}
Overall design of the conversation memory logic:
2 Function Calling
2.1 Introduction to Function Calling
AI is good at analyzing unstructured data. When a requirement involves strict logical validation, or business logic such as reading and writing a database, we can give the model the ability to trigger that logic. We define the database operations and other business logic as Functions, also called Tools. Then, in the prompt, we tell the model under which circumstances to call which tool; when users later interact with the model, it can invoke the tools at the appropriate moments.
Process walkthrough:
- Define the operations as Functions in advance (Spring AI calls them Tools)
- Package each Function's name, purpose, and required parameters into the prompt and send it together with the user's question to the model
- While interacting with the user, the model judges from the conversation whether a Function call is needed
- If it is, the model returns the Function name, arguments, and so on
- The Java side parses that result, determines which function to run, executes the Function, and wraps the result into a new prompt sent back to the AI
- The AI continues interacting with the user until the task is complete
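The dispatch loop described above can be sketched in plain Java. All names and the reply format here are illustrative, not Spring AI API; the real framework parses the model's structured tool-call response and matches it against registered @Tool methods:

```java
import java.util.*;
import java.util.function.Function;

// Sketch of the function-calling dispatch loop: the model's reply either
// names a tool to invoke (with arguments) or is a final answer.
class ToolDispatchSketch {
    final Map<String, Function<String, String>> tools = new HashMap<>();

    void register(String name, Function<String, String> fn) {
        tools.put(name, fn);
    }

    // modelReply format (made up for illustration): "CALL:toolName:args" or "FINAL:text"
    String handle(String modelReply) {
        if (modelReply.startsWith("CALL:")) {
            String[] parts = modelReply.split(":", 3);
            String result = tools.get(parts[1]).apply(parts[2]);
            // in the real flow, this result is wrapped back into the next prompt
            return "TOOL_RESULT:" + result;
        }
        return modelReply.substring("FINAL:".length());
    }
}
```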
Spring AI provides Function Calling support. Since parsing the model's response, extracting the function name and arguments, and invoking the function are all fixed, mechanical steps, Spring AI again uses its AOP machinery to perform the intermediate function-call handling for us automatically.
Our job is reduced to:
- Writing the base prompt (excluding the Tool definitions)
- Writing the Tool (Function)
- Configuring the Advisor (Spring AI appends the Tool definitions to the prompt and performs the Tool invocation)
2.2 Function Calling in Practice
Goal: have the model automatically summarize the knowledge points of the current conversation and save them.
2.2.1 Business Layer
@Mapper
public interface AINoteMapper {

    /**
     * Insert a record
     */
    @Insert("insert into tb_ai_note (user_id, chat_id, title, content) values (#{userId}, #{chatId}, #{title}, #{content})")
    Boolean insert(AINote aINote);

    /**
     * Delete a record by id
     */
    @Delete("delete from tb_ai_note where id = #{aiNoteId}")
    Boolean deleteById(Long aiNoteId);

    /**
     * Query all records
     */
    @Select("SELECT * FROM tb_ai_note")
    List<AINote> selectList();

    /**
     * Query a record by id
     */
    @Select("SELECT * FROM tb_ai_note where id = #{aiNoteId}")
    AINote selectById(Long aiNoteId);

    /**
     * Insert an AI text-generation record
     */
    @Insert("insert into tb_generate_text (user_id, prompt_words, generated_text, translated_text) values (#{userId}, #{promptWords}, #{generatedText}, #{translatedText})")
    Boolean addGT(GenerateText generateText);

    /**
     * List AI text-generation records for a user
     */
    @Select("SELECT * FROM tb_generate_text where user_id = #{userId} order by create_time desc")
    List<GenerateText> getGTList(Long userId);

    /**
     * Delete an AI text-generation record
     */
    @Delete("delete from tb_generate_text where id = #{generateTextId}")
    Boolean deleteGT(Long generateTextId);
}
/**
 * AI note entity
 * @Author GuihaoLv
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class AINote extends BaseEntity implements Serializable {
    private Long userId;    // user id
    private String chatId;  // conversation id
    private String title;   // title
    private String content; // content
}
2.2.2 Defining the Function
/**
 * AI note tool class
 * @Author GuihaoLv
 */
@Component
public class AINoteTools {

    @Autowired
    private AINoteMapper aiNoteMapper;

    /**
     * Save the knowledge points of a conversation as an AI note
     * @param chatId conversation id (links the note to its conversation)
     * @param title note title (summarizes the core knowledge point)
     * @param content note content (knowledge point details)
     * @return save result (true on success / false on failure)
     */
    @Tool(description = "Create an AI note from the knowledge points in the conversation. Requires the conversation id, a title, and the content; the current user is attached automatically.")
    public Boolean createAINote(
            @ToolParam(required = false, description = "Unique conversation id, used to link the note to its conversation") String chatId,
            @ToolParam(required = true, description = "Note title; a concise summary of the knowledge point (at most 20 characters)") String title,
            @ToolParam(required = true, description = "Detailed note content recording the knowledge points of the conversation") String content) {
        // build the note object (the current user id is filled in automatically)
        AINote aiNote = new AINote();
        aiNote.setUserId(UserUtil.getUserId());
        aiNote.setChatId(chatId);
        aiNote.setTitle(title);
        aiNote.setContent(content);
        // save to the database
        return aiNoteMapper.insert(aiNote);
    }
}
The @ToolParam annotation is provided by Spring AI to describe a Function parameter; everything in it is sent to the AI model as part of the prompt.
2.2.3 The System Prompt
public static final String CHAT_ROLE = """
        You are an assistant that helps users record conversation notes.
        You must call the createAINote tool when the user says things like:
        - "Save what we just discussed as a note"
        - "Record this knowledge point"
        - "Save the current conversation"
        - any similar request to save the conversation content

        When calling the tool, include three parameters:
        1. chatId: the current conversation id (from the conversation context)
        2. title: a title distilled from the conversation (at most 20 characters)
        3. content: the knowledge points to record (extract the relevant content in full)

        On success, reply "Saved your note: [title]"; on failure, reply "Failed to save the note, please retry".
        """;
2.2.4 Configuring the Tool on the ChatClient
@Bean
public ChatClient chatCommonClient(OpenAiChatModel model, ChatMemory chatMemory,
                                   VectorStore vectorStore, AINoteTools aiNoteTools) {
    return ChatClient.builder(model)
            .defaultOptions(ChatOptions.builder().model("qwen-omni-turbo").build())
            .defaultSystem(AIChatConstant.CHAT_ROLE)
            .defaultAdvisors(
                    new SimpleLoggerAdvisor(),
                    new MessageChatMemoryAdvisor(chatMemory),
                    new QuestionAnswerAdvisor(vectorStore,
                            SearchRequest.builder().similarityThreshold(0.6).topK(2).build()))
            .defaultTools(aiNoteTools) // register the @Tool-annotated bean
            .build();
}
At the moment, Spring AI's OpenAI client has compatibility problems with Alibaba Cloud Bailian, so Function Calling cannot be used in stream mode. To stay compatible with the Bailian platform, we need to make some adjustments.
3 Compatibility with Alibaba Cloud Bailian
As of Spring AI 1.0.0-M6, Spring AI's OpenAiChatModel has compatibility issues with some Alibaba Cloud Bailian endpoints, including but not limited to these two:
- In Function Calling's stream mode, Bailian returns tool-call arguments as incomplete fragments that must be concatenated, whereas OpenAI returns them complete, with no concatenation needed.
- For audio recognition, Bailian's qwen-omni models expect the media data in the form data:;base64,${media-data}, whereas OpenAI expects the raw ${media-data}.
Because Spring AI's OpenAI module follows the OpenAI specification, upgrading versions will not bring Alibaba Cloud compatibility unless Spring AI ships a dedicated starter for it. So there are currently two options:
- Wait for Alibaba's official spring-alibaba-ai to be upgraded to the latest version
- Rewrite the OpenAiChatModel logic ourselves.
Next, we take the rewrite approach to solve both problems.
First, we write our own ChatModel that follows the Bailian platform's interface conventions. Most of the code comes straight from Spring AI's OpenAiChatModel; only the places where the protocols do not match need to be rewritten.
Create an AlibabaOpenAiChatModel class:
package com.itheima.ai.model;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.*;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class AlibabaOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {

    private static final Logger logger = LoggerFactory.getLogger(AlibabaOpenAiChatModel.class);

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    /**
     * The default options used for the chat completion requests.
     */
    private final OpenAiChatOptions defaultOptions;

    /**
     * The retry template used to retry the OpenAI API calls.
     */
    private final RetryTemplate retryTemplate;

    /**
     * Low-level access to the OpenAI API.
     */
    private final OpenAiApi openAiApi;

    /**
     * Observation registry used for instrumentation.
     */
    private final ObservationRegistry observationRegistry;

    private final ToolCallingManager toolCallingManager;

    /**
     * Conventions to use for generating observations.
     */
    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates an instance of the AlibabaOpenAiChatModel.
     * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI Chat API.
     * @throws IllegalArgumentException if openAiApi is null
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi) {
        this(openAiApi, OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build());
    }

    /**
     * Initializes an instance of the AlibabaOpenAiChatModel.
     * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI Chat API.
     * @param options The OpenAiChatOptions to configure the chat model.
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
        this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Initializes a new instance of the AlibabaOpenAiChatModel.
     * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI Chat API.
     * @param options The OpenAiChatOptions to configure the chat model.
     * @param functionCallbackResolver The function callback resolver.
     * @param retryTemplate The retry template.
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
            @Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
        this(openAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
    }

    /**
     * Initializes a new instance of the AlibabaOpenAiChatModel.
     * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI Chat API.
     * @param options The OpenAiChatOptions to configure the chat model.
     * @param functionCallbackResolver The function callback resolver.
     * @param toolFunctionCallbacks The tool function callbacks.
     * @param retryTemplate The retry template.
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
            @Nullable FunctionCallbackResolver functionCallbackResolver,
            @Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
        this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
                ObservationRegistry.NOOP);
    }

    /**
     * Initializes a new instance of the AlibabaOpenAiChatModel.
     * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI Chat API.
     * @param options The OpenAiChatOptions to configure the chat model.
     * @param functionCallbackResolver The function callback resolver.
     * @param toolFunctionCallbacks The tool function callbacks.
     * @param retryTemplate The retry template.
     * @param observationRegistry The ObservationRegistry used for instrumentation.
     * @deprecated Use AlibabaOpenAiChatModel.Builder or AlibabaOpenAiChatModel(OpenAiApi,
     * OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry).
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
            @Nullable FunctionCallbackResolver functionCallbackResolver,
            @Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry) {
        this(openAiApi, options,
                LegacyToolCallingManager.builder()
                    .functionCallbackResolver(functionCallbackResolver)
                    .functionCallbacks(toolFunctionCallbacks)
                    .build(),
                retryTemplate, observationRegistry);
        logger.warn("This constructor is deprecated and will be removed in the next milestone. "
                + "Please use the AlibabaOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
    }

    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions,
            ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry) {
        // We do not pass the 'defaultOptions' to the AbstractToolSupport,
        // because it modifies them. We are using ToolCallingManager instead,
        // so we just pass empty options here.
        super(null, OpenAiChatOptions.builder().build(), List.of());
        Assert.notNull(openAiApi, "openAiApi cannot be null");
        Assert.notNull(defaultOptions, "defaultOptions cannot be null");
        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        this.openAiApi = openAiApi;
        this.defaultOptions = defaultOptions;
        this.toolCallingManager = toolCallingManager;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    @Override
    public ChatResponse call(Prompt prompt) {
        // Before moving any further, build the final request Prompt,
        // merging runtime and default options.
        Prompt requestPrompt = buildRequestPrompt(prompt);
        return this.internalCall(requestPrompt, null);
    }

    public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);

        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OpenAiApiConstants.PROVIDER_NAME)
            .requestOptions(prompt.getOptions())
            .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = this.retryTemplate
                    .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));

                var chatCompletion = completionEntity.getBody();
                if (chatCompletion == null) {
                    logger.warn("No chat completion returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
                if (choices == null) {
                    logger.warn("No choices returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                List<Generation> generations = choices.stream().map(choice -> {
                    // @formatter:off
                    Map<String, Object> metadata = Map.of(
                            "id", chatCompletion.id() != null ? chatCompletion.id() : "",
                            "role", choice.message().role() != null ? choice.message().role().name() : "",
                            "index", choice.index(),
                            "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                            "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
                    // @formatter:on
                    return buildGeneration(choice, metadata, request);
                }).toList();

                RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

                // Current usage
                OpenAiApi.Usage usage = completionEntity.getBody().usage();
                Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

                ChatResponse chatResponse = new ChatResponse(generations,
                        from(completionEntity.getBody(), rateLimit, accumulatedUsage));

                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
                && response.hasToolCalls()) {
            var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            if (toolExecutionResult.returnDirect()) {
                // Return tool execution result directly to the client.
                return ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build();
            }
            else {
                // Send the tool execution result back to the model.
                return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                        response);
            }
        }

        return response;
    }

    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        // Before moving any further, build the final request Prompt,
        // merging runtime and default options.
        Prompt requestPrompt = buildRequestPrompt(prompt);
        return internalStream(requestPrompt, null);
    }

    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);

            if (request.outputModalities() != null) {
                if (request.outputModalities().stream().anyMatch(m -> m.equals("audio"))) {
                    logger.warn("Audio output is not supported for streaming requests. Removing audio output.");
                    throw new IllegalArgumentException("Audio output is not supported for streaming requests.");
                }
            }
            if (request.audioParameters() != null) {
                logger.warn("Audio parameters are not supported for streaming requests. Removing audio parameters.");
                throw new IllegalArgumentException("Audio parameters are not supported for streaming requests.");
            }

            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
                    getAdditionalHttpHeaders(prompt));

            // For chunked responses, only the first chunk contains the choice role.
            // The rest of the chunks with same ID share the same role.
            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

            final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OpenAiApiConstants.PROVIDER_NAME)
                .requestOptions(prompt.getOptions())
                .build();

            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);

            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
            // the function call handling logic.
            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
                    try {
                        @SuppressWarnings("null")
                        String id = chatCompletion2.id();

                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
                            if (choice.message().role() != null) {
                                roleMap.putIfAbsent(id, choice.message().role().name());
                            }
                            Map<String, Object> metadata = Map.of(
                                    "id", chatCompletion2.id(),
                                    "role", roleMap.getOrDefault(id, ""),
                                    "index", choice.index(),
                                    "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                                    "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
                            return buildGeneration(choice, metadata, request);
                        }).toList();
                        // @formatter:on
                        OpenAiApi.Usage usage = chatCompletion2.usage();
                        Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                        Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage,
                                previousChatResponse);
                        return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
                    }
                    catch (Exception e) {
                        logger.error("Error processing chat completion", e);
                        return new ChatResponse(List.of());
                    }
                    // When in stream mode and enabled to include the usage, the OpenAI
                    // Chat completion response would have the usage set only in its
                    // final response. Hence, the following overlapping buffer is
                    // created to store both the current and the subsequent response
                    // to accumulate the usage from the subsequent response.
                }))
                .buffer(2, 1)
                .map(bufferList -> {
                    ChatResponse firstResponse = bufferList.get(0);
                    if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
                        if (bufferList.size() == 2) {
                            ChatResponse secondResponse = bufferList.get(1);
                            if (secondResponse != null && secondResponse.getMetadata() != null) {
                                // This is the usage from the final Chat response for a
                                // given Chat request.
                                Usage usage = secondResponse.getMetadata().getUsage();
                                if (!UsageUtils.isEmpty(usage)) {
                                    // Store the usage from the final response to the
                                    // penultimate response for accumulation.
                                    return new ChatResponse(firstResponse.getResults(),
                                            from(firstResponse.getMetadata(), usage));
                                }
                            }
                        }
                    }
                    return firstResponse;
                });

            // @formatter:off
            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
                if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
                    var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                    if (toolExecutionResult.returnDirect()) {
                        // Return tool execution result directly to the client.
                        return Flux.just(ChatResponse.builder().from(response)
                                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                                .build());
                    } else {
                        // Send the tool execution result back to the model.
                        return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(),
prompt.getOptions()),response);}}else {return Flux.just(response);}}).doOnError(observation::error).doFinally(s -> observation.stop()).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));// @formatter:onreturn new MessageAggregator().aggregate(flux, observationContext::setResponse);});}private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {Map<String, String> headers = new HashMap<>(this.defaultOptions.getHttpHeaders());if (prompt.getOptions() != null && prompt.getOptions() instanceof OpenAiChatOptions chatOptions) {headers.putAll(chatOptions.getHttpHeaders());}return CollectionUtils.toMultiValueMap(headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue()))));}private Generation buildGeneration(OpenAiApi.ChatCompletion.Choice choice, Map<String, Object> metadata, OpenAiApi.ChatCompletionRequest request) {List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of(): choice.message().toolCalls().stream().map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",toolCall.function().name(), toolCall.function().arguments())).reduce((tc1, tc2) -> new AssistantMessage.ToolCall(tc1.id(), "function", tc1.name(), tc1.arguments() + tc2.arguments())).stream().toList();String finishReason = (choice.finishReason() != null ? 
choice.finishReason().name() : "");var generationMetadataBuilder = ChatGenerationMetadata.builder().finishReason(finishReason);List<Media> media = new ArrayList<>();String textContent = choice.message().content();var audioOutput = choice.message().audioOutput();if (audioOutput != null) {String mimeType = String.format("audio/%s", request.audioParameters().format().name().toLowerCase());byte[] audioData = Base64.getDecoder().decode(audioOutput.data());Resource resource = new ByteArrayResource(audioData);Media.builder().mimeType(MimeTypeUtils.parseMimeType(mimeType)).data(resource).id(audioOutput.id()).build();media.add(Media.builder().mimeType(MimeTypeUtils.parseMimeType(mimeType)).data(resource).id(audioOutput.id()).build());if (!StringUtils.hasText(textContent)) {textContent = audioOutput.transcript();}generationMetadataBuilder.metadata("audioId", audioOutput.id());generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());}var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);return new Generation(assistantMessage, generationMetadataBuilder.build());}private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");var builder = ChatResponseMetadata.builder().id(result.id() != null ? result.id() : "").usage(usage).model(result.model() != null ? result.model() : "").keyValue("created", result.created() != null ? result.created() : 0L).keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");if (rateLimit != null) {builder.rateLimit(rateLimit);}return builder.build();}private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");var builder = ChatResponseMetadata.builder().id(chatResponseMetadata.getId() != null ? 
chatResponseMetadata.getId() : "").usage(usage).model(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");if (chatResponseMetadata.getRateLimit() != null) {builder.rateLimit(chatResponseMetadata.getRateLimit());}return builder.build();}/*** Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.* @param chunk the ChatCompletionChunk to convert* @return the ChatCompletion*/private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {List<OpenAiApi.ChatCompletion.Choice> choices = chunk.choices().stream().map(chunkChoice -> new OpenAiApi.ChatCompletion.Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(),chunkChoice.logprobs())).toList();return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),chunk.systemFingerprint(), "chat.completion", chunk.usage());}private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);}Prompt buildRequestPrompt(Prompt prompt) {// Process runtime optionsOpenAiChatOptions runtimeOptions = null;if (prompt.getOptions() != null) {if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,OpenAiChatOptions.class);}else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,OpenAiChatOptions.class);}else {runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,OpenAiChatOptions.class);}}// Define request options by merging runtime options and default optionsOpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,OpenAiChatOptions.class);// Merge @JsonIgnore-annotated 
options explicitly since they are ignored by// Jackson, used by ModelOptionsUtils.if (runtimeOptions != null) {requestOptions.setHttpHeaders(mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(),this.defaultOptions.isInternalToolExecutionEnabled()));requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),this.defaultOptions.getToolNames()));requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),this.defaultOptions.getToolCallbacks()));requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),this.defaultOptions.getToolContext()));}else {requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());requestOptions.setToolNames(this.defaultOptions.getToolNames());requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());requestOptions.setToolContext(this.defaultOptions.getToolContext());}ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());return new Prompt(prompt.getInstructions(), requestOptions);}private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,Map<String, String> defaultHttpHeaders) {var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);mergedHttpHeaders.putAll(runtimeHttpHeaders);return mergedHttpHeaders;}/*** Accessible for testing.*/OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {Object content = message.getText();if (message instanceof 
UserMessage userMessage) {if (!CollectionUtils.isEmpty(userMessage.getMedia())) {List<OpenAiApi.ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(List.of(new OpenAiApi.ChatCompletionMessage.MediaContent(message.getText())));contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());content = contentList;}}return List.of(new OpenAiApi.ChatCompletionMessage(content,OpenAiApi.ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));}else if (message.getMessageType() == MessageType.ASSISTANT) {var assistantMessage = (AssistantMessage) message;List<OpenAiApi.ChatCompletionMessage.ToolCall> toolCalls = null;if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {var function = new OpenAiApi.ChatCompletionMessage.ChatCompletionFunction(toolCall.name(), toolCall.arguments());return new OpenAiApi.ChatCompletionMessage.ToolCall(toolCall.id(), toolCall.type(), function);}).toList();}OpenAiApi.ChatCompletionMessage.AudioOutput audioOutput = null;if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {Assert.isTrue(assistantMessage.getMedia().size() == 1,"Only one media content is supported for assistant messages");audioOutput = new OpenAiApi.ChatCompletionMessage.AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null);}return List.of(new OpenAiApi.ChatCompletionMessage(assistantMessage.getText(),OpenAiApi.ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));}else if (message.getMessageType() == MessageType.TOOL) {ToolResponseMessage toolMessage = (ToolResponseMessage) message;toolMessage.getResponses().forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));return toolMessage.getResponses().stream().map(tr -> new OpenAiApi.ChatCompletionMessage(tr.responseData(), OpenAiApi.ChatCompletionMessage.Role.TOOL, tr.name(),tr.id(), null, null, 
null)).toList();}else {throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());}}).flatMap(List::stream).toList();OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions();request = ModelOptionsUtils.merge(requestOptions, request, OpenAiApi.ChatCompletionRequest.class);// Add the tool definitions to the request's tools parameter.List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);if (!CollectionUtils.isEmpty(toolDefinitions)) {request = ModelOptionsUtils.merge(OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,OpenAiApi.ChatCompletionRequest.class);}// Remove `streamOptions` from the request if it is not a streaming requestif (request.streamOptions() != null && !stream) {logger.warn("Removing streamOptions from the request as it is not a streaming request!");request = request.streamOptions(null);}return request;}private OpenAiApi.ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {var mimeType = media.getMimeType();if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType) || MimeTypeUtils.parseMimeType("audio/mpeg").equals(mimeType)) {return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(fromAudioData(media.getData()), OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format.MP3));}if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) {return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(fromAudioData(media.getData()), OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format.WAV));}else {return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), 
media.getData())));}}private String fromAudioData(Object audioData) {if (audioData instanceof byte[] bytes) {return String.format("data:;base64,%s", Base64.getEncoder().encodeToString(bytes));}throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());}private String fromMediaData(MimeType mimeType, Object mediaContentData) {if (mediaContentData instanceof byte[] bytes) {// Assume the bytes are an image. So, convert the bytes to a base64 encoded// following the prefix pattern.return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));}else if (mediaContentData instanceof String text) {// Assume the text is a URLs or a base64 encoded image prefixed by the user.return text;}else {throw new IllegalArgumentException("Unsupported media data type: " + mediaContentData.getClass().getSimpleName());}}private List<OpenAiApi.FunctionTool> getFunctionTools(List<ToolDefinition> toolDefinitions) {return toolDefinitions.stream().map(toolDefinition -> {var function = new OpenAiApi.FunctionTool.Function(toolDefinition.description(), toolDefinition.name(),toolDefinition.inputSchema());return new OpenAiApi.FunctionTool(function);}).toList();}@Overridepublic ChatOptions getDefaultOptions() {return OpenAiChatOptions.fromOptions(this.defaultOptions);}@Overridepublic String toString() {return "AlibabaOpenAiChatModel [defaultOptions=" + this.defaultOptions + "]";}/*** Use the provided convention for reporting observation data* @param observationConvention The provided convention*/public void setObservationConvention(ChatModelObservationConvention observationConvention) {Assert.notNull(observationConvention, "observationConvention cannot be null");this.observationConvention = observationConvention;}public static AlibabaOpenAiChatModel.Builder builder() {return new AlibabaOpenAiChatModel.Builder();}public static final class Builder {private OpenAiApi openAiApi;private OpenAiChatOptions 
defaultOptions = OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();private ToolCallingManager toolCallingManager;private FunctionCallbackResolver functionCallbackResolver;private List<FunctionCallback> toolFunctionCallbacks;private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;private Builder() {}public AlibabaOpenAiChatModel.Builder openAiApi(OpenAiApi openAiApi) {this.openAiApi = openAiApi;return this;}public AlibabaOpenAiChatModel.Builder defaultOptions(OpenAiChatOptions defaultOptions) {this.defaultOptions = defaultOptions;return this;}public AlibabaOpenAiChatModel.Builder toolCallingManager(ToolCallingManager toolCallingManager) {this.toolCallingManager = toolCallingManager;return this;}@Deprecatedpublic AlibabaOpenAiChatModel.Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {this.functionCallbackResolver = functionCallbackResolver;return this;}@Deprecatedpublic AlibabaOpenAiChatModel.Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {this.toolFunctionCallbacks = toolFunctionCallbacks;return this;}public AlibabaOpenAiChatModel.Builder retryTemplate(RetryTemplate retryTemplate) {this.retryTemplate = retryTemplate;return this;}public AlibabaOpenAiChatModel.Builder observationRegistry(ObservationRegistry observationRegistry) {this.observationRegistry = observationRegistry;return this;}public AlibabaOpenAiChatModel build() {if (toolCallingManager != null) {Assert.isNull(functionCallbackResolver,"functionCallbackResolver cannot be set when toolCallingManager is set");Assert.isNull(toolFunctionCallbacks,"toolFunctionCallbacks cannot be set when toolCallingManager is set");return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate,observationRegistry);}if (functionCallbackResolver != null) {Assert.isNull(toolCallingManager,"toolCallingManager 
cannot be set when functionCallbackResolver is set");List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks: List.of();return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,retryTemplate, observationRegistry);}return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,observationRegistry);}}}
Next, register the `AlibabaOpenAiChatModel` in the Spring container.
Modify `CommonConfiguration` and add the following bean:
@Bean
public AlibabaOpenAiChatModel alibabaOpenAiChatModel(
        OpenAiConnectionProperties commonProperties,
        OpenAiChatProperties chatProperties,
        ObjectProvider<RestClient.Builder> restClientBuilderProvider,
        ObjectProvider<WebClient.Builder> webClientBuilderProvider,
        ToolCallingManager toolCallingManager,
        RetryTemplate retryTemplate,
        ResponseErrorHandler responseErrorHandler,
        ObjectProvider<ObservationRegistry> observationRegistry,
        ObjectProvider<ChatModelObservationConvention> observationConvention) {
    // chat-level properties take precedence over the common properties
    String baseUrl = StringUtils.hasText(chatProperties.getBaseUrl()) ? chatProperties.getBaseUrl() : commonProperties.getBaseUrl();
    String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey() : commonProperties.getApiKey();
    String projectId = StringUtils.hasText(chatProperties.getProjectId()) ? chatProperties.getProjectId() : commonProperties.getProjectId();
    String organizationId = StringUtils.hasText(chatProperties.getOrganizationId()) ? chatProperties.getOrganizationId() : commonProperties.getOrganizationId();

    Map<String, List<String>> connectionHeaders = new HashMap<>();
    if (StringUtils.hasText(projectId)) {
        connectionHeaders.put("OpenAI-Project", List.of(projectId));
    }
    if (StringUtils.hasText(organizationId)) {
        connectionHeaders.put("OpenAI-Organization", List.of(organizationId));
    }

    RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder);
    WebClient.Builder webClientBuilder = webClientBuilderProvider.getIfAvailable(WebClient::builder);
    OpenAiApi openAiApi = OpenAiApi.builder()
            .baseUrl(baseUrl)
            .apiKey(new SimpleApiKey(apiKey))
            .headers(CollectionUtils.toMultiValueMap(connectionHeaders))
            .completionsPath(chatProperties.getCompletionsPath())
            .embeddingsPath("/v1/embeddings")
            .restClientBuilder(restClientBuilder)
            .webClientBuilder(webClientBuilder)
            .responseErrorHandler(responseErrorHandler)
            .build();

    AlibabaOpenAiChatModel chatModel = AlibabaOpenAiChatModel.builder()
            .openAiApi(openAiApi)
            .defaultOptions(chatProperties.getOptions())
            .toolCallingManager(toolCallingManager)
            .retryTemplate(retryTemplate)
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .build();
    observationConvention.ifAvailable(chatModel::setObservationConvention);
    return chatModel;
}
Finally, make the existing `ChatClient` beans use the custom `AlibabaOpenAiChatModel`.
Modify the `ChatClient` configuration in `CommonConfiguration`:
@Bean
public ChatClient chatClient(AlibabaOpenAiChatModel model, ChatMemory chatMemory) {
    return ChatClient.builder(model) // create a ChatClient builder
            .defaultOptions(ChatOptions.builder().model("qwen-omni-turbo").build())
            .defaultSystem("Please answer the user's questions in a friendly, helpful, and cheerful manner.")
            .defaultAdvisors(new SimpleLoggerAdvisor()) // default advisor that logs requests and responses
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory)) // default advisor that maintains chat memory
            .build(); // build the ChatClient instance
}

@Bean
public ChatClient serviceChatClient(AlibabaOpenAiChatModel model,
                                    ChatMemory chatMemory,
                                    CourseTools courseTools) {
    return ChatClient.builder(model)
            .defaultSystem(CUSTOMER_SERVICE_SYSTEM)
            .defaultAdvisors(
                    new MessageChatMemoryAdvisor(chatMemory), // chat memory
                    new SimpleLoggerAdvisor())
            .defaultTools(courseTools)
            .build();
}
4.RAG
Training a large model is extremely time-consuming, and the training corpus is inevitably somewhat stale, so large models suffer from knowledge limitations:

- their knowledge is out of date, often by several months
- they do not cover highly specialized domains or an enterprise's private data

To solve these problems, we turn to RAG (Retrieval-Augmented Generation).
4.1 How RAG Works
The idea is to attach an external knowledge base to the model, containing specialized domain knowledge or an enterprise's private data. A knowledge base is usually very large, while a model's context window is limited: early GPT models could not go beyond roughly 2,000 tokens, and even today's context windows are typically on the order of 200k tokens or less, so the knowledge base cannot simply be written into the prompt. Instead, we need a way to find the small portion of the knowledge base that is relevant to the user's question, assemble it into the prompt, and send that to the model.
4.1.1 Embedding Models
A vector is a quantity in space that has both direction and magnitude; the space can be two-dimensional or higher-dimensional. Since vectors live in a space, a distance between any two of them can always be computed. Taking two-dimensional vectors as an example, there are two common ways to compute the distance between them (Euclidean distance and cosine similarity):
Usually, the smaller the Euclidean distance between two vectors, the more similar we consider them to be (cosine similarity works the other way around: the larger the value, the more similar). So if we can turn text into vectors, we can judge text similarity by vector distance. There are now plenty of dedicated embedding models that do exactly this. A good embedding model maps texts with similar meanings to vectors that lie close together in the space:
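To make the two distance measures concrete, here is a minimal, self-contained sketch in plain Java (independent of Spring AI) that computes both for a pair of small vectors:

```java
public class VectorDistance {

    // Euclidean distance: square root of the sum of squared component differences
    static double euclidean(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Cosine similarity: dot product divided by the product of the two norms
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        double[] a = {1, 2, 2};
        double[] b = {2, 1, 2};
        System.out.println(euclidean(a, b)); // sqrt(2) ≈ 1.414
        System.out.println(cosine(a, b));    // 8/9 ≈ 0.889
    }
}
```

Real embedding vectors simply have many more dimensions (1024 in the configuration used later), but the arithmetic is identical.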
Next, let's get an embedding model to vectorize text.
Alibaba Cloud's Bailian (百煉) platform provides such models:
Here we choose 通用文本向量-v3 (text-embedding-v3). It is OpenAI-compatible, so we can keep using the OpenAI configuration.
Modify the configuration file and add the embedding model (note that all of these properties live under the `spring:` key):
spring:
  ai:
    openai:
      base-url: ${spring.ai.openai.base-url}
      api-key: ${spring.ai.openai.api-key}
      chat:
        options:
          model: qwen-max-latest
      embedding:
        options:
          model: text-embedding-v3
          dimensions: 1024
    vectorstore:
      redis:
        initialize-schema: true
        index: 0
        prefix: "doc:" # key prefix for entries in the vector store
4.1.2 Vector Databases
A vector database serves two main purposes:

- storing vector data
- retrieving data by similarity

which is exactly what we need.
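The retrieval half of that job is easy to sketch: keep (id, vector) pairs and return the top-k entries whose vectors are most similar to the query. Here is a toy in-memory version in plain Java (no Spring AI involved; production stores add indexing structures such as HNSW so they scale far beyond a linear scan):

```java
import java.util.*;

public class ToyVectorStore {

    private final Map<String, double[]> store = new HashMap<>();

    public void add(String id, double[] vector) {
        store.put(id, vector);
    }

    // return the ids of the k stored vectors most similar to the query (by cosine similarity)
    public List<String> search(double[] query, int k) {
        return store.entrySet().stream()
                .sorted(Comparator.comparingDouble(
                        (Map.Entry<String, double[]> e) -> cosine(query, e.getValue())).reversed())
                .limit(k)
                .map(Map.Entry::getKey)
                .toList();
    }

    private static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    public static void main(String[] args) {
        ToyVectorStore vs = new ToyVectorStore();
        vs.add("doc1", new double[]{1, 0, 0});
        vs.add("doc2", new double[]{0.9, 0.1, 0});
        vs.add("doc3", new double[]{0, 1, 0});
        System.out.println(vs.search(new double[]{1, 0, 0}, 2)); // [doc1, doc2]
    }
}
```

This is conceptually what `similaritySearch` does in any of the real stores below.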
Spring AI supports many vector databases and wraps them all behind a unified API:
- Azure Vector Search - The Azure vector store.
- Apache Cassandra - The Apache Cassandra vector store.
- Chroma Vector Store - The Chroma vector store.
- Elasticsearch Vector Store - The Elasticsearch vector store.
- GemFire Vector Store - The GemFire vector store.
- MariaDB Vector Store - The MariaDB vector store.
- Milvus Vector Store - The Milvus vector store.
- MongoDB Atlas Vector Store - The MongoDB Atlas vector store.
- Neo4j Vector Store - The Neo4j vector store.
- OpenSearch Vector Store - The OpenSearch vector store.
- Oracle Vector Store - The Oracle Database vector store.
- PgVector Store - The PostgreSQL/PGVector vector store.
- Pinecone Vector Store - PineCone vector store.
- Qdrant Vector Store - Qdrant vector store.
- Redis Vector Store - The Redis vector store.
- SAP Hana Vector Store - The SAP HANA vector store.
- Typesense Vector Store - The Typesense vector store.
- Weaviate Vector Store - The Weaviate vector store.
- SimpleVectorStore - A simple implementation of persistent vector storage, good for educational purposes.
All of these stores implement a single interface, `VectorStore`, so they are operated in exactly the same way: learn any one of them and the others pose no problem.
Note, however, that every store except the last one (`SimpleVectorStore`) requires installing and deploying a separate service, and different companies use different vector databases.
4.2 VectorStore
The `VectorStore` interface:
public interface VectorStore extends DocumentWriter {

    default String getName() {
        return this.getClass().getSimpleName();
    }

    // save documents into the vector store
    void add(List<Document> documents);

    // delete documents by id
    void delete(List<String> idList);

    void delete(Filter.Expression filterExpression);

    default void delete(String filterExpression) { ... };

    // retrieve documents similar to the given query text
    List<Document> similaritySearch(String query);

    // retrieve documents according to the given search request
    List<Document> similaritySearch(SearchRequest request);

    default <T> Optional<T> getNativeClient() {
        return Optional.empty();
    }
}
The basic unit a `VectorStore` operates on is the `Document`: to use it, we split our own knowledge base into individual `Document`s and then write them into the `VectorStore`.
Vector store beans backed by memory or by Redis Stack:
// Pick ONE of these two beans; both variants are shown here for comparison.
@Bean
public VectorStore vectorStore(OpenAiEmbeddingModel embeddingModel) {
    return SimpleVectorStore.builder(embeddingModel).build();
}

/**
 * Create a Redis Stack vector store.
 *
 * @param embeddingModel the embedding model
 * @param properties     the redis-stack configuration
 * @return the vector store
 */
@Bean
public VectorStore vectorStore(EmbeddingModel embeddingModel,
                               RedisVectorStoreProperties properties,
                               RedisConnectionDetails redisConnectionDetails) {
    JedisPooled jedisPooled = new JedisPooled(
            redisConnectionDetails.getStandalone().getHost(),
            redisConnectionDetails.getStandalone().getPort(),
            redisConnectionDetails.getUsername(),
            redisConnectionDetails.getPassword());
    return RedisVectorStore.builder(jedisPooled, embeddingModel)
            .indexName(properties.getIndex())
            .prefix(properties.getPrefix())
            .initializeSchema(properties.isInitializeSchema())
            .build();
}
Reading and converting files:
Since the knowledge base is large, it has to be split into document fragments before vectorization. Moreover, the vector store in Spring AI accepts documents of type `Document`, which means whatever content we process must end up as `Document` instances.
We don't have to implement the reading, splitting, and conversion ourselves. Spring AI ships with a variety of document readers; see the official docs: https://docs.spring.io/spring-ai/reference/api/etl-pipeline.html#_pdf_paragraph
For PDF reading and splitting, for example, Spring AI provides two built-in strategies:

- `PagePdfDocumentReader`: splits by page; recommended
- `ParagraphPdfDocumentReader`: splits along the PDF's table of contents; not recommended, because many PDFs are malformed and carry no chapter bookmarks

You can, of course, implement the PDF reading and splitting yourself.
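If you do roll your own splitting, the usual approach is fixed-size chunks with some overlap, so that a sentence cut at a chunk boundary still appears intact in at least one chunk. A minimal sketch in plain Java (the chunk size and overlap values are illustrative, not recommendations):

```java
import java.util.ArrayList;
import java.util.List;

public class TextChunker {

    /**
     * Split text into chunks of at most chunkSize characters,
     * where consecutive chunks share `overlap` characters.
     */
    static List<String> split(String text, int chunkSize, int overlap) {
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap; // how far the window advances each time
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // reached the end of the text
            }
        }
        return chunks;
    }

    public static void main(String[] args) {
        System.out.println(split("abcdefghij", 4, 2)); // [abcd, cdef, efgh, ghij]
    }
}
```

Each resulting chunk would then become one `Document` before being written to the vector store.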
Here we use `PagePdfDocumentReader`.
First, add the dependency to pom.xml:
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency>
Then we can use the reader to load a PDF and turn it into `Document`s.
Here is a unit test (don't forget to configure the API key):
@Test
public void testVectorStore() {
    Resource resource = new FileSystemResource("中二知識筆記.pdf");
    // 1. create the PDF reader
    PagePdfDocumentReader reader = new PagePdfDocumentReader(
            resource, // the source file
            PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
                    .withPagesPerDocument(1) // one page of PDF per Document
                    .build());
    // 2. read the PDF and split it into Documents
    List<Document> documents = reader.read();
    // 3. write them into the vector store
    vectorStore.add(documents);
    // 4. search
    SearchRequest request = SearchRequest.builder()
            .query("論語中教育的目的是什么")
            .topK(1)
            .similarityThreshold(0.6)
            .filterExpression("file_name == '中二知識筆記.pdf'")
            .build();
    List<Document> docs = vectorStore.similaritySearch(request);
    if (docs == null) {
        System.out.println("no results found");
        return;
    }
    for (Document doc : docs) {
        System.out.println(doc.getId());
        System.out.println(doc.getScore());
        System.out.println(doc.getText());
    }
}
4.3 RAG Recap
OK, we now have the following tools:

- PDF reader: reads a document and splits it into fragments
- embedding model: turns text fragments into vectors
- vector database: stores vectors and retrieves them by similarity

Let's recap the problems to solve and the approach:

- to overcome the model's knowledge limitations, we attach an external knowledge base
- because of the model's context limit, the knowledge base cannot simply be concatenated into the prompt
- we need to find the small part of the knowledge base relevant to the user's question and assemble it into the prompt
- the document reader, the embedding model, and the vector database together make this possible

So what RAG does is split the knowledge base, vectorize it with the embedding model, store the vectors in a vector database, and retrieve from that database at query time:
Phase 1 (index the knowledge base):

- split the knowledge base content into fragments
- vectorize each fragment with the embedding model
- write all the vectorized fragments into the vector database

Phase 2 (retrieve):

- whenever the user asks the AI a question, vectorize the question
- use the question vector to fetch the most relevant fragments from the vector database

Phase 3 (chat with the model):

- concatenate the retrieved fragments and the user's question into a prompt
- send the prompt to the model and get the response
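The last phase is plain string assembly. A minimal sketch of stuffing retrieved fragments into a prompt (the template wording here is illustrative, not a Spring AI API; Spring AI's advisors do this for you in practice):

```java
import java.util.List;

public class RagPromptBuilder {

    // assemble the retrieved fragments and the user question into one prompt
    static String buildPrompt(List<String> fragments, String question) {
        StringBuilder sb = new StringBuilder();
        sb.append("Answer the question using ONLY the context below.\n\n");
        sb.append("Context:\n");
        for (int i = 0; i < fragments.size(); i++) {
            sb.append(i + 1).append(". ").append(fragments.get(i)).append('\n');
        }
        sb.append("\nQuestion: ").append(question);
        return sb.toString();
    }

    public static void main(String[] args) {
        String prompt = buildPrompt(
                List.of("Spring AI wraps many chat models.",
                        "VectorStore offers similaritySearch."),
                "How do I search a vector store?");
        System.out.println(prompt);
    }
}
```

The instruction line ("using ONLY the context below") is what keeps the model grounded in the retrieved fragments instead of its stale built-in knowledge.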
4.4 Example: An AI Literature-Reading Assistant
Let's build an AI literature-reading assistant based on RAG.
Overall architecture:
4.4.1 PDF File Management
The file-management interface:
public interface FileRepository {

    /**
     * Save a file and record the mapping between chatId and the file.
     * @param chatId   the conversation id
     * @param resource the file
     * @return true if the upload succeeded, false otherwise
     */
    boolean save(String chatId, Resource resource);

    /**
     * Get the file associated with a chatId.
     * @param chatId the conversation id
     * @return the file found
     */
    Resource getFile(String chatId);
}
@Slf4j
@Component
public class LocalPdfFileRepository implements FileRepository {

    @Autowired
    private CommonFileServiceImpl commonFileService;
    @Autowired
    private PdfFileMappingMapper fileMappingMapper;
    @Autowired
    private FileUtil fileUtil;
    @Autowired
    private VectorStore vectorStore;

    /**
     * Save the file to MinIO and record the mapping in MySQL
     */
    @Override
    public boolean save(String chatId, Resource resource) {
        try {
            // Convert the Resource to a MultipartFile
            MultipartFile file = convertResourceToMultipartFile(resource);
            if (file == null) {
                log.error("File conversion failed, chatId: {}", chatId);
                return false;
            }
            // Upload to MinIO
            String fileUrl = commonFileService.upload(file);
            // Save a new record to the database
            FileMapping mapping = FileMapping.builder()
                    .chatId(chatId)
                    .fileName(file.getOriginalFilename())
                    .filePath(fileUrl)
                    .contentType(file.getContentType())
                    .build();
            int rows = fileMappingMapper.insert(mapping);
            return rows > 0;
        } catch (Exception e) {
            log.error("Failed to save file mapping, chatId: {}", chatId, e);
            return false;
        }
    }

    /**
     * Fetch the file from MinIO
     */
    @Override
    public Resource getFile(String chatId) {
        try {
            // Look up the file info in the database
            FileMapping mapping = fileMappingMapper.selectByChatId(chatId);
            if (mapping == null) {
                log.warn("File mapping not found, chatId: {}", chatId);
                return null;
            }
            // Download the file from MinIO
            String fileName = fileUtil.extractFileNameFromUrl(mapping.getFilePath());
            byte[] fileBytes = commonFileService.download(fileName);
            if (fileBytes == null || fileBytes.length == 0) {
                log.error("File content is empty, filePath: {}", mapping.getFilePath());
                return null;
            }
            // Wrap the bytes as a Resource and return it
            return new ByteArrayResource(fileBytes) {
                @Override
                public String getFilename() {
                    return mapping.getFileName();
                }

                @Override
                public long contentLength() {
                    return fileBytes.length;
                }
            };
        } catch (Exception e) {
            log.error("Failed to get file, chatId: {}", chatId, e);
            return null;
        }
    }

    /**
     * Convert a Resource to a MultipartFile (works around the private implementation class and type mismatch)
     */
    private MultipartFile convertResourceToMultipartFile(Resource resource) throws IOException {
        // Determine the file name
        String filename = Optional.ofNullable(resource.getFilename())
                .orElse("temp-" + UUID.randomUUID() + ".pdf");
        // Determine the content type (Resource has no getContentType() method)
        String contentType = null;
        if (resource.exists()) {
            // Try to probe the type from the file path
            try {
                contentType = Files.probeContentType(resource.getFile().toPath());
            } catch (IOException e) {
                log.warn("Failed to probe content type from file path", e);
            }
            // Fallback: default to the PDF type
            if (contentType == null) {
                contentType = MediaType.APPLICATION_PDF_VALUE;
            }
        } else {
            contentType = MediaType.APPLICATION_OCTET_STREAM_VALUE;
        }
        // Read the file content into a byte array
        byte[] content = FileCopyUtils.copyToByteArray(resource.getInputStream());
        // Custom MultipartFile implementation (avoids relying on private inner classes)
        String finalContentType = contentType;
        return new MultipartFile() {
            @Override
            public String getName() {
                return "file"; // parameter name, can be customized
            }

            @Override
            public String getOriginalFilename() {
                return filename;
            }

            @Override
            public String getContentType() {
                return finalContentType;
            }

            @Override
            public boolean isEmpty() {
                return content.length == 0;
            }

            @Override
            public long getSize() {
                return content.length;
            }

            @Override
            public byte[] getBytes() throws IOException {
                return content;
            }

            @Override
            public InputStream getInputStream() throws IOException {
                return new ByteArrayInputStream(content);
            }

            @Override
            public void transferTo(File dest) throws IOException, IllegalStateException {
                FileCopyUtils.copy(content, dest);
            }
        };
    }

    /**
     * Initialization: load the vector store from disk
     */
    @PostConstruct
    private void init() {
        try {
            File vectorFile = new File("chat-pdf.json");
            if (vectorFile.exists() && vectorStore instanceof SimpleVectorStore) {
                ((SimpleVectorStore) vectorStore).load(vectorFile);
                log.info("Vector store loaded");
            }
        } catch (Exception e) {
            log.error("Failed to initialize vector store", e);
        }
    }

    /**
     * Before destruction: persist the vector store to disk
     */
    @PreDestroy
    private void persistent() {
        try {
            if (vectorStore instanceof SimpleVectorStore) {
                ((SimpleVectorStore) vectorStore).save(new File("chat-pdf.json"));
                log.info("Vector store saved");
            }
        } catch (Exception e) {
            log.error("Failed to save vector store", e);
        }
    }
}
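The class above implements a `FileRepository` abstraction whose interface is not shown in this section. As a rough illustration of the contract it fulfills (save a file for a conversation, fetch it back by `chatId`), here is a framework-free sketch; the `SimpleFileRepository` and `InMemoryFileRepository` names and the `byte[]`-based signatures are assumptions for illustration only, since the real interface uses Spring's `Resource` type:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified version of the FileRepository contract:
// the real interface in this project uses Spring's Resource instead of byte[].
interface SimpleFileRepository {
    boolean save(String chatId, String fileName, byte[] content);
    byte[] getFile(String chatId);
}

// In-memory implementation of the sketch, useful for unit tests
// before wiring in MinIO and MySQL.
class InMemoryFileRepository implements SimpleFileRepository {
    private final Map<String, byte[]> files = new HashMap<>();

    @Override
    public boolean save(String chatId, String fileName, byte[] content) {
        if (content == null || content.length == 0) {
            return false; // mirror the "empty file" guard in the MinIO-backed version
        }
        files.put(chatId, content);
        return true;
    }

    @Override
    public byte[] getFile(String chatId) {
        return files.get(chatId); // null when no file was uploaded for this chat
    }
}
```

Keeping the storage behind an interface like this is what allows swapping the local JSON-file vector persistence and MinIO upload logic without touching the controller layer.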
4.4.2 Literature Reading Assistant
/**
 * Literature reading assistant
 */
@PostMapping(value = "/chat", produces = "text/html;charset=utf-8")
@Operation(summary = "Literature reading assistant")
public Flux<String> chat(@RequestBody PDFDto pdfDto) {
    String prompt = pdfDto.getPrompt();
    String chatId = pdfDto.getChatId();
    // 1. Find the file for this conversation
    Resource file = fileRepository.getFile(chatId);
    if (!file.exists()) {
        // File does not exist, refuse to answer
        throw new RuntimeException("Conversation file does not exist!");
    }
    // 2. Save the conversation id
    chatHistoryRepository.save("pdf", chatId);
    // 3. Call the model
    return pdfChatClient.prompt()
            .user(prompt)
            .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
            .advisors(a -> a.param(FILTER_EXPRESSION, "file_name == '" + file.getFilename() + "'"))
            .stream()
            .content();
}

/**
 * File upload for the reading assistant
 */
@PostMapping("/upload/{chatId}")
@Operation(summary = "File upload for the reading assistant")
public Result uploadPdf(@PathVariable String chatId, @RequestParam("file") MultipartFile file) {
    try {
        // 1. Verify that the file is a PDF
        if (!Objects.equals(file.getContentType(), "application/pdf")) {
            return Result.fail("Only PDF files can be uploaded!");
        }
        // 2. Save the file
        boolean success = fileRepository.save(chatId, file.getResource());
        if (!success) {
            return Result.fail("Failed to save the file!");
        }
        // 3. Write it into the vector store
        this.writeToVectorStore(file.getResource());
        return Result.success();
    } catch (Exception e) {
        log.error("Failed to upload PDF.", e);
        return Result.fail("Failed to upload the file!");
    }
}

/**
 * File download for the reading assistant
 */
@GetMapping("/file/{chatId}")
@Operation(summary = "File download for the reading assistant")
public ResponseEntity<Resource> download(@PathVariable("chatId") String chatId) throws IOException {
    // 1. Read the file
    Resource resource = fileRepository.getFile(chatId);
    if (!resource.exists()) {
        return ResponseEntity.notFound().build();
    }
    // 2. Encode the file name and write it to the response header
    String filename = URLEncoder.encode(Objects.requireNonNull(resource.getFilename()), StandardCharsets.UTF_8);
    // 3. Return the file
    return ResponseEntity.ok()
            .contentType(MediaType.APPLICATION_OCTET_STREAM)
            .header("Content-Disposition", "attachment; filename=\"" + filename + "\"")
            .body(resource);
}

// The original implementation based on Spring AI's PagePdfDocumentReader:
// private void writeToVectorStore(Resource resource) {
//     // 1. Create the PDF reader
//     PagePdfDocumentReader reader = new PagePdfDocumentReader(
//             resource, // file source
//             PdfDocumentReaderConfig.builder()
//                     .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
//                     .withPagesPerDocument(1) // each PDF page becomes one Document
//                     .build()
//     );
//     // 2. Read the PDF and split it into Documents
//     List<Document> documents = reader.read();
//     // 3. Write them into the vector store
//     vectorStore.add(documents);
// }

private void writeToVectorStore(Resource resource) {
    try {
        // Parse the PDF content with Tika
        ContentHandler handler = new BodyContentHandler(-1); // no limit on content length
        Metadata metadata = new Metadata();
        ParseContext context = new ParseContext();
        PDFParser parser = new PDFParser();
        // Parse the PDF and extract the text
        parser.parse(resource.getInputStream(), handler, metadata, context);
        String content = handler.toString();
        // Key fix: convert the Tika Metadata into a Map
        Map<String, Object> metadataMap = new HashMap<>();
        for (String name : metadata.names()) {
            metadataMap.put(name, metadata.get(name));
        }
        // Add the file name to the metadata (used by the retrieval filter expression)
        metadataMap.put("file_name", resource.getFilename());
        // Create a Document and write it into the vector store
        Document document = new Document(
                resource.getFilename(), // document id
                content,                // extracted text content
                metadataMap             // converted metadata map
        );
        vectorStore.add(List.of(document));
    } catch (Exception e) {
        log.error("Failed to parse PDF with Tika", e);
        throw new RuntimeException("Failed to parse the PDF", e);
    }
}
ChatClient configuration:
/**
 * AI literature reading assistant
 */
@Bean
public ChatClient pdfChatClient(OpenAiChatModel model, ChatMemory chatMemory, VectorStore vectorStore) {
    return ChatClient.builder(model)
            .defaultSystem("Answer questions based on the given context. If the context does not cover a question, do not make up an answer.")
            .defaultAdvisors(
                    new SimpleLoggerAdvisor(),
                    new MessageChatMemoryAdvisor(chatMemory),
                    new QuestionAnswerAdvisor(
                            vectorStore,
                            SearchRequest.builder()
                                    .similarityThreshold(0.6)
                                    .topK(2)
                                    .build()))
            .build();
}
5. Multimodality
Multimodality refers to accepting different types of input data, such as text, images, audio, and video. So far, all of our interactions with large models have used plain text input, which is a consequence of the models we chose: deepseek, qwen-max, and similar models are text-only. On Ollama and the Alibaba Cloud Bailian platform, however, there are also many multimodal models to choose from.
Taking Ollama as an example: when searching, click the Vision tag to find models that support image recognition:
The same applies on the Alibaba Cloud Bailian platform:
Alibaba Cloud's qwen-omni model is a fully multimodal model that accepts text, image, audio, and video input, and it even supports speech synthesis, which makes it very powerful.
Note:
In the current SpringAI version (1.0.0-M6), qwen-omni has compatibility issues with SpringAI's OpenAI module: only the text and image modalities work. Audio input fails with a data format error, and video is not supported at all.
There are currently two workarounds:
- Use spring-ai-alibaba as a replacement.
- Rewrite the OpenAiChatModel implementation.
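For the first workaround, the spring-ai-alibaba starter is added to the pom.xml much like the other starters shown earlier. A sketch of the dependency follows; the exact coordinates and version are assumptions here and should be verified against the official spring-ai-alibaba release notes for the version that matches your Spring AI baseline:

```xml
<dependency>
    <groupId>com.alibaba.cloud.ai</groupId>
    <artifactId>spring-ai-alibaba-starter</artifactId>
    <version>1.0.0-M6.1</version>
</dependency>
```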
Multimodal agent example:
/**
 * Multimodal assistant for agent conversations
 */
@Bean
public ChatClient chatCommonClient(AlibabaOpenAiChatModel model, ChatMemory chatMemory,
                                   VectorStore vectorStore, AINoteTools aiNoteTools) {
    return ChatClient.builder(model)
            .defaultOptions(ChatOptions.builder()
                    .model("qwen-omni-turbo")
                    .build())
            .defaultSystem(AIChatConstant.CHAT_ROLE)
            .defaultAdvisors(
                    new SimpleLoggerAdvisor(),
                    new MessageChatMemoryAdvisor(chatMemory),
                    new QuestionAnswerAdvisor(
                            vectorStore,
                            SearchRequest.builder()
                                    .similarityThreshold(0.6)
                                    .topK(2)
                                    .build()))
            .defaultTools(List.of(aiNoteTools))
            .build();
}
/**
 * Agent conversation
 */
@PostMapping(value = "/commonChat", produces = "text/html;charset=utf-8")
@Operation(summary = "Agent conversation")
public Flux<String> chat(@RequestParam("prompt") String prompt,
                         @RequestParam(required = false) String chatId,
                         @RequestParam(value = "files", required = false) List<MultipartFile> files) {
    // 1. Save the conversation id
    chatHistoryRepository.save("chat", chatId);
    // 2. Call the model
    if (files == null || files.isEmpty()) {
        // No attachments: plain-text chat
        return textChat(prompt, chatId);
    } else {
        // With attachments: multimodal chat
        return multiModalChat(prompt, chatId, files);
    }
}

private Flux<String> multiModalChat(String prompt, String chatId, List<MultipartFile> files) {
    // 1. Parse the media attachments
    List<Media> medias = files.stream()
            .map(file -> new Media(
                    MimeType.valueOf(Objects.requireNonNull(file.getContentType())),
                    file.getResource()))
            .toList();
    // 2. Call the model
    return chatCommonClient.prompt()
            .user(p -> p.text(prompt).media(medias.toArray(Media[]::new)))
            .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
            .stream()
            .content();
}

private Flux<String> textChat(String prompt, String chatId) {
    return chatCommonClient.prompt()
            .user(prompt)
            .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId))
            .stream()
            .content();
}