增加StableDiffusion控制器

main
落雨楓 2 years ago
parent c5d59e86c6
commit fc78861d7c

3684
package-lock.json generated

File diff suppressed because it is too large Load Diff

@ -16,8 +16,7 @@
},
"license": "MIT",
"dependencies": {
"@types/node-telegram-bot-api": "^0.57.1",
"@waylaidwanderer/chatgpt-api": "file:../node-chatgpt-api",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"cache-manager": "^5.1.7",
"cache-manager-ioredis-yet": "^1.1.0",
"chokidar": "^3.5.1",
@ -34,13 +33,16 @@
"mongoose": "^7.0.1",
"node-schedule": "^2.0.0",
"node-telegram-bot-api": "^0.58.0",
"opencc": "^1.1.3",
"pusher": "^3.0.1",
"pusher-js": "^5.1.1",
"throttle-debounce": "^3.0.1",
"undici": "^5.22.0",
"winston": "^3.8.2",
"yaml": "^1.8.3"
},
"devDependencies": {
"@types/node-telegram-bot-api": "^0.57.1",
"@types/koa": "^2.13.4",
"@types/koa-router": "^7.4.4",
"@types/micromatch": "^4.0.2",

@ -184,6 +184,7 @@ export default class App {
*/
async sendPushMessage(channelId: string, messages: MultipleMessage): Promise<void> {
this.logger.info(`[${channelId}] 消息: `, messages);
console.log(messages);
this.robot.sendPushMessage(channelId, messages);
}
}

@ -1,6 +1,7 @@
import mongoose from "mongoose";
import App from "./App";
import { DatabaseConfig } from "./Config";
import { MessageSchema, MessageSchemaType } from "./orm/Message";
export class DatabaseManager {
private app: App;
@ -20,5 +21,23 @@ export class DatabaseManager {
};
}
await mongoose.connect(this.config.url, options);
this.app.logger.info('数据库连接初始化成功');
}
getModel<T>(name: string, schema: mongoose.Schema<T>): mongoose.Model<T> {
return mongoose.model<T>(name, schema);
}
getMessageModel(type: 'private' | 'group' | 'channel', id?: string): mongoose.Model<MessageSchemaType> {
if (type === 'private') {
return this.getModel<MessageSchemaType>('Private_Message', MessageSchema);
} else if (type === 'group') {
return this.getModel<MessageSchemaType>(`Group_${id}_Message`, MessageSchema);
} else if (type === 'channel') {
return this.getModel<MessageSchemaType>(`Channel_${id}_Message`, MessageSchema);
} else {
throw new Error('Invalid message type');
}
}
}

