1mod billing_tests;
2mod provider_tests;
3mod usage_tests;
4
5use gpui::BackgroundExecutor;
6use parking_lot::Mutex;
7use rand::prelude::*;
8use sea_orm::ConnectionTrait;
9use sqlx::migrate::MigrateDatabase;
10use std::time::Duration;
11
12use crate::migrations::run_database_migrations;
13
14use super::*;
15
16pub struct TestLlmDb {
17 pub db: Option<LlmDatabase>,
18 pub connection: Option<sqlx::AnyConnection>,
19}
20
21impl TestLlmDb {
22 pub fn postgres(background: BackgroundExecutor) -> Self {
23 static LOCK: Mutex<()> = Mutex::new(());
24
25 let _guard = LOCK.lock();
26 let mut rng = StdRng::from_entropy();
27 let url = format!(
28 "postgres://postgres@localhost/zed-llm-test-{}",
29 rng.r#gen::<u128>()
30 );
31 let runtime = tokio::runtime::Builder::new_current_thread()
32 .enable_io()
33 .enable_time()
34 .build()
35 .unwrap();
36
37 let mut db = runtime.block_on(async {
38 sqlx::Postgres::create_database(&url)
39 .await
40 .expect("failed to create test db");
41 let mut options = ConnectOptions::new(url);
42 options
43 .max_connections(5)
44 .idle_timeout(Duration::from_secs(0));
45 let db = LlmDatabase::new(options, Executor::Deterministic(background))
46 .await
47 .unwrap();
48 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
49 run_database_migrations(db.options(), migrations_path)
50 .await
51 .unwrap();
52 db
53 });
54
55 db.runtime = Some(runtime);
56
57 Self {
58 db: Some(db),
59 connection: None,
60 }
61 }
62
63 pub fn db(&mut self) -> &mut LlmDatabase {
64 self.db.as_mut().unwrap()
65 }
66}
67
/// Generates a `#[gpui::test]` named `$postgres_test_name` that builds a Postgres-backed
/// `TestLlmDb` and awaits `$test_name` with a mutable reference to its database.
///
/// The generated test does nothing on targets other than macOS.
#[macro_export]
macro_rules! test_llm_db {
    ($test_name:ident, $postgres_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // The body only runs when compiled for macOS; elsewhere the test is a no-op.
            if cfg!(target_os = "macos") {
                let executor = cx.executor().clone();
                let mut fixture = $crate::llm::db::TestLlmDb::postgres(executor);
                $test_name(fixture.db()).await;
            }
        }
    };
}
82
83impl Drop for TestLlmDb {
84 fn drop(&mut self) {
85 let db = self.db.take().unwrap();
86 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
87 db.runtime.as_ref().unwrap().block_on(async {
88 use util::ResultExt;
89 let query = "
90 SELECT pg_terminate_backend(pg_stat_activity.pid)
91 FROM pg_stat_activity
92 WHERE
93 pg_stat_activity.datname = current_database() AND
94 pid <> pg_backend_pid();
95 ";
96 db.pool
97 .execute(sea_orm::Statement::from_string(
98 db.pool.get_database_backend(),
99 query,
100 ))
101 .await
102 .log_err();
103 sqlx::Postgres::drop_database(db.options.get_url())
104 .await
105 .log_err();
106 })
107 }
108 }
109}