1mod provider_tests;
2mod usage_tests;
3
4use gpui::BackgroundExecutor;
5use parking_lot::Mutex;
6use rand::prelude::*;
7use sea_orm::ConnectionTrait;
8use sqlx::migrate::MigrateDatabase;
9use std::sync::Arc;
10use std::time::Duration;
11
12use crate::migrations::run_database_migrations;
13
14use super::*;
15
/// Test fixture that owns an [`LlmDatabase`] connected either to an in-memory SQLite
/// database or to a freshly created, uniquely named Postgres database (see the
/// `sqlite` and `postgres` constructors on this type).
pub struct TestLlmDb {
    /// The database under test. Wrapped in `Option` so that `Drop` can take ownership
    /// of it (via `Option::take`) when tearing the fixture down.
    pub db: Option<Arc<LlmDatabase>>,
    /// NOTE(review): never set to anything but `None` by the constructors in this file;
    /// presumably kept for parity with other test fixtures — confirm before removing.
    pub connection: Option<sqlx::AnyConnection>,
}
20
21impl TestLlmDb {
22 pub fn sqlite(background: BackgroundExecutor) -> Self {
23 let url = "sqlite::memory:";
24 let runtime = tokio::runtime::Builder::new_current_thread()
25 .enable_io()
26 .enable_time()
27 .build()
28 .unwrap();
29
30 let mut db = runtime.block_on(async {
31 let mut options = ConnectOptions::new(url);
32 options.max_connections(5);
33 let db = LlmDatabase::new(options, Executor::Deterministic(background))
34 .await
35 .unwrap();
36 let sql = include_str!(concat!(
37 env!("CARGO_MANIFEST_DIR"),
38 "/migrations_llm.sqlite/20240806182921_test_schema.sql"
39 ));
40 db.pool
41 .execute(sea_orm::Statement::from_string(
42 db.pool.get_database_backend(),
43 sql,
44 ))
45 .await
46 .unwrap();
47 db
48 });
49
50 db.runtime = Some(runtime);
51
52 Self {
53 db: Some(Arc::new(db)),
54 connection: None,
55 }
56 }
57
58 pub fn postgres(background: BackgroundExecutor) -> Self {
59 static LOCK: Mutex<()> = Mutex::new(());
60
61 let _guard = LOCK.lock();
62 let mut rng = StdRng::from_entropy();
63 let url = format!(
64 "postgres://postgres@localhost/zed-llm-test-{}",
65 rng.gen::<u128>()
66 );
67 let runtime = tokio::runtime::Builder::new_current_thread()
68 .enable_io()
69 .enable_time()
70 .build()
71 .unwrap();
72
73 let mut db = runtime.block_on(async {
74 sqlx::Postgres::create_database(&url)
75 .await
76 .expect("failed to create test db");
77 let mut options = ConnectOptions::new(url);
78 options
79 .max_connections(5)
80 .idle_timeout(Duration::from_secs(0));
81 let db = LlmDatabase::new(options, Executor::Deterministic(background))
82 .await
83 .unwrap();
84 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
85 run_database_migrations(db.options(), migrations_path)
86 .await
87 .unwrap();
88 db
89 });
90
91 db.runtime = Some(runtime);
92
93 Self {
94 db: Some(Arc::new(db)),
95 connection: None,
96 }
97 }
98
99 pub fn db(&self) -> &Arc<LlmDatabase> {
100 self.db.as_ref().unwrap()
101 }
102}
103
/// Generates two `#[gpui::test]` functions that run the async test function `$test_fn`
/// against each database backend.
///
/// - `$postgres_name` builds a fixture with `TestLlmDb::postgres`. It is only compiled
///   when `target_os = "macos"`.
/// - `$sqlite_name` builds a fixture with `TestLlmDb::sqlite`.
///
/// `$test_fn` receives the fixture's `&Arc<LlmDatabase>`. The fixture is kept in a
/// named binding, so it is dropped only after the awaited test call finishes.
#[macro_export]
macro_rules! test_both_llm_dbs {
    ($test_fn:ident, $postgres_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::llm::db::TestLlmDb::postgres(executor);
            $test_fn(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::llm::db::TestLlmDb::sqlite(executor);
            $test_fn(fixture.db()).await;
        }
    };
}
121
122impl Drop for TestLlmDb {
123 fn drop(&mut self) {
124 let db = self.db.take().unwrap();
125 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
126 db.runtime.as_ref().unwrap().block_on(async {
127 use util::ResultExt;
128 let query = "
129 SELECT pg_terminate_backend(pg_stat_activity.pid)
130 FROM pg_stat_activity
131 WHERE
132 pg_stat_activity.datname = current_database() AND
133 pid <> pg_backend_pid();
134 ";
135 db.pool
136 .execute(sea_orm::Statement::from_string(
137 db.pool.get_database_backend(),
138 query,
139 ))
140 .await
141 .log_err();
142 sqlx::Postgres::drop_database(db.options.get_url())
143 .await
144 .log_err();
145 })
146 }
147 }
148}