tests.rs

  1mod billing_tests;
  2mod provider_tests;
  3mod usage_tests;
  4
  5use gpui::BackgroundExecutor;
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use sea_orm::ConnectionTrait;
  9use sqlx::migrate::MigrateDatabase;
 10use std::time::Duration;
 11
 12use crate::migrations::run_database_migrations;
 13
 14use super::*;
 15
/// Test fixture that owns a throwaway LLM database for a single test.
///
/// Instances are built by `TestLlmDb::postgres`; the `Drop` impl for this type
/// takes `db` back out and tears the Postgres database down.
pub struct TestLlmDb {
    /// The database under test. `Some` after construction; taken in `drop`.
    pub db: Option<LlmDatabase>,
    /// Not set by `postgres` (always `None` there).
    /// NOTE(review): presumably reserved for another backend/setup — confirm before relying on it.
    pub connection: Option<sqlx::AnyConnection>,
}
 20
 21impl TestLlmDb {
 22    pub fn postgres(background: BackgroundExecutor) -> Self {
 23        static LOCK: Mutex<()> = Mutex::new(());
 24
 25        let _guard = LOCK.lock();
 26        let mut rng = StdRng::from_entropy();
 27        let url = format!(
 28            "postgres://postgres@localhost/zed-llm-test-{}",
 29            rng.r#gen::<u128>()
 30        );
 31        let runtime = tokio::runtime::Builder::new_current_thread()
 32            .enable_io()
 33            .enable_time()
 34            .build()
 35            .unwrap();
 36
 37        let mut db = runtime.block_on(async {
 38            sqlx::Postgres::create_database(&url)
 39                .await
 40                .expect("failed to create test db");
 41            let mut options = ConnectOptions::new(url);
 42            options
 43                .max_connections(5)
 44                .idle_timeout(Duration::from_secs(0));
 45            let db = LlmDatabase::new(options, Executor::Deterministic(background))
 46                .await
 47                .unwrap();
 48            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
 49            run_database_migrations(db.options(), migrations_path)
 50                .await
 51                .unwrap();
 52            db
 53        });
 54
 55        db.runtime = Some(runtime);
 56
 57        Self {
 58            db: Some(db),
 59            connection: None,
 60        }
 61    }
 62
 63    pub fn db(&mut self) -> &mut LlmDatabase {
 64        self.db.as_mut().unwrap()
 65    }
 66}
 67
 68#[macro_export]
 69macro_rules! test_llm_db {
 70    ($test_name:ident, $postgres_test_name:ident) => {
 71        #[gpui::test]
 72        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
 73            if !cfg!(target_os = "macos") {
 74                return;
 75            }
 76
 77            let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
 78            $test_name(test_db.db()).await;
 79        }
 80    };
 81}
 82
 83impl Drop for TestLlmDb {
 84    fn drop(&mut self) {
 85        let db = self.db.take().unwrap();
 86        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
 87            db.runtime.as_ref().unwrap().block_on(async {
 88                use util::ResultExt;
 89                let query = "
 90                        SELECT pg_terminate_backend(pg_stat_activity.pid)
 91                        FROM pg_stat_activity
 92                        WHERE
 93                            pg_stat_activity.datname = current_database() AND
 94                            pid <> pg_backend_pid();
 95                    ";
 96                db.pool
 97                    .execute(sea_orm::Statement::from_string(
 98                        db.pool.get_database_backend(),
 99                        query,
100                    ))
101                    .await
102                    .log_err();
103                sqlx::Postgres::drop_database(db.options.get_url())
104                    .await
105                    .log_err();
106            })
107        }
108    }
109}