1use sea_orm::sea_query::OnConflict;
2use sea_orm::QueryOrder;
3
4use super::*;
5
6impl LlmDatabase {
7 pub async fn initialize_providers(&self) -> Result<()> {
8 self.transaction(|tx| async move {
9 let providers_and_models = vec![
10 ("anthropic", "claude-3-5-sonnet"),
11 ("anthropic", "claude-3-opus"),
12 ("anthropic", "claude-3-sonnet"),
13 ("anthropic", "claude-3-haiku"),
14 ];
15
16 for (provider_name, model_name) in providers_and_models {
17 let insert_provider = provider::Entity::insert(provider::ActiveModel {
18 name: ActiveValue::set(provider_name.to_owned()),
19 ..Default::default()
20 })
21 .on_conflict(
22 OnConflict::columns([provider::Column::Name])
23 .update_column(provider::Column::Name)
24 .to_owned(),
25 );
26
27 let provider = if tx.support_returning() {
28 insert_provider.exec_with_returning(&*tx).await?
29 } else {
30 insert_provider.exec_without_returning(&*tx).await?;
31 provider::Entity::find()
32 .filter(provider::Column::Name.eq(provider_name))
33 .one(&*tx)
34 .await?
35 .ok_or_else(|| anyhow!("failed to insert provider"))?
36 };
37
38 model::Entity::insert(model::ActiveModel {
39 provider_id: ActiveValue::set(provider.id),
40 name: ActiveValue::set(model_name.to_owned()),
41 ..Default::default()
42 })
43 .on_conflict(
44 OnConflict::columns([model::Column::ProviderId, model::Column::Name])
45 .update_column(model::Column::Name)
46 .to_owned(),
47 )
48 .exec_without_returning(&*tx)
49 .await?;
50 }
51
52 Ok(())
53 })
54 .await
55 }
56
57 /// Returns the list of LLM providers.
58 pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
59 self.transaction(|tx| async move {
60 Ok(provider::Entity::find()
61 .order_by_asc(provider::Column::Name)
62 .all(&*tx)
63 .await?)
64 })
65 .await
66 }
67}