tests.rs

  1mod provider_tests;
  2mod subscription_usage_tests;
  3
  4use gpui::BackgroundExecutor;
  5use parking_lot::Mutex;
  6use rand::prelude::*;
  7use sea_orm::ConnectionTrait;
  8use sqlx::migrate::MigrateDatabase;
  9use std::time::Duration;
 10
 11use crate::migrations::run_database_migrations;
 12
 13use super::*;
 14
 15pub struct TestLlmDb {
 16    pub db: Option<LlmDatabase>,
 17    pub connection: Option<sqlx::AnyConnection>,
 18}
 19
 20impl TestLlmDb {
 21    pub fn postgres(background: BackgroundExecutor) -> Self {
 22        static LOCK: Mutex<()> = Mutex::new(());
 23
 24        let _guard = LOCK.lock();
 25        let mut rng = StdRng::from_entropy();
 26        let url = format!(
 27            "postgres://postgres@localhost/zed-llm-test-{}",
 28            rng.r#gen::<u128>()
 29        );
 30        let runtime = tokio::runtime::Builder::new_current_thread()
 31            .enable_io()
 32            .enable_time()
 33            .build()
 34            .unwrap();
 35
 36        let mut db = runtime.block_on(async {
 37            sqlx::Postgres::create_database(&url)
 38                .await
 39                .expect("failed to create test db");
 40            let mut options = ConnectOptions::new(url);
 41            options
 42                .max_connections(5)
 43                .idle_timeout(Duration::from_secs(0));
 44            let db = LlmDatabase::new(options, Executor::Deterministic(background))
 45                .await
 46                .unwrap();
 47            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
 48            run_database_migrations(db.options(), migrations_path)
 49                .await
 50                .unwrap();
 51            db
 52        });
 53
 54        db.runtime = Some(runtime);
 55
 56        Self {
 57            db: Some(db),
 58            connection: None,
 59        }
 60    }
 61
 62    pub fn db(&mut self) -> &mut LlmDatabase {
 63        self.db.as_mut().unwrap()
 64    }
 65}
 66
 67#[macro_export]
 68macro_rules! test_llm_db {
 69    ($test_name:ident, $postgres_test_name:ident) => {
 70        #[gpui::test]
 71        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
 72            if !cfg!(target_os = "macos") {
 73                return;
 74            }
 75
 76            let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
 77            $test_name(test_db.db()).await;
 78        }
 79    };
 80}
 81
 82impl Drop for TestLlmDb {
 83    fn drop(&mut self) {
 84        let db = self.db.take().unwrap();
 85        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
 86            db.runtime.as_ref().unwrap().block_on(async {
 87                use util::ResultExt;
 88                let query = "
 89                        SELECT pg_terminate_backend(pg_stat_activity.pid)
 90                        FROM pg_stat_activity
 91                        WHERE
 92                            pg_stat_activity.datname = current_database() AND
 93                            pid <> pg_backend_pid();
 94                    ";
 95                db.pool
 96                    .execute(sea_orm::Statement::from_string(
 97                        db.pool.get_database_backend(),
 98                        query,
 99                    ))
100                    .await
101                    .log_err();
102                sqlx::Postgres::drop_database(db.options.get_url())
103                    .await
104                    .log_err();
105            })
106        }
107    }
108}