formatting
parent
2cf0bebdb7
commit
8794e519cf
@ -1,105 +1,105 @@
|
||||
use std::{collections::HashMap, sync::mpsc::Sender};
|
||||
|
||||
use crate::config::{ParamType, AI};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use once_cell::sync::Lazy;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
static CLIENT: Lazy<Client> = Lazy::new(Client::new);
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Params {
|
||||
pub prompt: String,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: i32,
|
||||
}
|
||||
|
||||
impl Default for Params {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
prompt: Default::default(),
|
||||
temperature: 1.0,
|
||||
max_tokens: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate(ai: AI, params: Params) -> Result<String> {
|
||||
let mut json_params: HashMap<String, Value> = HashMap::new();
|
||||
|
||||
// yeah this is cursed but i could not think of a better way to do it
|
||||
for (key, val) in ai.params {
|
||||
match val {
|
||||
ParamType::String(string) => {
|
||||
json_params.insert(key, Value::from(string));
|
||||
}
|
||||
ParamType::Prompt(addition) => {
|
||||
json_params.insert(
|
||||
key,
|
||||
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
|
||||
);
|
||||
}
|
||||
ParamType::Temperature => {
|
||||
json_params.insert(key, Value::from(params.temperature));
|
||||
}
|
||||
ParamType::MaxTokens => {
|
||||
json_params.insert(key, Value::from(params.max_tokens));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut res: &serde_json::Value = &CLIENT
|
||||
.post(ai.url)
|
||||
.json(&json_params)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
for retriever in ai.retriever {
|
||||
// it works but at what cost
|
||||
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
|
||||
res = res.get(retriever).context("Failed to execute retriever")?;
|
||||
} else {
|
||||
let retriever = retriever.into_rust::<String>()?;
|
||||
res = res.get(retriever).context("Failed to execute retriever")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res
|
||||
.as_str()
|
||||
.context("Output is not a valid string")?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum ResponseType {
|
||||
Error,
|
||||
Success,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
pub response_type: ResponseType,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
pub fn generate_with_mpsc(ai: AI, params: Params, tx: Sender<Response>) {
|
||||
tokio::spawn(async move {
|
||||
let output = match generate(ai, params).await {
|
||||
Ok(output) => Response {
|
||||
response_type: ResponseType::Success,
|
||||
output,
|
||||
},
|
||||
Err(error) => Response {
|
||||
response_type: ResponseType::Error,
|
||||
output: error.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
tx.send(output)
|
||||
.expect("Failed to send output, this should never happen!")
|
||||
});
|
||||
}
|
||||
use std::{collections::HashMap, sync::mpsc::Sender};
|
||||
|
||||
use crate::config::{ParamType, AI};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use once_cell::sync::Lazy;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
static CLIENT: Lazy<Client> = Lazy::new(Client::new);
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Params {
|
||||
pub prompt: String,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: i32,
|
||||
}
|
||||
|
||||
impl Default for Params {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
prompt: Default::default(),
|
||||
temperature: 1.0,
|
||||
max_tokens: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate(ai: AI, params: Params) -> Result<String> {
|
||||
let mut json_params: HashMap<String, Value> = HashMap::new();
|
||||
|
||||
// yeah this is cursed but i could not think of a better way to do it
|
||||
for (key, val) in ai.params {
|
||||
match val {
|
||||
ParamType::String(string) => {
|
||||
json_params.insert(key, Value::from(string));
|
||||
}
|
||||
ParamType::Prompt(addition) => {
|
||||
json_params.insert(
|
||||
key,
|
||||
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
|
||||
);
|
||||
}
|
||||
ParamType::Temperature => {
|
||||
json_params.insert(key, Value::from(params.temperature));
|
||||
}
|
||||
ParamType::MaxTokens => {
|
||||
json_params.insert(key, Value::from(params.max_tokens));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut res: &serde_json::Value = &CLIENT
|
||||
.post(ai.url)
|
||||
.json(&json_params)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
for retriever in ai.retriever {
|
||||
// it works but at what cost
|
||||
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
|
||||
res = res.get(retriever).context("Failed to execute retriever")?;
|
||||
} else {
|
||||
let retriever = retriever.into_rust::<String>()?;
|
||||
res = res.get(retriever).context("Failed to execute retriever")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res
|
||||
.as_str()
|
||||
.context("Output is not a valid string")?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum ResponseType {
|
||||
Error,
|
||||
Success,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
pub response_type: ResponseType,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
pub fn generate_with_mpsc(ai: AI, params: Params, tx: Sender<Response>) {
|
||||
tokio::spawn(async move {
|
||||
let output = match generate(ai, params).await {
|
||||
Ok(output) => Response {
|
||||
response_type: ResponseType::Success,
|
||||
output,
|
||||
},
|
||||
Err(error) => Response {
|
||||
response_type: ResponseType::Error,
|
||||
output: error.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
tx.send(output)
|
||||
.expect("Failed to send output, this should never happen!")
|
||||
});
|
||||
}
|
||||
|
@ -1,39 +1,38 @@
|
||||
use std::{fs, collections::HashMap};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use ron::Value;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub static CONFIG: Lazy<Config> = Lazy::new(|| {
|
||||
let config = fs::read_to_string("./config.ron").expect("Failed to read config.ron");
|
||||
|
||||
ron::from_str(&config).expect("Failed to parse config")
|
||||
});
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Config {
|
||||
pub app: App,
|
||||
pub ai: Vec<AI>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct App {
|
||||
pub history_dir: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
||||
pub struct AI {
|
||||
pub name: String,
|
||||
pub url: String,
|
||||
pub params: HashMap<String, ParamType>,
|
||||
pub retriever: Vec<Value>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
||||
pub enum ParamType {
|
||||
String(String),
|
||||
Prompt(Option<String>),
|
||||
Temperature,
|
||||
MaxTokens
|
||||
}
|
||||
|
||||
use std::{collections::HashMap, fs};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use ron::Value;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub static CONFIG: Lazy<Config> = Lazy::new(|| {
|
||||
let config = fs::read_to_string("./config.ron").expect("Failed to read config.ron");
|
||||
|
||||
ron::from_str(&config).expect("Failed to parse config")
|
||||
});
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Config {
|
||||
pub app: App,
|
||||
pub ai: Vec<AI>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct App {
|
||||
pub history_dir: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
||||
pub struct AI {
|
||||
pub name: String,
|
||||
pub url: String,
|
||||
pub params: HashMap<String, ParamType>,
|
||||
pub retriever: Vec<Value>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, PartialEq, Debug)]
|
||||
pub enum ParamType {
|
||||
String(String),
|
||||
Prompt(Option<String>),
|
||||
Temperature,
|
||||
MaxTokens,
|
||||
}
|
||||
|
Reference in new issue