1mod provider_tests;
2
3use gpui::BackgroundExecutor;
4use parking_lot::Mutex;
5use rand::prelude::*;
6use sea_orm::ConnectionTrait;
7use sqlx::migrate::MigrateDatabase;
8use std::time::Duration;
9
10use crate::migrations::run_database_migrations;
11
12use super::*;
13
/// Test fixture that owns an [`LlmDatabase`] backed by a temporary database.
///
/// The `Drop` implementation takes the database out of `db` and, for the
/// Postgres backend, drops the underlying database.
pub struct TestLlmDb {
    /// The database under test. Held in an `Option` so that `Drop` can move
    /// it out with `take()`.
    pub db: Option<LlmDatabase>,
    /// An optional raw `sqlx` connection. `TestLlmDb::postgres` always sets
    /// this to `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
18
19impl TestLlmDb {
20 pub fn postgres(background: BackgroundExecutor) -> Self {
21 static LOCK: Mutex<()> = Mutex::new(());
22
23 let _guard = LOCK.lock();
24 let mut rng = StdRng::from_entropy();
25 let url = format!(
26 "postgres://postgres@localhost/zed-llm-test-{}",
27 rng.r#gen::<u128>()
28 );
29 let runtime = tokio::runtime::Builder::new_current_thread()
30 .enable_io()
31 .enable_time()
32 .build()
33 .unwrap();
34
35 let mut db = runtime.block_on(async {
36 sqlx::Postgres::create_database(&url)
37 .await
38 .expect("failed to create test db");
39 let mut options = ConnectOptions::new(url);
40 options
41 .max_connections(5)
42 .idle_timeout(Duration::from_secs(0));
43 let db = LlmDatabase::new(options, Executor::Deterministic(background))
44 .await
45 .unwrap();
46 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
47 run_database_migrations(db.options(), migrations_path)
48 .await
49 .unwrap();
50 db
51 });
52
53 db.runtime = Some(runtime);
54
55 Self {
56 db: Some(db),
57 connection: None,
58 }
59 }
60
61 pub fn db(&mut self) -> &mut LlmDatabase {
62 self.db.as_mut().unwrap()
63 }
64}
65
/// Generates a `#[gpui::test]` async test named `$postgres_test_name` that
/// runs the async function `$test_name` against a [`TestLlmDb`] created with
/// `TestLlmDb::postgres`.
///
/// The generated test does nothing unless the target OS is macOS.
#[macro_export]
macro_rules! test_llm_db {
    ($test_name:ident, $postgres_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // On any target other than macOS the test body is skipped.
            if cfg!(target_os = "macos") {
                let executor = cx.executor().clone();
                let mut fixture = $crate::llm::db::TestLlmDb::postgres(executor);
                $test_name(fixture.db()).await;
            }
        }
    };
}
80
81impl Drop for TestLlmDb {
82 fn drop(&mut self) {
83 let db = self.db.take().unwrap();
84 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
85 db.runtime.as_ref().unwrap().block_on(async {
86 use util::ResultExt;
87 let query = "
88 SELECT pg_terminate_backend(pg_stat_activity.pid)
89 FROM pg_stat_activity
90 WHERE
91 pg_stat_activity.datname = current_database() AND
92 pid <> pg_backend_pid();
93 ";
94 db.pool
95 .execute(sea_orm::Statement::from_string(
96 db.pool.get_database_backend(),
97 query,
98 ))
99 .await
100 .log_err();
101 sqlx::Postgres::drop_database(db.options.get_url())
102 .await
103 .log_err();
104 })
105 }
106 }
107}