You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

93 lines
3.0 KiB

import { configType } from "../config.ts";
import { Channel } from "../database.ts";
/**
 * Returns the first element's non-empty `generated_text` from a parsed
 * inference response, or `undefined` when the payload does not have that
 * shape (for example a non-array error object, an empty array, or
 * `undefined` when the body could not be parsed as JSON).
 */
function extractGeneratedText(payload: unknown): string | undefined {
  if (!Array.isArray(payload)) return undefined;
  const first: unknown = payload[0];
  if (typeof first !== "object" || first === null) return undefined;
  const text = (first as Record<string, unknown>).generated_text;
  return typeof text === "string" && text.length > 0 ? text : undefined;
}

/**
 * Chat persona backed by the Hugging Face Inference API.
 *
 * The conversation is rendered as a transcript of `Name: "message"` lines.
 * The model is asked to continue the persona's next line, and the reply is
 * cut at the first closing quote.
 */
class HuggingFaceAI {
  /** Persona name; also used as the persona's speaker label in the transcript. */
  name: string;
  /** Persona description, interpolated into `prefix`. */
  description: string;
  /** Header line that starts every prompt; derived from `name` and `description`. */
  prefix: string;
  /**
   * Transcript lines, oldest first, formatted as `Name: "message"`.
   * NOTE(review): this array is never trimmed. Only its tail is sent to the
   * model, so it grows for the lifetime of the instance.
   */
  memory: string[];
  /** API tokens, used round-robin (one per request). */
  tokens: string[];
  /** Index into `tokens` of the token the next request will use. */
  tokenNum: number;
  /** Number of past exchanges (one user line plus one bot line each) included in a prompt. */
  memoryLen: number;
  /** Generation parameters sent with every request. */
  parameters: {
    max_new_tokens: number;
    temperature: number;
    repetition_penalty: number;
    top_k: number;
    return_full_text: boolean;
  };
  /** Model id appended to the inference endpoint URL. */
  model: string;
  /** Maximum number of requests made for one message before giving up. */
  maxAttempts: number;

  /**
   * @param name - persona name
   * @param description - persona description used in the prompt header
   * @param config - supplies `huggingface.tokens`, `huggingface.memoryLen` and `huggingface.model`
   * @param db - channel record; when `history` is present, each entry's `name`
   *   and `content` seed `memory`
   */
  constructor(name: string, description: string, config: configType, db: Channel) {
    this.name = name;
    this.description = description;
    this.prefix = this.#buildPrefix();
    this.tokens = config.huggingface.tokens;
    this.tokenNum = 0;
    this.memoryLen = config.huggingface.memoryLen;
    this.parameters = {
      "max_new_tokens": 50,
      "temperature": 0.8,
      "repetition_penalty": 1.8,
      "top_k": 40,
      "return_full_text": false,
    };
    this.model = config.huggingface.model;
    this.memory = db.history ? db.history.map((m) => `${m.name}: "${m.content}"`) : [];
    this.maxAttempts = 5;
  }

  /** Builds the prompt header line from the current `name` and `description`. */
  #buildPrefix(): string {
    return `The following is a chat with ${this.name}, ${this.description}.\n`;
  }

  /**
   * Sends `prompt` to the configured model and returns the generated text.
   *
   * Each request uses the next token in `tokens` (round-robin). A response
   * without a non-empty `generated_text` is logged and retried with the next
   * token. This includes a body that is not JSON. At most `maxAttempts`
   * requests are made.
   *
   * @throws Error if no tokens are configured or every attempt fails.
   */
  async #query(prompt: string): Promise<string> {
    if (this.tokens.length === 0) {
      throw new Error("HuggingFaceAI: no API tokens configured");
    }
    let lastPayload: unknown;
    for (let attempt = 1; attempt <= this.maxAttempts; attempt++) {
      const token = this.tokens[this.tokenNum % this.tokens.length];
      // Advance before awaiting so overlapping calls also spread across tokens.
      this.tokenNum = (this.tokenNum + 1) % this.tokens.length;
      const res = await fetch(`https://api-inference.huggingface.co/models/${this.model}`, {
        body: JSON.stringify({
          inputs: prompt,
          parameters: this.parameters,
        }),
        headers: {
          Authorization: `Bearer ${token}`,
        },
        method: "POST",
      });
      let payload: unknown;
      try {
        payload = await res.json();
      } catch {
        // Body was not valid JSON; count it as a failed attempt.
        payload = undefined;
      }
      const text = extractGeneratedText(payload);
      if (text !== undefined) {
        return text;
      }
      lastPayload = payload;
      console.log(payload);
      if (attempt < this.maxAttempts) {
        console.warn("Retrying generation with new token...");
      }
    }
    throw new Error(
      `HuggingFaceAI: generation failed after ${this.maxAttempts} attempts; last response: ${JSON.stringify(lastPayload)}`,
    );
  }

  /** Rebuilds `prefix` from the current name/description and clears `memory`. */
  reset() {
    this.prefix = this.#buildPrefix();
    this.memory = [];
  }

  /**
   * Updates the persona's name and/or description, then resets the conversation.
   * Omitted fields keep their current value.
   */
  changeShit(opts: {
    name?: string;
    description?: string;
  }) {
    this.name = opts.name ?? this.name;
    this.description = opts.description ?? this.description;
    this.reset();
  }

  /**
   * Generates the persona's reply to `message` from `username`.
   * Both lines are appended to `memory`.
   *
   * @returns the reply text, without surrounding quotes
   * @throws Error when generation fails (see `#query`)
   */
  async complete(username: string, message: string): Promise<string> {
    console.log(`${username}: "${message}"`);
    // Two transcript lines per exchange. Guard memoryLen <= 0 explicitly,
    // because `slice(-0)` is `slice(0)` and would send the entire history.
    const keep = Math.max(0, this.memoryLen) * 2;
    const ctx = keep > 0 ? this.memory.slice(-keep) : [];
    const prompt = `${this.prefix}\n${ctx.join("\n")}\n${username}: "${message}"\n${this.name}: "`;
    const generated = await this.#query(prompt);
    // The prompt leaves the persona's line open after `"`. The reply ends at
    // the first closing quote (optionally followed by . ? or !) that is
    // followed by a newline or by the end of the generated text.
    const botMsg = generated.split(/"[.?!]?(?:\n|$)/)[0] ?? "";
    this.memory.push(`${username}: "${message}"`);
    this.memory.push(`${this.name}: "${botMsg}"`);
    console.log(`${this.name}: "${botMsg}"`);
    return botMsg;
  }
}
export default HuggingFaceAI;