providers.rs

  1use super::*;
  2use sea_orm::{QueryOrder, sea_query::OnConflict};
  3use std::str::FromStr;
  4use strum::IntoEnumIterator as _;
  5
  6pub struct ModelParams {
  7    pub provider: LanguageModelProvider,
  8    pub name: String,
  9    pub max_requests_per_minute: i64,
 10    pub max_tokens_per_minute: i64,
 11    pub max_tokens_per_day: i64,
 12    pub price_per_million_input_tokens: i32,
 13    pub price_per_million_output_tokens: i32,
 14}
 15
 16impl LlmDatabase {
 17    pub async fn initialize_providers(&mut self) -> Result<()> {
 18        self.provider_ids = self
 19            .transaction(|tx| async move {
 20                let existing_providers = provider::Entity::find().all(&*tx).await?;
 21
 22                let mut new_providers = LanguageModelProvider::iter()
 23                    .filter(|provider| {
 24                        !existing_providers
 25                            .iter()
 26                            .any(|p| p.name == provider.to_string())
 27                    })
 28                    .map(|provider| provider::ActiveModel {
 29                        name: ActiveValue::set(provider.to_string()),
 30                        ..Default::default()
 31                    })
 32                    .peekable();
 33
 34                if new_providers.peek().is_some() {
 35                    provider::Entity::insert_many(new_providers)
 36                        .exec(&*tx)
 37                        .await?;
 38                }
 39
 40                let all_providers: HashMap<_, _> = provider::Entity::find()
 41                    .all(&*tx)
 42                    .await?
 43                    .iter()
 44                    .filter_map(|provider| {
 45                        LanguageModelProvider::from_str(&provider.name)
 46                            .ok()
 47                            .map(|p| (p, provider.id))
 48                    })
 49                    .collect();
 50
 51                Ok(all_providers)
 52            })
 53            .await?;
 54        Ok(())
 55    }
 56
 57    pub async fn initialize_models(&mut self) -> Result<()> {
 58        let all_provider_ids = &self.provider_ids;
 59        self.models = self
 60            .transaction(|tx| async move {
 61                let all_models: HashMap<_, _> = model::Entity::find()
 62                    .all(&*tx)
 63                    .await?
 64                    .into_iter()
 65                    .filter_map(|model| {
 66                        let provider = all_provider_ids.iter().find_map(|(provider, id)| {
 67                            if *id == model.provider_id {
 68                                Some(provider)
 69                            } else {
 70                                None
 71                            }
 72                        })?;
 73                        Some(((*provider, model.name.clone()), model))
 74                    })
 75                    .collect();
 76                Ok(all_models)
 77            })
 78            .await?;
 79        Ok(())
 80    }
 81
 82    pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
 83        let all_provider_ids = &self.provider_ids;
 84        self.transaction(|tx| async move {
 85            model::Entity::insert_many(models.iter().map(|model_params| {
 86                let provider_id = all_provider_ids[&model_params.provider];
 87                model::ActiveModel {
 88                    provider_id: ActiveValue::set(provider_id),
 89                    name: ActiveValue::set(model_params.name.clone()),
 90                    max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
 91                    max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
 92                    max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
 93                    price_per_million_input_tokens: ActiveValue::set(
 94                        model_params.price_per_million_input_tokens,
 95                    ),
 96                    price_per_million_output_tokens: ActiveValue::set(
 97                        model_params.price_per_million_output_tokens,
 98                    ),
 99                    ..Default::default()
100                }
101            }))
102            .on_conflict(
103                OnConflict::columns([model::Column::ProviderId, model::Column::Name])
104                    .update_columns([
105                        model::Column::MaxRequestsPerMinute,
106                        model::Column::MaxTokensPerMinute,
107                        model::Column::MaxTokensPerDay,
108                        model::Column::PricePerMillionInputTokens,
109                        model::Column::PricePerMillionOutputTokens,
110                    ])
111                    .to_owned(),
112            )
113            .exec_without_returning(&*tx)
114            .await?;
115            Ok(())
116        })
117        .await?;
118        self.initialize_models().await
119    }
120
121    /// Returns the list of LLM providers.
122    pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
123        self.transaction(|tx| async move {
124            Ok(provider::Entity::find()
125                .order_by_asc(provider::Column::Name)
126                .all(&*tx)
127                .await?
128                .into_iter()
129                .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
130                .collect())
131        })
132        .await
133    }
134}