collab: Use `StripeClient` to retrieve prices and meters from Stripe (#31624)

Marshall Bowers created

This PR updates `StripeBilling` to use the `StripeClient` trait to
retrieve prices and meters from Stripe instead of using the
`stripe::Client` directly.

Release Notes:

- N/A

Change summary

crates/collab/src/api/billing.rs                      |  6 
crates/collab/src/stripe_billing.rs                   | 63 +++------
crates/collab/src/stripe_client.rs                    | 32 ++++
crates/collab/src/stripe_client/fake_stripe_client.rs | 21 +++
crates/collab/src/stripe_client/real_stripe_client.rs | 70 ++++++++++
crates/collab/src/tests/stripe_billing_tests.rs       | 85 ++++++++++++
6 files changed, 228 insertions(+), 49 deletions(-)

Detailed changes

crates/collab/src/api/billing.rs 🔗

@@ -499,8 +499,10 @@ async fn manage_billing_subscription(
     let flow = match body.intent {
         ManageSubscriptionIntent::ManageSubscription => None,
         ManageSubscriptionIntent::UpgradeToPro => {
-            let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
-            let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
+            let zed_pro_price_id: stripe::PriceId =
+                stripe_billing.zed_pro_price_id().await?.try_into()?;
+            let zed_free_price_id: stripe::PriceId =
+                stripe_billing.zed_free_price_id().await?.try_into()?;
 
             let stripe_subscription =
                 Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;

crates/collab/src/stripe_billing.rs 🔗

@@ -4,14 +4,16 @@ use anyhow::{Context as _, anyhow};
 use chrono::Utc;
 use collections::HashMap;
 use serde::{Deserialize, Serialize};
-use stripe::{PriceId, SubscriptionStatus};
+use stripe::SubscriptionStatus;
 use tokio::sync::RwLock;
 use uuid::Uuid;
 
 use crate::Result;
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
-use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
+use crate::stripe_client::{
+    RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+};
 
 pub struct StripeBilling {
     state: RwLock<StripeBillingState>,
@@ -22,8 +24,8 @@ pub struct StripeBilling {
 #[derive(Default)]
 struct StripeBillingState {
     meters_by_event_name: HashMap<String, StripeMeter>,
-    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
-    prices_by_lookup_key: HashMap<String, stripe::Price>,
+    price_ids_by_meter_id: HashMap<String, StripePriceId>,
+    prices_by_lookup_key: HashMap<String, StripePrice>,
 }
 
 impl StripeBilling {
@@ -50,24 +52,16 @@ impl StripeBilling {
 
         let mut state = self.state.write().await;
 
-        let (meters, prices) = futures::try_join!(
-            StripeMeter::list(&self.real_client),
-            stripe::Price::list(
-                &self.real_client,
-                &stripe::ListPrices {
-                    limit: Some(100),
-                    ..Default::default()
-                }
-            )
-        )?;
+        let (meters, prices) =
+            futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
 
-        for meter in meters.data {
+        for meter in meters {
             state
                 .meters_by_event_name
                 .insert(meter.event_name.clone(), meter);
         }
 
-        for price in prices.data {
+        for price in prices {
             if let Some(lookup_key) = price.lookup_key.clone() {
                 state.prices_by_lookup_key.insert(lookup_key, price.clone());
             }
@@ -84,15 +78,15 @@ impl StripeBilling {
         Ok(())
     }
 
-    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
+    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
         self.find_price_id_by_lookup_key("zed-pro").await
     }
 
-    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
+    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
         self.find_price_id_by_lookup_key("zed-free").await
     }
 
-    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
+    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
         self.state
             .read()
             .await
@@ -102,7 +96,7 @@ impl StripeBilling {
             .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
     }
 
-    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
+    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
         self.state
             .read()
             .await
@@ -116,8 +110,10 @@ impl StripeBilling {
         &self,
         subscription: &stripe::Subscription,
     ) -> Option<SubscriptionKind> {
-        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
-        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
+        let zed_pro_price_id: stripe::PriceId =
+            self.zed_pro_price_id().await.ok()?.try_into().ok()?;
+        let zed_free_price_id: stripe::PriceId =
+            self.zed_free_price_id().await.ok()?.try_into().ok()?;
 
         subscription.items.data.iter().find_map(|item| {
             let price = item.price.as_ref()?;
@@ -171,12 +167,13 @@ impl StripeBilling {
     pub async fn subscribe_to_price(
         &self,
         subscription_id: &stripe::SubscriptionId,
-        price: &stripe::Price,
+        price: &StripePrice,
     ) -> Result<()> {
         let subscription =
             stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
 
-        if subscription_contains_price(&subscription, &price.id) {
+        let price_id = price.id.clone().try_into()?;
+        if subscription_contains_price(&subscription, &price_id) {
             return Ok(());
         }
 
@@ -375,24 +372,6 @@ impl StripeBilling {
     }
 }
 
-#[derive(Clone, Deserialize)]
-struct StripeMeter {
-    id: String,
-    event_name: String,
-}
-
-impl StripeMeter {
-    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
-        #[derive(Serialize)]
-        struct Params {
-            #[serde(skip_serializing_if = "Option::is_none")]
-            limit: Option<u64>,
-        }
-
-        client.get_query("/billing/meters", Params { limit: Some(100) })
-    }
-}
-
 #[derive(Deserialize)]
 struct StripeMeterEvent {
     identifier: String,

crates/collab/src/stripe_client.rs 🔗

@@ -10,8 +10,9 @@ use async_trait::async_trait;
 #[cfg(test)]
 pub use fake_stripe_client::*;
 pub use real_stripe_client::*;
+use serde::Deserialize;
 
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
 pub struct StripeCustomerId(pub Arc<str>);
 
 #[derive(Debug, Clone)]
@@ -25,9 +26,38 @@ pub struct CreateCustomerParams<'a> {
     pub email: Option<&'a str>,
 }
 
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripePriceId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripePrice {
+    pub id: StripePriceId,
+    pub unit_amount: Option<i64>,
+    pub lookup_key: Option<String>,
+    pub recurring: Option<StripePriceRecurring>,
+}
+
+#[derive(Debug, Clone)]
+pub struct StripePriceRecurring {
+    pub meter: Option<String>,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
+pub struct StripeMeterId(pub Arc<str>);
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct StripeMeter {
+    pub id: StripeMeterId,
+    pub event_name: String,
+}
+
 #[async_trait]
 pub trait StripeClient: Send + Sync {
     async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
 
     async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+
+    async fn list_prices(&self) -> Result<Vec<StripePrice>>;
+
+    async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
 }

crates/collab/src/stripe_client/fake_stripe_client.rs 🔗

@@ -6,16 +6,23 @@ use collections::HashMap;
 use parking_lot::Mutex;
 use uuid::Uuid;
 
-use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+use crate::stripe_client::{
+    CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
+    StripeMeterId, StripePrice, StripePriceId,
+};
 
 pub struct FakeStripeClient {
     pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
+    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 }
 
 impl FakeStripeClient {
     pub fn new() -> Self {
         Self {
             customers: Arc::new(Mutex::new(HashMap::default())),
+            prices: Arc::new(Mutex::new(HashMap::default())),
+            meters: Arc::new(Mutex::new(HashMap::default())),
         }
     }
 }
@@ -44,4 +51,16 @@ impl StripeClient for FakeStripeClient {
 
         Ok(customer)
     }
+
+    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
+        let prices = self.prices.lock().values().cloned().collect();
+
+        Ok(prices)
+    }
+
+    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
+        let meters = self.meters.lock().values().cloned().collect();
+
+        Ok(meters)
+    }
 }

crates/collab/src/stripe_client/real_stripe_client.rs 🔗

@@ -3,9 +3,13 @@ use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
-use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
+use serde::Serialize;
+use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
 
-use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+use crate::stripe_client::{
+    CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
+    StripePriceId, StripePriceRecurring,
+};
 
 pub struct RealStripeClient {
     client: Arc<stripe::Client>,
@@ -48,6 +52,37 @@ impl StripeClient for RealStripeClient {
 
         Ok(StripeCustomer::from(customer))
     }
+
+    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
+        let response = stripe::Price::list(
+            &self.client,
+            &stripe::ListPrices {
+                limit: Some(100),
+                ..Default::default()
+            },
+        )
+        .await?;
+
+        Ok(response.data.into_iter().map(StripePrice::from).collect())
+    }
+
+    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
+        #[derive(Serialize)]
+        struct Params {
+            #[serde(skip_serializing_if = "Option::is_none")]
+            limit: Option<u64>,
+        }
+
+        let response = self
+            .client
+            .get_query::<stripe::List<StripeMeter>, _>(
+                "/billing/meters",
+                Params { limit: Some(100) },
+            )
+            .await?;
+
+        Ok(response.data)
+    }
 }
 
 impl From<CustomerId> for StripeCustomerId {
@@ -72,3 +107,34 @@ impl From<Customer> for StripeCustomer {
         }
     }
 }
+
+impl From<PriceId> for StripePriceId {
+    fn from(value: PriceId) -> Self {
+        Self(value.as_str().into())
+    }
+}
+
+impl TryFrom<StripePriceId> for PriceId {
+    type Error = anyhow::Error;
+
+    fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
+        Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
+    }
+}
+
+impl From<Price> for StripePrice {
+    fn from(value: Price) -> Self {
+        Self {
+            id: value.id.into(),
+            unit_amount: value.unit_amount,
+            lookup_key: value.lookup_key,
+            recurring: value.recurring.map(StripePriceRecurring::from),
+        }
+    }
+}
+
+impl From<Recurring> for StripePriceRecurring {
+    fn from(value: Recurring) -> Self {
+        Self { meter: value.meter }
+    }
+}

crates/collab/src/tests/stripe_billing_tests.rs 🔗

@@ -3,7 +3,9 @@ use std::sync::Arc;
 use pretty_assertions::assert_eq;
 
 use crate::stripe_billing::StripeBilling;
-use crate::stripe_client::FakeStripeClient;
+use crate::stripe_client::{
+    FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+};
 
 fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
     let stripe_client = Arc::new(FakeStripeClient::new());
@@ -12,6 +14,87 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
     (stripe_billing, stripe_client)
 }
 
+#[gpui::test]
+async fn test_initialize() {
+    let (stripe_billing, stripe_client) = make_stripe_billing();
+
+    // Add test meters
+    let meter1 = StripeMeter {
+        id: StripeMeterId("meter_1".into()),
+        event_name: "event_1".to_string(),
+    };
+    let meter2 = StripeMeter {
+        id: StripeMeterId("meter_2".into()),
+        event_name: "event_2".to_string(),
+    };
+    stripe_client
+        .meters
+        .lock()
+        .insert(meter1.id.clone(), meter1);
+    stripe_client
+        .meters
+        .lock()
+        .insert(meter2.id.clone(), meter2);
+
+    // Add test prices
+    let price1 = StripePrice {
+        id: StripePriceId("price_1".into()),
+        unit_amount: Some(1_000),
+        lookup_key: Some("zed-pro".to_string()),
+        recurring: None,
+    };
+    let price2 = StripePrice {
+        id: StripePriceId("price_2".into()),
+        unit_amount: Some(0),
+        lookup_key: Some("zed-free".to_string()),
+        recurring: None,
+    };
+    let price3 = StripePrice {
+        id: StripePriceId("price_3".into()),
+        unit_amount: Some(500),
+        lookup_key: None,
+        recurring: Some(StripePriceRecurring {
+            meter: Some("meter_1".to_string()),
+        }),
+    };
+    stripe_client
+        .prices
+        .lock()
+        .insert(price1.id.clone(), price1);
+    stripe_client
+        .prices
+        .lock()
+        .insert(price2.id.clone(), price2);
+    stripe_client
+        .prices
+        .lock()
+        .insert(price3.id.clone(), price3);
+
+    // Initialize the billing system
+    stripe_billing.initialize().await.unwrap();
+
+    // Verify that prices can be found by lookup key
+    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
+    assert_eq!(zed_pro_price_id.to_string(), "price_1");
+
+    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
+    assert_eq!(zed_free_price_id.to_string(), "price_2");
+
+    // Verify that a price can be found by lookup key
+    let zed_pro_price = stripe_billing
+        .find_price_by_lookup_key("zed-pro")
+        .await
+        .unwrap();
+    assert_eq!(zed_pro_price.id.to_string(), "price_1");
+    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
+
+    // Verify that finding a non-existent lookup key returns an error
+    let result = stripe_billing
+        .find_price_by_lookup_key("non-existent")
+        .await;
+    assert!(result.is_err());
+}
+
 #[gpui::test]
 async fn test_find_or_create_customer_by_email() {
     let (stripe_billing, stripe_client) = make_stripe_billing();