formatting

master
Tymon 1 year ago
parent 2cf0bebdb7
commit 8794e519cf

@ -1,105 +1,105 @@
use std::{collections::HashMap, sync::mpsc::Sender};
use crate::config::{ParamType, AI};
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
static CLIENT: Lazy<Client> = Lazy::new(Client::new);
#[derive(Clone, Serialize, Deserialize)]
pub struct Params {
pub prompt: String,
pub temperature: f32,
pub max_tokens: i32,
}
impl Default for Params {
fn default() -> Self {
Self {
prompt: Default::default(),
temperature: 1.0,
max_tokens: 1000,
}
}
}
pub async fn generate(ai: AI, params: Params) -> Result<String> {
let mut json_params: HashMap<String, Value> = HashMap::new();
// yeah this is cursed but i could not think of a better way to do it
for (key, val) in ai.params {
match val {
ParamType::String(string) => {
json_params.insert(key, Value::from(string));
}
ParamType::Prompt(addition) => {
json_params.insert(
key,
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
);
}
ParamType::Temperature => {
json_params.insert(key, Value::from(params.temperature));
}
ParamType::MaxTokens => {
json_params.insert(key, Value::from(params.max_tokens));
}
}
}
let mut res: &serde_json::Value = &CLIENT
.post(ai.url)
.json(&json_params)
.send()
.await?
.json()
.await?;
for retriever in ai.retriever {
// it works but at what cost
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
res = res.get(retriever).context("Failed to execute retriever")?;
} else {
let retriever = retriever.into_rust::<String>()?;
res = res.get(retriever).context("Failed to execute retriever")?;
}
}
Ok(res
.as_str()
.context("Output is not a valid string")?
.to_string())
}
#[derive(PartialEq)]
pub enum ResponseType {
Error,
Success,
}
pub struct Response {
pub response_type: ResponseType,
pub output: String,
}
pub fn generate_with_mpsc(ai: AI, params: Params, tx: Sender<Response>) {
tokio::spawn(async move {
let output = match generate(ai, params).await {
Ok(output) => Response {
response_type: ResponseType::Success,
output,
},
Err(error) => Response {
response_type: ResponseType::Error,
output: error.to_string(),
},
};
tx.send(output)
.expect("Failed to send output, this should never happen!")
});
}
use std::{collections::HashMap, sync::mpsc::Sender};
use crate::config::{ParamType, AI};
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
static CLIENT: Lazy<Client> = Lazy::new(Client::new);
#[derive(Clone, Serialize, Deserialize)]
pub struct Params {
pub prompt: String,
pub temperature: f32,
pub max_tokens: i32,
}
impl Default for Params {
fn default() -> Self {
Self {
prompt: Default::default(),
temperature: 1.0,
max_tokens: 1000,
}
}
}
pub async fn generate(ai: AI, params: Params) -> Result<String> {
let mut json_params: HashMap<String, Value> = HashMap::new();
// yeah this is cursed but i could not think of a better way to do it
for (key, val) in ai.params {
match val {
ParamType::String(string) => {
json_params.insert(key, Value::from(string));
}
ParamType::Prompt(addition) => {
json_params.insert(
key,
Value::from(format!("{}{}", params.prompt, addition.unwrap_or_default())),
);
}
ParamType::Temperature => {
json_params.insert(key, Value::from(params.temperature));
}
ParamType::MaxTokens => {
json_params.insert(key, Value::from(params.max_tokens));
}
}
}
let mut res: &serde_json::Value = &CLIENT
.post(ai.url)
.json(&json_params)
.send()
.await?
.json()
.await?;
for retriever in ai.retriever {
// it works but at what cost
if let Ok(retriever) = retriever.clone().into_rust::<usize>() {
res = res.get(retriever).context("Failed to execute retriever")?;
} else {
let retriever = retriever.into_rust::<String>()?;
res = res.get(retriever).context("Failed to execute retriever")?;
}
}
Ok(res
.as_str()
.context("Output is not a valid string")?
.to_string())
}
#[derive(PartialEq)]
pub enum ResponseType {
Error,
Success,
}
pub struct Response {
pub response_type: ResponseType,
pub output: String,
}
pub fn generate_with_mpsc(ai: AI, params: Params, tx: Sender<Response>) {
tokio::spawn(async move {
let output = match generate(ai, params).await {
Ok(output) => Response {
response_type: ResponseType::Success,
output,
},
Err(error) => Response {
response_type: ResponseType::Error,
output: error.to_string(),
},
};
tx.send(output)
.expect("Failed to send output, this should never happen!")
});
}

@ -1,39 +1,38 @@
use std::{fs, collections::HashMap};
use once_cell::sync::Lazy;
use ron::Value;
use serde::Deserialize;
pub static CONFIG: Lazy<Config> = Lazy::new(|| {
let config = fs::read_to_string("./config.ron").expect("Failed to read config.ron");
ron::from_str(&config).expect("Failed to parse config")
});
#[derive(Deserialize, Debug)]
pub struct Config {
pub app: App,
pub ai: Vec<AI>,
}
#[derive(Deserialize, Debug)]
pub struct App {
pub history_dir: Option<String>,
}
#[derive(Deserialize, Clone, PartialEq, Debug)]
pub struct AI {
pub name: String,
pub url: String,
pub params: HashMap<String, ParamType>,
pub retriever: Vec<Value>,
}
#[derive(Deserialize, Clone, PartialEq, Debug)]
pub enum ParamType {
String(String),
Prompt(Option<String>),
Temperature,
MaxTokens
}
use std::{collections::HashMap, fs};
use once_cell::sync::Lazy;
use ron::Value;
use serde::Deserialize;
pub static CONFIG: Lazy<Config> = Lazy::new(|| {
let config = fs::read_to_string("./config.ron").expect("Failed to read config.ron");
ron::from_str(&config).expect("Failed to parse config")
});
#[derive(Deserialize, Debug)]
pub struct Config {
pub app: App,
pub ai: Vec<AI>,
}
#[derive(Deserialize, Debug)]
pub struct App {
pub history_dir: Option<String>,
}
#[derive(Deserialize, Clone, PartialEq, Debug)]
pub struct AI {
pub name: String,
pub url: String,
pub params: HashMap<String, ParamType>,
pub retriever: Vec<Value>,
}
#[derive(Deserialize, Clone, PartialEq, Debug)]
pub enum ParamType {
String(String),
Prompt(Option<String>),
Temperature,
MaxTokens,
}

@ -1,4 +1,4 @@
use anyhow::{Result, Context};
use anyhow::{Context, Result};
use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};
use std::fs;
@ -49,7 +49,7 @@ pub fn read_all_history() -> Result<Vec<HistoryEntry>> {
}
}
entries.sort_by(|a, b|b.date.cmp(&a.date));
entries.sort_by(|a, b| b.date.cmp(&a.date));
Ok(entries)
}