db.rs

  1mod ids;
  2mod queries;
  3mod seed;
  4mod tables;
  5
  6#[cfg(test)]
  7mod tests;
  8
  9use cloud_llm_client::LanguageModelProvider;
 10use collections::HashMap;
 11pub use ids::*;
 12pub use seed::*;
 13pub use tables::*;
 14
 15#[cfg(test)]
 16pub use tests::TestLlmDb;
 17use usage_measure::UsageMeasure;
 18
 19use std::future::Future;
 20use std::sync::Arc;
 21
 22use anyhow::Context;
 23pub use sea_orm::ConnectOptions;
 24use sea_orm::prelude::*;
 25use sea_orm::{
 26    ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
 27};
 28
 29use crate::Result;
 30use crate::db::TransactionHandle;
 31use crate::executor::Executor;
 32
 33/// The database for the LLM service.
 34pub struct LlmDatabase {
 35    options: ConnectOptions,
 36    pool: DatabaseConnection,
 37    #[allow(unused)]
 38    executor: Executor,
 39    provider_ids: HashMap<LanguageModelProvider, ProviderId>,
 40    models: HashMap<(LanguageModelProvider, String), model::Model>,
 41    usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
 42    #[cfg(test)]
 43    runtime: Option<tokio::runtime::Runtime>,
 44}
 45
 46impl LlmDatabase {
 47    /// Connects to the database with the given options
 48    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 49        sqlx::any::install_default_drivers();
 50        Ok(Self {
 51            options: options.clone(),
 52            pool: sea_orm::Database::connect(options).await?,
 53            executor,
 54            provider_ids: HashMap::default(),
 55            models: HashMap::default(),
 56            usage_measure_ids: HashMap::default(),
 57            #[cfg(test)]
 58            runtime: None,
 59        })
 60    }
 61
 62    pub async fn initialize(&mut self) -> Result<()> {
 63        self.initialize_providers().await?;
 64        self.initialize_models().await?;
 65        self.initialize_usage_measures().await?;
 66        Ok(())
 67    }
 68
 69    /// Returns the list of all known models, with their [`LanguageModelProvider`].
 70    pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
 71        self.models
 72            .iter()
 73            .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
 74            .collect::<Vec<_>>()
 75    }
 76
 77    /// Returns the names of the known models for the given [`LanguageModelProvider`].
 78    pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
 79        self.models
 80            .keys()
 81            .filter_map(|(model_provider, model_name)| {
 82                if model_provider == &provider {
 83                    Some(model_name)
 84                } else {
 85                    None
 86                }
 87            })
 88            .cloned()
 89            .collect::<Vec<_>>()
 90    }
 91
 92    pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
 93        Ok(self
 94            .models
 95            .get(&(provider, name.to_string()))
 96            .with_context(|| format!("unknown model {provider:?}:{name}"))?)
 97    }
 98
 99    pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
100        Ok(self
101            .models
102            .values()
103            .find(|model| model.id == id)
104            .with_context(|| format!("no model for ID {id:?}"))?)
105    }
106
107    pub fn options(&self) -> &ConnectOptions {
108        &self.options
109    }
110
111    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
112    where
113        F: Send + Fn(TransactionHandle) -> Fut,
114        Fut: Send + Future<Output = Result<T>>,
115    {
116        let body = async {
117            let (tx, result) = self.with_transaction(&f).await?;
118            match result {
119                Ok(result) => match tx.commit().await.map_err(Into::into) {
120                    Ok(()) => Ok(result),
121                    Err(error) => Err(error),
122                },
123                Err(error) => {
124                    tx.rollback().await?;
125                    Err(error)
126                }
127            }
128        };
129
130        self.run(body).await
131    }
132
133    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
134    where
135        F: Send + Fn(TransactionHandle) -> Fut,
136        Fut: Send + Future<Output = Result<T>>,
137    {
138        let tx = self
139            .pool
140            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
141            .await?;
142
143        let mut tx = Arc::new(Some(tx));
144        let result = f(TransactionHandle(tx.clone())).await;
145        let tx = Arc::get_mut(&mut tx)
146            .and_then(|tx| tx.take())
147            .context("couldn't complete transaction because it's still in use")?;
148
149        Ok((tx, result))
150    }
151
152    async fn run<F, T>(&self, future: F) -> Result<T>
153    where
154        F: Future<Output = Result<T>>,
155    {
156        #[cfg(test)]
157        {
158            if let Executor::Deterministic(executor) = &self.executor {
159                executor.simulate_random_delay().await;
160            }
161
162            self.runtime.as_ref().unwrap().block_on(future)
163        }
164
165        #[cfg(not(test))]
166        {
167            future.await
168        }
169    }
170}