diff --git a/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java b/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java index 979e160f..cdb155b8 100644 --- a/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java +++ b/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java @@ -18,7 +18,6 @@ import com.cloud.kicc.system.mapper.ImContentMapper; import com.cloud.kicc.system.service.FileService; import com.cloud.kicc.system.service.IImContentService; import com.cloud.kicc.system.util.AiUtil; -import com.knuddels.jtokkit.api.ModelType; import com.pig4cloud.plugin.oss.OssProperties; import com.pig4cloud.plugin.oss.service.OssTemplate; import com.theokanning.openai.audio.CreateSpeechRequest; @@ -29,7 +28,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.service.OpenAiService; -import com.theokanning.openai.utils.TikTokensUtil; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import okhttp3.ResponseBody; @@ -94,7 +92,7 @@ public class ImContentServiceImpl extends ServiceImpl historyMessages = imContents.stream().map(item -> { @@ -103,8 +101,6 @@ public class ImContentServiceImpl extends ServiceImpl messages) { - ModelType modelType = ModelType.fromName(modelName) - .orElseThrow(() -> new CheckedException(String.format("找不到指定的:%s模型请检查配置!", modelName)));; - int sumTokens = TikTokensUtil.tokens(modelName, messages); - // 从前向后遍历消息,直到总 token 数在限制之内 - while (!messages.isEmpty() && sumTokens > modelType.getMaxContextLength()) { - // 移除列表中的第一个消息 - messages.remove(0); - // 重新计算总 token 数 - sumTokens = 
TikTokensUtil.tokens(modelName, messages); - } - } - } diff --git a/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java b/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java index bdc9ea90..115b7f4a 100644 --- a/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java +++ b/kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java @@ -1,14 +1,17 @@ package com.cloud.kicc.system.util; +import com.cloud.kicc.common.core.exception.CheckedException; import com.cloud.kicc.common.core.util.SpringContextHolderUtil; import com.cloud.kicc.common.security.util.SecurityUtils; import com.cloud.kicc.system.config.OpenAiConfigProperties; +import com.knuddels.jtokkit.api.ModelType; import com.theokanning.openai.client.OpenAiApi; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.service.OpenAiService; +import com.theokanning.openai.utils.TikTokensUtil; import lombok.experimental.UtilityClass; import okhttp3.OkHttpClient; import retrofit2.Retrofit; @@ -56,6 +59,7 @@ public class AiUtil { ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), openAiConfigProperties.getSystemRule()); messages.add(systemMessage); messages.addAll(historyMessages); + truncateToTokenLimit(openAiConfigProperties.getModel(), historyMessages); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(openAiConfigProperties.getModel()) .messages(messages) @@ -69,4 +73,27 @@ public class AiUtil { return getOpenAiService().createChatCompletion(request); } + + /** + * ... 
+     * Trims a chat context in place until it fits the model's context window.
+     *
+     * <p>Index 0 is treated as the system message and is never removed; entries are removed
+     * from index 1 onward, one at a time, so the list must be mutable and callers must pass
+     * the full request list (system message first) rather than the bare history.
+     *
+     * <p>NOTE(review): the visible call site passes {@code historyMessages} after it has already
+     * been copied into the request list, so the trim does not reach the request - confirm and
+     * pass the request list instead.
+     *
+     * @param modelName model name, resolved via {@link ModelType#fromName(String)}
+     * @param messages  context to trim; {@code null} or fewer than two entries is a no-op
+     * @throws CheckedException if {@code modelName} does not resolve to a known model
+     */
+    public void truncateToTokenLimit(String modelName, List<ChatMessage> messages) {
+        // Resolve the model first so a bad configuration fails even when there is nothing to trim.
+        ModelType modelType = ModelType.fromName(modelName)
+                .orElseThrow(() -> new CheckedException(String.format("找不到指定的:%s模型请检查配置!", modelName)));
+        if (messages == null) {
+            return;
+        }
+        int maxTokens = modelType.getMaxContextLength();
+        // The token total is recomputed after every removal; stop once it fits or only index 0 is left.
+        while (messages.size() > 1 && TikTokensUtil.tokens(modelName, messages) > maxTokens) {
+            messages.remove(1);
+        }
+    }
+
 }