collab: Setup database for LLM service (#15882)

Marshall Bowers created

This PR puts the initial infrastructure for the LLM service's database
in place.

The LLM service will use a separate Postgres database, with its own
set of migrations.

Currently we only connect to the database in development, as we don't
yet have the database set up for the staging/production environments.
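
For reference, a minimal sketch of how the new pieces fit together: connect using the options from .env.toml, run the migrations in migrations_llm, and seed the providers. The setup_llm_db helper and the hard-coded values are illustrative only; the real wiring lives in setup_llm_database in crates/collab/src/main.rs.

use collab::executor::Executor;
use collab::llm::db::{ConnectOptions, LlmDatabase};
use collab::migrations::run_database_migrations;

// Hypothetical helper showing the connect -> migrate -> seed flow added in this PR.
async fn setup_llm_db() -> collab::Result<LlmDatabase> {
    // Values mirror LLM_DATABASE_URL and LLM_DATABASE_MAX_CONNECTIONS in crates/collab/.env.toml.
    let mut options = ConnectOptions::new("postgres://postgres@localhost/zed_llm");
    options.max_connections(5);

    // Connect to the dedicated LLM database.
    let db = LlmDatabase::new(options, Executor::Production).await?;

    // Apply the Postgres migrations from crates/collab/migrations_llm.
    run_database_migrations(db.options(), "crates/collab/migrations_llm", false).await?;

    // Seed the known providers/models (idempotent; see provider_tests.rs).
    db.initialize_providers().await?;

    Ok(db)
}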

Release Notes:

- N/A

Change summary

Dockerfile                                                                  |   2 
crates/collab/.env.toml                                                     |   2 
crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql          |  16 
crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql |  16 
crates/collab/src/db.rs                                                     |  53 
crates/collab/src/db/ids.rs                                                 |   1 
crates/collab/src/db/tests.rs                                               |   6 
crates/collab/src/lib.rs                                                    |   7 
crates/collab/src/llm.rs                                                    |  25 
crates/collab/src/llm/db.rs                                                 | 118 
crates/collab/src/llm/db/ids.rs                                             |   7 
crates/collab/src/llm/db/queries.rs                                         |   3 
crates/collab/src/llm/db/queries/providers.rs                               |  67 
crates/collab/src/llm/db/tables.rs                                          |   2 
crates/collab/src/llm/db/tables/model.rs                                    |  31 
crates/collab/src/llm/db/tables/provider.rs                                 |  26 
crates/collab/src/llm/db/tests.rs                                           | 147 
crates/collab/src/llm/db/tests/provider_tests.rs                            |  30 
crates/collab/src/main.rs                                                   |  53 
crates/collab/src/migrations.rs                                             |  49 
crates/collab/src/tests/test_server.rs                                      |   3 
docker-compose.sql                                                          |   1 
script/bootstrap                                                            |  18 
script/reset_db                                                             |   1 
script/sqlx                                                                 |  17 
25 files changed, 627 insertions(+), 74 deletions(-)

Detailed changes

Dockerfile 🔗

@@ -27,5 +27,7 @@ RUN apt-get update; \
 WORKDIR app
 COPY --from=builder /app/collab /app/collab
 COPY --from=builder /app/crates/collab/migrations /app/migrations
+COPY --from=builder /app/crates/collab/migrations_llm /app/migrations_llm
 ENV MIGRATIONS_PATH=/app/migrations
+ENV LLM_DATABASE_MIGRATIONS_PATH=/app/migrations_llm
 ENTRYPOINT ["/app/collab"]

crates/collab/.env.toml 🔗

@@ -15,6 +15,8 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
 BLOB_STORE_REGION = "the-region"
 ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
 SEED_PATH = "crates/collab/seed.default.json"
+LLM_DATABASE_URL = "postgres://postgres@localhost/zed_llm"
+LLM_DATABASE_MAX_CONNECTIONS = 5
 LLM_API_SECRET = "llm-secret"
 
 # CLICKHOUSE_URL = ""

crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql 🔗

@@ -0,0 +1,16 @@
+create table providers (
+    id integer primary key autoincrement,
+    name text not null
+);
+
+create unique index uix_providers_on_name on providers (name);
+
+create table models (
+    id integer primary key autoincrement,
+    provider_id integer not null references providers (id) on delete cascade,
+    name text not null
+);
+
+create unique index uix_models_on_provider_id_name on models (provider_id, name);
+create index ix_models_on_provider_id on models (provider_id);
+create index ix_models_on_name on models (name);

crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql 🔗

@@ -0,0 +1,16 @@
+create table if not exists providers (
+    id serial primary key,
+    name text not null
+);
+
+create unique index uix_providers_on_name on providers (name);
+
+create table if not exists models (
+    id serial primary key,
+    provider_id integer not null references providers (id) on delete cascade,
+    name text not null
+);
+
+create unique index uix_models_on_provider_id_name on models (provider_id, name);
+create index ix_models_on_provider_id on models (provider_id);
+create index ix_models_on_name on models (name);

crates/collab/src/db.rs 🔗

@@ -23,17 +23,12 @@ use sea_orm::{
 };
 use semantic_version::SemanticVersion;
 use serde::{Deserialize, Serialize};
-use sqlx::{
-    migrate::{Migrate, Migration, MigrationSource},
-    Connection,
-};
 use std::ops::RangeInclusive;
 use std::{
     fmt::Write as _,
     future::Future,
     marker::PhantomData,
     ops::{Deref, DerefMut},
-    path::Path,
     rc::Rc,
     sync::Arc,
     time::Duration,
@@ -90,54 +85,16 @@ impl Database {
         })
     }
 
+    pub fn options(&self) -> &ConnectOptions {
+        &self.options
+    }
+
     #[cfg(test)]
     pub fn reset(&self) {
         self.rooms.clear();
         self.projects.clear();
     }
 
-    /// Runs the database migrations.
-    pub async fn migrate(
-        &self,
-        migrations_path: &Path,
-        ignore_checksum_mismatch: bool,
-    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
-        let migrations = MigrationSource::resolve(migrations_path)
-            .await
-            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
-
-        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
-
-        connection.ensure_migrations_table().await?;
-        let applied_migrations: HashMap<_, _> = connection
-            .list_applied_migrations()
-            .await?
-            .into_iter()
-            .map(|m| (m.version, m))
-            .collect();
-
-        let mut new_migrations = Vec::new();
-        for migration in migrations {
-            match applied_migrations.get(&migration.version) {
-                Some(applied_migration) => {
-                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
-                    {
-                        Err(anyhow!(
-                            "checksum mismatch for applied migration {}",
-                            migration.description
-                        ))?;
-                    }
-                }
-                None => {
-                    let elapsed = connection.apply(&migration).await?;
-                    new_migrations.push((migration, elapsed));
-                }
-            }
-        }
-
-        Ok(new_migrations)
-    }
-
     /// Transaction runs things in a transaction. If you want to call other methods
     /// and pass the transaction around you need to reborrow the transaction at each
     /// call site with: `&*tx`.
@@ -453,7 +410,7 @@ fn is_serialization_error(error: &Error) -> bool {
 }
 
 /// A handle to a [`DatabaseTransaction`].
-pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
+pub struct TransactionHandle(pub(crate) Arc<Option<DatabaseTransaction>>);
 
 impl Deref for TransactionHandle {
     type Target = DatabaseTransaction;

crates/collab/src/db/ids.rs 🔗

@@ -3,6 +3,7 @@ use rpc::proto;
 use sea_orm::{entity::prelude::*, DbErr};
 use serde::{Deserialize, Serialize};
 
+#[macro_export]
 macro_rules! id_type {
     ($name:ident) => {
         #[derive(

crates/collab/src/db/tests.rs 🔗

@@ -11,6 +11,8 @@ mod feature_flag_tests;
 mod message_tests;
 mod processed_stripe_event_tests;
 
+use crate::migrations::run_database_migrations;
+
 use super::*;
 use gpui::BackgroundExecutor;
 use parking_lot::Mutex;
@@ -91,7 +93,9 @@ impl TestDb {
                 .await
                 .unwrap();
             let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
-            db.migrate(Path::new(migrations_path), false).await.unwrap();
+            run_database_migrations(db.options(), migrations_path, false)
+                .await
+                .unwrap();
             db.initialize_notification_kinds().await.unwrap();
             db
         });

crates/collab/src/lib.rs 🔗

@@ -4,6 +4,7 @@ pub mod db;
 pub mod env;
 pub mod executor;
 pub mod llm;
+pub mod migrations;
 mod rate_limiter;
 pub mod rpc;
 pub mod seed;
@@ -150,6 +151,9 @@ pub struct Config {
     pub live_kit_server: Option<String>,
     pub live_kit_key: Option<String>,
     pub live_kit_secret: Option<String>,
+    pub llm_database_url: Option<String>,
+    pub llm_database_max_connections: Option<u32>,
+    pub llm_database_migrations_path: Option<PathBuf>,
     pub llm_api_secret: Option<String>,
     pub rust_log: Option<String>,
     pub log_json: Option<bool>,
@@ -197,6 +201,9 @@ impl Config {
             live_kit_server: None,
             live_kit_key: None,
             live_kit_secret: None,
+            llm_database_url: None,
+            llm_database_max_connections: None,
+            llm_database_migrations_path: None,
             llm_api_secret: None,
             rust_log: None,
             log_json: None,

crates/collab/src/llm.rs 🔗

@@ -1,10 +1,12 @@
 mod authorization;
+pub mod db;
 mod token;
 
 use crate::api::CloudflareIpCountryHeader;
 use crate::llm::authorization::authorize_access_to_language_model;
+use crate::llm::db::LlmDatabase;
 use crate::{executor::Executor, Config, Error, Result};
-use anyhow::Context as _;
+use anyhow::{anyhow, Context as _};
 use axum::TypedHeader;
 use axum::{
     body::Body,
@@ -24,11 +26,31 @@ pub use token::*;
 pub struct LlmState {
     pub config: Config,
     pub executor: Executor,
+    pub db: Option<Arc<LlmDatabase>>,
     pub http_client: IsahcHttpClient,
 }
 
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
+        // TODO: This is temporary until we have the LLM database stood up.
+        let db = if config.is_development() {
+            let database_url = config
+                .llm_database_url
+                .as_ref()
+                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+            let max_connections = config
+                .llm_database_max_connections
+                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+            let mut db_options = db::ConnectOptions::new(database_url);
+            db_options.max_connections(max_connections);
+            let db = LlmDatabase::new(db_options, executor.clone()).await?;
+
+            Some(Arc::new(db))
+        } else {
+            None
+        };
+
         let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
         let http_client = IsahcHttpClient::builder()
             .default_header("User-Agent", user_agent)
@@ -38,6 +60,7 @@ impl LlmState {
         let this = Self {
             config,
             executor,
+            db,
             http_client,
         };
 

crates/collab/src/llm/db.rs 🔗

@@ -0,0 +1,118 @@
+mod ids;
+mod queries;
+mod tables;
+
+#[cfg(test)]
+mod tests;
+
+pub use ids::*;
+pub use tables::*;
+
+#[cfg(test)]
+pub use tests::TestLlmDb;
+
+use std::future::Future;
+use std::sync::Arc;
+
+use anyhow::anyhow;
+use sea_orm::prelude::*;
+pub use sea_orm::ConnectOptions;
+use sea_orm::{
+    ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
+};
+
+use crate::db::TransactionHandle;
+use crate::executor::Executor;
+use crate::Result;
+
+/// The database for the LLM service.
+pub struct LlmDatabase {
+    options: ConnectOptions,
+    pool: DatabaseConnection,
+    #[allow(unused)]
+    executor: Executor,
+    #[cfg(test)]
+    runtime: Option<tokio::runtime::Runtime>,
+}
+
+impl LlmDatabase {
+    /// Connects to the database with the given options
+    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
+        sqlx::any::install_default_drivers();
+        Ok(Self {
+            options: options.clone(),
+            pool: sea_orm::Database::connect(options).await?,
+            executor,
+            #[cfg(test)]
+            runtime: None,
+        })
+    }
+
+    pub fn options(&self) -> &ConnectOptions {
+        &self.options
+    }
+
+    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let body = async {
+            let (tx, result) = self.with_transaction(&f).await?;
+            match result {
+                Ok(result) => match tx.commit().await.map_err(Into::into) {
+                    Ok(()) => return Ok(result),
+                    Err(error) => {
+                        return Err(error);
+                    }
+                },
+                Err(error) => {
+                    tx.rollback().await?;
+                    return Err(error);
+                }
+            }
+        };
+
+        self.run(body).await
+    }
+
+    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let tx = self
+            .pool
+            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
+            .await?;
+
+        let mut tx = Arc::new(Some(tx));
+        let result = f(TransactionHandle(tx.clone())).await;
+        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+            return Err(anyhow!(
+                "couldn't complete transaction because it's still in use"
+            ))?;
+        };
+
+        Ok((tx, result))
+    }
+
+    async fn run<F, T>(&self, future: F) -> Result<T>
+    where
+        F: Future<Output = Result<T>>,
+    {
+        #[cfg(test)]
+        {
+            if let Executor::Deterministic(executor) = &self.executor {
+                executor.simulate_random_delay().await;
+            }
+
+            self.runtime.as_ref().unwrap().block_on(future)
+        }
+
+        #[cfg(not(test))]
+        {
+            future.await
+        }
+    }
+}

crates/collab/src/llm/db/ids.rs 🔗

@@ -0,0 +1,7 @@
+use sea_orm::{entity::prelude::*, DbErr};
+use serde::{Deserialize, Serialize};
+
+use crate::id_type;
+
+id_type!(ProviderId);
+id_type!(ModelId);

crates/collab/src/llm/db/queries/providers.rs 🔗

@@ -0,0 +1,67 @@
+use sea_orm::sea_query::OnConflict;
+use sea_orm::QueryOrder;
+
+use super::*;
+
+impl LlmDatabase {
+    pub async fn initialize_providers(&self) -> Result<()> {
+        self.transaction(|tx| async move {
+            let providers_and_models = vec![
+                ("anthropic", "claude-3-5-sonnet"),
+                ("anthropic", "claude-3-opus"),
+                ("anthropic", "claude-3-sonnet"),
+                ("anthropic", "claude-3-haiku"),
+            ];
+
+            for (provider_name, model_name) in providers_and_models {
+                let insert_provider = provider::Entity::insert(provider::ActiveModel {
+                    name: ActiveValue::set(provider_name.to_owned()),
+                    ..Default::default()
+                })
+                .on_conflict(
+                    OnConflict::columns([provider::Column::Name])
+                        .update_column(provider::Column::Name)
+                        .to_owned(),
+                );
+
+                let provider = if tx.support_returning() {
+                    insert_provider.exec_with_returning(&*tx).await?
+                } else {
+                    insert_provider.exec_without_returning(&*tx).await?;
+                    provider::Entity::find()
+                        .filter(provider::Column::Name.eq(provider_name))
+                        .one(&*tx)
+                        .await?
+                        .ok_or_else(|| anyhow!("failed to insert provider"))?
+                };
+
+                model::Entity::insert(model::ActiveModel {
+                    provider_id: ActiveValue::set(provider.id),
+                    name: ActiveValue::set(model_name.to_owned()),
+                    ..Default::default()
+                })
+                .on_conflict(
+                    OnConflict::columns([model::Column::ProviderId, model::Column::Name])
+                        .update_column(model::Column::Name)
+                        .to_owned(),
+                )
+                .exec_without_returning(&*tx)
+                .await?;
+            }
+
+            Ok(())
+        })
+        .await
+    }
+
+    /// Returns the list of LLM providers.
+    pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
+        self.transaction(|tx| async move {
+            Ok(provider::Entity::find()
+                .order_by_asc(provider::Column::Name)
+                .all(&*tx)
+                .await?)
+        })
+        .await
+    }
+}

crates/collab/src/llm/db/tables/model.rs 🔗

@@ -0,0 +1,31 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::{ModelId, ProviderId};
+
+/// An LLM model.
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "models")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: ModelId,
+    pub provider_id: ProviderId,
+    pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::provider::Entity",
+        from = "Column::ProviderId",
+        to = "super::provider::Column::Id"
+    )]
+    Provider,
+}
+
+impl Related<super::provider::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Provider.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/provider.rs 🔗

@@ -0,0 +1,26 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::ProviderId;
+
+/// An LLM provider.
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "providers")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: ProviderId,
+    pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(has_many = "super::model::Entity")]
+    Models,
+}
+
+impl Related<super::model::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Models.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tests.rs 🔗

@@ -0,0 +1,147 @@
+mod provider_tests;
+
+use gpui::BackgroundExecutor;
+use parking_lot::Mutex;
+use rand::prelude::*;
+use sea_orm::ConnectionTrait;
+use sqlx::migrate::MigrateDatabase;
+use std::sync::Arc;
+use std::time::Duration;
+
+use crate::migrations::run_database_migrations;
+
+use super::*;
+
+pub struct TestLlmDb {
+    pub db: Option<Arc<LlmDatabase>>,
+    pub connection: Option<sqlx::AnyConnection>,
+}
+
+impl TestLlmDb {
+    pub fn sqlite(background: BackgroundExecutor) -> Self {
+        let url = "sqlite::memory:";
+        let runtime = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .unwrap();
+
+        let mut db = runtime.block_on(async {
+            let mut options = ConnectOptions::new(url);
+            options.max_connections(5);
+            let db = LlmDatabase::new(options, Executor::Deterministic(background))
+                .await
+                .unwrap();
+            let sql = include_str!(concat!(
+                env!("CARGO_MANIFEST_DIR"),
+                "/migrations_llm.sqlite/20240806182921_test_schema.sql"
+            ));
+            db.pool
+                .execute(sea_orm::Statement::from_string(
+                    db.pool.get_database_backend(),
+                    sql,
+                ))
+                .await
+                .unwrap();
+            db
+        });
+
+        db.runtime = Some(runtime);
+
+        Self {
+            db: Some(Arc::new(db)),
+            connection: None,
+        }
+    }
+
+    pub fn postgres(background: BackgroundExecutor) -> Self {
+        static LOCK: Mutex<()> = Mutex::new(());
+
+        let _guard = LOCK.lock();
+        let mut rng = StdRng::from_entropy();
+        let url = format!(
+            "postgres://postgres@localhost/zed-llm-test-{}",
+            rng.gen::<u128>()
+        );
+        let runtime = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .unwrap();
+
+        let mut db = runtime.block_on(async {
+            sqlx::Postgres::create_database(&url)
+                .await
+                .expect("failed to create test db");
+            let mut options = ConnectOptions::new(url);
+            options
+                .max_connections(5)
+                .idle_timeout(Duration::from_secs(0));
+            let db = LlmDatabase::new(options, Executor::Deterministic(background))
+                .await
+                .unwrap();
+            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
+            run_database_migrations(db.options(), migrations_path, false)
+                .await
+                .unwrap();
+            db
+        });
+
+        db.runtime = Some(runtime);
+
+        Self {
+            db: Some(Arc::new(db)),
+            connection: None,
+        }
+    }
+
+    pub fn db(&self) -> &Arc<LlmDatabase> {
+        self.db.as_ref().unwrap()
+    }
+}
+
+#[macro_export]
+macro_rules! test_both_llm_dbs {
+    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
+        #[cfg(target_os = "macos")]
+        #[gpui::test]
+        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
+            let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
+            $test_name(test_db.db()).await;
+        }
+
+        #[gpui::test]
+        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
+            let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
+            $test_name(test_db.db()).await;
+        }
+    };
+}
+
+impl Drop for TestLlmDb {
+    fn drop(&mut self) {
+        let db = self.db.take().unwrap();
+        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
+            db.runtime.as_ref().unwrap().block_on(async {
+                use util::ResultExt;
+                let query = "
+                        SELECT pg_terminate_backend(pg_stat_activity.pid)
+                        FROM pg_stat_activity
+                        WHERE
+                            pg_stat_activity.datname = current_database() AND
+                            pid <> pg_backend_pid();
+                    ";
+                db.pool
+                    .execute(sea_orm::Statement::from_string(
+                        db.pool.get_database_backend(),
+                        query,
+                    ))
+                    .await
+                    .log_err();
+                sqlx::Postgres::drop_database(db.options.get_url())
+                    .await
+                    .log_err();
+            })
+        }
+    }
+}

crates/collab/src/llm/db/tests/provider_tests.rs 🔗

@@ -0,0 +1,30 @@
+use std::sync::Arc;
+
+use pretty_assertions::assert_eq;
+
+use crate::llm::db::LlmDatabase;
+use crate::test_both_llm_dbs;
+
+test_both_llm_dbs!(
+    test_initialize_providers,
+    test_initialize_providers_postgres,
+    test_initialize_providers_sqlite
+);
+
+async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
+    let initial_providers = db.list_providers().await.unwrap();
+    assert_eq!(initial_providers, vec![]);
+
+    db.initialize_providers().await.unwrap();
+
+    // Do it twice, to make sure the operation is idempotent.
+    db.initialize_providers().await.unwrap();
+
+    let providers = db.list_providers().await.unwrap();
+
+    let provider_names = providers
+        .into_iter()
+        .map(|provider| provider.name)
+        .collect::<Vec<_>>();
+    assert_eq!(provider_names, vec!["anthropic".to_string()]);
+}

crates/collab/src/main.rs 🔗

@@ -5,6 +5,8 @@ use axum::{
     routing::get,
     Extension, Router,
 };
+use collab::llm::db::LlmDatabase;
+use collab::migrations::run_database_migrations;
 use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 use collab::{
     api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
@@ -45,7 +47,7 @@ async fn main() -> Result<()> {
         }
         Some("migrate") => {
             let config = envy::from_env::<Config>().expect("error loading config");
-            run_migrations(&config).await?;
+            setup_app_database(&config).await?;
         }
         Some("seed") => {
             let config = envy::from_env::<Config>().expect("error loading config");
@@ -81,6 +83,8 @@ async fn main() -> Result<()> {
             let mut on_shutdown = None;
 
             if mode.is_llm() {
+                setup_llm_database(&config).await?;
+
                 let state = LlmState::new(config.clone(), Executor::Production).await?;
 
                 app = app
@@ -89,7 +93,7 @@ async fn main() -> Result<()> {
             }
 
             if mode.is_collab() || mode.is_api() {
-                run_migrations(&config).await?;
+                setup_app_database(&config).await?;
 
                 let state = AppState::new(config, Executor::Production).await?;
 
@@ -203,7 +207,7 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
-async fn run_migrations(config: &Config) -> Result<()> {
+async fn setup_app_database(config: &Config) -> Result<()> {
     let db_options = db::ConnectOptions::new(config.database_url.clone());
     let mut db = Database::new(db_options, Executor::Production).await?;
 
@@ -216,7 +220,7 @@ async fn run_migrations(config: &Config) -> Result<()> {
         Path::new(default_migrations)
     });
 
-    let migrations = db.migrate(&migrations_path, false).await?;
+    let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
     for (migration, duration) in migrations {
         log::info!(
             "Migrated {} {} {:?}",
@@ -232,7 +236,46 @@ async fn run_migrations(config: &Config) -> Result<()> {
         collab::seed::seed(&config, &db, false).await?;
     }
 
-    return Ok(());
+    Ok(())
+}
+
+async fn setup_llm_database(config: &Config) -> Result<()> {
+    // TODO: This is temporary until we have the LLM database stood up.
+    if !config.is_development() {
+        return Ok(());
+    }
+
+    let database_url = config
+        .llm_database_url
+        .as_ref()
+        .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+
+    let db_options = db::ConnectOptions::new(database_url.clone());
+    let db = LlmDatabase::new(db_options, Executor::Production).await?;
+
+    let migrations_path = config
+        .llm_database_migrations_path
+        .as_deref()
+        .unwrap_or_else(|| {
+            #[cfg(feature = "sqlite")]
+            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
+            #[cfg(not(feature = "sqlite"))]
+            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
+
+            Path::new(default_migrations)
+        });
+
+    let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
+    for (migration, duration) in migrations {
+        log::info!(
+            "Migrated {} {} {:?}",
+            migration.version,
+            migration.description,
+            duration
+        );
+    }
+
+    Ok(())
 }
 
 async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {

crates/collab/src/migrations.rs 🔗

@@ -0,0 +1,49 @@
+use std::path::Path;
+use std::time::Duration;
+
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use sea_orm::ConnectOptions;
+use sqlx::migrate::{Migrate, Migration, MigrationSource};
+use sqlx::Connection;
+
+/// Runs the database migrations for the specified database.
+pub async fn run_database_migrations(
+    database_options: &ConnectOptions,
+    migrations_path: impl AsRef<Path>,
+    ignore_checksum_mismatch: bool,
+) -> Result<Vec<(Migration, Duration)>> {
+    let migrations = MigrationSource::resolve(migrations_path.as_ref())
+        .await
+        .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
+
+    let mut connection = sqlx::AnyConnection::connect(database_options.get_url()).await?;
+
+    connection.ensure_migrations_table().await?;
+    let applied_migrations: HashMap<_, _> = connection
+        .list_applied_migrations()
+        .await?
+        .into_iter()
+        .map(|migration| (migration.version, migration))
+        .collect();
+
+    let mut new_migrations = Vec::new();
+    for migration in migrations {
+        match applied_migrations.get(&migration.version) {
+            Some(applied_migration) => {
+                if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch {
+                    Err(anyhow!(
+                        "checksum mismatch for applied migration {}",
+                        migration.description
+                    ))?;
+                }
+            }
+            None => {
+                let elapsed = connection.apply(&migration).await?;
+                new_migrations.push((migration, elapsed));
+            }
+        }
+    }
+
+    Ok(new_migrations)
+}

crates/collab/src/tests/test_server.rs 🔗

@@ -651,6 +651,9 @@ impl TestServer {
                 live_kit_server: None,
                 live_kit_key: None,
                 live_kit_secret: None,
+                llm_database_url: None,
+                llm_database_max_connections: None,
+                llm_database_migrations_path: None,
                 llm_api_secret: None,
                 rust_log: None,
                 log_json: None,

script/bootstrap 🔗

@@ -1,5 +1,7 @@
 #!/usr/bin/env bash
 
+set -e
+
 if [[ "$OSTYPE" == "linux-gnu"* ]]; then
   echo "Linux dependencies..."
   script/linux
@@ -8,5 +10,17 @@ else
   which foreman > /dev/null || brew install foreman
 fi
 
-echo "creating database..."
-script/sqlx database create
+# Install sqlx-cli if needed
+if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
+    echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
+    cargo install sqlx-cli --version 0.5.7
+fi
+
+cd crates/collab
+
+# Export contents of .env.toml
+eval "$(cargo run --bin dotenv)"
+
+echo "creating databases..."
+sqlx database create --database-url "$DATABASE_URL"
+sqlx database create --database-url "$LLM_DATABASE_URL"

script/reset_db 🔗

@@ -1,2 +1,3 @@
 psql -c "DROP DATABASE zed (FORCE);"
+psql -c "DROP DATABASE zed_llm (FORCE);"
 script/bootstrap

script/sqlx 🔗

@@ -1,17 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# Install sqlx-cli if needed
-if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
-    echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
-    cargo install sqlx-cli --version 0.5.7
-fi
-
-cd crates/collab
-
-# Export contents of .env.toml
-eval "$(cargo run --bin dotenv)"
-
-# Run sqlx command
-sqlx $@