diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 64cbe8422f413dd1a38b9fbb9cb4ceaa63b5f74f..028b3e8b43bd96ccbd7d0a203c9761add36d91e5 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -26,6 +26,7 @@ use crate::api::events::SnowflakeRow; use crate::db::billing_subscription::{ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, }; +use crate::llm::db::subscription_usage_meter::CompletionMode; use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; use crate::rpc::{ResultExt as _, Server}; use crate::{AppState, Cents, Error, Result}; @@ -1372,6 +1373,9 @@ async fn sync_model_request_usage_with_stripe( let claude_3_7_sonnet = stripe_billing .find_price_by_lookup_key("claude-3-7-sonnet-requests") .await?; + let claude_3_7_sonnet_max = stripe_billing + .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") + .await?; for (usage_meter, usage) in usage_meters { maybe!(async { @@ -1397,7 +1401,12 @@ async fn sync_model_request_usage_with_stripe( let (price_id, meter_event_name) = match model.name.as_str() { "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"), + "claude-3-7-sonnet" => match usage_meter.mode { + CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"), + CompletionMode::Max => { + (&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max") + } + }, model_name => { bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") } diff --git a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs index a7241e8f95c9bcf500c3363c451bd8c873bd8880..02ed5c0877942e40dc46f9217e50d9a2b5f180b2 100644 --- a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs +++ b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs @@ -1,4 +1,5 @@ use sea_orm::entity::prelude::*; +use serde::Serialize; use crate::llm::db::ModelId; @@ -9,6 +10,7 @@ pub struct Model { pub id: i32, pub subscription_usage_id: i32, pub model_id: ModelId, + pub mode: CompletionMode, pub requests: i32, } @@ -41,3 +43,13 @@ impl Related for Entity { } impl ActiveModelBehavior for ActiveModel {} + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] +#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] +#[serde(rename_all = "snake_case")] +pub enum CompletionMode { + #[sea_orm(string_value = "normal")] + Normal, + #[sea_orm(string_value = "max")] + Max, +}