providers.rs

  1use super::*;
  2use sea_orm::QueryOrder;
  3use std::str::FromStr;
  4use strum::IntoEnumIterator as _;
  5
  6pub struct ModelParams {
  7    pub provider: LanguageModelProvider,
  8    pub name: String,
  9    pub max_requests_per_minute: i64,
 10    pub max_tokens_per_minute: i64,
 11    pub max_tokens_per_day: i64,
 12    pub price_per_million_input_tokens: i32,
 13    pub price_per_million_output_tokens: i32,
 14}
 15
 16impl LlmDatabase {
 17    pub async fn initialize_providers(&mut self) -> Result<()> {
 18        self.provider_ids = self
 19            .transaction(|tx| async move {
 20                let existing_providers = provider::Entity::find().all(&*tx).await?;
 21
 22                let mut new_providers = LanguageModelProvider::iter()
 23                    .filter(|provider| {
 24                        !existing_providers
 25                            .iter()
 26                            .any(|p| p.name == provider.to_string())
 27                    })
 28                    .map(|provider| provider::ActiveModel {
 29                        name: ActiveValue::set(provider.to_string()),
 30                        ..Default::default()
 31                    })
 32                    .peekable();
 33
 34                if new_providers.peek().is_some() {
 35                    provider::Entity::insert_many(new_providers)
 36                        .exec(&*tx)
 37                        .await?;
 38                }
 39
 40                let all_providers: HashMap<_, _> = provider::Entity::find()
 41                    .all(&*tx)
 42                    .await?
 43                    .iter()
 44                    .filter_map(|provider| {
 45                        LanguageModelProvider::from_str(&provider.name)
 46                            .ok()
 47                            .map(|p| (p, provider.id))
 48                    })
 49                    .collect();
 50
 51                Ok(all_providers)
 52            })
 53            .await?;
 54        Ok(())
 55    }
 56
 57    pub async fn initialize_models(&mut self) -> Result<()> {
 58        let all_provider_ids = &self.provider_ids;
 59        self.models = self
 60            .transaction(|tx| async move {
 61                let all_models: HashMap<_, _> = model::Entity::find()
 62                    .all(&*tx)
 63                    .await?
 64                    .into_iter()
 65                    .filter_map(|model| {
 66                        let provider = all_provider_ids.iter().find_map(|(provider, id)| {
 67                            if *id == model.provider_id {
 68                                Some(provider)
 69                            } else {
 70                                None
 71                            }
 72                        })?;
 73                        Some(((*provider, model.name.clone()), model))
 74                    })
 75                    .collect();
 76                Ok(all_models)
 77            })
 78            .await?;
 79        Ok(())
 80    }
 81
 82    pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
 83        let all_provider_ids = &self.provider_ids;
 84        self.transaction(|tx| async move {
 85            model::Entity::insert_many(models.into_iter().map(|model_params| {
 86                let provider_id = all_provider_ids[&model_params.provider];
 87                model::ActiveModel {
 88                    provider_id: ActiveValue::set(provider_id),
 89                    name: ActiveValue::set(model_params.name.clone()),
 90                    max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
 91                    max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
 92                    max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
 93                    price_per_million_input_tokens: ActiveValue::set(
 94                        model_params.price_per_million_input_tokens,
 95                    ),
 96                    price_per_million_output_tokens: ActiveValue::set(
 97                        model_params.price_per_million_output_tokens,
 98                    ),
 99                    ..Default::default()
100                }
101            }))
102            .exec_without_returning(&*tx)
103            .await?;
104            Ok(())
105        })
106        .await?;
107        self.initialize_models().await
108    }
109
110    /// Returns the list of LLM providers.
111    pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
112        self.transaction(|tx| async move {
113            Ok(provider::Entity::find()
114                .order_by_asc(provider::Column::Name)
115                .all(&*tx)
116                .await?
117                .into_iter()
118                .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
119                .collect())
120        })
121        .await
122    }
123}