
fix: enforce the model token limit

master
wangxiang committed 2 years ago
parent commit e9429a033a
  1. kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java (25 changed lines)
  2. kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java (27 changed lines)

kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/service/impl/ImContentServiceImpl.java (25 changed lines)

@@ -18,7 +18,6 @@ import com.cloud.kicc.system.mapper.ImContentMapper;
import com.cloud.kicc.system.service.FileService;
import com.cloud.kicc.system.service.IImContentService;
import com.cloud.kicc.system.util.AiUtil;
import com.knuddels.jtokkit.api.ModelType;
import com.pig4cloud.plugin.oss.OssProperties;
import com.pig4cloud.plugin.oss.service.OssTemplate;
import com.theokanning.openai.audio.CreateSpeechRequest;
@@ -29,7 +28,6 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.utils.TikTokensUtil;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import okhttp3.ResponseBody;
@@ -94,7 +92,7 @@ public class ImContentServiceImpl extends ServiceImpl<ImContentMapper, ImContent
i.eq(ImContent::getSendUserId, imContent.getReceiveUserId())
.eq(ImContent::getReceiveUserId, imContent.getSendUserId())
).last(String.format("LIMIT %s", openAiConfigProperties.getContextMessageMaxCount()))
.orderByAsc(ImContent::getSendTime));
.orderByDesc(ImContent::getSendTime));
// Historical context messages
List<ChatMessage> historyMessages = imContents.stream().map(item -> {
@@ -103,8 +101,6 @@ public class ImContentServiceImpl extends ServiceImpl<ImContentMapper, ImContent
chatMessage.setContent(item.getContent());
return chatMessage;
}).collect(Collectors.toList());
// Truncate the context message list to the model's token limit
truncateToTokenLimit(openAiConfigProperties.getModel(), historyMessages);
// Completion message
ChatMessage completionMessage = new ChatMessage(ChatMessageRole.USER.value(), imContent.getContent());
@@ -265,23 +261,4 @@ public class ImContentServiceImpl extends ServiceImpl<ImContentMapper, ImContent
return String.format("Please try again in %s hours, %s minutes and %s seconds!", hours, minutes, seconds);
}
/**
* Truncate the context message list to the model's token limit
* @param modelName the model used for token counting and context-length lookup
* @param messages the context message list, truncated in place
*/
private void truncateToTokenLimit(String modelName, List<ChatMessage> messages) {
ModelType modelType = ModelType.fromName(modelName)
.orElseThrow(() -> new CheckedException(String.format("The specified model %s could not be found, please check the configuration!", modelName)));
int sumTokens = TikTokensUtil.tokens(modelName, messages);
// Drop messages from the front until the total token count is within the limit
while (!messages.isEmpty() && sumTokens > modelType.getMaxContextLength()) {
// Remove the first (oldest) message in the list
messages.remove(0);
// Recompute the total token count
sumTokens = TikTokensUtil.tokens(modelName, messages);
}
}
}
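
The ordering change above pairs with the existing LIMIT clause: with ORDER BY send_time DESC the LIMIT keeps the newest contextMessageMaxCount rows, while the previous ascending order kept the oldest ones. A minimal, self-contained sketch of that difference, assuming a hypothetical Message record and a limit of 5 in place of the ImContent entity and openAiConfigProperties.getContextMessageMaxCount():

import java.time.LocalDateTime;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class RecentMessagesDemo {

    // Hypothetical stand-in for the ImContent entity; only the fields needed here
    record Message(String content, LocalDateTime sendTime) {}

    public static void main(String[] args) {
        // 20 messages, one minute apart, oldest first
        List<Message> all = IntStream.rangeClosed(1, 20)
                .mapToObj(i -> new Message("msg " + i, LocalDateTime.now().minusMinutes(20 - i)))
                .collect(Collectors.toList());

        int limit = 5; // stands in for openAiConfigProperties.getContextMessageMaxCount()

        // Equivalent of ORDER BY send_time DESC LIMIT 5: the five most recent messages
        List<Message> newest = all.stream()
                .sorted(Comparator.comparing(Message::sendTime).reversed())
                .limit(limit)
                .collect(Collectors.toList());

        // Equivalent of the old ORDER BY send_time ASC LIMIT 5: the five oldest messages
        List<Message> oldest = all.stream()
                .sorted(Comparator.comparing(Message::sendTime))
                .limit(limit)
                .collect(Collectors.toList());

        newest.forEach(m -> System.out.println("newest: " + m.content())); // msg 20 .. msg 16
        oldest.forEach(m -> System.out.println("oldest: " + m.content())); // msg 1 .. msg 5
    }
}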

kicc-platform/kicc-platform-biz/kicc-system-biz/src/main/java/com/cloud/kicc/system/util/AiUtil.java (27 changed lines)

@@ -1,14 +1,17 @@
package com.cloud.kicc.system.util;
import com.cloud.kicc.common.core.exception.CheckedException;
import com.cloud.kicc.common.core.util.SpringContextHolderUtil;
import com.cloud.kicc.common.security.util.SecurityUtils;
import com.cloud.kicc.system.config.OpenAiConfigProperties;
import com.knuddels.jtokkit.api.ModelType;
import com.theokanning.openai.client.OpenAiApi;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.utils.TikTokensUtil;
import lombok.experimental.UtilityClass;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
@@ -56,6 +59,7 @@ public class AiUtil {
ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), openAiConfigProperties.getSystemRule());
messages.add(systemMessage);
messages.addAll(historyMessages);
truncateToTokenLimit(openAiConfigProperties.getModel(), messages);
ChatCompletionRequest request = ChatCompletionRequest.builder()
.model(openAiConfigProperties.getModel())
.messages(messages)
@@ -69,4 +73,27 @@ public class AiUtil {
return getOpenAiService().createChatCompletion(request);
}
/**
* <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">...</a>
* Truncate the context message list to the model's token limit
* @param modelName the model used for token counting and context-length lookup
* @param messages the context message list, truncated in place
*/
public void truncateToTokenLimit(String modelName, List<ChatMessage> messages) {
ModelType modelType = ModelType.fromName(modelName)
.orElseThrow(() -> new CheckedException(String.format("The specified model %s could not be found, please check the configuration!", modelName)));
int sumTokens = TikTokensUtil.tokens(modelName, messages);
// Keep at least one message (the system message) intact
if (!messages.isEmpty()) {
// Walk the messages from the second one onward until the total token count is within the limit
while (messages.size() > 1 && sumTokens > modelType.getMaxContextLength()) {
// Remove the second message (the oldest non-system message)
messages.remove(1);
// Recompute the total token count
sumTokens = TikTokensUtil.tokens(modelName, messages);
}
}
}
}
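
For reference, the truncation logic added here rests on two library calls brought in by the new imports: ModelType.fromName(...).getMaxContextLength() from jtokkit for the model's context window, and TikTokensUtil.tokens(modelName, messages) from openai-java for the running token count. A small standalone sketch of the same loop, assuming gpt-3.5-turbo and a synthetic history (the class name TokenTruncationDemo and the message contents are made up for illustration):

import com.knuddels.jtokkit.api.ModelType;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.utils.TikTokensUtil;

import java.util.ArrayList;
import java.util.List;

public class TokenTruncationDemo {
    public static void main(String[] args) {
        String modelName = "gpt-3.5-turbo";

        // Resolve the model's context window size via jtokkit
        ModelType modelType = ModelType.fromName(modelName)
                .orElseThrow(() -> new IllegalArgumentException("Unknown model: " + modelName));

        // A system prompt followed by a long synthetic history
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant."));
        for (int i = 1; i <= 600; i++) {
            messages.add(new ChatMessage(ChatMessageRole.USER.value(), "history message number " + i));
        }

        // Count tokens with the same utility AiUtil uses
        int sumTokens = TikTokensUtil.tokens(modelName, messages);
        System.out.printf("before: %d messages, %d tokens (limit %d)%n",
                messages.size(), sumTokens, modelType.getMaxContextLength());

        // Same loop as truncateToTokenLimit: keep the system message at index 0 and drop
        // the oldest history message until the conversation fits the context window
        while (messages.size() > 1 && sumTokens > modelType.getMaxContextLength()) {
            messages.remove(1);
            sumTokens = TikTokensUtil.tokens(modelName, messages);
        }

        System.out.printf("after: %d messages, %d tokens%n", messages.size(), sumTokens);
    }
}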
