1mod ids;
2mod queries;
3mod tables;
4
5#[cfg(test)]
6mod tests;
7
8pub use ids::*;
9pub use tables::*;
10
11#[cfg(test)]
12pub use tests::TestLlmDb;
13
14use std::future::Future;
15use std::sync::Arc;
16
17use anyhow::anyhow;
18use sea_orm::prelude::*;
19pub use sea_orm::ConnectOptions;
20use sea_orm::{
21 ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
22};
23
24use crate::db::TransactionHandle;
25use crate::executor::Executor;
26use crate::Result;
27
/// The database for the LLM service.
pub struct LlmDatabase {
    /// The options the connection pool was opened with (kept for `options()`).
    options: ConnectOptions,
    /// The connection pool that transactions are started from.
    pool: DatabaseConnection,
    // Only read by `run` under `#[cfg(test)]` (to detect a deterministic
    // executor); non-test builds never touch it, hence the lint allowance.
    #[allow(unused)]
    executor: Executor,
    /// Runtime that `run` blocks on in test builds. `new` leaves it `None`.
    // NOTE(review): presumably installed by the test harness before any
    // transaction runs — confirm against `TestLlmDb`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
37
38impl LlmDatabase {
39 /// Connects to the database with the given options
40 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
41 sqlx::any::install_default_drivers();
42 Ok(Self {
43 options: options.clone(),
44 pool: sea_orm::Database::connect(options).await?,
45 executor,
46 #[cfg(test)]
47 runtime: None,
48 })
49 }
50
51 pub fn options(&self) -> &ConnectOptions {
52 &self.options
53 }
54
55 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
56 where
57 F: Send + Fn(TransactionHandle) -> Fut,
58 Fut: Send + Future<Output = Result<T>>,
59 {
60 let body = async {
61 let (tx, result) = self.with_transaction(&f).await?;
62 match result {
63 Ok(result) => match tx.commit().await.map_err(Into::into) {
64 Ok(()) => return Ok(result),
65 Err(error) => {
66 return Err(error);
67 }
68 },
69 Err(error) => {
70 tx.rollback().await?;
71 return Err(error);
72 }
73 }
74 };
75
76 self.run(body).await
77 }
78
79 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
80 where
81 F: Send + Fn(TransactionHandle) -> Fut,
82 Fut: Send + Future<Output = Result<T>>,
83 {
84 let tx = self
85 .pool
86 .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
87 .await?;
88
89 let mut tx = Arc::new(Some(tx));
90 let result = f(TransactionHandle(tx.clone())).await;
91 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
92 return Err(anyhow!(
93 "couldn't complete transaction because it's still in use"
94 ))?;
95 };
96
97 Ok((tx, result))
98 }
99
100 async fn run<F, T>(&self, future: F) -> Result<T>
101 where
102 F: Future<Output = Result<T>>,
103 {
104 #[cfg(test)]
105 {
106 if let Executor::Deterministic(executor) = &self.executor {
107 executor.simulate_random_delay().await;
108 }
109
110 self.runtime.as_ref().unwrap().block_on(future)
111 }
112
113 #[cfg(not(test))]
114 {
115 future.await
116 }
117 }
118}