collab: Use `StripeClient` for creating model usage meter events (#31633)

Marshall Bowers created

This PR updates the `StripeBilling::bill_model_request_usage` method to
use the `StripeClient` trait.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      | 15 +-
crates/collab/src/stripe_billing.rs                   | 64 +-----------
crates/collab/src/stripe_client.rs                    | 20 +++
crates/collab/src/stripe_client/fake_stripe_client.rs | 31 +++++
crates/collab/src/stripe_client/real_stripe_client.rs | 29 +++++
crates/collab/src/tests/stripe_billing_tests.rs       | 37 ++++++
6 files changed, 119 insertions(+), 77 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -29,6 +29,7 @@ use crate::db::billing_subscription::{
 use crate::llm::db::subscription_usage_meter::CompletionMode;
 use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
 use crate::rpc::{ResultExt as _, Server};
+use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
 use crate::{AppState, Error, Result};
 use crate::{db::UserId, llm::db::LlmDatabase};
 use crate::{
@@ -1545,14 +1546,10 @@ async fn sync_model_request_usage_with_stripe(
                 );
             };
 
-            let stripe_customer_id = billing_customer
-                .stripe_customer_id
-                .parse::<stripe::CustomerId>()
-                .context("failed to parse Stripe customer ID from database")?;
-            let stripe_subscription_id = billing_subscription
-                .stripe_subscription_id
-                .parse::<stripe::SubscriptionId>()
-                .context("failed to parse Stripe subscription ID from database")?;
+            let stripe_customer_id =
+                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
+            let stripe_subscription_id =
+                StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
 
             let model = llm_db.model_by_id(usage_meter.model_id)?;
 
@@ -1578,7 +1575,7 @@ async fn sync_model_request_usage_with_stripe(
             };
 
             stripe_billing
-                .subscribe_to_price(&stripe_subscription_id.into(), price)
+                .subscribe_to_price(&stripe_subscription_id, price)
                 .await?;
             stripe_billing
                 .bill_model_request_usage(

crates/collab/src/stripe_billing.rs 🔗

@@ -3,7 +3,6 @@ use std::sync::Arc;
 use anyhow::{Context as _, anyhow};
 use chrono::Utc;
 use collections::HashMap;
-use serde::{Deserialize, Serialize};
 use stripe::SubscriptionStatus;
 use tokio::sync::RwLock;
 use uuid::Uuid;
@@ -12,8 +11,9 @@ use crate::Result;
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 use crate::stripe_client::{
-    RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
-    StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
+    RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
+    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
+    StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
     UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
     UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
 };
@@ -204,16 +204,15 @@ impl StripeBilling {
 
     pub async fn bill_model_request_usage(
         &self,
-        customer_id: &stripe::CustomerId,
+        customer_id: &StripeCustomerId,
         event_name: &str,
         requests: i32,
     ) -> Result<()> {
         let timestamp = Utc::now().timestamp();
         let idempotency_key = Uuid::new_v4();
 
-        StripeMeterEvent::create(
-            &self.real_client,
-            StripeCreateMeterEventParams {
+        self.client
+            .create_meter_event(StripeCreateMeterEventParams {
                 identifier: &format!("model_requests/{}", idempotency_key),
                 event_name,
                 payload: StripeCreateMeterEventPayload {
@@ -221,9 +220,8 @@ impl StripeBilling {
                     stripe_customer_id: customer_id,
                 },
                 timestamp: Some(timestamp),
-            },
-        )
-        .await?;
+            })
+            .await?;
 
         Ok(())
     }
@@ -371,52 +369,6 @@ impl StripeBilling {
     }
 }
 
-#[derive(Deserialize)]
-struct StripeMeterEvent {
-    identifier: String,
-}
-
-impl StripeMeterEvent {
-    pub async fn create(
-        client: &stripe::Client,
-        params: StripeCreateMeterEventParams<'_>,
-    ) -> Result<Self, stripe::StripeError> {
-        let identifier = params.identifier;
-        match client.post_form("/billing/meter_events", params).await {
-            Ok(event) => Ok(event),
-            Err(stripe::StripeError::Stripe(error)) => {
-                if error.http_status == 400
-                    && error
-                        .message
-                        .as_ref()
-                        .map_or(false, |message| message.contains(identifier))
-                {
-                    Ok(Self {
-                        identifier: identifier.to_string(),
-                    })
-                } else {
-                    Err(stripe::StripeError::Stripe(error))
-                }
-            }
-            Err(error) => Err(error),
-        }
-    }
-}
-
-#[derive(Serialize)]
-struct StripeCreateMeterEventParams<'a> {
-    identifier: &'a str,
-    event_name: &'a str,
-    payload: StripeCreateMeterEventPayload<'a>,
-    timestamp: Option<i64>,
-}
-
-#[derive(Serialize)]
-struct StripeCreateMeterEventPayload<'a> {
-    value: u64,
-    stripe_customer_id: &'a stripe::CustomerId,
-}
-
 fn subscription_contains_price(
     subscription: &StripeSubscription,
     price_id: &StripePriceId,

crates/collab/src/stripe_client.rs 🔗

@@ -10,9 +10,9 @@ use async_trait::async_trait;
 #[cfg(test)]
 pub use fake_stripe_client::*;
 pub use real_stripe_client::*;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
-#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
 pub struct StripeCustomerId(pub Arc<str>);
 
 #[derive(Debug, Clone)]
@@ -97,6 +97,20 @@ pub struct StripeMeter {
     pub event_name: String,
 }
 
+#[derive(Debug, Serialize)]
+pub struct StripeCreateMeterEventParams<'a> {
+    pub identifier: &'a str,
+    pub event_name: &'a str,
+    pub payload: StripeCreateMeterEventPayload<'a>,
+    pub timestamp: Option<i64>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct StripeCreateMeterEventPayload<'a> {
+    pub value: u64,
+    pub stripe_customer_id: &'a StripeCustomerId,
+}
+
 #[async_trait]
 pub trait StripeClient: Send + Sync {
     async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
@@ -117,4 +131,6 @@ pub trait StripeClient: Send + Sync {
     async fn list_prices(&self) -> Result<Vec<StripePrice>>;
 
     async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
+
+    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
 }

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -7,11 +7,20 @@ use parking_lot::Mutex;
 use uuid::Uuid;
 
 use crate::stripe_client::{
-    CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
-    StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
-    UpdateSubscriptionParams,
+    CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
+    StripeSubscriptionId, UpdateSubscriptionParams,
 };
 
+#[derive(Debug, Clone)]
+pub struct StripeCreateMeterEventCall {
+    pub identifier: Arc<str>,
+    pub event_name: Arc<str>,
+    pub value: u64,
+    pub stripe_customer_id: StripeCustomerId,
+    pub timestamp: Option<i64>,
+}
+
 pub struct FakeStripeClient {
     pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
     pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
@@ -19,6 +28,7 @@ pub struct FakeStripeClient {
         Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
     pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
     pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
+    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
 }
 
 impl FakeStripeClient {
@@ -29,6 +39,7 @@ impl FakeStripeClient {
             update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
             prices: Arc::new(Mutex::new(HashMap::default())),
             meters: Arc::new(Mutex::new(HashMap::default())),
+            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
         }
     }
 }
@@ -94,4 +105,18 @@ impl StripeClient for FakeStripeClient {
 
         Ok(meters)
     }
+
+    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
+        self.create_meter_event_calls
+            .lock()
+            .push(StripeCreateMeterEventCall {
+                identifier: params.identifier.into(),
+                event_name: params.event_name.into(),
+                value: params.payload.value,
+                stripe_customer_id: params.payload.stripe_customer_id.clone(),
+                timestamp: params.timestamp,
+            });
+
+        Ok(())
+    }
 }

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -1,7 +1,7 @@
 use std::str::FromStr as _;
 use std::sync::Arc;
 
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
 use serde::Serialize;
 use stripe::{
@@ -12,9 +12,10 @@ use stripe::{
 };
 
 use crate::stripe_client::{
-    CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
-    StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
-    StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams,
+    CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
+    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+    UpdateSubscriptionParams,
 };
 
 pub struct RealStripeClient {
@@ -129,6 +130,26 @@ impl StripeClient for RealStripeClient {
 
         Ok(response.data)
     }
+
+    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
+        let identifier = params.identifier;
+        match self.client.post_form("/billing/meter_events", params).await {
+            Ok(event) => Ok(event),
+            Err(stripe::StripeError::Stripe(error)) => {
+                if error.http_status == 400
+                    && error
+                        .message
+                        .as_ref()
+                        .map_or(false, |message| message.contains(identifier))
+                {
+                    Ok(())
+                } else {
+                    Err(anyhow!(stripe::StripeError::Stripe(error)))
+                }
+            }
+            Err(error) => Err(anyhow!(error)),
+        }
+    }
 }
 
 impl From<CustomerId> for StripeCustomerId {

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -4,9 +4,9 @@ use pretty_assertions::assert_eq;
 
 use crate::stripe_billing::StripeBilling;
 use crate::stripe_client::{
-    FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
-    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
-    UpdateSubscriptionItems,
+    FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
+    StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
+    StripeSubscriptionItemId, UpdateSubscriptionItems,
 };
 
 fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -210,3 +210,34 @@ async fn test_subscribe_to_price() {
         assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
     }
 }
+
+#[gpui::test]
+async fn test_bill_model_request_usage() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    let customer_id = StripeCustomerId("cus_test".into());
+
+    stripe_billing
+        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
+        .await
+        .unwrap();
+
+    let create_meter_event_calls = stripe_client
+        .create_meter_event_calls
+        .lock()
+        .iter()
+        .cloned()
+        .collect::<Vec<_>>();
+    assert_eq!(create_meter_event_calls.len(), 1);
+    assert!(
+        create_meter_event_calls[0]
+            .identifier
+            .starts_with("model_requests/")
+    );
+    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
+    assert_eq!(
+        create_meter_event_calls[0].event_name.as_ref(),
+        "some_model/requests"
+    );
+    assert_eq!(create_meter_event_calls[0].value, 73);
+}