db.rs

  1mod ids;
  2mod queries;
  3mod seed;
  4mod tables;
  5
  6#[cfg(test)]
  7mod tests;
  8
  9use collections::HashMap;
 10pub use ids::*;
 11use rpc::LanguageModelProvider;
 12pub use seed::*;
 13pub use tables::*;
 14
 15#[cfg(test)]
 16pub use tests::TestLlmDb;
 17use usage_measure::UsageMeasure;
 18
 19use std::future::Future;
 20use std::sync::Arc;
 21
 22use anyhow::anyhow;
 23pub use queries::usages::{ActiveUserCount, TokenUsage};
 24pub use sea_orm::ConnectOptions;
 25use sea_orm::prelude::*;
 26use sea_orm::{
 27    ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
 28};
 29
 30use crate::Result;
 31use crate::db::TransactionHandle;
 32use crate::executor::Executor;
 33
 34/// The database for the LLM service.
 35pub struct LlmDatabase {
 36    options: ConnectOptions,
 37    pool: DatabaseConnection,
 38    #[allow(unused)]
 39    executor: Executor,
 40    provider_ids: HashMap<LanguageModelProvider, ProviderId>,
 41    models: HashMap<(LanguageModelProvider, String), model::Model>,
 42    usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
 43    #[cfg(test)]
 44    runtime: Option<tokio::runtime::Runtime>,
 45}
 46
 47impl LlmDatabase {
 48    /// Connects to the database with the given options
 49    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 50        sqlx::any::install_default_drivers();
 51        Ok(Self {
 52            options: options.clone(),
 53            pool: sea_orm::Database::connect(options).await?,
 54            executor,
 55            provider_ids: HashMap::default(),
 56            models: HashMap::default(),
 57            usage_measure_ids: HashMap::default(),
 58            #[cfg(test)]
 59            runtime: None,
 60        })
 61    }
 62
 63    pub async fn initialize(&mut self) -> Result<()> {
 64        self.initialize_providers().await?;
 65        self.initialize_models().await?;
 66        self.initialize_usage_measures().await?;
 67        Ok(())
 68    }
 69
 70    /// Returns the list of all known models, with their [`LanguageModelProvider`].
 71    pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
 72        self.models
 73            .iter()
 74            .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
 75            .collect::<Vec<_>>()
 76    }
 77
 78    /// Returns the names of the known models for the given [`LanguageModelProvider`].
 79    pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
 80        self.models
 81            .keys()
 82            .filter_map(|(model_provider, model_name)| {
 83                if model_provider == &provider {
 84                    Some(model_name)
 85                } else {
 86                    None
 87                }
 88            })
 89            .cloned()
 90            .collect::<Vec<_>>()
 91    }
 92
 93    pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
 94        Ok(self
 95            .models
 96            .get(&(provider, name.to_string()))
 97            .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
 98    }
 99
100    pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
101        Ok(self
102            .models
103            .values()
104            .find(|model| model.id == id)
105            .ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
106    }
107
108    pub fn options(&self) -> &ConnectOptions {
109        &self.options
110    }
111
112    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
113    where
114        F: Send + Fn(TransactionHandle) -> Fut,
115        Fut: Send + Future<Output = Result<T>>,
116    {
117        let body = async {
118            let (tx, result) = self.with_transaction(&f).await?;
119            match result {
120                Ok(result) => match tx.commit().await.map_err(Into::into) {
121                    Ok(()) => Ok(result),
122                    Err(error) => Err(error),
123                },
124                Err(error) => {
125                    tx.rollback().await?;
126                    Err(error)
127                }
128            }
129        };
130
131        self.run(body).await
132    }
133
134    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
135    where
136        F: Send + Fn(TransactionHandle) -> Fut,
137        Fut: Send + Future<Output = Result<T>>,
138    {
139        let tx = self
140            .pool
141            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
142            .await?;
143
144        let mut tx = Arc::new(Some(tx));
145        let result = f(TransactionHandle(tx.clone())).await;
146        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
147            return Err(anyhow!(
148                "couldn't complete transaction because it's still in use"
149            ))?;
150        };
151
152        Ok((tx, result))
153    }
154
155    async fn run<F, T>(&self, future: F) -> Result<T>
156    where
157        F: Future<Output = Result<T>>,
158    {
159        #[cfg(test)]
160        {
161            if let Executor::Deterministic(executor) = &self.executor {
162                executor.simulate_random_delay().await;
163            }
164
165            self.runtime.as_ref().unwrap().block_on(future)
166        }
167
168        #[cfg(not(test))]
169        {
170            future.await
171        }
172    }
173}