You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.6 KiB
104 lines
2.6 KiB
use std::collections::HashMap;
|
|
|
|
use anyhow::{Context, Result};
|
|
use once_cell::sync::Lazy;
|
|
use reqwest::Client;
|
|
use ron::Value as RonValue;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
|
|
/// Process-wide HTTP client, built on first use and shared by every call to
/// [`generate`] so connections can be reused between requests.
static CLIENT: Lazy<Client> = Lazy::new(Client::default);
|
|
|
|
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
|
pub struct AI {
|
|
pub name: String,
|
|
pub url: String,
|
|
pub params: HashMap<String, ParamType>,
|
|
pub retriever: Vec<RonValue>,
|
|
}
|
|
|
|
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
|
pub enum ParamType {
|
|
String(String),
|
|
Prompt(Option<String>),
|
|
Temperature,
|
|
MaxTokens,
|
|
}
|
|
|
|
#[derive(Clone, Serialize, Deserialize)]
|
|
pub struct Params {
|
|
pub prompt: String,
|
|
pub temperature: f32,
|
|
pub max_tokens: i32,
|
|
}
|
|
|
|
impl Default for Params {
|
|
fn default() -> Self {
|
|
Self {
|
|
prompt: Default::default(),
|
|
temperature: 1.0,
|
|
max_tokens: 1000,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Whether a `Response` represents a failure or a success.
///
/// Derives the common traits (`Debug`, `Clone`, `Copy`, `Eq`) so values can be
/// logged, compared in `assert_eq!`, and passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseType {
    /// Marks a failed outcome.
    Error,

    /// Marks a successful outcome.
    Success,
}
|
|
|
|
pub struct Response {
|
|
pub response_type: ResponseType,
|
|
pub output: String,
|
|
}
|
|
|
|
pub async fn generate(ai: AI, params: Params) -> Result<String> {
|
|
let mut json_params: HashMap<String, Value> = HashMap::new();
|
|
|
|
// yeah this is cursed but i could not think of a better way to do it
|
|
for (key, val) in ai.params {
|
|
match val {
|
|
ParamType::String(string) => {
|
|
json_params.insert(key, Value::from(string));
|
|
}
|
|
ParamType::Prompt(addition) => {
|
|
json_params.insert(
|
|
key,
|
|
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
|
|
);
|
|
}
|
|
ParamType::Temperature => {
|
|
json_params.insert(key, Value::from(params.temperature));
|
|
}
|
|
ParamType::MaxTokens => {
|
|
json_params.insert(key, Value::from(params.max_tokens));
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut res: &serde_json::Value = &CLIENT
|
|
.post(ai.url)
|
|
.json(&json_params)
|
|
.send()
|
|
.await?
|
|
.json()
|
|
.await?;
|
|
|
|
for retriever in ai.retriever {
|
|
// it works but at what cost
|
|
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
|
|
res = res.get(retriever).context("Failed to execute retriever")?;
|
|
} else {
|
|
let retriever = retriever.into_rust::<String>()?;
|
|
res = res.get(retriever).context("Failed to execute retriever")?;
|
|
}
|
|
}
|
|
|
|
Ok(res
|
|
.as_str()
|
|
.context("Output is not a valid string")?
|
|
.to_string())
|
|
}
|
|
|