db.rs

  1mod ids;
  2mod queries;
  3mod tables;
  4
  5#[cfg(test)]
  6mod tests;
  7
  8pub use ids::*;
  9pub use tables::*;
 10
 11#[cfg(test)]
 12pub use tests::TestLlmDb;
 13
 14use std::future::Future;
 15use std::sync::Arc;
 16
 17use anyhow::anyhow;
 18use sea_orm::prelude::*;
 19pub use sea_orm::ConnectOptions;
 20use sea_orm::{
 21    ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
 22};
 23
 24use crate::db::TransactionHandle;
 25use crate::executor::Executor;
 26use crate::Result;
 27
 28/// The database for the LLM service.
 29pub struct LlmDatabase {
 30    options: ConnectOptions,
 31    pool: DatabaseConnection,
 32    #[allow(unused)]
 33    executor: Executor,
 34    #[cfg(test)]
 35    runtime: Option<tokio::runtime::Runtime>,
 36}
 37
 38impl LlmDatabase {
 39    /// Connects to the database with the given options
 40    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 41        sqlx::any::install_default_drivers();
 42        Ok(Self {
 43            options: options.clone(),
 44            pool: sea_orm::Database::connect(options).await?,
 45            executor,
 46            #[cfg(test)]
 47            runtime: None,
 48        })
 49    }
 50
 51    pub fn options(&self) -> &ConnectOptions {
 52        &self.options
 53    }
 54
 55    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
 56    where
 57        F: Send + Fn(TransactionHandle) -> Fut,
 58        Fut: Send + Future<Output = Result<T>>,
 59    {
 60        let body = async {
 61            let (tx, result) = self.with_transaction(&f).await?;
 62            match result {
 63                Ok(result) => match tx.commit().await.map_err(Into::into) {
 64                    Ok(()) => return Ok(result),
 65                    Err(error) => {
 66                        return Err(error);
 67                    }
 68                },
 69                Err(error) => {
 70                    tx.rollback().await?;
 71                    return Err(error);
 72                }
 73            }
 74        };
 75
 76        self.run(body).await
 77    }
 78
 79    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
 80    where
 81        F: Send + Fn(TransactionHandle) -> Fut,
 82        Fut: Send + Future<Output = Result<T>>,
 83    {
 84        let tx = self
 85            .pool
 86            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
 87            .await?;
 88
 89        let mut tx = Arc::new(Some(tx));
 90        let result = f(TransactionHandle(tx.clone())).await;
 91        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
 92            return Err(anyhow!(
 93                "couldn't complete transaction because it's still in use"
 94            ))?;
 95        };
 96
 97        Ok((tx, result))
 98    }
 99
100    async fn run<F, T>(&self, future: F) -> Result<T>
101    where
102        F: Future<Output = Result<T>>,
103    {
104        #[cfg(test)]
105        {
106            if let Executor::Deterministic(executor) = &self.executor {
107                executor.simulate_random_delay().await;
108            }
109
110            self.runtime.as_ref().unwrap().block_on(future)
111        }
112
113        #[cfg(not(test))]
114        {
115            future.await
116        }
117    }
118}