You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

106 lines
2.9 KiB

use std::{collections::HashMap, sync::mpsc::Sender};
use crate::config::{ParamType, AI};
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Shared HTTP client, created lazily on first use and reused by every `generate` call.
static CLIENT: Lazy<Client> = Lazy::new(Client::default);
/// Caller-supplied values that `generate` substitutes into an AI's request template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Params {
    /// Prompt text; any suffix configured by a `ParamType::Prompt` entry is appended to it.
    pub prompt: String,
    /// Value sent for every `ParamType::Temperature` key.
    pub temperature: f32,
    /// Value sent for every `ParamType::MaxTokens` key.
    pub max_tokens: i32,
}
impl Default for Params {
    /// Returns an empty prompt, a temperature of `1.0`, and a `max_tokens` of `1000`.
    fn default() -> Self {
        let prompt = String::new();
        Self {
            temperature: 1.0,
            max_tokens: 1000,
            prompt,
        }
    }
}
pub async fn generate(ai: AI, params: Params) -> Result<String> {
let mut json_params: HashMap<String, Value> = HashMap::new();
// yeah this is cursed but i could not think of a better way to do it
for (key, val) in ai.params {
match val {
ParamType::String(string) => {
json_params.insert(key, Value::from(string));
}
ParamType::Prompt(addition) => {
json_params.insert(
key,
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
);
}
ParamType::Temperature => {
json_params.insert(key, Value::from(params.temperature));
}
ParamType::MaxTokens => {
json_params.insert(key, Value::from(params.max_tokens));
}
}
}
let mut res: &serde_json::Value = &CLIENT
.post(ai.url)
.json(&json_params)
.send()
.await?
.json()
.await?;
for retriever in ai.retriever {
// it works but at what cost
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
res = res.get(retriever).context("Failed to execute retriever")?;
} else {
let retriever = retriever.into_rust::<String>()?;
res = res.get(retriever).context("Failed to execute retriever")?;
}
}
Ok(res
.as_str()
.context("Output is not a valid string")?
.to_string())
}
/// Outcome of a request made through `generate_with_mpsc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseType {
    /// `generate` returned an error; `Response::output` holds its message.
    Error,
    /// `generate` succeeded; `Response::output` holds the generated text.
    Success,
}

/// Message delivered over the channel by `generate_with_mpsc`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Response {
    /// Whether the request succeeded or failed.
    pub response_type: ResponseType,
    /// The generated text on success, or the error's display string on failure.
    pub output: String,
}
/// Runs `generate` in a detached Tokio task and reports the outcome over `tx`.
///
/// Returns immediately. When the request finishes, one `Response` is sent. On success it
/// carries `ResponseType::Success` and the generated text. On failure it carries
/// `ResponseType::Error` and the error's display string. If the receiving end of `tx` has
/// already been dropped, the result is silently discarded.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because `tokio::spawn` requires one.
pub fn generate_with_mpsc(ai: AI, params: Params, tx: Sender<Response>) {
    tokio::spawn(async move {
        let response = match generate(ai, params).await {
            Ok(output) => Response {
                response_type: ResponseType::Success,
                output,
            },
            Err(error) => Response {
                response_type: ResponseType::Error,
                output: error.to_string(),
            },
        };
        // `send` only fails when the receiver has been dropped, for example because the
        // caller stopped waiting. Nobody is left to read the result, so discard it rather
        // than panicking the task.
        let _ = tx.send(response);
    });
}