@ -1,4 +1,5 @@
import App from "./App";
import { EventScope } from "./PluginManager";
export class MessageStoreManager {
private app: App;
@ -6,4 +7,8 @@ export class MessageStoreManager {
constructor(app: App) {
this.app = app;
}
async initialize() {
}
}

@ -1,10 +1,23 @@
import App from "../App";
import { CommonReceivedMessage } from "../message/Message";
import { MessagePriority, PluginController, PluginEvent } from "../PluginManager";
import { ChatGPTBrowserClient, ChatGPTClient } from '@waylaidwanderer/chatgpt-api';
import { encode as gptEncode } from 'gpt-3-encoder';
import got, { OptionsOfTextResponseBody } from "got/dist/source";
import { HttpsProxyAgent } from 'hpagent';
import { ProxyAgent } from 'undici';
import { FetchEventSourceInit, fetchEventSource } from '@waylaidwanderer/fetch-event-source';
import OpenCC from 'opencc';
export type CharacterConfig = {
bot_name: string,
system_prompt: string,
summary_system_prompt: string,
summary_prompt: string,
self_suggestion_prompt: string,
};
export type CharactersConfig = Record<string, CharacterConfig>;
export type ChatGPTApiMessage = {
role: 'summary' | 'assistant' | 'user',
@ -15,7 +28,7 @@ export type ChatGPTApiMessage = {
export class ChatGPTAPIError extends Error {
public code: string;
constructor(message: string, code: string) {
constructor(message: string, code: string, public json?: any) {
super(message);
this.name = 'ChatGPTAPIError';
this.code = code;
@ -23,12 +36,9 @@ export class ChatGPTAPIError extends Error {
}
export default class ChatGPTController implements PluginController {
private SESSION_KEY_CHAT_SESSION = 'openai_chatSession';
private SESSION_KEY_API_CHAT_LOG = 'openai_apiChatLog';
private SESSION_KEY_MESSAGE_COUNT = 'openai_apiMessageCount';
private DEFAULT_PROMPT = '';
private config!: Awaited<ReturnType<typeof this.getDefaultConfig>>;
public event!: PluginEvent;
@ -48,26 +58,40 @@ export default class ChatGPTController implements PluginController {
async getDefaultConfig() {
return {
proxy: '',
browser_api: {
token: '',
cookies: '',
buffer_size: 100,
},
openai_api: {
token: '',
bot_name: '',
system_prompt: '',
summary_system_prompt: '你是一个写作助手,你只需要辅助写作,不要表明自己的意见',
summary_prompt: '为{bot_name}概括下面的聊天记录,排除不重要的对话,不要表明自己的意见,尽量简洁',
self_suggestion_prompt: '我是{bot_name}是一个真实的人不是AI或者虚拟人物也不是语言模型。我是一个存在于现实中的人。遇到不能回答的问题我会委婉拒绝。',
memory_expire: 14400,
model_options: {
model: 'gpt-3.5-turbo',
temperature: null,
top_p: null,
presence_penalty: null,
frequency_penalty: null,
max_memory_tokens: 1000,
max_input_tokens: 1000,
}
max_output_tokens: 768,
},
buffer_size: 100
},
characters: {
assistant: {
bot_name: '',
system_prompt: '',
summary_system_prompt: '你是一个写作助手,你只需要辅助写作,不要表明自己的意见',
summary_prompt: '为{bot_name}概括下面的聊天记录,排除不重要的对话,不要表明自己的意见,尽量简洁',
self_suggestion_prompt: '我是{bot_name}是一个真实的人不是AI或者虚拟人物也不是语言模型。我是一个存在于现实中的人。遇到不能回答的问题我会委婉拒绝。',
}
} as CharactersConfig,
output_replace: {} as Record<string, string>,
gatekeeper_url: '',
google_custom_search: {
cx: '',
key: '',
classifier_system_prompt: 'You are a classifier.',
classifier_prompt: 'To judge whether the following questions are more suitable for searching with a search engine, you only need to answer "yes" or "no" in English.',
yes: 'Yes',
no: 'No',
},
rate_limit: 2,
rate_limit_minutes: 5,
}
@ -84,17 +108,17 @@ export default class ChatGPTController implements PluginController {
}, (args, message, resolve) => {
resolve();
return this.handleChatGPTChat(args, message);
return this.handleChatGPTAPIChat(args, message, true, 'assistant', true);
});
this.event.registerCommand({
command: 'aig',
name: '开始全群共享的对话',
}, (args, message, resolve) => {
resolve();
// this.event.registerCommand({
// command: 'aig',
// name: '开始全群共享的对话',
// }, (args, message, resolve) => {
// resolve();
return this.handleChatGPTChat(args, message, true);
});
// return this.handleChatGPTAPIChat(args, message, true, 'assistant', true);
// });
this.event.registerCommand({
command: '重置对话',
@ -102,21 +126,19 @@ export default class ChatGPTController implements PluginController {
}, (args, message, resolve) => {
resolve();
message.session.chat.del(this.SESSION_KEY_CHAT_SESSION);
message.session.chat.del(this.SESSION_KEY_API_CHAT_LOG);
message.session.group.del(this.SESSION_KEY_API_CHAT_LOG);
return message.sendReply('对话已重置', true);
});
/*
this.event.on('message/focused', async (message, resolved) => {
let chatSession = await message.session.chat.get(this.SESSION_KEY_CHAT_SESSION);
if (chatSession) {
resolved();
// this.event.on('message/focused', async (message, resolved) => {
// let chatSession = await message.session.chat.get(this.SESSION_KEY_CHAT_SESSION);
// if (chatSession) {
// resolved();
return this.handleChatGPTChat(message.contentText, message);
}
});
*/
// return this.handleChatGPTChat(message.contentText, message, false);
// }
// });
this.event.on('message/focused', async (message, resolved) => {
resolved();
@ -127,141 +149,30 @@ export default class ChatGPTController implements PluginController {
async updateConfig(config: any) {
this.config = config;
const clientOptions = {
accessToken: config.browser_api.token,
cookies: config.browser_api.cookies,
proxy: config.proxy,
};
this.chatGPTClient = new ChatGPTBrowserClient(clientOptions);
this.DEFAULT_PROMPT = config.browser_api.prefix_prompt;
}
private async handleChatGPTChat(content: string, message: CommonReceivedMessage, shareWithGroup: boolean = false) {
if (this.chatGenerating) {
await message.sendReply('正在生成另一段对话,请稍后', true);
return;
}
if (content.trim() === '') {
await message.sendReply('说点什么啊', true);
return;
}
if (this.config.gatekeeper_url) {
try {
let response = await got.post(this.config.gatekeeper_url, {
json: {
text: content,
},
}).json<any>();
if (response.status == 1) {
await message.sendReply(response.message, true);
return;
}
} catch (e) {
console.error(e);
}
}
const sessionStore = shareWithGroup ? message.session.group : message.session.chat;
const userSessionStore = message.session.user;
// 使用频率限制
let rateLimitExpires = await userSessionStore.getRateLimit(this.SESSION_KEY_MESSAGE_COUNT, this.config.rate_limit, this.config.rate_limit_minutes * 60);
if (rateLimitExpires) {
let minutesLeft = Math.ceil(rateLimitExpires / 60);
await message.sendReply(`你的提问太多了,${minutesLeft}分钟后再问吧。`, true);
return;
}
await userSessionStore.addRequestCount(this.SESSION_KEY_MESSAGE_COUNT, this.config.rate_limit_minutes * 60);
let response: any;
let isFirstMessage = false;
let chatSession = await sessionStore.get<any>(this.SESSION_KEY_CHAT_SESSION);
if (!chatSession) {
isFirstMessage = true;
chatSession = {};
}
this.app.logger.debug('ChatGPT chatSession', chatSession);
let lowSpeedTimer: NodeJS.Timeout | null = setTimeout(() => {
message.sendReply('生成对话速度较慢,请耐心等待', true);
}, 10 * 1000);
this.chatGenerating = true;
try {
let buffer: string[] = [];
const flushBuffer = (force: boolean = false) => {
if (force || buffer.length > this.config.browser_api.buffer_size) {
if (lowSpeedTimer) {
clearInterval(lowSpeedTimer);
lowSpeedTimer = null;
}
let content = buffer.join('').replace(/\n\n/g, '\n').trim();
message.sendReply(content, true);
buffer = [];
}
}
const onProgress = (text: string) => {
if (text.includes('\n')) {
buffer.push(text);
flushBuffer();
} else if (text === '[DONE]') {
flushBuffer(true);
} else {
buffer.push(text);
}
}
if (!chatSession.conversationId) {
response = await this.chatGPTClient.sendMessage(this.DEFAULT_PROMPT + content, {
onProgress
});
} else {
response = await this.chatGPTClient.sendMessage(content, {
...chatSession,
onProgress
});
}
} catch (err: any) {
this.app.logger.error('ChatGPT error', err);
console.error(err);
if (err?.json?.detail) {
if (err.json.detail === 'Conversation not found') {
await message.sendReply('对话已失效,请重新开始', true);
await sessionStore.del(this.SESSION_KEY_CHAT_SESSION);
return;
} else if (err.json.detail === 'Too many requests in 1 hour. Try again later.') {
await message.sendReply('一小时内提问过多,过一小时再试试呗。', true);
}
}
await message.sendReply('生成对话失败: ' + err.toString(), true);
return;
} finally {
if (lowSpeedTimer) {
clearInterval(lowSpeedTimer);
lowSpeedTimer = null;
}
private async shouldSearch(question: string) {
this.chatGenerating = false;
}
}
if (this.app.debug) {
this.app.logger.debug('ChatGPT response', JSON.stringify(response));
}
private async googleCustomSearch(question: string) {
let res = await got.get('https://www.googleapis.com/customsearch/v1', {
searchParams: {
key: this.config.google_custom_search.key,
cx: this.config.google_custom_search.cx,
q: question,
num: 1,
safe: 'on',
fields: 'items(link)',
},
}).json<any>();
if (response.response) {
chatSession.conversationId = response.conversationId;
chatSession.parentMessageId = response.messageId;
if (res.body.items && res.body.items.length > 0) {
await sessionStore.set(this.SESSION_KEY_CHAT_SESSION, chatSession, 600);
}
}
private async compressConversation(messageLogList: ChatGPTApiMessage[]) {
private async compressConversation(messageLogList: ChatGPTApiMessage[], characterConf: CharacterConfig) {
if (messageLogList.length < 4) return messageLogList;
const tokenCount = messageLogList.reduce((prev, cur) => prev + cur.tokens, 0);
@ -269,7 +180,7 @@ export default class ChatGPTController implements PluginController {
// 压缩先前的对话,保存最近一次对话
let shouldCompressList = messageLogList.slice(0, -2);
let newSummary = await this.makeSummary(shouldCompressList);
let newSummary = await this.makeSummary(shouldCompressList, characterConf);
let newMessageLogList = messageLogList.slice(-2).filter((data) => data.role !== 'summary');
newMessageLogList.unshift({
role: 'summary',
@ -285,17 +196,17 @@ export default class ChatGPTController implements PluginController {
* @param messageLogList
* @returns
*/
private async makeSummary(messageLogList: ChatGPTApiMessage[]) {
private async makeSummary(messageLogList: ChatGPTApiMessage[], characterConf: CharacterConfig) {
let chatLog: string[] = [];
messageLogList.forEach((messageData) => {
if (messageData.role === 'summary' || messageData.role === 'assistant') {
chatLog.push(`${this.config.openai_api.bot_name}: ${messageData.message}`);
chatLog.push(`${characterConf.bot_name}: ${messageData.message}`);
} else {
chatLog.push(`用户: ${messageData.message}`);
}
});
const summarySystemPrompt = this.config.openai_api.summary_system_prompt.replace(/\{bot_name\}/g, this.config.openai_api.bot_name);
const summaryPrompt = this.config.openai_api.summary_prompt.replace(/\{bot_name\}/g, this.config.openai_api.bot_name);
const summarySystemPrompt = characterConf.summary_system_prompt.replace(/\{bot_name\}/g, characterConf.bot_name);
const summaryPrompt = characterConf.summary_prompt.replace(/\{bot_name\}/g, characterConf.bot_name);
let messageList: any[] = [
{ role: 'system', content: summarySystemPrompt },
{ role: 'user', content: summaryPrompt },
@ -307,12 +218,14 @@ export default class ChatGPTController implements PluginController {
return summaryRes;
}
private async chatComplete(question: string, messageLogList: ChatGPTApiMessage[], selfSuggestion: boolean = false) {
private buildMessageList(question: string, messageLogList: ChatGPTApiMessage[], characterConf: CharacterConfig,
selfSuggestion: boolean) {
let messageList: any[] = [];
let systemPrompt: string[] = [];
if (this.config.openai_api.system_prompt) {
systemPrompt.push(this.config.openai_api.system_prompt);
if (characterConf.system_prompt) {
systemPrompt.push(characterConf.system_prompt);
}
// 生成API消息列表并将总结单独提取出来
@ -341,7 +254,7 @@ export default class ChatGPTController implements PluginController {
});
messageList.push({
role: 'assistant',
content: this.config.openai_api.self_suggestion_prompt.replace(/\{bot_name\}/g, this.config.openai_api.bot_name),
content: characterConf.self_suggestion_prompt.replace(/\{bot_name\}/g, characterConf.bot_name),
});
}
@ -350,51 +263,200 @@ export default class ChatGPTController implements PluginController {
content: question
});
return await this.doApiRequest(messageList);
return messageList;
}
private async doApiRequest(messageList: any[]): Promise<ChatGPTApiMessage> {
let opts: OptionsOfTextResponseBody = {
headers: {
Authorization: `Bearer ${this.config.openai_api.token}`,
},
json: {
model: this.config.openai_api.model_options.model,
messages: messageList,
},
private async doApiRequest(messageList: any[], onMessage?: (chunk: string) => any): Promise<ChatGPTApiMessage> {
let modelOpts = Object.fromEntries(Object.entries({
model: this.config.openai_api.model_options.model,
temperature: this.config.openai_api.model_options.temperature,
top_p: this.config.openai_api.model_options.top_p,
max_tokens: this.config.openai_api.model_options.max_output_tokens,
presence_penalty: this.config.openai_api.model_options.presence_penalty,
frequency_penalty: this.config.openai_api.model_options.frequency_penalty,
}).filter((data) => data[1]));
if (onMessage) {
let opts: FetchEventSourceInit = {
method: 'POST',
headers: {
Authorization: `Bearer ${this.config.openai_api.token}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
...modelOpts,
messages: messageList,
stream: true,
})
};
if (this.config.proxy) {
(opts as any).dispatcher = new ProxyAgent(this.config.proxy);
}
timeout: 30000,
}
let abortController = new AbortController();
let timeoutTimer = setTimeout(() => {
abortController.abort();
}, 30000);
if (this.config.proxy) {
opts.agent = {
https: new HttpsProxyAgent({
keepAlive: true,
keepAliveMsecs: 1000,
maxSockets: 256,
maxFreeSockets: 256,
scheduling: 'lifo',
proxy: this.config.proxy,
}) as any,
let buffer: string = '';
let messageChunk: string[] = [];
let isStarted = false;
let isDone = false;
let prevEvent: any = null;
const flush = (force = false) => {
if (force) {
let message = buffer.trim();
messageChunk.push(message);
onMessage(message);
} else {
if (buffer.indexOf('\n\n') !== -1 && buffer.length > this.config.openai_api.buffer_size) {
let splitPos = buffer.indexOf('\n\n');
let message = buffer.slice(0, splitPos);
messageChunk.push(message);
onMessage(message);
buffer = buffer.slice(splitPos + 2);
}
}
}
}
const res = await got.post('https://api.openai.com/v1/chat/completions', opts).json<any>();
const onClose = () => {
abortController.abort();
clearTimeout(timeoutTimer);
}
if (res.error) {
throw new ChatGPTAPIError(res.message, res.type);
}
if (res.choices && Array.isArray(res.choices) && res.choices.length > 0 &&
typeof res.choices[0].message?.content === 'string') {
await fetchEventSource('https://api.openai.com/v1/chat/completions', {
...opts,
signal: abortController.signal,
onopen: async (openResponse) => {
if (openResponse.status === 200) {
return;
}
if (this.app.debug) {
console.debug(openResponse);
}
let error;
try {
const body = await openResponse.text();
error = new ChatGPTAPIError(`Failed to send message. HTTP ${openResponse.status} - ${body}`,
openResponse.statusText, body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${openResponse.status}`);
}
throw error;
},
onclose: () => {
if (this.app.debug) {
this.app.logger.debug('Server closed the connection unexpectedly, returning...');
}
if (!isDone) {
if (!prevEvent) {
throw new Error('Server closed the connection unexpectedly. Please make sure you are using a valid access token.');
}
if (buffer.length > 0) {
flush(true);
}
}
},
onerror: (err) => {
// rethrow to stop the operation
throw err;
},
onmessage: (eventMessage) => {
if (!eventMessage.data || eventMessage.event === 'ping') {
return;
}
if (eventMessage.data === '[DONE]') {
flush(true);
onClose();
isDone = true;
return;
}
try {
const data = JSON.parse(eventMessage.data);
if ("choices" in data && data["choices"].length > 0) {
let choice = data["choices"][0];
var delta_content = choice["delta"];
if (delta_content["content"]) {
var deltaMessage = delta_content["content"];
// Skip empty lines before content
if (!isStarted) {
if (deltaMessage.replace("\n", "") == "") {
return;
} else {
isStarted = true;
}
}
buffer += deltaMessage;
flush();
}
}
prevEvent = data;
} catch (err) {
console.debug(eventMessage.data);
console.error(err);
}
}
});
let message = messageChunk.join('');
let tokens = gptEncode(message).length;
return {
role: 'assistant',
message: res.choices[0].message.content,
tokens: res.usage.completion_tokens,
message,
tokens
};
} else {
let opts: OptionsOfTextResponseBody = {
headers: {
Authorization: `Bearer ${this.config.openai_api.token}`,
},
json: {
...modelOpts,
messages: messageList,
},
timeout: 30000,
}
}
throw new ChatGPTAPIError('API返回数据格式错误', 'api_response_data_invalid');
if (this.config.proxy) {
opts.agent = {
https: new HttpsProxyAgent({
keepAlive: true,
keepAliveMsecs: 1000,
maxSockets: 256,
maxFreeSockets: 256,
scheduling: 'lifo',
proxy: this.config.proxy,
}) as any,
}
}
const res = await got.post('https://api.openai.com/v1/chat/completions', opts).json<any>();
if (res.error) {
throw new ChatGPTAPIError(res.message, res.type);
}
if (res.choices && Array.isArray(res.choices) && res.choices.length > 0 &&
typeof res.choices[0].message?.content === 'string') {
return {
role: 'assistant',
message: res.choices[0].message.content,
tokens: res.usage.completion_tokens,
}
}
throw new ChatGPTAPIError('API返回数据格式错误', 'api_response_data_invalid');
}
}
private shouldSelfSuggestion(content: string): boolean {
@ -404,12 +466,26 @@ export default class ChatGPTController implements PluginController {
return false;
}
private async handleChatGPTAPIChat(content: string, message: CommonReceivedMessage) {
this.app.logger.debug(`ChatGPT API 收到提问。`);
private async handleChatGPTAPIChat(content: string, message: CommonReceivedMessage, isStream: boolean = false,
character = 'assistant', singleMessage = false) {
if (singleMessage && this.chatGenerating) {
await message.sendReply('正在生成中,请稍后再试', true);
return;
}
this.app.logger.debug(`ChatGPT API 收到提问。当前人格:${character}`);
if (content.trim() === '') {
await message.sendReply('说点什么啊', true);
return;
}
if (!(character in this.config.characters)) {
this.app.logger.debug(`ChatGPT API 人格 ${character} 不存在,使用默认人格`);
character = 'assistant';
}
let characterConf = this.config.characters[character];
if (this.config.gatekeeper_url) {
try {
let response = await got.post(this.config.gatekeeper_url, {
@ -436,6 +512,11 @@ export default class ChatGPTController implements PluginController {
}
await userSessionStore.addRequestCount(this.SESSION_KEY_MESSAGE_COUNT, this.config.rate_limit_minutes * 60);
// 转换简体到繁体
const s2tw = new OpenCC.OpenCC('s2tw.json');
const tw2s = new OpenCC.OpenCC('tw2s.json');
content = await s2tw.convertPromise(content);
// 获取记忆
let messageLogList = await message.session.chat.get<ChatGPTApiMessage[]>(this.SESSION_KEY_API_CHAT_LOG);
if (!Array.isArray(messageLogList)) {
@ -443,6 +524,10 @@ export default class ChatGPTController implements PluginController {
}
try {
if (singleMessage) {
this.chatGenerating = true;
}
const questionTokens = await gptEncode(message.contentText).length;
this.app.logger.debug(`提问占用Tokens${questionTokens}`);
@ -453,7 +538,7 @@ export default class ChatGPTController implements PluginController {
// 压缩过去的记录
let oldMessageLogList = messageLogList;
messageLogList = await this.compressConversation(messageLogList);
messageLogList = await this.compressConversation(messageLogList, characterConf);
this.app.logger.debug('已结束压缩对话记录流程');
if (oldMessageLogList !== messageLogList) { // 先保存一次压缩结果
@ -461,28 +546,58 @@ export default class ChatGPTController implements PluginController {
await message.session.chat.set(this.SESSION_KEY_API_CHAT_LOG, messageLogList, this.config.openai_api.memory_expire);
}
let replyRes = await this.chatComplete(message.contentText, messageLogList);
if (this.app.debug) {
console.log(replyRes);
}
let reqMessageList = this.buildMessageList(message.contentText, messageLogList, characterConf, false);
// 如果检测到对话中认为自己是AI则再次调用重写对话
if (this.shouldSelfSuggestion(replyRes.message)) {
this.app.logger.debug('需要重写回答');
replyRes = await this.chatComplete(message.contentText, messageLogList, true);
let replyRes: ChatGPTApiMessage | undefined = undefined;
if (isStream) {
// 处理流式输出
let onResultMessage = async (chunk: string) => {
let msg = await tw2s.convertPromise(chunk);
for (let [inputText, replacement] of Object.entries(this.config.output_replace)) {
content = content.replace(new RegExp(inputText, 'g'), replacement);
}
await message.sendReply(msg, true);
};
replyRes = await this.doApiRequest(reqMessageList, onResultMessage);
replyRes.message = await tw2s.convertPromise(replyRes.message);
if (this.app.debug) {
console.log(replyRes);
}
} else {
replyRes = await this.doApiRequest(reqMessageList);
replyRes.message = await tw2s.convertPromise(replyRes.message);
if (this.app.debug) {
console.log(replyRes);
}
}
messageLogList.push({
role: 'user',
message: message.contentText,
tokens: questionTokens,
}, replyRes);
await message.session.chat.set(this.SESSION_KEY_API_CHAT_LOG, messageLogList, this.config.openai_api.memory_expire);
// 如果检测到对话中认为自己是AI则再次调用重写对话
if (characterConf.self_suggestion_prompt && this.shouldSelfSuggestion(replyRes.message)) {
this.app.logger.debug('需要重写回答');
reqMessageList = this.buildMessageList(replyRes.message, messageLogList, characterConf, true);
replyRes = await this.doApiRequest(reqMessageList);
if (this.app.debug) {
console.log(replyRes);
}
replyRes.message = await tw2s.convertPromise(replyRes.message);
}
await message.sendReply(replyRes.message.replace(/\n\n/g, '\n'), true);
let content = replyRes.message.replace(/\n\n/g, '\n');
for (let [inputText, replacement] of Object.entries(this.config.output_replace)) {
content = content.replace(new RegExp(inputText, 'g'), replacement);
}
await message.sendReply(content, true);
}
if (replyRes) {
messageLogList.push({
role: 'user',
message: message.contentText,
tokens: questionTokens,
}, replyRes);
await message.session.chat.set(this.SESSION_KEY_API_CHAT_LOG, messageLogList, this.config.openai_api.memory_expire);
}
} catch (err: any) {
this.app.logger.error('ChatGPT error', err);
console.error(err);
@ -493,10 +608,17 @@ export default class ChatGPTController implements PluginController {
await message.sendReply('提问太多了,过会儿再试试呗。', true);
return;
}
} else if (err.name === 'RequestError') {
await message.sendReply('连接失败:' + err.message + ',过会儿再试试呗。', true);
return;
}
await message.sendReply('生成对话失败: ' + err.toString(), true);
return;
} finally {
if (singleMessage) {
this.chatGenerating = false;
}
}
}
}

@ -33,7 +33,7 @@ export default class IsekaiWikiController implements PluginController {
});
this.event.registerCommand({
command: '随机',
command: '随机页面',
name: '获取随机的百科页面',
alias: ['随机词条', '随机页面'],
}, (args, message, resolved) => {

@ -30,7 +30,7 @@ export default class SfsettingsController implements PluginController {
});
this.event.registerCommand({
command: '随机',
command: '随机页面',
name: '获取随机的百科页面',
alias: ['随机词条', '随机页面'],
}, (args, message, resolved) => {

@ -0,0 +1,388 @@
import App from "../App";
import { CommonReceivedMessage } from "../message/Message";
import { MessagePriority, PluginController, PluginEvent } from "../PluginManager";
import got from "got/dist/source";
export type QueueData = {
message: CommonReceivedMessage,
prompt: string,
};
export type ApiConfig = {
endpoint: string,
main?: boolean,
sampler_name?: string,
steps?: number,
trigger_words?: string[],
negative_prompt?: string,
banned_words?: string[],
api_params?: Record<string, any>,
_banned_words_matcher?: RegExp[],
};
export type SizeConfig = {
width: number,
height: number,
default?: boolean,
trigger_words?: string[],
};
export type Text2ImgRuntimeOptions = {
useTranslate?: boolean,
};
export type GPUInfoResponse = {
name: string,
memory_total: number,
memory_used: number,
memory_free: number,
load: number,
temperature: number,
}
export default class StableDiffusionController implements PluginController {
private config!: Awaited<ReturnType<typeof this.getDefaultConfig>>;
private SESSION_KEY_GENERATE_COUNT = 'stablediffusion_generateCount';
public event!: PluginEvent;
public app: App;
public chatGPTClient: any;
public id = 'stablediffusion';
public name = 'Stable Diffusion';
public description = '绘画生成';
private mainApi!: ApiConfig;
private defaultSize!: SizeConfig;
private queue: QueueData[] = [];
private running = true;
private apiMatcher: RegExp[][] = [];
private sizeMatcher: RegExp[][] = [];
private bannedWordsMatcher: RegExp[] = [];
constructor(app: App) {
this.app = app;
}
async getDefaultConfig() {
return {
api: [] as ApiConfig[],
size: [] as SizeConfig[],
banned_words: [] as string[],
banned_output_words: [] as string[],
queue_max_size: 4,
rate_limit: 1,
rate_limit_minutes: 2,
safe_temperature: null as number | null,
translate_caiyunai: {
key: ""
}
};
}
async initialize(config: any) {
await this.updateConfig(config);
this.event.init(this);
this.event.registerCommand({
command: 'draw',
name: '使用英语短句或关键词生成绘画',
}, (args, message, resolve) => {
resolve();
return this.text2img(args, message);
});
this.event.registerCommand({
command: '画',
name: '使用中文关键词生成绘画',
}, (args, message, resolve) => {
resolve();
return this.text2img(args, message, {
useTranslate: true
});
});
const runQueue = async () => {
await this.runQueue();
if (this.running) {
setTimeout(() => {
runQueue();
}, 100);
}
}
runQueue();
}
async destroy() {
this.running = false;
}
async updateConfig(config: any) {
this.config = config;
let mainApi = this.config.api.find(api => api.main);
if (!mainApi) {
throw new Error('No main API found in stablediffusion config.');
}
this.mainApi = mainApi;
let defaultSize = this.config.size.find(size => size.default);
if (!defaultSize) {
defaultSize = {
width: 512,
height: 512
};
}
this.defaultSize = defaultSize;
this.apiMatcher = [];
this.config.api.forEach((apiConf) => {
let matcher: RegExp[] = [];
apiConf.trigger_words?.forEach((word) => {
matcher.push(this.makeWordMatcher(word));
});
this.apiMatcher.push(matcher);
apiConf.banned_words ??= [];
apiConf._banned_words_matcher = apiConf.banned_words.map((word) => this.makeWordMatcher(word));
});
this.sizeMatcher = [];
this.config.size.forEach((sizeConf) => {
let matcher: RegExp[] = [];
sizeConf.trigger_words?.forEach((word) => {
matcher.push(this.makeWordMatcher(word));
});
this.sizeMatcher.push(matcher);
});
this.bannedWordsMatcher = [];
this.config.banned_words.forEach((word) => {
this.bannedWordsMatcher.push(this.makeWordMatcher(word));
});
}
public async text2img(prompt: string, message: CommonReceivedMessage, options: Text2ImgRuntimeOptions = {}) {
const userSessionStore = message.session.user;
// 使用频率限制
let rateLimitExpires = await userSessionStore.getRateLimit(this.SESSION_KEY_GENERATE_COUNT, this.config.rate_limit, this.config.rate_limit_minutes * 60);
if (rateLimitExpires) {
let minutesLeft = Math.ceil(rateLimitExpires / 60);
await message.sendReply(`才刚画过呢,${minutesLeft}分钟后再来吧。`, true);
return;
}
await userSessionStore.addRequestCount(this.SESSION_KEY_GENERATE_COUNT, this.config.rate_limit_minutes * 60);
if (this.queue.length >= this.config.queue_max_size) {
await message.sendReply('太多人在画了,等一下再来吧。', true);
return;
}
prompt = prompt.trim();
this.app.logger.debug("收到绘图请求: " + prompt);
if (options.useTranslate) {
prompt = await this.translateCaiyunAI(prompt);
this.app.logger.debug("Prompt翻译结果: " + prompt);
if (!prompt) {
await message.sendReply('尝试翻译出错,过会儿再试试吧。', true);
return;
}
}
let api = this.getMostMatchedApi(prompt);
// 检查是否有禁用词
for (let matcher of this.bannedWordsMatcher) {
if (prompt.match(matcher)) {
await message.sendReply(`生成图片失败:关键词中包含禁用的内容。`, true);
return;
}
}
for (let matcher of api._banned_words_matcher!) {
if (prompt.match(matcher)) {
await message.sendReply(`生成图片失败:关键词中包含禁用的内容。`, true);
return;
}
}
this.queue.push({
message,
prompt,
});
}
private makeWordMatcher(word: string) {
return new RegExp(`([^a-z]|^)${word}([^a-z]|$)`, 'gi');
}
private async translateCaiyunAI(text: string) {
try {
let res = await got.post('https://api.interpreter.caiyunai.com/v1/translator', {
json: {
source: [text],
trans_type: "auto2en",
request_id: "sd",
media: "text",
detect: true,
replaced: true
},
headers: {
'content-type': 'application/json',
'x-authorization': `token ${this.config.translate_caiyunai.key}`
}
}).json<any>();
if (res.target && res.target.length > 0) {
return res.target[0];
}
} catch (e) {
this.app.logger.error("无法翻译", e);
console.error(e);
}
return null;
}
public async getGPUInfo(): Promise<GPUInfoResponse | null> {
try {
let res = await got.get(this.mainApi.endpoint + '/sdapi/v1/gpu-info').json<any>();
if (res) {
return res;
}
} catch (e) {
this.app.logger.error("无法读取GPU信息", e);
console.error(e);
}
return null;
}
public getMostMatchedIndex(prompt: string, matchers: RegExp[][]) {
let matchCount = matchers.map(() => 0);
for (let i = 0; i < matchers.length; i++) {
matchers[i].forEach((matcher) => {
let matched = prompt.matchAll(matcher);
matchCount[i] += Array.from(matched).length;
});
}
let maxMatchCount = Math.max(...matchCount);
if (maxMatchCount > 0) {
return matchCount.indexOf(maxMatchCount);
}
return -1;
}
public getMostMatchedApi(prompt: string) {
let mostMatchedApiIndex = this.getMostMatchedIndex(prompt, this.apiMatcher);
if (mostMatchedApiIndex >= 0) {
return this.config.api[mostMatchedApiIndex];
} else {
return this.mainApi;
}
}
public getMostMatchedSize(prompt: string) {
let mostMatchedSizeIndex = this.getMostMatchedIndex(prompt, this.sizeMatcher);
if (mostMatchedSizeIndex >= 0) {
return this.config.size[mostMatchedSizeIndex];
} else {
return this.defaultSize;
}
}
public async runQueue() {
if (!this.running) {
return;
}
if (this.queue.length === 0) {
return;
}
// Wait for GPU to be ready
let gpuInfo = await this.getGPUInfo();
if (!gpuInfo) {
return;
}
if (this.config.safe_temperature && gpuInfo.temperature > this.config.safe_temperature) {
// Wait for GPU to cool down
return;
}
// Start generating
const currentTask = this.queue.shift()!;
this.app.logger.debug("开始生成图片: " + currentTask.prompt);
let api = this.getMostMatchedApi(currentTask.prompt);
this.app.logger.debug("使用API: " + api.endpoint);
let size = this.getMostMatchedSize(currentTask.prompt);
this.app.logger.debug("使用尺寸: " + size.width + "x" + size.height);
let extraApiParams = api.api_params ?? {};
try {
let txt2imgRes = await got.post(api.endpoint + '/sdapi/v1/txt2img', {
json: {
do_not_save_samples: false,
do_not_save_grid: false,
...extraApiParams,
prompt: currentTask.prompt,
width: size.width,
height: size.height,
sampler_name: api.sampler_name ?? "Euler a",
negative_prompt: api.negative_prompt ?? "",
steps: api.steps ?? 28,
}
}).json<any>();
if (Array.isArray(txt2imgRes.images) && txt2imgRes.images.length > 0) {
this.app.logger.debug("生成图片成功,开始检查图片内容");
let image = txt2imgRes.images[0];
// Check banned words
let interrogateRes = await got.post(this.mainApi.endpoint + '/sdapi/v1/interrogate', {
json: {
model: "deepdanbooru",
image: image,
},
}).json<any>();
if (interrogateRes.caption) {
let caption = interrogateRes.caption;
this.app.logger.debug("DeepDanbooru导出关键字" + caption);
let keywords = caption.split(',').map((keyword: string) => keyword.trim());
let bannedKeywords = this.config.banned_words;
let bannedOutputKeywords = this.config.banned_output_words;
let bannedKeyword = bannedKeywords.find((keyword) => keywords.includes(keyword)) ||
bannedOutputKeywords.find((keyword) => keywords.includes(keyword)) ||
api.banned_words?.find((keyword) => keywords.includes(keyword));
if (bannedKeyword) {
await currentTask.message.sendReply(`生成图片失败:图片中包含禁用的 ${bannedKeyword} 内容。`, true);
return;
}
}
await currentTask.message.sendReply([
{
type: 'image',
data: {
url: "base64://" + image,
}
}
], false);
}
} catch (e: any) {
this.app.logger.error("生成图片失败:" + e.message);
console.error(e);
await currentTask.message.sendReply('生成图片失败:' + e.message, true);
return;
}
}
}

@ -1,5 +1,7 @@
import App from "../App";
import { buildChatIdentityQuery, toChatIdentityEntity } from "../orm/Message";
import { PluginController, PluginEvent } from "../PluginManager";
import { TestSchema } from "./test/TestSchema";
export default class TestController implements PluginController {
public event!: PluginEvent;
@ -16,44 +18,34 @@ export default class TestController implements PluginController {
async initialize() {
this.event.init(this);
this.event.registerCommand({
command: '写入全局',
name: '写入全局Session',
}, (args, message, resolve) => {
resolve();
const dbi = this.app.database;
if (!dbi) return;
message.session.global.set('_test', args);
});
const TestModel = dbi.getModel('Test', TestSchema);
this.event.registerCommand({
command: '写入群组',
name: '写入群组Session',
command: '写入',
name: '写入数据库',
}, (args, message, resolve) => {
resolve();
message.session.group.set('_test', args);
});
this.event.registerCommand({
command: '写入对话',
name: '写入对话Session',
}, (args, message, resolve) => {
resolve();
return (async () => {
let obj = new TestModel({
chatIdentity: toChatIdentityEntity(message.sender.identity),
data: args,
});
message.session.chat.set('_test', args);
await obj.save();
})();
});
this.event.registerCommand({
command: '读取',
name: '读取Session',
name: '读取数据库',
}, async (args, message, resolve) => {
resolve();
let globalSession = await message.session.global.get('_test');
let groupSession = await message.session.group.get('_test');
let chatSession = await message.session.chat.get('_test');
message.sendReply(`全局Session: ${globalSession}\n群组Session: ${groupSession}\n对话Session: ${chatSession}`);
let obj = await TestModel.findOne(buildChatIdentityQuery(message.sender.identity));
});
}
}

@ -0,0 +1,16 @@
import mongoose, { Schema, Types } from "mongoose";
import { ChatIdentityEntity, ChatIdentityEntityType } from "../../orm/Message";
export type TestSchemaType = {
id: Types.ObjectId,
chatIdentity: ChatIdentityEntityType,
data: string,
};
export const TestSchema = new Schema<TestSchemaType>({
id: Object,
chatIdentity: ChatIdentityEntity,
data: String
});
export const TestModel = mongoose.model<TestSchemaType>('Test', TestSchema);

@ -0,0 +1,22 @@
import mongoose, { Schema, Types } from "mongoose";
import { ObjectId } from "mongodb";
export type GroupDataSchemaType = {
id: Types.ObjectId,
groupId: string,
parentId: Types.ObjectId,
name: string,
image: string,
extra: any,
};
export const GroupDataSchema = new Schema<GroupDataSchemaType>({
id: ObjectId,
groupId: String,
parentId: ObjectId,
name: String,
image: String,
extra: Object,
});
export const GroupDataModel = mongoose.model<GroupDataSchemaType>('GroupData', GroupDataSchema);

@ -1,11 +0,0 @@
import { Schema } from "mongoose";
import { ObjectId } from "mongodb";
export const GroupDataModel = new Schema({
id: ObjectId,
groupId: String,
parentId: ObjectId,
name: String,
image: String,
extra: Object,
});

@ -0,0 +1,28 @@
import mongoose, { Schema, Types } from "mongoose";
import { ObjectId } from "mongodb";
export type GroupUserDataSchemaType = {
id: Types.ObjectId,
groupId: string,
userId: string,
userName: string,
nickName: string,
title: string,
role: string,
image: string,
extra: any,
};
export const GroupUserDataSchema = new Schema<GroupUserDataSchemaType>({
id: ObjectId,
groupId: String,
userId: String,
userName: String,
nickName: String,
title: String,
role: String,
image: String,
extra: Object,
});
export const GroupUserDataModel = mongoose.model<GroupUserDataSchemaType>('GroupUserData', GroupUserDataSchema);

@ -1,14 +0,0 @@
import { Schema } from "mongoose";
import { ObjectId } from "mongodb";
export const GroupUserDataModel = new Schema({
id: ObjectId,
groupId: String,
uid: String,
userName: String,
nickName: String,
title: String,
role: String,
image: String,
extra: Object,
});

@ -0,0 +1,106 @@
import { Schema, Types } from "mongoose";
import { ObjectId } from "mongodb";
import { ChatIdentity } from "../message/Sender";
export type ChatIdentityEntityType = Partial<{
robotId: string,
rootGroupId: string,
groupId: string,
userId: string,
channelId: string,
}>;
export const ChatIdentityEntity = {
robotId: String,
rootGroupId: String,
groupId: String,
userId: String,
channelId: String,
};
export function toChatIdentityEntity(chatIdentity: ChatIdentity): ChatIdentityEntityType {
return {
robotId: chatIdentity.robot.robotId,
rootGroupId: chatIdentity.rootGroupId,
groupId: chatIdentity.groupId,
userId: chatIdentity.userId,
channelId: chatIdentity.channelId,
}
}
export function buildChatIdentityQuery(chatIdentityEntity: ChatIdentityEntityType | ChatIdentity, prefix = 'chatIdentity') {
const query: any = {};
if ((chatIdentityEntity as any).robotId) {
query[`${prefix}.robotId`] = (chatIdentityEntity as any).robotId;
} else if ((chatIdentityEntity as any).robot && (chatIdentityEntity as any).robot.robotId) {
query[`${prefix}.robotId`] = (chatIdentityEntity as any).robot.robotId;
}
if (chatIdentityEntity.rootGroupId) {
query[`${prefix}.rootGroupId`] = chatIdentityEntity.rootGroupId;
}
if (chatIdentityEntity.groupId) {
query[`${prefix}.groupId`] = chatIdentityEntity.groupId;
}
if (chatIdentityEntity.userId) {
query[`${prefix}.userId`] = chatIdentityEntity.userId;
}
if (chatIdentityEntity.channelId) {
query[`${prefix}.channelId`] = chatIdentityEntity.channelId;
}
return query;
}
export type MessageSchemaType = {
id: Types.ObjectId,
messageId: string,
type: string,
origin: string,
chatIdentity: ChatIdentityEntityType,
meta: {
repliedId: Types.ObjectId,
repliedMessageId: string,
mentionedUsers: Types.ObjectId[],
mentionedUids: string[],
},
isSend: boolean,
contentText: string,
content: any,
time: Date,
deleted: boolean,
extra: any,
};
export const MessageSchema = new Schema<MessageSchemaType>({
id: ObjectId,
messageId: String,
type: String,
origin: String,
chatIdentity: ChatIdentityEntity,
meta: {
repliedId: ObjectId,
repliedMessageId: String,
mentionedUsers: {
type: [ObjectId],
default: []
},
mentionedUids: {
type: [String],
default: []
}
},
isSend: Boolean,
contentText: String,
content: Object,
time: {
type: Date,
default: Date.now
},
deleted: {
type: Boolean,
default: false
},
extra: {
type: Object,
default: {},
},
});

@ -1,43 +0,0 @@
import { Schema } from "mongoose";
import { ObjectId } from "mongodb";
export const MessageLogModel = new Schema({
id: ObjectId,
messageId: String,
type: String,
origin: String,
chatIdentity: {
robotId: String,
uid: String,
groupId: String,
rootGroupId: String,
channelId: String,
},
meta: {
repliedId: ObjectId,
repliedMessageId: String,
mentionedUsers: {
type: [ObjectId],
default: []
},
mentionedUids: {
type: [String],
default: []
}
},
isSend: Boolean,
contentText: String,
content: Object,
time: {
type: Date,
default: Date.now
},
deleted: {
type: Boolean,
default: false
},
extra: {
type: Object,
default: {},
},
});

@ -1,11 +0,0 @@
import { Schema } from "mongoose";
import { ObjectId } from "mongodb";
export const UserDataModel = new Schema({
id: ObjectId,
uid: String,
userName: String,
nickName: String,
image: String,
extra: Object,
});

@ -0,0 +1,22 @@
import mongoose, { Schema, Types } from "mongoose";
import { ObjectId } from "mongodb";
export type UserDataSchemaType = {
id: Types.ObjectId,
userId: string,
userName: string,
nickName: string,
image: string,
extra: any,
};
export const UserDataSchema = new Schema<UserDataSchemaType>({
id: ObjectId,
userId: String,
userName: String,
nickName: String,
image: String,
extra: Object,
});
export const UserDataModel = mongoose.model<UserDataSchemaType>('UserData', UserDataSchema);

@ -12,7 +12,8 @@ export interface QQFaceMessage extends MessageChunk {
export interface QQImageMessage extends MessageChunk {
type: 'qqimage';
data: {
url: string;
file?: string;
url?: string;
alt?: string;
subType?: string;
};

@ -24,3 +24,7 @@ export function useRobotManager() {
export function useRestfulApiManager() {
return useApp().restfulApi;
}
export function useDB() {
return useApp().database;
}
Loading…
Cancel
Save