tests.rs

  1mod provider_tests;
  2
  3use gpui::BackgroundExecutor;
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use sea_orm::ConnectionTrait;
  7use sqlx::migrate::MigrateDatabase;
  8use std::time::Duration;
  9
 10use crate::migrations::run_database_migrations;
 11
 12use super::*;
 13
 14pub struct TestLlmDb {
 15    pub db: Option<LlmDatabase>,
 16    pub connection: Option<sqlx::AnyConnection>,
 17}
 18
 19impl TestLlmDb {
 20    pub fn postgres(background: BackgroundExecutor) -> Self {
 21        static LOCK: Mutex<()> = Mutex::new(());
 22
 23        let _guard = LOCK.lock();
 24        let mut rng = StdRng::from_entropy();
 25        let url = format!(
 26            "postgres://postgres@localhost/zed-llm-test-{}",
 27            rng.r#gen::<u128>()
 28        );
 29        let runtime = tokio::runtime::Builder::new_current_thread()
 30            .enable_io()
 31            .enable_time()
 32            .build()
 33            .unwrap();
 34
 35        let mut db = runtime.block_on(async {
 36            sqlx::Postgres::create_database(&url)
 37                .await
 38                .expect("failed to create test db");
 39            let mut options = ConnectOptions::new(url);
 40            options
 41                .max_connections(5)
 42                .idle_timeout(Duration::from_secs(0));
 43            let db = LlmDatabase::new(options, Executor::Deterministic(background))
 44                .await
 45                .unwrap();
 46            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
 47            run_database_migrations(db.options(), migrations_path)
 48                .await
 49                .unwrap();
 50            db
 51        });
 52
 53        db.runtime = Some(runtime);
 54
 55        Self {
 56            db: Some(db),
 57            connection: None,
 58        }
 59    }
 60
 61    pub fn db(&mut self) -> &mut LlmDatabase {
 62        self.db.as_mut().unwrap()
 63    }
 64}
 65
 66#[macro_export]
 67macro_rules! test_llm_db {
 68    ($test_name:ident, $postgres_test_name:ident) => {
 69        #[gpui::test]
 70        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
 71            if !cfg!(target_os = "macos") {
 72                return;
 73            }
 74
 75            let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
 76            $test_name(test_db.db()).await;
 77        }
 78    };
 79}
 80
 81impl Drop for TestLlmDb {
 82    fn drop(&mut self) {
 83        let db = self.db.take().unwrap();
 84        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
 85            db.runtime.as_ref().unwrap().block_on(async {
 86                use util::ResultExt;
 87                let query = "
 88                        SELECT pg_terminate_backend(pg_stat_activity.pid)
 89                        FROM pg_stat_activity
 90                        WHERE
 91                            pg_stat_activity.datname = current_database() AND
 92                            pid <> pg_backend_pid();
 93                    ";
 94                db.pool
 95                    .execute(sea_orm::Statement::from_string(
 96                        db.pool.get_database_backend(),
 97                        query,
 98                    ))
 99                    .await
100                    .log_err();
101                sqlx::Postgres::drop_database(db.options.get_url())
102                    .await
103                    .log_err();
104            })
105        }
106    }
107}