collab: Add billing preferences for maximum LLM monthly spend (#18948)

Marshall Bowers and Richard created

This PR adds a new `billing_preferences` table.

Right now there is a single preference: the maximum monthly spend for
LLM usage.

Release Notes:

- N/A

---------

Co-authored-by: Richard <richard@zed.dev>

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql      |  9 
crates/collab/migrations/20241009190639_add_billing_preferences.sql |  8 
crates/collab/src/api/billing.rs                                    | 84 
crates/collab/src/db.rs                                             |  3 
crates/collab/src/db/ids.rs                                         |  1 
crates/collab/src/db/queries.rs                                     |  1 
crates/collab/src/db/queries/billing_preferences.rs                 | 75 
crates/collab/src/db/tables.rs                                      |  1 
crates/collab/src/db/tables/billing_preference.rs                   | 30 
crates/collab/src/llm.rs                                            |  6 
10 files changed, 216 insertions(+), 2 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -422,6 +422,15 @@ CREATE TABLE dev_server_projects (
     paths TEXT NOT NULL
 );
 
+CREATE TABLE IF NOT EXISTS billing_preferences (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
+
 CREATE TABLE IF NOT EXISTS billing_customers (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,

crates/collab/migrations/20241009190639_add_billing_preferences.sql 🔗

@@ -0,0 +1,8 @@
+create table if not exists billing_preferences (
+    id serial primary key,
+    created_at timestamp without time zone not null default now(),
+    user_id integer not null references users(id) on delete cascade,
+    max_monthly_llm_usage_spending_in_cents integer not null
+);
+
+create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id);

crates/collab/src/api/billing.rs 🔗

@@ -26,15 +26,19 @@ use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
 use crate::db::{
     billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
     CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
-    UpdateBillingSubscriptionParams,
+    UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
 };
 use crate::llm::db::LlmDatabase;
-use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
+use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
 use crate::rpc::ResultExt as _;
 use crate::{AppState, Error, Result};
 
 pub fn router() -> Router {
     Router::new()
+        .route(
+            "/billing/preferences",
+            get(get_billing_preferences).put(update_billing_preferences),
+        )
         .route(
             "/billing/subscriptions",
             get(list_billing_subscriptions).post(create_billing_subscription),
@@ -45,6 +49,82 @@ pub fn router() -> Router {
         )
 }
 
+#[derive(Debug, Deserialize)]
+struct GetBillingPreferencesParams {
+    github_user_id: i32,
+}
+
+#[derive(Debug, Serialize)]
+struct BillingPreferencesResponse {
+    max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+async fn get_billing_preferences(
+    Extension(app): Extension<Arc<AppState>>,
+    Query(params): Query<GetBillingPreferencesParams>,
+) -> Result<Json<BillingPreferencesResponse>> {
+    let user = app
+        .db
+        .get_user_by_github_user_id(params.github_user_id)
+        .await?
+        .ok_or_else(|| anyhow!("user not found"))?;
+
+    let preferences = app.db.get_billing_preferences(user.id).await?;
+
+    Ok(Json(BillingPreferencesResponse {
+        max_monthly_llm_usage_spending_in_cents: preferences
+            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
+                preferences.max_monthly_llm_usage_spending_in_cents
+            }),
+    }))
+}
+
+#[derive(Debug, Deserialize)]
+struct UpdateBillingPreferencesBody {
+    github_user_id: i32,
+    max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+async fn update_billing_preferences(
+    Extension(app): Extension<Arc<AppState>>,
+    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
+) -> Result<Json<BillingPreferencesResponse>> {
+    let user = app
+        .db
+        .get_user_by_github_user_id(body.github_user_id)
+        .await?
+        .ok_or_else(|| anyhow!("user not found"))?;
+
+    let billing_preferences =
+        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
+            app.db
+                .update_billing_preferences(
+                    user.id,
+                    &UpdateBillingPreferencesParams {
+                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
+                            body.max_monthly_llm_usage_spending_in_cents,
+                        ),
+                    },
+                )
+                .await?
+        } else {
+            app.db
+                .create_billing_preferences(
+                    user.id,
+                    &crate::db::CreateBillingPreferencesParams {
+                        max_monthly_llm_usage_spending_in_cents: body
+                            .max_monthly_llm_usage_spending_in_cents,
+                    },
+                )
+                .await?
+        };
+
+    Ok(Json(BillingPreferencesResponse {
+        max_monthly_llm_usage_spending_in_cents: billing_preferences
+            .max_monthly_llm_usage_spending_in_cents,
+    }))
+}
+
 #[derive(Debug, Deserialize)]
 struct ListBillingSubscriptionsParams {
     github_user_id: i32,

crates/collab/src/db.rs 🔗

@@ -42,6 +42,9 @@ pub use tests::TestDb;
 
 pub use ids::*;
 pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
+pub use queries::billing_preferences::{
+    CreateBillingPreferencesParams, UpdateBillingPreferencesParams,
+};
 pub use queries::billing_subscriptions::{
     CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
 };

crates/collab/src/db/ids.rs 🔗

@@ -72,6 +72,7 @@ macro_rules! id_type {
 id_type!(AccessTokenId);
 id_type!(BillingCustomerId);
 id_type!(BillingSubscriptionId);
+id_type!(BillingPreferencesId);
 id_type!(BufferId);
 id_type!(ChannelBufferCollaboratorId);
 id_type!(ChannelChatParticipantId);

crates/collab/src/db/queries.rs 🔗

@@ -2,6 +2,7 @@ use super::*;
 
 pub mod access_tokens;
 pub mod billing_customers;
+pub mod billing_preferences;
 pub mod billing_subscriptions;
 pub mod buffers;
 pub mod channels;

crates/collab/src/db/queries/billing_preferences.rs 🔗

@@ -0,0 +1,75 @@
+use super::*;
+
+#[derive(Debug)]
+pub struct CreateBillingPreferencesParams {
+    pub max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+#[derive(Debug, Default)]
+pub struct UpdateBillingPreferencesParams {
+    pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
+}
+
+impl Database {
+    /// Returns the billing preferences for the given user, if they exist.
+    pub async fn get_billing_preferences(
+        &self,
+        user_id: UserId,
+    ) -> Result<Option<billing_preference::Model>> {
+        self.transaction(|tx| async move {
+            Ok(billing_preference::Entity::find()
+                .filter(billing_preference::Column::UserId.eq(user_id))
+                .one(&*tx)
+                .await?)
+        })
+        .await
+    }
+
+    /// Creates new billing preferences for the given user.
+    pub async fn create_billing_preferences(
+        &self,
+        user_id: UserId,
+        params: &CreateBillingPreferencesParams,
+    ) -> Result<billing_preference::Model> {
+        self.transaction(|tx| async move {
+            let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
+                user_id: ActiveValue::set(user_id),
+                max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
+                    params.max_monthly_llm_usage_spending_in_cents,
+                ),
+                ..Default::default()
+            })
+            .exec_with_returning(&*tx)
+            .await?;
+
+            Ok(preferences)
+        })
+        .await
+    }
+
+    /// Updates the billing preferences for the given user.
+    pub async fn update_billing_preferences(
+        &self,
+        user_id: UserId,
+        params: &UpdateBillingPreferencesParams,
+    ) -> Result<billing_preference::Model> {
+        self.transaction(|tx| async move {
+            let preferences = billing_preference::Entity::update_many()
+                .set(billing_preference::ActiveModel {
+                    max_monthly_llm_usage_spending_in_cents: params
+                        .max_monthly_llm_usage_spending_in_cents
+                        .clone(),
+                    ..Default::default()
+                })
+                .filter(billing_preference::Column::UserId.eq(user_id))
+                .exec_with_returning(&*tx)
+                .await?;
+
+            Ok(preferences
+                .into_iter()
+                .next()
+                .ok_or_else(|| anyhow!("billing preferences not found"))?)
+        })
+        .await
+    }
+}

crates/collab/src/db/tables.rs 🔗

@@ -1,5 +1,6 @@
 pub mod access_token;
 pub mod billing_customer;
+pub mod billing_preference;
 pub mod billing_subscription;
 pub mod buffer;
 pub mod buffer_operation;

crates/collab/src/db/tables/billing_preference.rs 🔗

@@ -0,0 +1,30 @@
+use crate::db::{BillingPreferencesId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_preferences")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: BillingPreferencesId,
+    pub created_at: DateTime,
+    pub user_id: UserId,
+    pub max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm.rs 🔗

@@ -442,6 +442,12 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
 /// before they have to pay.
 pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5);
 
+/// The default value to use for maximum spend per month if the user did not
+/// explicitly set a maximum spend.
+///
+/// Used to prevent surprise bills.
+pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
+
 /// The maximum lifetime spending an individual user can reach before being cut off.
 const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);