1use std::future::Future;
2use std::sync::Arc;
3
4use anyhow::Context;
5pub use sea_orm::ConnectOptions;
6use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
7
8use crate::Result;
9use crate::db::TransactionHandle;
10use crate::executor::Executor;
11
12/// The database for the LLM service.
13pub struct LlmDatabase {
14 options: ConnectOptions,
15 pool: DatabaseConnection,
16 #[allow(unused)]
17 executor: Executor,
18 #[cfg(test)]
19 runtime: Option<tokio::runtime::Runtime>,
20}
21
22impl LlmDatabase {
23 /// Connects to the database with the given options
24 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
25 sqlx::any::install_default_drivers();
26 Ok(Self {
27 options: options.clone(),
28 pool: sea_orm::Database::connect(options).await?,
29 executor,
30 #[cfg(test)]
31 runtime: None,
32 })
33 }
34
35 pub fn options(&self) -> &ConnectOptions {
36 &self.options
37 }
38
39 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
40 where
41 F: Send + Fn(TransactionHandle) -> Fut,
42 Fut: Send + Future<Output = Result<T>>,
43 {
44 let body = async {
45 let (tx, result) = self.with_transaction(&f).await?;
46 match result {
47 Ok(result) => match tx.commit().await.map_err(Into::into) {
48 Ok(()) => Ok(result),
49 Err(error) => Err(error),
50 },
51 Err(error) => {
52 tx.rollback().await?;
53 Err(error)
54 }
55 }
56 };
57
58 self.run(body).await
59 }
60
61 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
62 where
63 F: Send + Fn(TransactionHandle) -> Fut,
64 Fut: Send + Future<Output = Result<T>>,
65 {
66 let tx = self
67 .pool
68 .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
69 .await?;
70
71 let mut tx = Arc::new(Some(tx));
72 let result = f(TransactionHandle(tx.clone())).await;
73 let tx = Arc::get_mut(&mut tx)
74 .and_then(|tx| tx.take())
75 .context("couldn't complete transaction because it's still in use")?;
76
77 Ok((tx, result))
78 }
79
80 async fn run<F, T>(&self, future: F) -> Result<T>
81 where
82 F: Future<Output = Result<T>>,
83 {
84 #[cfg(test)]
85 {
86 if let Executor::Deterministic(executor) = &self.executor {
87 executor.simulate_random_delay().await;
88 }
89
90 self.runtime.as_ref().unwrap().block_on(future)
91 }
92
93 #[cfg(not(test))]
94 {
95 future.await
96 }
97 }
98}