1mod ids;
2mod queries;
3mod seed;
4mod tables;
5
6#[cfg(test)]
7mod tests;
8
9use collections::HashMap;
10pub use ids::*;
11use rpc::LanguageModelProvider;
12pub use seed::*;
13pub use tables::*;
14
15#[cfg(test)]
16pub use tests::TestLlmDb;
17use usage_measure::UsageMeasure;
18
19use std::future::Future;
20use std::sync::Arc;
21
22use anyhow::anyhow;
23pub use queries::usages::{ActiveUserCount, TokenUsage};
24pub use sea_orm::ConnectOptions;
25use sea_orm::prelude::*;
26use sea_orm::{
27 ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
28};
29
30use crate::Result;
31use crate::db::TransactionHandle;
32use crate::executor::Executor;
33
/// The database for the LLM service.
///
/// Wraps a sea-orm connection pool together with in-memory lookup tables
/// (providers, models, usage measures). `LlmDatabase::new` creates these tables empty.
pub struct LlmDatabase {
    /// The options the pool was connected with; exposed via `LlmDatabase::options`.
    options: ConnectOptions,
    /// Connection pool that transactions are begun on.
    pool: DatabaseConnection,
    // Only read by `run` under `#[cfg(test)]` (to simulate random delays on a
    // deterministic executor), so non-test builds would otherwise warn about it.
    #[allow(unused)]
    executor: Executor,
    /// Provider -> database ID. Empty after `new`; presumably filled during
    /// `initialize` — confirm against `initialize_providers`.
    provider_ids: HashMap<LanguageModelProvider, ProviderId>,
    /// Known models keyed by `(provider, model name)`.
    models: HashMap<(LanguageModelProvider, String), model::Model>,
    /// Usage measure -> database ID. Empty after `new`; presumably filled during
    /// `initialize` — confirm against `initialize_usage_measures`.
    usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
    /// Test-only runtime that `run` blocks on. `new` leaves it `None`; presumably the
    /// test harness installs one before any transaction runs.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
46
47impl LlmDatabase {
48 /// Connects to the database with the given options
49 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
50 sqlx::any::install_default_drivers();
51 Ok(Self {
52 options: options.clone(),
53 pool: sea_orm::Database::connect(options).await?,
54 executor,
55 provider_ids: HashMap::default(),
56 models: HashMap::default(),
57 usage_measure_ids: HashMap::default(),
58 #[cfg(test)]
59 runtime: None,
60 })
61 }
62
63 pub async fn initialize(&mut self) -> Result<()> {
64 self.initialize_providers().await?;
65 self.initialize_models().await?;
66 self.initialize_usage_measures().await?;
67 Ok(())
68 }
69
70 /// Returns the list of all known models, with their [`LanguageModelProvider`].
71 pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
72 self.models
73 .iter()
74 .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
75 .collect::<Vec<_>>()
76 }
77
78 /// Returns the names of the known models for the given [`LanguageModelProvider`].
79 pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
80 self.models
81 .keys()
82 .filter_map(|(model_provider, model_name)| {
83 if model_provider == &provider {
84 Some(model_name)
85 } else {
86 None
87 }
88 })
89 .cloned()
90 .collect::<Vec<_>>()
91 }
92
93 pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
94 Ok(self
95 .models
96 .get(&(provider, name.to_string()))
97 .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
98 }
99
100 pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
101 Ok(self
102 .models
103 .values()
104 .find(|model| model.id == id)
105 .ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
106 }
107
108 pub fn options(&self) -> &ConnectOptions {
109 &self.options
110 }
111
112 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
113 where
114 F: Send + Fn(TransactionHandle) -> Fut,
115 Fut: Send + Future<Output = Result<T>>,
116 {
117 let body = async {
118 let (tx, result) = self.with_transaction(&f).await?;
119 match result {
120 Ok(result) => match tx.commit().await.map_err(Into::into) {
121 Ok(()) => Ok(result),
122 Err(error) => Err(error),
123 },
124 Err(error) => {
125 tx.rollback().await?;
126 Err(error)
127 }
128 }
129 };
130
131 self.run(body).await
132 }
133
134 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
135 where
136 F: Send + Fn(TransactionHandle) -> Fut,
137 Fut: Send + Future<Output = Result<T>>,
138 {
139 let tx = self
140 .pool
141 .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
142 .await?;
143
144 let mut tx = Arc::new(Some(tx));
145 let result = f(TransactionHandle(tx.clone())).await;
146 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
147 return Err(anyhow!(
148 "couldn't complete transaction because it's still in use"
149 ))?;
150 };
151
152 Ok((tx, result))
153 }
154
155 async fn run<F, T>(&self, future: F) -> Result<T>
156 where
157 F: Future<Output = Result<T>>,
158 {
159 #[cfg(test)]
160 {
161 if let Executor::Deterministic(executor) = &self.executor {
162 executor.simulate_random_delay().await;
163 }
164
165 self.runtime.as_ref().unwrap().block_on(future)
166 }
167
168 #[cfg(not(test))]
169 {
170 future.await
171 }
172 }
173}