Cargo.lock 🔗
@@ -3270,7 +3270,6 @@ dependencies = [
"chrono",
"client",
"clock",
- "cloud_llm_client",
"collab_ui",
"collections",
"command_palette_hooks",
Marshall Bowers created
This PR removes a bunch of unused database code related to billing, as
we no longer need it.
Release Notes:
- N/A
Cargo.lock | 1
crates/collab/Cargo.toml | 1
crates/collab/src/db.rs | 5
crates/collab/src/db/ids.rs | 3
crates/collab/src/db/queries.rs | 4
crates/collab/src/db/queries/billing_customers.rs | 100 ----
crates/collab/src/db/queries/billing_preferences.rs | 17
crates/collab/src/db/queries/billing_subscriptions.rs | 158 ------
crates/collab/src/db/queries/processed_stripe_events.rs | 69 ---
crates/collab/src/db/tables.rs | 4
crates/collab/src/db/tables/billing_customer.rs | 41 -
crates/collab/src/db/tables/billing_preference.rs | 32 -
crates/collab/src/db/tables/billing_subscription.rs | 161 -------
crates/collab/src/db/tables/processed_stripe_event.rs | 16
crates/collab/src/db/tables/user.rs | 8
crates/collab/src/db/tests.rs | 1
crates/collab/src/db/tests/processed_stripe_event_tests.rs | 38 -
crates/collab/src/lib.rs | 17
crates/collab/src/llm/db.rs | 74 ---
crates/collab/src/llm/db/ids.rs | 11
crates/collab/src/llm/db/queries.rs | 5
crates/collab/src/llm/db/queries/providers.rs | 134 -----
crates/collab/src/llm/db/queries/subscription_usages.rs | 38 -
crates/collab/src/llm/db/queries/usages.rs | 44 -
crates/collab/src/llm/db/seed.rs | 45 -
crates/collab/src/llm/db/tables.rs | 6
crates/collab/src/llm/db/tables/model.rs | 48 --
crates/collab/src/llm/db/tables/provider.rs | 25 -
crates/collab/src/llm/db/tables/subscription_usage.rs | 22
crates/collab/src/llm/db/tables/subscription_usage_meter.rs | 55 --
crates/collab/src/llm/db/tables/usage.rs | 52 --
crates/collab/src/llm/db/tables/usage_measure.rs | 36 -
crates/collab/src/llm/db/tests.rs | 107 ----
crates/collab/src/llm/db/tests/provider_tests.rs | 31 -
crates/collab/src/main.rs | 10
crates/collab/src/tests/test_server.rs | 1
36 files changed, 1 insertion(+), 1,419 deletions(-)
@@ -3270,7 +3270,6 @@ dependencies = [
"chrono",
"client",
"clock",
- "cloud_llm_client",
"collab_ui",
"collections",
"command_palette_hooks",
@@ -29,7 +29,6 @@ axum-extra = { version = "0.4", features = ["erased-json"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
-cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
envy = "0.4.2"
@@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind;
pub use tests::TestDb;
pub use ids::*;
-pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
-pub use queries::billing_subscriptions::{
- CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
-};
pub use queries::contributors::ContributorSelector;
-pub use queries::processed_stripe_events::CreateProcessedStripeEventParams;
pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
pub use tables::*;
@@ -70,9 +70,6 @@ macro_rules! id_type {
}
id_type!(AccessTokenId);
-id_type!(BillingCustomerId);
-id_type!(BillingSubscriptionId);
-id_type!(BillingPreferencesId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
id_type!(ChannelChatParticipantId);
@@ -1,9 +1,6 @@
use super::*;
pub mod access_tokens;
-pub mod billing_customers;
-pub mod billing_preferences;
-pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;
pub mod contacts;
@@ -12,7 +9,6 @@ pub mod embeddings;
pub mod extensions;
pub mod messages;
pub mod notifications;
-pub mod processed_stripe_events;
pub mod projects;
pub mod rooms;
pub mod servers;
@@ -1,100 +0,0 @@
-use super::*;
-
-#[derive(Debug)]
-pub struct CreateBillingCustomerParams {
- pub user_id: UserId,
- pub stripe_customer_id: String,
-}
-
-#[derive(Debug, Default)]
-pub struct UpdateBillingCustomerParams {
- pub user_id: ActiveValue<UserId>,
- pub stripe_customer_id: ActiveValue<String>,
- pub has_overdue_invoices: ActiveValue<bool>,
- pub trial_started_at: ActiveValue<Option<DateTime>>,
-}
-
-impl Database {
- /// Creates a new billing customer.
- pub async fn create_billing_customer(
- &self,
- params: &CreateBillingCustomerParams,
- ) -> Result<billing_customer::Model> {
- self.transaction(|tx| async move {
- let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
- user_id: ActiveValue::set(params.user_id),
- stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
- ..Default::default()
- })
- .exec_with_returning(&*tx)
- .await?;
-
- Ok(customer)
- })
- .await
- }
-
- /// Updates the specified billing customer.
- pub async fn update_billing_customer(
- &self,
- id: BillingCustomerId,
- params: &UpdateBillingCustomerParams,
- ) -> Result<()> {
- self.transaction(|tx| async move {
- billing_customer::Entity::update(billing_customer::ActiveModel {
- id: ActiveValue::set(id),
- user_id: params.user_id.clone(),
- stripe_customer_id: params.stripe_customer_id.clone(),
- has_overdue_invoices: params.has_overdue_invoices.clone(),
- trial_started_at: params.trial_started_at.clone(),
- created_at: ActiveValue::not_set(),
- })
- .exec(&*tx)
- .await?;
-
- Ok(())
- })
- .await
- }
-
- pub async fn get_billing_customer_by_id(
- &self,
- id: BillingCustomerId,
- ) -> Result<Option<billing_customer::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_customer::Entity::find()
- .filter(billing_customer::Column::Id.eq(id))
- .one(&*tx)
- .await?)
- })
- .await
- }
-
- /// Returns the billing customer for the user with the specified ID.
- pub async fn get_billing_customer_by_user_id(
- &self,
- user_id: UserId,
- ) -> Result<Option<billing_customer::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_customer::Entity::find()
- .filter(billing_customer::Column::UserId.eq(user_id))
- .one(&*tx)
- .await?)
- })
- .await
- }
-
- /// Returns the billing customer for the user with the specified Stripe customer ID.
- pub async fn get_billing_customer_by_stripe_customer_id(
- &self,
- stripe_customer_id: &str,
- ) -> Result<Option<billing_customer::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_customer::Entity::find()
- .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
- .one(&*tx)
- .await?)
- })
- .await
- }
-}
@@ -1,17 +0,0 @@
-use super::*;
-
-impl Database {
- /// Returns the billing preferences for the given user, if they exist.
- pub async fn get_billing_preferences(
- &self,
- user_id: UserId,
- ) -> Result<Option<billing_preference::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_preference::Entity::find()
- .filter(billing_preference::Column::UserId.eq(user_id))
- .one(&*tx)
- .await?)
- })
- .await
- }
-}
@@ -1,158 +0,0 @@
-use anyhow::Context as _;
-
-use crate::db::billing_subscription::{
- StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
-};
-
-use super::*;
-
-#[derive(Debug)]
-pub struct CreateBillingSubscriptionParams {
- pub billing_customer_id: BillingCustomerId,
- pub kind: Option<SubscriptionKind>,
- pub stripe_subscription_id: String,
- pub stripe_subscription_status: StripeSubscriptionStatus,
- pub stripe_cancellation_reason: Option<StripeCancellationReason>,
- pub stripe_current_period_start: Option<i64>,
- pub stripe_current_period_end: Option<i64>,
-}
-
-#[derive(Debug, Default)]
-pub struct UpdateBillingSubscriptionParams {
- pub billing_customer_id: ActiveValue<BillingCustomerId>,
- pub kind: ActiveValue<Option<SubscriptionKind>>,
- pub stripe_subscription_id: ActiveValue<String>,
- pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
- pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
- pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
- pub stripe_current_period_start: ActiveValue<Option<i64>>,
- pub stripe_current_period_end: ActiveValue<Option<i64>>,
-}
-
-impl Database {
- /// Creates a new billing subscription.
- pub async fn create_billing_subscription(
- &self,
- params: &CreateBillingSubscriptionParams,
- ) -> Result<billing_subscription::Model> {
- self.transaction(|tx| async move {
- let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
- billing_customer_id: ActiveValue::set(params.billing_customer_id),
- kind: ActiveValue::set(params.kind),
- stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
- stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
- stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
- stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
- stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
- ..Default::default()
- })
- .exec(&*tx)
- .await?
- .last_insert_id;
-
- Ok(billing_subscription::Entity::find_by_id(id)
- .one(&*tx)
- .await?
- .context("failed to retrieve inserted billing subscription")?)
- })
- .await
- }
-
- /// Updates the specified billing subscription.
- pub async fn update_billing_subscription(
- &self,
- id: BillingSubscriptionId,
- params: &UpdateBillingSubscriptionParams,
- ) -> Result<()> {
- self.transaction(|tx| async move {
- billing_subscription::Entity::update(billing_subscription::ActiveModel {
- id: ActiveValue::set(id),
- billing_customer_id: params.billing_customer_id.clone(),
- kind: params.kind.clone(),
- stripe_subscription_id: params.stripe_subscription_id.clone(),
- stripe_subscription_status: params.stripe_subscription_status.clone(),
- stripe_cancel_at: params.stripe_cancel_at.clone(),
- stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
- stripe_current_period_start: params.stripe_current_period_start.clone(),
- stripe_current_period_end: params.stripe_current_period_end.clone(),
- created_at: ActiveValue::not_set(),
- })
- .exec(&*tx)
- .await?;
-
- Ok(())
- })
- .await
- }
-
- /// Returns the billing subscription with the specified Stripe subscription ID.
- pub async fn get_billing_subscription_by_stripe_subscription_id(
- &self,
- stripe_subscription_id: &str,
- ) -> Result<Option<billing_subscription::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_subscription::Entity::find()
- .filter(
- billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
- )
- .one(&*tx)
- .await?)
- })
- .await
- }
-
- pub async fn get_active_billing_subscription(
- &self,
- user_id: UserId,
- ) -> Result<Option<billing_subscription::Model>> {
- self.transaction(|tx| async move {
- Ok(billing_subscription::Entity::find()
- .inner_join(billing_customer::Entity)
- .filter(billing_customer::Column::UserId.eq(user_id))
- .filter(
- Condition::all()
- .add(
- Condition::any()
- .add(
- billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::Active),
- )
- .add(
- billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::Trialing),
- ),
- )
- .add(billing_subscription::Column::Kind.is_not_null()),
- )
- .one(&*tx)
- .await?)
- })
- .await
- }
-
- /// Returns whether the user has an active billing subscription.
- pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
- Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
- }
-
- /// Returns the count of the active billing subscriptions for the user with the specified ID.
- pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
- self.transaction(|tx| async move {
- let count = billing_subscription::Entity::find()
- .inner_join(billing_customer::Entity)
- .filter(
- billing_customer::Column::UserId.eq(user_id).and(
- billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::Active)
- .or(billing_subscription::Column::StripeSubscriptionStatus
- .eq(StripeSubscriptionStatus::Trialing)),
- ),
- )
- .count(&*tx)
- .await?;
-
- Ok(count as usize)
- })
- .await
- }
-}
@@ -1,69 +0,0 @@
-use super::*;
-
-#[derive(Debug)]
-pub struct CreateProcessedStripeEventParams {
- pub stripe_event_id: String,
- pub stripe_event_type: String,
- pub stripe_event_created_timestamp: i64,
-}
-
-impl Database {
- /// Creates a new processed Stripe event.
- pub async fn create_processed_stripe_event(
- &self,
- params: &CreateProcessedStripeEventParams,
- ) -> Result<()> {
- self.transaction(|tx| async move {
- processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
- stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
- stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
- stripe_event_created_timestamp: ActiveValue::set(
- params.stripe_event_created_timestamp,
- ),
- ..Default::default()
- })
- .exec_without_returning(&*tx)
- .await?;
-
- Ok(())
- })
- .await
- }
-
- /// Returns the processed Stripe event with the specified event ID.
- pub async fn get_processed_stripe_event_by_event_id(
- &self,
- event_id: &str,
- ) -> Result<Option<processed_stripe_event::Model>> {
- self.transaction(|tx| async move {
- Ok(processed_stripe_event::Entity::find_by_id(event_id)
- .one(&*tx)
- .await?)
- })
- .await
- }
-
- /// Returns the processed Stripe events with the specified event IDs.
- pub async fn get_processed_stripe_events_by_event_ids(
- &self,
- event_ids: &[&str],
- ) -> Result<Vec<processed_stripe_event::Model>> {
- self.transaction(|tx| async move {
- Ok(processed_stripe_event::Entity::find()
- .filter(
- processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
- )
- .all(&*tx)
- .await?)
- })
- .await
- }
-
- /// Returns whether the Stripe event with the specified ID has already been processed.
- pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result<bool> {
- Ok(self
- .get_processed_stripe_event_by_event_id(event_id)
- .await?
- .is_some())
- }
-}
@@ -1,7 +1,4 @@
pub mod access_token;
-pub mod billing_customer;
-pub mod billing_preference;
-pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;
pub mod buffer_snapshot;
@@ -23,7 +20,6 @@ pub mod notification;
pub mod notification_kind;
pub mod observed_buffer_edits;
pub mod observed_channel_messages;
-pub mod processed_stripe_event;
pub mod project;
pub mod project_collaborator;
pub mod project_repository;
@@ -1,41 +0,0 @@
-use crate::db::{BillingCustomerId, UserId};
-use sea_orm::entity::prelude::*;
-
-/// A billing customer.
-#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
-#[sea_orm(table_name = "billing_customers")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: BillingCustomerId,
- pub user_id: UserId,
- pub stripe_customer_id: String,
- pub has_overdue_invoices: bool,
- pub trial_started_at: Option<DateTime>,
- pub created_at: DateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::user::Entity",
- from = "Column::UserId",
- to = "super::user::Column::Id"
- )]
- User,
- #[sea_orm(has_many = "super::billing_subscription::Entity")]
- BillingSubscription,
-}
-
-impl Related<super::user::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::User.def()
- }
-}
-
-impl Related<super::billing_subscription::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::BillingSubscription.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,32 +0,0 @@
-use crate::db::{BillingPreferencesId, UserId};
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
-#[sea_orm(table_name = "billing_preferences")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: BillingPreferencesId,
- pub created_at: DateTime,
- pub user_id: UserId,
- pub max_monthly_llm_usage_spending_in_cents: i32,
- pub model_request_overages_enabled: bool,
- pub model_request_overages_spend_limit_in_cents: i32,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::user::Entity",
- from = "Column::UserId",
- to = "super::user::Column::Id"
- )]
- User,
-}
-
-impl Related<super::user::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::User.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,161 +0,0 @@
-use crate::db::{BillingCustomerId, BillingSubscriptionId};
-use chrono::{Datelike as _, NaiveDate, Utc};
-use sea_orm::entity::prelude::*;
-use serde::Serialize;
-
-/// A billing subscription.
-#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
-#[sea_orm(table_name = "billing_subscriptions")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: BillingSubscriptionId,
- pub billing_customer_id: BillingCustomerId,
- pub kind: Option<SubscriptionKind>,
- pub stripe_subscription_id: String,
- pub stripe_subscription_status: StripeSubscriptionStatus,
- pub stripe_cancel_at: Option<DateTime>,
- pub stripe_cancellation_reason: Option<StripeCancellationReason>,
- pub stripe_current_period_start: Option<i64>,
- pub stripe_current_period_end: Option<i64>,
- pub created_at: DateTime,
-}
-
-impl Model {
- pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
- let period_start = self.stripe_current_period_start?;
- chrono::DateTime::from_timestamp(period_start, 0)
- }
-
- pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
- let period_end = self.stripe_current_period_end?;
- chrono::DateTime::from_timestamp(period_end, 0)
- }
-
- pub fn current_period(
- subscription: Option<Self>,
- is_staff: bool,
- ) -> Option<(DateTimeUtc, DateTimeUtc)> {
- if is_staff {
- let now = Utc::now();
- let year = now.year();
- let month = now.month();
-
- let first_day_of_this_month =
- NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?;
-
- let next_month = if month == 12 { 1 } else { month + 1 };
- let next_month_year = if month == 12 { year + 1 } else { year };
- let first_day_of_next_month =
- NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?;
-
- let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1);
-
- Some((
- first_day_of_this_month.and_utc(),
- last_day_of_this_month.and_utc(),
- ))
- } else {
- let subscription = subscription?;
- let period_start_at = subscription.current_period_start_at()?;
- let period_end_at = subscription.current_period_end_at()?;
-
- Some((period_start_at, period_end_at))
- }
- }
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::billing_customer::Entity",
- from = "Column::BillingCustomerId",
- to = "super::billing_customer::Column::Id"
- )]
- BillingCustomer,
-}
-
-impl Related<super::billing_customer::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::BillingCustomer.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
-
-#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
-#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
-#[serde(rename_all = "snake_case")]
-pub enum SubscriptionKind {
- #[sea_orm(string_value = "zed_pro")]
- ZedPro,
- #[sea_orm(string_value = "zed_pro_trial")]
- ZedProTrial,
- #[sea_orm(string_value = "zed_free")]
- ZedFree,
-}
-
-impl From<SubscriptionKind> for cloud_llm_client::Plan {
- fn from(value: SubscriptionKind) -> Self {
- match value {
- SubscriptionKind::ZedPro => Self::ZedPro,
- SubscriptionKind::ZedProTrial => Self::ZedProTrial,
- SubscriptionKind::ZedFree => Self::ZedFree,
- }
- }
-}
-
-/// The status of a Stripe subscription.
-///
-/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
-#[derive(
- Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
-)]
-#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
-#[serde(rename_all = "snake_case")]
-pub enum StripeSubscriptionStatus {
- #[default]
- #[sea_orm(string_value = "incomplete")]
- Incomplete,
- #[sea_orm(string_value = "incomplete_expired")]
- IncompleteExpired,
- #[sea_orm(string_value = "trialing")]
- Trialing,
- #[sea_orm(string_value = "active")]
- Active,
- #[sea_orm(string_value = "past_due")]
- PastDue,
- #[sea_orm(string_value = "canceled")]
- Canceled,
- #[sea_orm(string_value = "unpaid")]
- Unpaid,
- #[sea_orm(string_value = "paused")]
- Paused,
-}
-
-impl StripeSubscriptionStatus {
- pub fn is_cancelable(&self) -> bool {
- match self {
- Self::Trialing | Self::Active | Self::PastDue => true,
- Self::Incomplete
- | Self::IncompleteExpired
- | Self::Canceled
- | Self::Unpaid
- | Self::Paused => false,
- }
- }
-}
-
-/// The cancellation reason for a Stripe subscription.
-///
-/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason)
-#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
-#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
-#[serde(rename_all = "snake_case")]
-pub enum StripeCancellationReason {
- #[sea_orm(string_value = "cancellation_requested")]
- CancellationRequested,
- #[sea_orm(string_value = "payment_disputed")]
- PaymentDisputed,
- #[sea_orm(string_value = "payment_failed")]
- PaymentFailed,
-}
@@ -1,16 +0,0 @@
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "processed_stripe_events")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub stripe_event_id: String,
- pub stripe_event_type: String,
- pub stripe_event_created_timestamp: i64,
- pub processed_at: DateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -29,8 +29,6 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::access_token::Entity")]
AccessToken,
- #[sea_orm(has_one = "super::billing_customer::Entity")]
- BillingCustomer,
#[sea_orm(has_one = "super::room_participant::Entity")]
RoomParticipant,
#[sea_orm(has_many = "super::project::Entity")]
@@ -68,12 +66,6 @@ impl Related<super::access_token::Entity> for Entity {
}
}
-impl Related<super::billing_customer::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::BillingCustomer.def()
- }
-}
-
impl Related<super::room_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::RoomParticipant.def()
@@ -8,7 +8,6 @@ mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
-mod processed_stripe_event_tests;
mod user_tests;
use crate::migrations::run_database_migrations;
@@ -1,38 +0,0 @@
-use std::sync::Arc;
-
-use crate::test_both_dbs;
-
-use super::{CreateProcessedStripeEventParams, Database};
-
-test_both_dbs!(
- test_already_processed_stripe_event,
- test_already_processed_stripe_event_postgres,
- test_already_processed_stripe_event_sqlite
-);
-
-async fn test_already_processed_stripe_event(db: &Arc<Database>) {
- let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string();
- let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string();
-
- db.create_processed_stripe_event(&CreateProcessedStripeEventParams {
- stripe_event_id: processed_event_id.clone(),
- stripe_event_type: "customer.created".into(),
- stripe_event_created_timestamp: 1722355968,
- })
- .await
- .unwrap();
-
- assert!(
- db.already_processed_stripe_event(&processed_event_id)
- .await
- .unwrap(),
- "Expected {processed_event_id} to already be processed"
- );
-
- assert!(
- !db.already_processed_stripe_event(&unprocessed_event_id)
- .await
- .unwrap(),
- "Expected {unprocessed_event_id} to be unprocessed"
- );
-}
@@ -20,7 +20,6 @@ use axum::{
};
use db::{ChannelId, Database};
use executor::Executor;
-use llm::db::LlmDatabase;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@@ -242,7 +241,6 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
- pub llm_db: Option<Arc<LlmDatabase>>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub executor: Executor,
@@ -257,20 +255,6 @@ impl AppState {
let mut db = Database::new(db_options).await?;
db.initialize_notification_kinds().await?;
- let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
- .llm_database_url
- .clone()
- .zip(config.llm_database_max_connections)
- {
- let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
- llm_db_options.max_connections(llm_database_max_connections);
- let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
- llm_db.initialize().await?;
- Some(Arc::new(llm_db))
- } else {
- None
- };
-
let livekit_client = if let Some(((server, key), secret)) = config
.livekit_server
.as_ref()
@@ -289,7 +273,6 @@ impl AppState {
let db = Arc::new(db);
let this = Self {
db: db.clone(),
- llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
executor,
@@ -1,30 +1,9 @@
-mod ids;
-mod queries;
-mod seed;
-mod tables;
-
-#[cfg(test)]
-mod tests;
-
-use cloud_llm_client::LanguageModelProvider;
-use collections::HashMap;
-pub use ids::*;
-pub use seed::*;
-pub use tables::*;
-
-#[cfg(test)]
-pub use tests::TestLlmDb;
-use usage_measure::UsageMeasure;
-
use std::future::Future;
use std::sync::Arc;
use anyhow::Context;
pub use sea_orm::ConnectOptions;
-use sea_orm::prelude::*;
-use sea_orm::{
- ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
-};
+use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
use crate::Result;
use crate::db::TransactionHandle;
@@ -36,9 +15,6 @@ pub struct LlmDatabase {
pool: DatabaseConnection,
#[allow(unused)]
executor: Executor,
- provider_ids: HashMap<LanguageModelProvider, ProviderId>,
- models: HashMap<(LanguageModelProvider, String), model::Model>,
- usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}
@@ -51,59 +27,11 @@ impl LlmDatabase {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
executor,
- provider_ids: HashMap::default(),
- models: HashMap::default(),
- usage_measure_ids: HashMap::default(),
#[cfg(test)]
runtime: None,
})
}
- pub async fn initialize(&mut self) -> Result<()> {
- self.initialize_providers().await?;
- self.initialize_models().await?;
- self.initialize_usage_measures().await?;
- Ok(())
- }
-
- /// Returns the list of all known models, with their [`LanguageModelProvider`].
- pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
- self.models
- .iter()
- .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
- .collect::<Vec<_>>()
- }
-
- /// Returns the names of the known models for the given [`LanguageModelProvider`].
- pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
- self.models
- .keys()
- .filter_map(|(model_provider, model_name)| {
- if model_provider == &provider {
- Some(model_name)
- } else {
- None
- }
- })
- .cloned()
- .collect::<Vec<_>>()
- }
-
- pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
- Ok(self
- .models
- .get(&(provider, name.to_string()))
- .with_context(|| format!("unknown model {provider:?}:{name}"))?)
- }
-
- pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
- Ok(self
- .models
- .values()
- .find(|model| model.id == id)
- .with_context(|| format!("no model for ID {id:?}"))?)
- }
-
pub fn options(&self) -> &ConnectOptions {
&self.options
}
@@ -1,11 +0,0 @@
-use sea_orm::{DbErr, entity::prelude::*};
-use serde::{Deserialize, Serialize};
-
-use crate::id_type;
-
-id_type!(BillingEventId);
-id_type!(ModelId);
-id_type!(ProviderId);
-id_type!(RevokedAccessTokenId);
-id_type!(UsageId);
-id_type!(UsageMeasureId);
@@ -1,5 +0,0 @@
-use super::*;
-
-pub mod providers;
-pub mod subscription_usages;
-pub mod usages;
@@ -1,134 +0,0 @@
-use super::*;
-use sea_orm::{QueryOrder, sea_query::OnConflict};
-use std::str::FromStr;
-use strum::IntoEnumIterator as _;
-
-pub struct ModelParams {
- pub provider: LanguageModelProvider,
- pub name: String,
- pub max_requests_per_minute: i64,
- pub max_tokens_per_minute: i64,
- pub max_tokens_per_day: i64,
- pub price_per_million_input_tokens: i32,
- pub price_per_million_output_tokens: i32,
-}
-
-impl LlmDatabase {
- pub async fn initialize_providers(&mut self) -> Result<()> {
- self.provider_ids = self
- .transaction(|tx| async move {
- let existing_providers = provider::Entity::find().all(&*tx).await?;
-
- let mut new_providers = LanguageModelProvider::iter()
- .filter(|provider| {
- !existing_providers
- .iter()
- .any(|p| p.name == provider.to_string())
- })
- .map(|provider| provider::ActiveModel {
- name: ActiveValue::set(provider.to_string()),
- ..Default::default()
- })
- .peekable();
-
- if new_providers.peek().is_some() {
- provider::Entity::insert_many(new_providers)
- .exec(&*tx)
- .await?;
- }
-
- let all_providers: HashMap<_, _> = provider::Entity::find()
- .all(&*tx)
- .await?
- .iter()
- .filter_map(|provider| {
- LanguageModelProvider::from_str(&provider.name)
- .ok()
- .map(|p| (p, provider.id))
- })
- .collect();
-
- Ok(all_providers)
- })
- .await?;
- Ok(())
- }
-
- pub async fn initialize_models(&mut self) -> Result<()> {
- let all_provider_ids = &self.provider_ids;
- self.models = self
- .transaction(|tx| async move {
- let all_models: HashMap<_, _> = model::Entity::find()
- .all(&*tx)
- .await?
- .into_iter()
- .filter_map(|model| {
- let provider = all_provider_ids.iter().find_map(|(provider, id)| {
- if *id == model.provider_id {
- Some(provider)
- } else {
- None
- }
- })?;
- Some(((*provider, model.name.clone()), model))
- })
- .collect();
- Ok(all_models)
- })
- .await?;
- Ok(())
- }
-
- pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
- let all_provider_ids = &self.provider_ids;
- self.transaction(|tx| async move {
- model::Entity::insert_many(models.iter().map(|model_params| {
- let provider_id = all_provider_ids[&model_params.provider];
- model::ActiveModel {
- provider_id: ActiveValue::set(provider_id),
- name: ActiveValue::set(model_params.name.clone()),
- max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
- max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
- max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
- price_per_million_input_tokens: ActiveValue::set(
- model_params.price_per_million_input_tokens,
- ),
- price_per_million_output_tokens: ActiveValue::set(
- model_params.price_per_million_output_tokens,
- ),
- ..Default::default()
- }
- }))
- .on_conflict(
- OnConflict::columns([model::Column::ProviderId, model::Column::Name])
- .update_columns([
- model::Column::MaxRequestsPerMinute,
- model::Column::MaxTokensPerMinute,
- model::Column::MaxTokensPerDay,
- model::Column::PricePerMillionInputTokens,
- model::Column::PricePerMillionOutputTokens,
- ])
- .to_owned(),
- )
- .exec_without_returning(&*tx)
- .await?;
- Ok(())
- })
- .await?;
- self.initialize_models().await
- }
-
- /// Returns the list of LLM providers.
- pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
- self.transaction(|tx| async move {
- Ok(provider::Entity::find()
- .order_by_asc(provider::Column::Name)
- .all(&*tx)
- .await?
- .into_iter()
- .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
- .collect())
- })
- .await
- }
-}
@@ -1,38 +0,0 @@
-use crate::db::UserId;
-
-use super::*;
-
-impl LlmDatabase {
- pub async fn get_subscription_usage_for_period(
- &self,
- user_id: UserId,
- period_start_at: DateTimeUtc,
- period_end_at: DateTimeUtc,
- ) -> Result<Option<subscription_usage::Model>> {
- self.transaction(|tx| async move {
- self.get_subscription_usage_for_period_in_tx(
- user_id,
- period_start_at,
- period_end_at,
- &tx,
- )
- .await
- })
- .await
- }
-
- async fn get_subscription_usage_for_period_in_tx(
- &self,
- user_id: UserId,
- period_start_at: DateTimeUtc,
- period_end_at: DateTimeUtc,
- tx: &DatabaseTransaction,
- ) -> Result<Option<subscription_usage::Model>> {
- Ok(subscription_usage::Entity::find()
- .filter(subscription_usage::Column::UserId.eq(user_id))
- .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
- .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
- .one(tx)
- .await?)
- }
-}
@@ -1,44 +0,0 @@
-use std::str::FromStr;
-use strum::IntoEnumIterator as _;
-
-use super::*;
-
-impl LlmDatabase {
- pub async fn initialize_usage_measures(&mut self) -> Result<()> {
- let all_measures = self
- .transaction(|tx| async move {
- let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
-
- let new_measures = UsageMeasure::iter()
- .filter(|measure| {
- !existing_measures
- .iter()
- .any(|m| m.name == measure.to_string())
- })
- .map(|measure| usage_measure::ActiveModel {
- name: ActiveValue::set(measure.to_string()),
- ..Default::default()
- })
- .collect::<Vec<_>>();
-
- if !new_measures.is_empty() {
- usage_measure::Entity::insert_many(new_measures)
- .exec(&*tx)
- .await?;
- }
-
- Ok(usage_measure::Entity::find().all(&*tx).await?)
- })
- .await?;
-
- self.usage_measure_ids = all_measures
- .into_iter()
- .filter_map(|measure| {
- UsageMeasure::from_str(&measure.name)
- .ok()
- .map(|um| (um, measure.id))
- })
- .collect();
- Ok(())
- }
-}
@@ -1,45 +0,0 @@
-use super::*;
-use crate::{Config, Result};
-use queries::providers::ModelParams;
-
-pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
- db.insert_models(&[
- ModelParams {
- provider: LanguageModelProvider::Anthropic,
- name: "claude-3-5-sonnet".into(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 20_000,
- max_tokens_per_day: 300_000,
- price_per_million_input_tokens: 300, // $3.00/MTok
- price_per_million_output_tokens: 1500, // $15.00/MTok
- },
- ModelParams {
- provider: LanguageModelProvider::Anthropic,
- name: "claude-3-opus".into(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 10_000,
- max_tokens_per_day: 300_000,
- price_per_million_input_tokens: 1500, // $15.00/MTok
- price_per_million_output_tokens: 7500, // $75.00/MTok
- },
- ModelParams {
- provider: LanguageModelProvider::Anthropic,
- name: "claude-3-sonnet".into(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 20_000,
- max_tokens_per_day: 300_000,
- price_per_million_input_tokens: 1500, // $15.00/MTok
- price_per_million_output_tokens: 7500, // $75.00/MTok
- },
- ModelParams {
- provider: LanguageModelProvider::Anthropic,
- name: "claude-3-haiku".into(),
- max_requests_per_minute: 5,
- max_tokens_per_minute: 25_000,
- max_tokens_per_day: 300_000,
- price_per_million_input_tokens: 25, // $0.25/MTok
- price_per_million_output_tokens: 125, // $1.25/MTok
- },
- ])
- .await
-}
@@ -1,6 +0,0 @@
-pub mod model;
-pub mod provider;
-pub mod subscription_usage;
-pub mod subscription_usage_meter;
-pub mod usage;
-pub mod usage_measure;
@@ -1,48 +0,0 @@
-use sea_orm::entity::prelude::*;
-
-use crate::llm::db::{ModelId, ProviderId};
-
-/// An LLM model.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "models")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: ModelId,
- pub provider_id: ProviderId,
- pub name: String,
- pub max_requests_per_minute: i64,
- pub max_tokens_per_minute: i64,
- pub max_input_tokens_per_minute: i64,
- pub max_output_tokens_per_minute: i64,
- pub max_tokens_per_day: i64,
- pub price_per_million_input_tokens: i32,
- pub price_per_million_cache_creation_input_tokens: i32,
- pub price_per_million_cache_read_input_tokens: i32,
- pub price_per_million_output_tokens: i32,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::provider::Entity",
- from = "Column::ProviderId",
- to = "super::provider::Column::Id"
- )]
- Provider,
- #[sea_orm(has_many = "super::usage::Entity")]
- Usages,
-}
-
-impl Related<super::provider::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Provider.def()
- }
-}
-
-impl Related<super::usage::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Usages.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,25 +0,0 @@
-use crate::llm::db::ProviderId;
-use sea_orm::entity::prelude::*;
-
-/// An LLM provider.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "providers")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: ProviderId,
- pub name: String,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(has_many = "super::model::Entity")]
- Models,
-}
-
-impl Related<super::model::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Models.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,22 +0,0 @@
-use crate::db::UserId;
-use crate::db::billing_subscription::SubscriptionKind;
-use sea_orm::entity::prelude::*;
-use time::PrimitiveDateTime;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "subscription_usages_v2")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: Uuid,
- pub user_id: UserId,
- pub period_start_at: PrimitiveDateTime,
- pub period_end_at: PrimitiveDateTime,
- pub plan: SubscriptionKind,
- pub model_requests: i32,
- pub edit_predictions: i32,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,55 +0,0 @@
-use sea_orm::entity::prelude::*;
-use serde::Serialize;
-
-use crate::llm::db::ModelId;
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "subscription_usage_meters_v2")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: Uuid,
- pub subscription_usage_id: Uuid,
- pub model_id: ModelId,
- pub mode: CompletionMode,
- pub requests: i32,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::subscription_usage::Entity",
- from = "Column::SubscriptionUsageId",
- to = "super::subscription_usage::Column::Id"
- )]
- SubscriptionUsage,
- #[sea_orm(
- belongs_to = "super::model::Entity",
- from = "Column::ModelId",
- to = "super::model::Column::Id"
- )]
- Model,
-}
-
-impl Related<super::subscription_usage::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::SubscriptionUsage.def()
- }
-}
-
-impl Related<super::model::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Model.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
-
-#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
-#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
-#[serde(rename_all = "snake_case")]
-pub enum CompletionMode {
- #[sea_orm(string_value = "normal")]
- Normal,
- #[sea_orm(string_value = "max")]
- Max,
-}
@@ -1,52 +0,0 @@
-use crate::{
- db::UserId,
- llm::db::{ModelId, UsageId, UsageMeasureId},
-};
-use sea_orm::entity::prelude::*;
-
-/// An LLM usage record.
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "usages")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: UsageId,
- /// The ID of the Zed user.
- ///
- /// Corresponds to the `users` table in the primary collab database.
- pub user_id: UserId,
- pub model_id: ModelId,
- pub measure_id: UsageMeasureId,
- pub timestamp: DateTime,
- pub buckets: Vec<i64>,
- pub is_staff: bool,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(
- belongs_to = "super::model::Entity",
- from = "Column::ModelId",
- to = "super::model::Column::Id"
- )]
- Model,
- #[sea_orm(
- belongs_to = "super::usage_measure::Entity",
- from = "Column::MeasureId",
- to = "super::usage_measure::Column::Id"
- )]
- UsageMeasure,
-}
-
-impl Related<super::model::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Model.def()
- }
-}
-
-impl Related<super::usage_measure::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::UsageMeasure.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,36 +0,0 @@
-use crate::llm::db::UsageMeasureId;
-use sea_orm::entity::prelude::*;
-
-#[derive(
- Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
-)]
-#[strum(serialize_all = "snake_case")]
-pub enum UsageMeasure {
- RequestsPerMinute,
- TokensPerMinute,
- InputTokensPerMinute,
- OutputTokensPerMinute,
- TokensPerDay,
-}
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
-#[sea_orm(table_name = "usage_measures")]
-pub struct Model {
- #[sea_orm(primary_key)]
- pub id: UsageMeasureId,
- pub name: String,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
- #[sea_orm(has_many = "super::usage::Entity")]
- Usages,
-}
-
-impl Related<super::usage::Entity> for Entity {
- fn to() -> RelationDef {
- Relation::Usages.def()
- }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,107 +0,0 @@
-mod provider_tests;
-
-use gpui::BackgroundExecutor;
-use parking_lot::Mutex;
-use rand::prelude::*;
-use sea_orm::ConnectionTrait;
-use sqlx::migrate::MigrateDatabase;
-use std::time::Duration;
-
-use crate::migrations::run_database_migrations;
-
-use super::*;
-
-pub struct TestLlmDb {
- pub db: Option<LlmDatabase>,
- pub connection: Option<sqlx::AnyConnection>,
-}
-
-impl TestLlmDb {
- pub fn postgres(background: BackgroundExecutor) -> Self {
- static LOCK: Mutex<()> = Mutex::new(());
-
- let _guard = LOCK.lock();
- let mut rng = StdRng::from_entropy();
- let url = format!(
- "postgres://postgres@localhost/zed-llm-test-{}",
- rng.r#gen::<u128>()
- );
- let runtime = tokio::runtime::Builder::new_current_thread()
- .enable_io()
- .enable_time()
- .build()
- .unwrap();
-
- let mut db = runtime.block_on(async {
- sqlx::Postgres::create_database(&url)
- .await
- .expect("failed to create test db");
- let mut options = ConnectOptions::new(url);
- options
- .max_connections(5)
- .idle_timeout(Duration::from_secs(0));
- let db = LlmDatabase::new(options, Executor::Deterministic(background))
- .await
- .unwrap();
- let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
- run_database_migrations(db.options(), migrations_path)
- .await
- .unwrap();
- db
- });
-
- db.runtime = Some(runtime);
-
- Self {
- db: Some(db),
- connection: None,
- }
- }
-
- pub fn db(&mut self) -> &mut LlmDatabase {
- self.db.as_mut().unwrap()
- }
-}
-
-#[macro_export]
-macro_rules! test_llm_db {
- ($test_name:ident, $postgres_test_name:ident) => {
- #[gpui::test]
- async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
- if !cfg!(target_os = "macos") {
- return;
- }
-
- let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
- $test_name(test_db.db()).await;
- }
- };
-}
-
-impl Drop for TestLlmDb {
- fn drop(&mut self) {
- let db = self.db.take().unwrap();
- if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
- db.runtime.as_ref().unwrap().block_on(async {
- use util::ResultExt;
- let query = "
- SELECT pg_terminate_backend(pg_stat_activity.pid)
- FROM pg_stat_activity
- WHERE
- pg_stat_activity.datname = current_database() AND
- pid <> pg_backend_pid();
- ";
- db.pool
- .execute(sea_orm::Statement::from_string(
- db.pool.get_database_backend(),
- query,
- ))
- .await
- .log_err();
- sqlx::Postgres::drop_database(db.options.get_url())
- .await
- .log_err();
- })
- }
- }
-}
@@ -1,31 +0,0 @@
-use cloud_llm_client::LanguageModelProvider;
-use pretty_assertions::assert_eq;
-
-use crate::llm::db::LlmDatabase;
-use crate::test_llm_db;
-
-test_llm_db!(
- test_initialize_providers,
- test_initialize_providers_postgres
-);
-
-async fn test_initialize_providers(db: &mut LlmDatabase) {
- let initial_providers = db.list_providers().await.unwrap();
- assert_eq!(initial_providers, vec![]);
-
- db.initialize_providers().await.unwrap();
-
- // Do it twice, to make sure the operation is idempotent.
- db.initialize_providers().await.unwrap();
-
- let providers = db.list_providers().await.unwrap();
-
- assert_eq!(
- providers,
- &[
- LanguageModelProvider::Anthropic,
- LanguageModelProvider::Google,
- LanguageModelProvider::OpenAi,
- ]
- )
-}
@@ -62,13 +62,6 @@ async fn main() -> Result<()> {
db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, false).await?;
-
- if let Some(llm_database_url) = config.llm_database_url.clone() {
- let db_options = db::ConnectOptions::new(llm_database_url);
- let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
- db.initialize().await?;
- collab::llm::db::seed_database(&config, &mut db, true).await?;
- }
}
Some("serve") => {
let mode = match args.next().as_deref() {
@@ -263,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> {
.llm_database_migrations_path
.as_deref()
.unwrap_or_else(|| {
- #[cfg(feature = "sqlite")]
- let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
- #[cfg(not(feature = "sqlite"))]
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
Path::new(default_migrations)
@@ -565,7 +565,6 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
- llm_db: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
executor,