Detailed changes
@@ -499,8 +499,10 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
- let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
- let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
+ let zed_pro_price_id: stripe::PriceId =
+ stripe_billing.zed_pro_price_id().await?.try_into()?;
+ let zed_free_price_id: stripe::PriceId =
+ stripe_billing.zed_free_price_id().await?.try_into()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
@@ -4,14 +4,16 @@ use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
-use stripe::{PriceId, SubscriptionStatus};
+use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
-use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
+use crate::stripe_client::{
+ RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
+};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
@@ -22,8 +24,8 @@ pub struct StripeBilling {
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
- price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
- prices_by_lookup_key: HashMap<String, stripe::Price>,
+ price_ids_by_meter_id: HashMap<String, StripePriceId>,
+ prices_by_lookup_key: HashMap<String, StripePrice>,
}
impl StripeBilling {
@@ -50,24 +52,16 @@ impl StripeBilling {
let mut state = self.state.write().await;
- let (meters, prices) = futures::try_join!(
- StripeMeter::list(&self.real_client),
- stripe::Price::list(
- &self.real_client,
- &stripe::ListPrices {
- limit: Some(100),
- ..Default::default()
- }
- )
- )?;
+ let (meters, prices) =
+ futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
- for meter in meters.data {
+ for meter in meters {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
}
- for price in prices.data {
+ for price in prices {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
@@ -84,15 +78,15 @@ impl StripeBilling {
Ok(())
}
- pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
+ pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
- pub async fn zed_free_price_id(&self) -> Result<PriceId> {
+ pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
- pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
+ pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
self.state
.read()
.await
@@ -102,7 +96,7 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
- pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
+ pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
self.state
.read()
.await
@@ -116,8 +110,10 @@ impl StripeBilling {
&self,
subscription: &stripe::Subscription,
) -> Option<SubscriptionKind> {
- let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
- let zed_free_price_id = self.zed_free_price_id().await.ok()?;
+ let zed_pro_price_id: stripe::PriceId =
+ self.zed_pro_price_id().await.ok()?.try_into().ok()?;
+ let zed_free_price_id: stripe::PriceId =
+ self.zed_free_price_id().await.ok()?.try_into().ok()?;
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
@@ -171,12 +167,13 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
- price: &stripe::Price,
+ price: &StripePrice,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
- if subscription_contains_price(&subscription, &price.id) {
+ let price_id = price.id.clone().try_into()?;
+ if subscription_contains_price(&subscription, &price_id) {
return Ok(());
}
@@ -375,24 +372,6 @@ impl StripeBilling {
}
}
-#[derive(Clone, Deserialize)]
-struct StripeMeter {
- id: String,
- event_name: String,
-}
-
-impl StripeMeter {
- pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
- #[derive(Serialize)]
- struct Params {
- #[serde(skip_serializing_if = "Option::is_none")]
- limit: Option<u64>,
- }
-
- client.get_query("/billing/meters", Params { limit: Some(100) })
- }
-}
-
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,
@@ -10,8 +10,9 @@ use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
+use serde::Deserialize;
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
@@ -25,9 +26,38 @@ pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+pub struct StripePriceId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripePrice {
+ pub id: StripePriceId,
+ pub unit_amount: Option<i64>,
+ pub lookup_key: Option<String>,
+ pub recurring: Option<StripePriceRecurring>,
+}
+
+#[derive(Debug, Clone)]
+pub struct StripePriceRecurring {
+ pub meter: Option<String>,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
+pub struct StripeMeterId(pub Arc<str>);
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct StripeMeter {
+ pub id: StripeMeterId,
+ pub event_name: String,
+}
+
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+
+ async fn list_prices(&self) -> Result<Vec<StripePrice>>;
+
+ async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
}
@@ -6,16 +6,23 @@ use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
-use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+use crate::stripe_client::{
+ CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
+ StripeMeterId, StripePrice, StripePriceId,
+};
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+ pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
+ pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
+ prices: Arc::new(Mutex::new(HashMap::default())),
+ meters: Arc::new(Mutex::new(HashMap::default())),
}
}
}
@@ -44,4 +51,16 @@ impl StripeClient for FakeStripeClient {
Ok(customer)
}
+
+ async fn list_prices(&self) -> Result<Vec<StripePrice>> {
+ let prices = self.prices.lock().values().cloned().collect();
+
+ Ok(prices)
+ }
+
+ async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
+ let meters = self.meters.lock().values().cloned().collect();
+
+ Ok(meters)
+ }
}
@@ -3,9 +3,13 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
-use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
+use serde::Serialize;
+use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
-use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+use crate::stripe_client::{
+ CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
+ StripePriceId, StripePriceRecurring,
+};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
@@ -48,6 +52,37 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer))
}
+
+ async fn list_prices(&self) -> Result<Vec<StripePrice>> {
+ let response = stripe::Price::list(
+ &self.client,
+ &stripe::ListPrices {
+ limit: Some(100),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(response.data.into_iter().map(StripePrice::from).collect())
+ }
+
+ async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
+ #[derive(Serialize)]
+ struct Params {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ limit: Option<u64>,
+ }
+
+ let response = self
+ .client
+ .get_query::<stripe::List<StripeMeter>, _>(
+ "/billing/meters",
+ Params { limit: Some(100) },
+ )
+ .await?;
+
+ Ok(response.data)
+ }
}
impl From<CustomerId> for StripeCustomerId {
@@ -72,3 +107,34 @@ impl From<Customer> for StripeCustomer {
}
}
}
+
+impl From<PriceId> for StripePriceId {
+ fn from(value: PriceId) -> Self {
+ Self(value.as_str().into())
+ }
+}
+
+impl TryFrom<StripePriceId> for PriceId {
+ type Error = anyhow::Error;
+
+ fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
+ Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
+ }
+}
+
+impl From<Price> for StripePrice {
+ fn from(value: Price) -> Self {
+ Self {
+ id: value.id.into(),
+ unit_amount: value.unit_amount,
+ lookup_key: value.lookup_key,
+ recurring: value.recurring.map(StripePriceRecurring::from),
+ }
+ }
+}
+
+impl From<Recurring> for StripePriceRecurring {
+ fn from(value: Recurring) -> Self {
+ Self { meter: value.meter }
+ }
+}
@@ -3,7 +3,9 @@ use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
-use crate::stripe_client::FakeStripeClient;
+use crate::stripe_client::{
+ FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
+};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
@@ -12,6 +14,87 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
(stripe_billing, stripe_client)
}
+#[gpui::test]
+async fn test_initialize() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ // Add test meters
+ let meter1 = StripeMeter {
+ id: StripeMeterId("meter_1".into()),
+ event_name: "event_1".to_string(),
+ };
+ let meter2 = StripeMeter {
+ id: StripeMeterId("meter_2".into()),
+ event_name: "event_2".to_string(),
+ };
+ stripe_client
+ .meters
+ .lock()
+ .insert(meter1.id.clone(), meter1);
+ stripe_client
+ .meters
+ .lock()
+ .insert(meter2.id.clone(), meter2);
+
+ // Add test prices
+ let price1 = StripePrice {
+ id: StripePriceId("price_1".into()),
+ unit_amount: Some(1_000),
+ lookup_key: Some("zed-pro".to_string()),
+ recurring: None,
+ };
+ let price2 = StripePrice {
+ id: StripePriceId("price_2".into()),
+ unit_amount: Some(0),
+ lookup_key: Some("zed-free".to_string()),
+ recurring: None,
+ };
+ let price3 = StripePrice {
+ id: StripePriceId("price_3".into()),
+ unit_amount: Some(500),
+ lookup_key: None,
+ recurring: Some(StripePriceRecurring {
+ meter: Some("meter_1".to_string()),
+ }),
+ };
+ stripe_client
+ .prices
+ .lock()
+ .insert(price1.id.clone(), price1);
+ stripe_client
+ .prices
+ .lock()
+ .insert(price2.id.clone(), price2);
+ stripe_client
+ .prices
+ .lock()
+ .insert(price3.id.clone(), price3);
+
+ // Initialize the billing system
+ stripe_billing.initialize().await.unwrap();
+
+ // Verify that prices can be found by lookup key
+ let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
+ assert_eq!(zed_pro_price_id.to_string(), "price_1");
+
+ let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
+ assert_eq!(zed_free_price_id.to_string(), "price_2");
+
+ // Verify that a price can be found by lookup key
+ let zed_pro_price = stripe_billing
+ .find_price_by_lookup_key("zed-pro")
+ .await
+ .unwrap();
+ assert_eq!(zed_pro_price.id.to_string(), "price_1");
+ assert_eq!(zed_pro_price.unit_amount, Some(1_000));
+
+ // Verify that finding a non-existent lookup key returns an error
+ let result = stripe_billing
+ .find_price_by_lookup_key("non-existent")
+ .await;
+ assert!(result.is_err());
+}
+
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();