tests.rs

  1mod provider_tests;
  2mod usage_tests;
  3
  4use gpui::BackgroundExecutor;
  5use parking_lot::Mutex;
  6use rand::prelude::*;
  7use sea_orm::ConnectionTrait;
  8use sqlx::migrate::MigrateDatabase;
  9use std::sync::Arc;
 10use std::time::Duration;
 11
 12use crate::migrations::run_database_migrations;
 13
 14use super::*;
 15
 16pub struct TestLlmDb {
 17    pub db: Option<Arc<LlmDatabase>>,
 18    pub connection: Option<sqlx::AnyConnection>,
 19}
 20
 21impl TestLlmDb {
 22    pub fn sqlite(background: BackgroundExecutor) -> Self {
 23        let url = "sqlite::memory:";
 24        let runtime = tokio::runtime::Builder::new_current_thread()
 25            .enable_io()
 26            .enable_time()
 27            .build()
 28            .unwrap();
 29
 30        let mut db = runtime.block_on(async {
 31            let mut options = ConnectOptions::new(url);
 32            options.max_connections(5);
 33            let db = LlmDatabase::new(options, Executor::Deterministic(background))
 34                .await
 35                .unwrap();
 36            let sql = include_str!(concat!(
 37                env!("CARGO_MANIFEST_DIR"),
 38                "/migrations_llm.sqlite/20240806182921_test_schema.sql"
 39            ));
 40            db.pool
 41                .execute(sea_orm::Statement::from_string(
 42                    db.pool.get_database_backend(),
 43                    sql,
 44                ))
 45                .await
 46                .unwrap();
 47            db
 48        });
 49
 50        db.runtime = Some(runtime);
 51
 52        Self {
 53            db: Some(Arc::new(db)),
 54            connection: None,
 55        }
 56    }
 57
 58    pub fn postgres(background: BackgroundExecutor) -> Self {
 59        static LOCK: Mutex<()> = Mutex::new(());
 60
 61        let _guard = LOCK.lock();
 62        let mut rng = StdRng::from_entropy();
 63        let url = format!(
 64            "postgres://postgres@localhost/zed-llm-test-{}",
 65            rng.gen::<u128>()
 66        );
 67        let runtime = tokio::runtime::Builder::new_current_thread()
 68            .enable_io()
 69            .enable_time()
 70            .build()
 71            .unwrap();
 72
 73        let mut db = runtime.block_on(async {
 74            sqlx::Postgres::create_database(&url)
 75                .await
 76                .expect("failed to create test db");
 77            let mut options = ConnectOptions::new(url);
 78            options
 79                .max_connections(5)
 80                .idle_timeout(Duration::from_secs(0));
 81            let db = LlmDatabase::new(options, Executor::Deterministic(background))
 82                .await
 83                .unwrap();
 84            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
 85            run_database_migrations(db.options(), migrations_path)
 86                .await
 87                .unwrap();
 88            db
 89        });
 90
 91        db.runtime = Some(runtime);
 92
 93        Self {
 94            db: Some(Arc::new(db)),
 95            connection: None,
 96        }
 97    }
 98
 99    pub fn db(&self) -> &Arc<LlmDatabase> {
100        self.db.as_ref().unwrap()
101    }
102}
103
/// Expands to two `#[gpui::test]` functions that each run the async test
/// function `$test_name` against a fresh `TestLlmDb`:
/// `$postgres_test_name` uses Postgres (compiled only on macOS) and
/// `$sqlite_test_name` uses in-memory SQLite.
#[macro_export]
macro_rules! test_both_llm_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::llm::db::TestLlmDb::postgres(executor);
            let db = test_db.db();
            $test_name(db).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::llm::db::TestLlmDb::sqlite(executor);
            let db = test_db.db();
            $test_name(db).await;
        }
    };
}
121
122impl Drop for TestLlmDb {
123    fn drop(&mut self) {
124        let db = self.db.take().unwrap();
125        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
126            db.runtime.as_ref().unwrap().block_on(async {
127                use util::ResultExt;
128                let query = "
129                        SELECT pg_terminate_backend(pg_stat_activity.pid)
130                        FROM pg_stat_activity
131                        WHERE
132                            pg_stat_activity.datname = current_database() AND
133                            pid <> pg_backend_pid();
134                    ";
135                db.pool
136                    .execute(sea_orm::Statement::from_string(
137                        db.pool.get_database_backend(),
138                        query,
139                    ))
140                    .await
141                    .log_err();
142                sqlx::Postgres::drop_database(db.options.get_url())
143                    .await
144                    .log_err();
145            })
146        }
147    }
148}