1use super::*;
2use sea_orm::{QueryOrder, sea_query::OnConflict};
3use std::str::FromStr;
4use strum::IntoEnumIterator as _;
5
6pub struct ModelParams {
7 pub provider: LanguageModelProvider,
8 pub name: String,
9 pub max_requests_per_minute: i64,
10 pub max_tokens_per_minute: i64,
11 pub max_tokens_per_day: i64,
12 pub price_per_million_input_tokens: i32,
13 pub price_per_million_output_tokens: i32,
14}
15
16impl LlmDatabase {
17 pub async fn initialize_providers(&mut self) -> Result<()> {
18 self.provider_ids = self
19 .transaction(|tx| async move {
20 let existing_providers = provider::Entity::find().all(&*tx).await?;
21
22 let mut new_providers = LanguageModelProvider::iter()
23 .filter(|provider| {
24 !existing_providers
25 .iter()
26 .any(|p| p.name == provider.to_string())
27 })
28 .map(|provider| provider::ActiveModel {
29 name: ActiveValue::set(provider.to_string()),
30 ..Default::default()
31 })
32 .peekable();
33
34 if new_providers.peek().is_some() {
35 provider::Entity::insert_many(new_providers)
36 .exec(&*tx)
37 .await?;
38 }
39
40 let all_providers: HashMap<_, _> = provider::Entity::find()
41 .all(&*tx)
42 .await?
43 .iter()
44 .filter_map(|provider| {
45 LanguageModelProvider::from_str(&provider.name)
46 .ok()
47 .map(|p| (p, provider.id))
48 })
49 .collect();
50
51 Ok(all_providers)
52 })
53 .await?;
54 Ok(())
55 }
56
57 pub async fn initialize_models(&mut self) -> Result<()> {
58 let all_provider_ids = &self.provider_ids;
59 self.models = self
60 .transaction(|tx| async move {
61 let all_models: HashMap<_, _> = model::Entity::find()
62 .all(&*tx)
63 .await?
64 .into_iter()
65 .filter_map(|model| {
66 let provider = all_provider_ids.iter().find_map(|(provider, id)| {
67 if *id == model.provider_id {
68 Some(provider)
69 } else {
70 None
71 }
72 })?;
73 Some(((*provider, model.name.clone()), model))
74 })
75 .collect();
76 Ok(all_models)
77 })
78 .await?;
79 Ok(())
80 }
81
82 pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
83 let all_provider_ids = &self.provider_ids;
84 self.transaction(|tx| async move {
85 model::Entity::insert_many(models.iter().map(|model_params| {
86 let provider_id = all_provider_ids[&model_params.provider];
87 model::ActiveModel {
88 provider_id: ActiveValue::set(provider_id),
89 name: ActiveValue::set(model_params.name.clone()),
90 max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
91 max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
92 max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
93 price_per_million_input_tokens: ActiveValue::set(
94 model_params.price_per_million_input_tokens,
95 ),
96 price_per_million_output_tokens: ActiveValue::set(
97 model_params.price_per_million_output_tokens,
98 ),
99 ..Default::default()
100 }
101 }))
102 .on_conflict(
103 OnConflict::columns([model::Column::ProviderId, model::Column::Name])
104 .update_columns([
105 model::Column::MaxRequestsPerMinute,
106 model::Column::MaxTokensPerMinute,
107 model::Column::MaxTokensPerDay,
108 model::Column::PricePerMillionInputTokens,
109 model::Column::PricePerMillionOutputTokens,
110 ])
111 .to_owned(),
112 )
113 .exec_without_returning(&*tx)
114 .await?;
115 Ok(())
116 })
117 .await?;
118 self.initialize_models().await
119 }
120
121 /// Returns the list of LLM providers.
122 pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
123 self.transaction(|tx| async move {
124 Ok(provider::Entity::find()
125 .order_by_asc(provider::Column::Name)
126 .all(&*tx)
127 .await?
128 .into_iter()
129 .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
130 .collect())
131 })
132 .await
133 }
134}