db.rs

 1use std::future::Future;
 2use std::sync::Arc;
 3
 4use anyhow::Context;
 5pub use sea_orm::ConnectOptions;
 6use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
 7
 8use crate::Result;
 9use crate::db::TransactionHandle;
10use crate::executor::Executor;
11
12/// The database for the LLM service.
13pub struct LlmDatabase {
14    options: ConnectOptions,
15    pool: DatabaseConnection,
16    #[allow(unused)]
17    executor: Executor,
18    #[cfg(test)]
19    runtime: Option<tokio::runtime::Runtime>,
20}
21
22impl LlmDatabase {
23    /// Connects to the database with the given options
24    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
25        sqlx::any::install_default_drivers();
26        Ok(Self {
27            options: options.clone(),
28            pool: sea_orm::Database::connect(options).await?,
29            executor,
30            #[cfg(test)]
31            runtime: None,
32        })
33    }
34
35    pub fn options(&self) -> &ConnectOptions {
36        &self.options
37    }
38
39    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
40    where
41        F: Send + Fn(TransactionHandle) -> Fut,
42        Fut: Send + Future<Output = Result<T>>,
43    {
44        let body = async {
45            let (tx, result) = self.with_transaction(&f).await?;
46            match result {
47                Ok(result) => match tx.commit().await.map_err(Into::into) {
48                    Ok(()) => Ok(result),
49                    Err(error) => Err(error),
50                },
51                Err(error) => {
52                    tx.rollback().await?;
53                    Err(error)
54                }
55            }
56        };
57
58        self.run(body).await
59    }
60
61    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
62    where
63        F: Send + Fn(TransactionHandle) -> Fut,
64        Fut: Send + Future<Output = Result<T>>,
65    {
66        let tx = self
67            .pool
68            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
69            .await?;
70
71        let mut tx = Arc::new(Some(tx));
72        let result = f(TransactionHandle(tx.clone())).await;
73        let tx = Arc::get_mut(&mut tx)
74            .and_then(|tx| tx.take())
75            .context("couldn't complete transaction because it's still in use")?;
76
77        Ok((tx, result))
78    }
79
80    async fn run<F, T>(&self, future: F) -> Result<T>
81    where
82        F: Future<Output = Result<T>>,
83    {
84        #[cfg(test)]
85        {
86            if let Executor::Deterministic(executor) = &self.executor {
87                executor.simulate_random_delay().await;
88            }
89
90            self.runtime.as_ref().unwrap().block_on(future)
91        }
92
93        #[cfg(not(test))]
94        {
95            future.await
96        }
97    }
98}