Detailed changes
@@ -27,5 +27,7 @@ RUN apt-get update; \
WORKDIR app
COPY --from=builder /app/collab /app/collab
COPY --from=builder /app/crates/collab/migrations /app/migrations
+COPY --from=builder /app/crates/collab/migrations_llm /app/migrations_llm
ENV MIGRATIONS_PATH=/app/migrations
+ENV LLM_DATABASE_MIGRATIONS_PATH=/app/migrations_llm
ENTRYPOINT ["/app/collab"]
@@ -15,6 +15,8 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
SEED_PATH = "crates/collab/seed.default.json"
+LLM_DATABASE_URL = "postgres://postgres@localhost/zed_llm"
+LLM_DATABASE_MAX_CONNECTIONS = 5
LLM_API_SECRET = "llm-secret"
# CLICKHOUSE_URL = ""
@@ -0,0 +1,16 @@
+create table providers (
+ id integer primary key autoincrement,
+ name text not null
+);
+
+create unique index uix_providers_on_name on providers (name);
+
+create table models (
+ id integer primary key autoincrement,
+ provider_id integer not null references providers (id) on delete cascade,
+ name text not null
+);
+
+create unique index uix_models_on_provider_id_name on models (provider_id, name);
+create index ix_models_on_provider_id on models (provider_id);
+create index ix_models_on_name on models (name);
@@ -0,0 +1,16 @@
+create table if not exists providers (
+ id serial primary key,
+ name text not null
+);
+
+create unique index uix_providers_on_name on providers (name);
+
+create table if not exists models (
+ id serial primary key,
+ provider_id integer not null references providers (id) on delete cascade,
+ name text not null
+);
+
+create unique index uix_models_on_provider_id_name on models (provider_id, name);
+create index ix_models_on_provider_id on models (provider_id);
+create index ix_models_on_name on models (name);
@@ -23,17 +23,12 @@ use sea_orm::{
};
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
-use sqlx::{
- migrate::{Migrate, Migration, MigrationSource},
- Connection,
-};
use std::ops::RangeInclusive;
use std::{
fmt::Write as _,
future::Future,
marker::PhantomData,
ops::{Deref, DerefMut},
- path::Path,
rc::Rc,
sync::Arc,
time::Duration,
@@ -90,54 +85,16 @@ impl Database {
})
}
+ pub fn options(&self) -> &ConnectOptions {
+ &self.options
+ }
+
#[cfg(test)]
pub fn reset(&self) {
self.rooms.clear();
self.projects.clear();
}
- /// Runs the database migrations.
- pub async fn migrate(
- &self,
- migrations_path: &Path,
- ignore_checksum_mismatch: bool,
- ) -> anyhow::Result<Vec<(Migration, Duration)>> {
- let migrations = MigrationSource::resolve(migrations_path)
- .await
- .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
-
- let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
-
- connection.ensure_migrations_table().await?;
- let applied_migrations: HashMap<_, _> = connection
- .list_applied_migrations()
- .await?
- .into_iter()
- .map(|m| (m.version, m))
- .collect();
-
- let mut new_migrations = Vec::new();
- for migration in migrations {
- match applied_migrations.get(&migration.version) {
- Some(applied_migration) => {
- if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
- {
- Err(anyhow!(
- "checksum mismatch for applied migration {}",
- migration.description
- ))?;
- }
- }
- None => {
- let elapsed = connection.apply(&migration).await?;
- new_migrations.push((migration, elapsed));
- }
- }
- }
-
- Ok(new_migrations)
- }
-
/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
@@ -453,7 +410,7 @@ fn is_serialization_error(error: &Error) -> bool {
}
/// A handle to a [`DatabaseTransaction`].
-pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
+pub struct TransactionHandle(pub(crate) Arc<Option<DatabaseTransaction>>);
impl Deref for TransactionHandle {
type Target = DatabaseTransaction;
@@ -3,6 +3,7 @@ use rpc::proto;
use sea_orm::{entity::prelude::*, DbErr};
use serde::{Deserialize, Serialize};
+#[macro_export]
macro_rules! id_type {
($name:ident) => {
#[derive(
@@ -11,6 +11,8 @@ mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
+use crate::migrations::run_database_migrations;
+
use super::*;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
@@ -91,7 +93,9 @@ impl TestDb {
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
- db.migrate(Path::new(migrations_path), false).await.unwrap();
+ run_database_migrations(db.options(), migrations_path, false)
+ .await
+ .unwrap();
db.initialize_notification_kinds().await.unwrap();
db
});
@@ -4,6 +4,7 @@ pub mod db;
pub mod env;
pub mod executor;
pub mod llm;
+pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
@@ -150,6 +151,9 @@ pub struct Config {
pub live_kit_server: Option<String>,
pub live_kit_key: Option<String>,
pub live_kit_secret: Option<String>,
+ pub llm_database_url: Option<String>,
+ pub llm_database_max_connections: Option<u32>,
+ pub llm_database_migrations_path: Option<PathBuf>,
pub llm_api_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
@@ -197,6 +201,9 @@ impl Config {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
+ llm_database_url: None,
+ llm_database_max_connections: None,
+ llm_database_migrations_path: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
@@ -1,10 +1,12 @@
mod authorization;
+pub mod db;
mod token;
use crate::api::CloudflareIpCountryHeader;
use crate::llm::authorization::authorize_access_to_language_model;
+use crate::llm::db::LlmDatabase;
use crate::{executor::Executor, Config, Error, Result};
-use anyhow::Context as _;
+use anyhow::{anyhow, Context as _};
use axum::TypedHeader;
use axum::{
body::Body,
@@ -24,11 +26,31 @@ pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
+ pub db: Option<Arc<LlmDatabase>>,
pub http_client: IsahcHttpClient,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
+ // TODO: This is temporary until we have the LLM database stood up.
+ let db = if config.is_development() {
+ let database_url = config
+ .llm_database_url
+ .as_ref()
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+ let max_connections = config
+ .llm_database_max_connections
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+ let mut db_options = db::ConnectOptions::new(database_url);
+ db_options.max_connections(max_connections);
+ let db = LlmDatabase::new(db_options, executor.clone()).await?;
+
+ Some(Arc::new(db))
+ } else {
+ None
+ };
+
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
.default_header("User-Agent", user_agent)
@@ -38,6 +60,7 @@ impl LlmState {
let this = Self {
config,
executor,
+ db,
http_client,
};
@@ -0,0 +1,118 @@
+mod ids;
+mod queries;
+mod tables;
+
+#[cfg(test)]
+mod tests;
+
+pub use ids::*;
+pub use tables::*;
+
+#[cfg(test)]
+pub use tests::TestLlmDb;
+
+use std::future::Future;
+use std::sync::Arc;
+
+use anyhow::anyhow;
+use sea_orm::prelude::*;
+pub use sea_orm::ConnectOptions;
+use sea_orm::{
+ ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
+};
+
+use crate::db::TransactionHandle;
+use crate::executor::Executor;
+use crate::Result;
+
+/// The database for the LLM service.
+pub struct LlmDatabase {
+ options: ConnectOptions,
+ pool: DatabaseConnection,
+ #[allow(unused)]
+ executor: Executor,
+ #[cfg(test)]
+ runtime: Option<tokio::runtime::Runtime>,
+}
+
+impl LlmDatabase {
+ /// Connects to the database with the given options
+ pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
+ sqlx::any::install_default_drivers();
+ Ok(Self {
+ options: options.clone(),
+ pool: sea_orm::Database::connect(options).await?,
+ executor,
+ #[cfg(test)]
+ runtime: None,
+ })
+ }
+
+ pub fn options(&self) -> &ConnectOptions {
+ &self.options
+ }
+
+ pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
+ where
+ F: Send + Fn(TransactionHandle) -> Fut,
+ Fut: Send + Future<Output = Result<T>>,
+ {
+ let body = async {
+ let (tx, result) = self.with_transaction(&f).await?;
+ match result {
+ Ok(result) => match tx.commit().await.map_err(Into::into) {
+ Ok(()) => return Ok(result),
+ Err(error) => {
+ return Err(error);
+ }
+ },
+ Err(error) => {
+ tx.rollback().await?;
+ return Err(error);
+ }
+ }
+ };
+
+ self.run(body).await
+ }
+
+ async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
+ where
+ F: Send + Fn(TransactionHandle) -> Fut,
+ Fut: Send + Future<Output = Result<T>>,
+ {
+ let tx = self
+ .pool
+ .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
+ .await?;
+
+ let mut tx = Arc::new(Some(tx));
+ let result = f(TransactionHandle(tx.clone())).await;
+ let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+ return Err(anyhow!(
+ "couldn't complete transaction because it's still in use"
+ ))?;
+ };
+
+ Ok((tx, result))
+ }
+
+ async fn run<F, T>(&self, future: F) -> Result<T>
+ where
+ F: Future<Output = Result<T>>,
+ {
+ #[cfg(test)]
+ {
+ if let Executor::Deterministic(executor) = &self.executor {
+ executor.simulate_random_delay().await;
+ }
+
+ self.runtime.as_ref().unwrap().block_on(future)
+ }
+
+ #[cfg(not(test))]
+ {
+ future.await
+ }
+ }
+}
@@ -0,0 +1,7 @@
+use sea_orm::{entity::prelude::*, DbErr};
+use serde::{Deserialize, Serialize};
+
+use crate::id_type;
+
+id_type!(ProviderId);
+id_type!(ModelId);
@@ -0,0 +1,3 @@
+use super::*;
+
+pub mod providers;
@@ -0,0 +1,67 @@
+use sea_orm::sea_query::OnConflict;
+use sea_orm::QueryOrder;
+
+use super::*;
+
+impl LlmDatabase {
+ pub async fn initialize_providers(&self) -> Result<()> {
+ self.transaction(|tx| async move {
+ let providers_and_models = vec![
+ ("anthropic", "claude-3-5-sonnet"),
+ ("anthropic", "claude-3-opus"),
+ ("anthropic", "claude-3-sonnet"),
+ ("anthropic", "claude-3-haiku"),
+ ];
+
+ for (provider_name, model_name) in providers_and_models {
+ let insert_provider = provider::Entity::insert(provider::ActiveModel {
+ name: ActiveValue::set(provider_name.to_owned()),
+ ..Default::default()
+ })
+ .on_conflict(
+ OnConflict::columns([provider::Column::Name])
+ .update_column(provider::Column::Name)
+ .to_owned(),
+ );
+
+ let provider = if tx.support_returning() {
+ insert_provider.exec_with_returning(&*tx).await?
+ } else {
+ insert_provider.exec_without_returning(&*tx).await?;
+ provider::Entity::find()
+ .filter(provider::Column::Name.eq(provider_name))
+ .one(&*tx)
+ .await?
+ .ok_or_else(|| anyhow!("failed to insert provider"))?
+ };
+
+ model::Entity::insert(model::ActiveModel {
+ provider_id: ActiveValue::set(provider.id),
+ name: ActiveValue::set(model_name.to_owned()),
+ ..Default::default()
+ })
+ .on_conflict(
+ OnConflict::columns([model::Column::ProviderId, model::Column::Name])
+ .update_column(model::Column::Name)
+ .to_owned(),
+ )
+ .exec_without_returning(&*tx)
+ .await?;
+ }
+
+ Ok(())
+ })
+ .await
+ }
+
+ /// Returns the list of LLM providers.
+ pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
+ self.transaction(|tx| async move {
+ Ok(provider::Entity::find()
+ .order_by_asc(provider::Column::Name)
+ .all(&*tx)
+ .await?)
+ })
+ .await
+ }
+}
@@ -0,0 +1,2 @@
+pub mod model;
+pub mod provider;
@@ -0,0 +1,31 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::{ModelId, ProviderId};
+
+/// An LLM model.
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "models")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: ModelId,
+ pub provider_id: ProviderId,
+ pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::provider::Entity",
+ from = "Column::ProviderId",
+ to = "super::provider::Column::Id"
+ )]
+ Provider,
+}
+
+impl Related<super::provider::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Provider.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -0,0 +1,26 @@
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::ProviderId;
+
+/// An LLM provider.
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "providers")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: ProviderId,
+ pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(has_many = "super::model::Entity")]
+ Models,
+}
+
+impl Related<super::model::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Models.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -0,0 +1,147 @@
+mod provider_tests;
+
+use gpui::BackgroundExecutor;
+use parking_lot::Mutex;
+use rand::prelude::*;
+use sea_orm::ConnectionTrait;
+use sqlx::migrate::MigrateDatabase;
+use std::sync::Arc;
+use std::time::Duration;
+
+use crate::migrations::run_database_migrations;
+
+use super::*;
+
+pub struct TestLlmDb {
+ pub db: Option<Arc<LlmDatabase>>,
+ pub connection: Option<sqlx::AnyConnection>,
+}
+
+impl TestLlmDb {
+ pub fn sqlite(background: BackgroundExecutor) -> Self {
+ let url = "sqlite::memory:";
+ let runtime = tokio::runtime::Builder::new_current_thread()
+ .enable_io()
+ .enable_time()
+ .build()
+ .unwrap();
+
+ let mut db = runtime.block_on(async {
+ let mut options = ConnectOptions::new(url);
+ options.max_connections(5);
+ let db = LlmDatabase::new(options, Executor::Deterministic(background))
+ .await
+ .unwrap();
+ let sql = include_str!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/migrations_llm.sqlite/20240806182921_test_schema.sql"
+ ));
+ db.pool
+ .execute(sea_orm::Statement::from_string(
+ db.pool.get_database_backend(),
+ sql,
+ ))
+ .await
+ .unwrap();
+ db
+ });
+
+ db.runtime = Some(runtime);
+
+ Self {
+ db: Some(Arc::new(db)),
+ connection: None,
+ }
+ }
+
+ pub fn postgres(background: BackgroundExecutor) -> Self {
+ static LOCK: Mutex<()> = Mutex::new(());
+
+ let _guard = LOCK.lock();
+ let mut rng = StdRng::from_entropy();
+ let url = format!(
+ "postgres://postgres@localhost/zed-llm-test-{}",
+ rng.gen::<u128>()
+ );
+ let runtime = tokio::runtime::Builder::new_current_thread()
+ .enable_io()
+ .enable_time()
+ .build()
+ .unwrap();
+
+ let mut db = runtime.block_on(async {
+ sqlx::Postgres::create_database(&url)
+ .await
+ .expect("failed to create test db");
+ let mut options = ConnectOptions::new(url);
+ options
+ .max_connections(5)
+ .idle_timeout(Duration::from_secs(0));
+ let db = LlmDatabase::new(options, Executor::Deterministic(background))
+ .await
+ .unwrap();
+ let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
+ run_database_migrations(db.options(), migrations_path, false)
+ .await
+ .unwrap();
+ db
+ });
+
+ db.runtime = Some(runtime);
+
+ Self {
+ db: Some(Arc::new(db)),
+ connection: None,
+ }
+ }
+
+ pub fn db(&self) -> &Arc<LlmDatabase> {
+ self.db.as_ref().unwrap()
+ }
+}
+
+#[macro_export]
+macro_rules! test_both_llm_dbs {
+ ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
+ #[cfg(target_os = "macos")]
+ #[gpui::test]
+ async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
+ let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
+ $test_name(test_db.db()).await;
+ }
+
+ #[gpui::test]
+ async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
+ let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
+ $test_name(test_db.db()).await;
+ }
+ };
+}
+
+impl Drop for TestLlmDb {
+ fn drop(&mut self) {
+ let db = self.db.take().unwrap();
+ if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
+ db.runtime.as_ref().unwrap().block_on(async {
+ use util::ResultExt;
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE
+ pg_stat_activity.datname = current_database() AND
+ pid <> pg_backend_pid();
+ ";
+ db.pool
+ .execute(sea_orm::Statement::from_string(
+ db.pool.get_database_backend(),
+ query,
+ ))
+ .await
+ .log_err();
+ sqlx::Postgres::drop_database(db.options.get_url())
+ .await
+ .log_err();
+ })
+ }
+ }
+}
@@ -0,0 +1,30 @@
+use std::sync::Arc;
+
+use pretty_assertions::assert_eq;
+
+use crate::llm::db::LlmDatabase;
+use crate::test_both_llm_dbs;
+
+test_both_llm_dbs!(
+ test_initialize_providers,
+ test_initialize_providers_postgres,
+ test_initialize_providers_sqlite
+);
+
+async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
+ let initial_providers = db.list_providers().await.unwrap();
+ assert_eq!(initial_providers, vec![]);
+
+ db.initialize_providers().await.unwrap();
+
+ // Do it twice, to make sure the operation is idempotent.
+ db.initialize_providers().await.unwrap();
+
+ let providers = db.list_providers().await.unwrap();
+
+ let provider_names = providers
+ .into_iter()
+ .map(|provider| provider.name)
+ .collect::<Vec<_>>();
+ assert_eq!(provider_names, vec!["anthropic".to_string()]);
+}
@@ -5,6 +5,8 @@ use axum::{
routing::get,
Extension, Router,
};
+use collab::llm::db::LlmDatabase;
+use collab::migrations::run_database_migrations;
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
@@ -45,7 +47,7 @@ async fn main() -> Result<()> {
}
Some("migrate") => {
let config = envy::from_env::<Config>().expect("error loading config");
- run_migrations(&config).await?;
+ setup_app_database(&config).await?;
}
Some("seed") => {
let config = envy::from_env::<Config>().expect("error loading config");
@@ -81,6 +83,8 @@ async fn main() -> Result<()> {
let mut on_shutdown = None;
if mode.is_llm() {
+ setup_llm_database(&config).await?;
+
let state = LlmState::new(config.clone(), Executor::Production).await?;
app = app
@@ -89,7 +93,7 @@ async fn main() -> Result<()> {
}
if mode.is_collab() || mode.is_api() {
- run_migrations(&config).await?;
+ setup_app_database(&config).await?;
let state = AppState::new(config, Executor::Production).await?;
@@ -203,7 +207,7 @@ async fn main() -> Result<()> {
Ok(())
}
-async fn run_migrations(config: &Config) -> Result<()> {
+async fn setup_app_database(config: &Config) -> Result<()> {
let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options, Executor::Production).await?;
@@ -216,7 +220,7 @@ async fn run_migrations(config: &Config) -> Result<()> {
Path::new(default_migrations)
});
- let migrations = db.migrate(&migrations_path, false).await?;
+ let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
for (migration, duration) in migrations {
log::info!(
"Migrated {} {} {:?}",
@@ -232,7 +236,46 @@ async fn run_migrations(config: &Config) -> Result<()> {
collab::seed::seed(&config, &db, false).await?;
}
- return Ok(());
+ Ok(())
+}
+
+async fn setup_llm_database(config: &Config) -> Result<()> {
+ // TODO: This is temporary until we have the LLM database stood up.
+ if !config.is_development() {
+ return Ok(());
+ }
+
+ let database_url = config
+ .llm_database_url
+ .as_ref()
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+
+ let db_options = db::ConnectOptions::new(database_url.clone());
+ let db = LlmDatabase::new(db_options, Executor::Production).await?;
+
+ let migrations_path = config
+ .llm_database_migrations_path
+ .as_deref()
+ .unwrap_or_else(|| {
+ #[cfg(feature = "sqlite")]
+ let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
+ #[cfg(not(feature = "sqlite"))]
+ let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
+
+ Path::new(default_migrations)
+ });
+
+ let migrations = run_database_migrations(db.options(), migrations_path, false).await?;
+ for (migration, duration) in migrations {
+ log::info!(
+ "Migrated {} {} {:?}",
+ migration.version,
+ migration.description,
+ duration
+ );
+ }
+
+ Ok(())
}
async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
@@ -0,0 +1,49 @@
+use std::path::Path;
+use std::time::Duration;
+
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use sea_orm::ConnectOptions;
+use sqlx::migrate::{Migrate, Migration, MigrationSource};
+use sqlx::Connection;
+
+/// Runs the database migrations for the specified database.
+pub async fn run_database_migrations(
+ database_options: &ConnectOptions,
+ migrations_path: impl AsRef<Path>,
+ ignore_checksum_mismatch: bool,
+) -> Result<Vec<(Migration, Duration)>> {
+ let migrations = MigrationSource::resolve(migrations_path.as_ref())
+ .await
+ .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
+
+ let mut connection = sqlx::AnyConnection::connect(database_options.get_url()).await?;
+
+ connection.ensure_migrations_table().await?;
+ let applied_migrations: HashMap<_, _> = connection
+ .list_applied_migrations()
+ .await?
+ .into_iter()
+ .map(|migration| (migration.version, migration))
+ .collect();
+
+ let mut new_migrations = Vec::new();
+ for migration in migrations {
+ match applied_migrations.get(&migration.version) {
+ Some(applied_migration) => {
+ if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch {
+ Err(anyhow!(
+ "checksum mismatch for applied migration {}",
+ migration.description
+ ))?;
+ }
+ }
+ None => {
+ let elapsed = connection.apply(&migration).await?;
+ new_migrations.push((migration, elapsed));
+ }
+ }
+ }
+
+ Ok(new_migrations)
+}
@@ -651,6 +651,9 @@ impl TestServer {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
+ llm_database_url: None,
+ llm_database_max_connections: None,
+ llm_database_migrations_path: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
@@ -1 +1,2 @@
create database zed;
+create database zed_llm;
@@ -1,5 +1,7 @@
#!/usr/bin/env bash
+set -e
+
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
echo "Linux dependencies..."
script/linux
@@ -8,5 +10,17 @@ else
which foreman > /dev/null || brew install foreman
fi
-echo "creating database..."
-script/sqlx database create
+# Install sqlx-cli if needed
+if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
+ echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
+ cargo install sqlx-cli --version 0.5.7
+fi
+
+cd crates/collab
+
+# Export contents of .env.toml
+eval "$(cargo run --bin dotenv)"
+
+echo "creating databases..."
+sqlx database create --database-url "$DATABASE_URL"
+sqlx database create --database-url "$LLM_DATABASE_URL"
@@ -1,2 +1,3 @@
psql -c "DROP DATABASE zed (FORCE);"
+psql -c "DROP DATABASE zed_llm (FORCE);"
script/bootstrap
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# Install sqlx-cli if needed
-if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
- echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
- cargo install sqlx-cli --version 0.5.7
-fi
-
-cd crates/collab
-
-# Export contents of .env.toml
-eval "$(cargo run --bin dotenv)"
-
-# Run sqlx command
-sqlx $@