1mod ids;
2mod queries;
3mod seed;
4mod tables;
5
6#[cfg(test)]
7mod tests;
8
9use collections::HashMap;
10pub use ids::*;
11pub use seed::*;
12pub use tables::*;
13use zed_llm_client::LanguageModelProvider;
14
15#[cfg(test)]
16pub use tests::TestLlmDb;
17use usage_measure::UsageMeasure;
18
19use std::future::Future;
20use std::sync::Arc;
21
22use anyhow::anyhow;
23pub use sea_orm::ConnectOptions;
24use sea_orm::prelude::*;
25use sea_orm::{
26 ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
27};
28
29use crate::Result;
30use crate::db::TransactionHandle;
31use crate::executor::Executor;
32
33/// The database for the LLM service.
34pub struct LlmDatabase {
35 options: ConnectOptions,
36 pool: DatabaseConnection,
37 #[allow(unused)]
38 executor: Executor,
39 provider_ids: HashMap<LanguageModelProvider, ProviderId>,
40 models: HashMap<(LanguageModelProvider, String), model::Model>,
41 usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
42 #[cfg(test)]
43 runtime: Option<tokio::runtime::Runtime>,
44}
45
46impl LlmDatabase {
47 /// Connects to the database with the given options
48 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
49 sqlx::any::install_default_drivers();
50 Ok(Self {
51 options: options.clone(),
52 pool: sea_orm::Database::connect(options).await?,
53 executor,
54 provider_ids: HashMap::default(),
55 models: HashMap::default(),
56 usage_measure_ids: HashMap::default(),
57 #[cfg(test)]
58 runtime: None,
59 })
60 }
61
62 pub async fn initialize(&mut self) -> Result<()> {
63 self.initialize_providers().await?;
64 self.initialize_models().await?;
65 self.initialize_usage_measures().await?;
66 Ok(())
67 }
68
69 /// Returns the list of all known models, with their [`LanguageModelProvider`].
70 pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
71 self.models
72 .iter()
73 .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
74 .collect::<Vec<_>>()
75 }
76
77 /// Returns the names of the known models for the given [`LanguageModelProvider`].
78 pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
79 self.models
80 .keys()
81 .filter_map(|(model_provider, model_name)| {
82 if model_provider == &provider {
83 Some(model_name)
84 } else {
85 None
86 }
87 })
88 .cloned()
89 .collect::<Vec<_>>()
90 }
91
92 pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
93 Ok(self
94 .models
95 .get(&(provider, name.to_string()))
96 .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
97 }
98
99 pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
100 Ok(self
101 .models
102 .values()
103 .find(|model| model.id == id)
104 .ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
105 }
106
107 pub fn options(&self) -> &ConnectOptions {
108 &self.options
109 }
110
111 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
112 where
113 F: Send + Fn(TransactionHandle) -> Fut,
114 Fut: Send + Future<Output = Result<T>>,
115 {
116 let body = async {
117 let (tx, result) = self.with_transaction(&f).await?;
118 match result {
119 Ok(result) => match tx.commit().await.map_err(Into::into) {
120 Ok(()) => Ok(result),
121 Err(error) => Err(error),
122 },
123 Err(error) => {
124 tx.rollback().await?;
125 Err(error)
126 }
127 }
128 };
129
130 self.run(body).await
131 }
132
133 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
134 where
135 F: Send + Fn(TransactionHandle) -> Fut,
136 Fut: Send + Future<Output = Result<T>>,
137 {
138 let tx = self
139 .pool
140 .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
141 .await?;
142
143 let mut tx = Arc::new(Some(tx));
144 let result = f(TransactionHandle(tx.clone())).await;
145 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
146 return Err(anyhow!(
147 "couldn't complete transaction because it's still in use"
148 ))?;
149 };
150
151 Ok((tx, result))
152 }
153
154 async fn run<F, T>(&self, future: F) -> Result<T>
155 where
156 F: Future<Output = Result<T>>,
157 {
158 #[cfg(test)]
159 {
160 if let Executor::Deterministic(executor) = &self.executor {
161 executor.simulate_random_delay().await;
162 }
163
164 self.runtime.as_ref().unwrap().block_on(future)
165 }
166
167 #[cfg(not(test))]
168 {
169 future.await
170 }
171 }
172}