You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.0 KiB
93 lines
3.0 KiB
import { configType } from "../config.ts";
|
|
import { Channel } from "../database.ts";
|
|
|
|
/**
 * Chat persona that generates replies through the Hugging Face hosted
 * inference endpoint (`api-inference.huggingface.co`).
 *
 * The instance keeps a rolling transcript in `memory`, builds a plain-text
 * prompt from it for every incoming message, and rotates through the
 * configured API tokens on each request.
 */
class HuggingFaceAI {
  /** Persona name; also used as the speaker label in prompts and memory. */
  name: string;
  /** Free-text persona description interpolated into the prompt prefix. */
  description: string;
  /** First line of every prompt, derived from `name` and `description`. */
  prefix: string;
  /** Transcript lines in the form `Speaker: "text"`, oldest first. */
  memory: string[];
  /** API tokens to rotate through, taken from `config.huggingface.tokens`. */
  tokens: string[];
  /** Index into `tokens` of the token the next request will use. */
  tokenNum: number;
  /** Number of exchanges (user line + bot line) replayed into each prompt. */
  memoryLen: number;
  /** Generation parameters sent verbatim with every inference request. */
  parameters: {
    max_new_tokens: number;
    temperature: number;
    repetition_penalty: number;
    top_k: number;
    return_full_text: boolean;
  };
  /** Model identifier appended to the inference endpoint URL. */
  model: string;

  /**
   * @param name - Persona name.
   * @param description - Persona description used in the prompt prefix.
   * @param config - Application config; only `config.huggingface` is read
   *   (`tokens`, `memoryLen`, `model`).
   * @param db - Channel record; its optional `history` entries (`name`,
   *   `content`) seed the transcript.
   */
  constructor(name: string, description: string, config: configType, db: Channel) {
    this.name = name;
    this.description = description;
    this.prefix = this.#buildPrefix();
    this.tokens = config.huggingface.tokens;
    this.tokenNum = 0;
    this.memoryLen = config.huggingface.memoryLen;
    this.parameters = {
      max_new_tokens: 50,
      temperature: 0.8,
      repetition_penalty: 1.8,
      top_k: 40,
      return_full_text: false,
    };
    this.model = config.huggingface.model;
    this.memory = db.history ? db.history.map((m) => `${m.name}: "${m.content}"`) : [];
  }

  /** Builds the prompt prefix from the current `name` and `description`. */
  #buildPrefix(): string {
    return `The following is a chat with ${this.name}, ${this.description}.\n`;
  }

  /**
   * Sends `prompt` to the inference endpoint and returns the parsed response
   * once it contains a non-empty `generated_text` in its first element.
   *
   * Every attempt uses the next token in `tokens` (wrapping around). A reply
   * that is not usable -- an error object, a non-JSON body, or an empty
   * generation -- triggers a retry, up to two passes over the token list.
   * With no tokens configured the request is sent without an Authorization
   * header.
   *
   * TODO: type our response object
   *
   * @param prompt - Full prompt text to send as `inputs`.
   * @returns The response array; element 0 has a non-empty `generated_text`.
   * @throws Error when no attempt produced a usable generation. Errors thrown
   *   by `fetch` itself propagate unchanged.
   */
  async #query(prompt: string): Promise<Record<string, string>[]> {
    const tokenCount = this.tokens.length;
    const maxAttempts = Math.max(tokenCount, 1) * 2;
    const body = JSON.stringify({
      inputs: prompt,
      parameters: this.parameters,
    });

    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      const headers: Record<string, string> = {};
      if (tokenCount > 0) {
        // The modulo also protects against `tokenNum` having been set out of
        // range from outside the class.
        headers.Authorization = `Bearer ${this.tokens[this.tokenNum % tokenCount]}`;
        this.tokenNum = (this.tokenNum + 1) % tokenCount;
      }

      const res = await fetch(`https://api-inference.huggingface.co/models/${this.model}`, {
        body,
        headers,
        method: "POST",
      });

      let payload: unknown;
      try {
        payload = await res.json();
      } catch {
        // Body was not JSON; treat it like any other unusable reply.
        payload = undefined;
      }

      // Error replies are plain objects rather than arrays, so check the
      // shape before indexing into the payload.
      const first: unknown = Array.isArray(payload) ? payload[0] : undefined;
      const text =
        typeof first === "object" && first !== null
          ? (first as Record<string, unknown>).generated_text
          : undefined;
      if (typeof text === "string" && text.length > 0) {
        return payload as Record<string, string>[];
      }

      console.log(payload);
      if (attempt < maxAttempts) {
        console.warn("Retrying generation with new token...");
      }
    }

    throw new Error(
      `Hugging Face generation failed after ${maxAttempts} attempts (model: ${this.model})`,
    );
  }

  /** Rebuilds the prompt prefix and clears the transcript. */
  reset() {
    this.prefix = this.#buildPrefix();
    this.memory = [];
  }

  /**
   * Updates the persona and resets the conversation.
   *
   * @param opts - New `name` and/or `description`; omitted fields keep their
   *   current value. The transcript is cleared in every case.
   */
  changeShit(opts: {
    name?: string;
    description?: string;
  }) {
    this.name = opts.name ?? this.name;
    this.description = opts.description ?? this.description;
    this.reset();
  }

  /**
   * Generates the persona's reply to one user message and records the
   * exchange in `memory`.
   *
   * @param username - Speaker label for the incoming message.
   * @param message - Text of the incoming message.
   * @returns The generated reply, cut at the first closing quote that is
   *   followed by an optional `.`, `?` or `!` and a newline.
   * @throws Whatever the underlying request throws, including the error
   *   raised when all retry attempts fail.
   */
  async complete(username: string, message: string) {
    const userLine = `${username}: "${message}"`;
    console.log(userLine);

    // `slice(-0)` is `slice(0)` and would replay the whole transcript, so a
    // non-positive `memoryLen` has to be handled explicitly.
    const ctx = this.memoryLen > 0 ? this.memory.slice(this.memoryLen * -2) : [];
    const prompt = `${this.prefix}\n${ctx.join("\n")}\n${userLine}\n${this.name}: "`;

    const res = await this.#query(prompt);
    // The model tends to keep writing further turns; keep only the first one.
    const botMsg = res[0].generated_text.split(/"[.?!]?\n/gm)[0];

    const botLine = `${this.name}: "${botMsg}"`;
    this.memory.push(userLine);
    this.memory.push(botLine);
    console.log(botLine);
    return botMsg;
  }
}
|
|
|
|
export default HuggingFaceAI;
|