1mod provider_tests;
2mod subscription_usage_tests;
3
4use gpui::BackgroundExecutor;
5use parking_lot::Mutex;
6use rand::prelude::*;
7use sea_orm::ConnectionTrait;
8use sqlx::migrate::MigrateDatabase;
9use std::time::Duration;
10
11use crate::migrations::run_database_migrations;
12
13use super::*;
14
15pub struct TestLlmDb {
16 pub db: Option<LlmDatabase>,
17 pub connection: Option<sqlx::AnyConnection>,
18}
19
20impl TestLlmDb {
21 pub fn postgres(background: BackgroundExecutor) -> Self {
22 static LOCK: Mutex<()> = Mutex::new(());
23
24 let _guard = LOCK.lock();
25 let mut rng = StdRng::from_entropy();
26 let url = format!(
27 "postgres://postgres@localhost/zed-llm-test-{}",
28 rng.r#gen::<u128>()
29 );
30 let runtime = tokio::runtime::Builder::new_current_thread()
31 .enable_io()
32 .enable_time()
33 .build()
34 .unwrap();
35
36 let mut db = runtime.block_on(async {
37 sqlx::Postgres::create_database(&url)
38 .await
39 .expect("failed to create test db");
40 let mut options = ConnectOptions::new(url);
41 options
42 .max_connections(5)
43 .idle_timeout(Duration::from_secs(0));
44 let db = LlmDatabase::new(options, Executor::Deterministic(background))
45 .await
46 .unwrap();
47 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
48 run_database_migrations(db.options(), migrations_path)
49 .await
50 .unwrap();
51 db
52 });
53
54 db.runtime = Some(runtime);
55
56 Self {
57 db: Some(db),
58 connection: None,
59 }
60 }
61
62 pub fn db(&mut self) -> &mut LlmDatabase {
63 self.db.as_mut().unwrap()
64 }
65}
66
/// Generates a `#[gpui::test]` named `$postgres_test_name` that runs the async
/// test function `$test_name` against a freshly created [`TestLlmDb`].
///
/// The generated test does nothing on targets other than macOS.
#[macro_export]
macro_rules! test_llm_db {
    ($test_name:ident, $postgres_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // Only exercise the Postgres-backed test on macOS.
            if cfg!(target_os = "macos") {
                let mut llm_test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
                $test_name(llm_test_db.db()).await;
            }
        }
    };
}
81
82impl Drop for TestLlmDb {
83 fn drop(&mut self) {
84 let db = self.db.take().unwrap();
85 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
86 db.runtime.as_ref().unwrap().block_on(async {
87 use util::ResultExt;
88 let query = "
89 SELECT pg_terminate_backend(pg_stat_activity.pid)
90 FROM pg_stat_activity
91 WHERE
92 pg_stat_activity.datname = current_database() AND
93 pid <> pg_backend_pid();
94 ";
95 db.pool
96 .execute(sea_orm::Statement::from_string(
97 db.pool.get_database_backend(),
98 query,
99 ))
100 .await
101 .log_err();
102 sqlx::Postgres::drop_database(db.options.get_url())
103 .await
104 .log_err();
105 })
106 }
107 }
108}