Detailed changes
@@ -20,6 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anyhow.workspace = true
async-stripe.workspace = true
+async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
@@ -344,6 +344,7 @@ async fn create_billing_subscription(
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?
+ .try_into()?
};
let success_url = format!(
@@ -9,6 +9,7 @@ pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
+pub mod stripe_client;
pub mod user_backfiller;
#[cfg(test)]
@@ -4039,7 +4039,8 @@ async fn get_llm_api_token(
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
- .await?;
+ .await?
+ .try_into()?;
find_or_create_billing_customer(
&session.app_state,
@@ -1,19 +1,22 @@
use std::sync::Arc;
-use crate::Result;
-use crate::db::billing_subscription::SubscriptionKind;
-use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
-use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
+use stripe::{PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
+use crate::Result;
+use crate::db::billing_subscription::SubscriptionKind;
+use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
+use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
+
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
- client: Arc<stripe::Client>,
+ real_client: Arc<stripe::Client>,
+ client: Arc<dyn StripeClient>,
}
#[derive(Default)]
@@ -26,6 +29,17 @@ struct StripeBillingState {
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
+ client: Arc::new(RealStripeClient::new(client.clone())),
+ real_client: client,
+ state: RwLock::default(),
+ }
+ }
+
+ #[cfg(test)]
+ pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
+ Self {
+ // This is just temporary until we can remove all usages of the real Stripe client.
+ real_client: Arc::new(stripe::Client::new("sk_test")),
client,
state: RwLock::default(),
}
@@ -37,9 +51,9 @@ impl StripeBilling {
let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!(
- StripeMeter::list(&self.client),
+ StripeMeter::list(&self.real_client),
stripe::Price::list(
- &self.client,
+ &self.real_client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
@@ -129,18 +143,11 @@ impl StripeBilling {
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
- ) -> Result<CustomerId> {
+ ) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address {
- let customers = Customer::list(
- &self.client,
- &stripe::ListCustomers {
- email: Some(email),
- ..Default::default()
- },
- )
- .await?;
+ let customers = self.client.list_customers_by_email(email).await?;
- customers.data.first().cloned()
+ customers.first().cloned()
} else {
None
};
@@ -148,14 +155,12 @@ impl StripeBilling {
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
- let customer = Customer::create(
- &self.client,
- CreateCustomer {
+ let customer = self
+ .client
+ .create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address,
- ..Default::default()
- },
- )
- .await?;
+ })
+ .await?;
customer.id
};
@@ -169,7 +174,7 @@ impl StripeBilling {
price: &stripe::Price,
) -> Result<()> {
let subscription =
- stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
+ stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
@@ -181,7 +186,7 @@ impl StripeBilling {
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
- &self.client,
+ &self.real_client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
@@ -211,7 +216,7 @@ impl StripeBilling {
let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create(
- &self.client,
+ &self.real_client,
StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
@@ -246,7 +251,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
- let session = stripe::CheckoutSession::create(&self.client, params).await?;
+ let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
@@ -300,7 +305,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
- let session = stripe::CheckoutSession::create(&self.client, params).await?;
+ let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
@@ -311,7 +316,7 @@ impl StripeBilling {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list(
- &self.client,
+ &self.real_client,
&stripe::ListSubscriptions {
customer: Some(customer_id.clone()),
status: None,
@@ -339,7 +344,7 @@ impl StripeBilling {
..Default::default()
}]);
- let subscription = stripe::Subscription::create(&self.client, params).await?;
+ let subscription = stripe::Subscription::create(&self.real_client, params).await?;
Ok(subscription)
}
@@ -365,7 +370,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
- let session = stripe::CheckoutSession::create(&self.client, params).await?;
+ let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
@@ -0,0 +1,33 @@
+#[cfg(test)]
+mod fake_stripe_client;
+mod real_stripe_client;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use async_trait::async_trait;
+
+#[cfg(test)]
+pub use fake_stripe_client::*;
+pub use real_stripe_client::*;
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub struct StripeCustomerId(pub Arc<str>);
+
+#[derive(Debug, Clone)]
+pub struct StripeCustomer {
+ pub id: StripeCustomerId,
+ pub email: Option<String>,
+}
+
+#[derive(Debug)]
+pub struct CreateCustomerParams<'a> {
+ pub email: Option<&'a str>,
+}
+
+#[async_trait]
+pub trait StripeClient: Send + Sync {
+ async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
+
+ async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
+}
@@ -0,0 +1,47 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use collections::HashMap;
+use parking_lot::Mutex;
+use uuid::Uuid;
+
+use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+
+pub struct FakeStripeClient {
+ pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
+}
+
+impl FakeStripeClient {
+ pub fn new() -> Self {
+ Self {
+ customers: Arc::new(Mutex::new(HashMap::default())),
+ }
+ }
+}
+
+#[async_trait]
+impl StripeClient for FakeStripeClient {
+ async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
+ Ok(self
+ .customers
+ .lock()
+ .values()
+ .filter(|customer| customer.email.as_deref() == Some(email))
+ .cloned()
+ .collect())
+ }
+
+ async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
+ let customer = StripeCustomer {
+ id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
+ email: params.email.map(|email| email.to_string()),
+ };
+
+ self.customers
+ .lock()
+ .insert(customer.id.clone(), customer.clone());
+
+ Ok(customer)
+ }
+}
@@ -0,0 +1,74 @@
+use std::str::FromStr as _;
+use std::sync::Arc;
+
+use anyhow::{Context as _, Result};
+use async_trait::async_trait;
+use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
+
+use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
+
+pub struct RealStripeClient {
+ client: Arc<stripe::Client>,
+}
+
+impl RealStripeClient {
+ pub fn new(client: Arc<stripe::Client>) -> Self {
+ Self { client }
+ }
+}
+
+#[async_trait]
+impl StripeClient for RealStripeClient {
+ async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
+ let response = Customer::list(
+ &self.client,
+ &ListCustomers {
+ email: Some(email),
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(response
+ .data
+ .into_iter()
+ .map(StripeCustomer::from)
+ .collect())
+ }
+
+ async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
+ let customer = Customer::create(
+ &self.client,
+ CreateCustomer {
+ email: params.email,
+ ..Default::default()
+ },
+ )
+ .await?;
+
+ Ok(StripeCustomer::from(customer))
+ }
+}
+
+impl From<CustomerId> for StripeCustomerId {
+ fn from(value: CustomerId) -> Self {
+ Self(value.as_str().into())
+ }
+}
+
+impl TryFrom<StripeCustomerId> for CustomerId {
+ type Error = anyhow::Error;
+
+ fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
+ Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
+ }
+}
+
+impl From<Customer> for StripeCustomer {
+ fn from(value: Customer) -> Self {
+ StripeCustomer {
+ id: value.id.into(),
+ email: value.email,
+ }
+ }
+}
@@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
+mod stripe_billing_tests;
mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
@@ -0,0 +1,60 @@
+use std::sync::Arc;
+
+use pretty_assertions::assert_eq;
+
+use crate::stripe_billing::StripeBilling;
+use crate::stripe_client::FakeStripeClient;
+
+fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
+ let stripe_client = Arc::new(FakeStripeClient::new());
+ let stripe_billing = StripeBilling::test(stripe_client.clone());
+
+ (stripe_billing, stripe_client)
+}
+
+#[gpui::test]
+async fn test_find_or_create_customer_by_email() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ // Create a customer with an email that doesn't yet correspond to a customer.
+ {
+ let email = "user@example.com";
+
+ let customer_id = stripe_billing
+ .find_or_create_customer_by_email(Some(email))
+ .await
+ .unwrap();
+
+ let customer = stripe_client
+ .customers
+ .lock()
+ .get(&customer_id)
+ .unwrap()
+ .clone();
+ assert_eq!(customer.email.as_deref(), Some(email));
+ }
+
+ // Create a customer with an email that corresponds to an existing customer.
+ {
+ let email = "user2@example.com";
+
+ let existing_customer_id = stripe_billing
+ .find_or_create_customer_by_email(Some(email))
+ .await
+ .unwrap();
+
+ let customer_id = stripe_billing
+ .find_or_create_customer_by_email(Some(email))
+ .await
+ .unwrap();
+ assert_eq!(customer_id, existing_customer_id);
+
+ let customer = stripe_client
+ .customers
+ .lock()
+ .get(&customer_id)
+ .unwrap()
+ .clone();
+ assert_eq!(customer.email.as_deref(), Some(email));
+ }
+}