providers.rs

 1use sea_orm::sea_query::OnConflict;
 2use sea_orm::QueryOrder;
 3
 4use super::*;
 5
 6impl LlmDatabase {
 7    pub async fn initialize_providers(&self) -> Result<()> {
 8        self.transaction(|tx| async move {
 9            let providers_and_models = vec![
10                ("anthropic", "claude-3-5-sonnet"),
11                ("anthropic", "claude-3-opus"),
12                ("anthropic", "claude-3-sonnet"),
13                ("anthropic", "claude-3-haiku"),
14            ];
15
16            for (provider_name, model_name) in providers_and_models {
17                let insert_provider = provider::Entity::insert(provider::ActiveModel {
18                    name: ActiveValue::set(provider_name.to_owned()),
19                    ..Default::default()
20                })
21                .on_conflict(
22                    OnConflict::columns([provider::Column::Name])
23                        .update_column(provider::Column::Name)
24                        .to_owned(),
25                );
26
27                let provider = if tx.support_returning() {
28                    insert_provider.exec_with_returning(&*tx).await?
29                } else {
30                    insert_provider.exec_without_returning(&*tx).await?;
31                    provider::Entity::find()
32                        .filter(provider::Column::Name.eq(provider_name))
33                        .one(&*tx)
34                        .await?
35                        .ok_or_else(|| anyhow!("failed to insert provider"))?
36                };
37
38                model::Entity::insert(model::ActiveModel {
39                    provider_id: ActiveValue::set(provider.id),
40                    name: ActiveValue::set(model_name.to_owned()),
41                    ..Default::default()
42                })
43                .on_conflict(
44                    OnConflict::columns([model::Column::ProviderId, model::Column::Name])
45                        .update_column(model::Column::Name)
46                        .to_owned(),
47                )
48                .exec_without_returning(&*tx)
49                .await?;
50            }
51
52            Ok(())
53        })
54        .await
55    }
56
57    /// Returns the list of LLM providers.
58    pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
59        self.transaction(|tx| async move {
60            Ok(provider::Entity::find()
61                .order_by_asc(provider::Column::Name)
62                .all(&*tx)
63                .await?)
64        })
65        .await
66    }
67}