stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::anyhow;
  4use collections::HashMap;
  5use stripe::SubscriptionStatus;
  6use tokio::sync::RwLock;
  7
  8use crate::Result;
  9use crate::stripe_client::{
 10    RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
 11    StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
 12    StripeSubscription,
 13};
 14
 15pub struct StripeBilling {
 16    state: RwLock<StripeBillingState>,
 17    client: Arc<dyn StripeClient>,
 18}
 19
 20#[derive(Default)]
 21struct StripeBillingState {
 22    prices_by_lookup_key: HashMap<String, StripePrice>,
 23}
 24
 25impl StripeBilling {
 26    pub fn new(client: Arc<stripe::Client>) -> Self {
 27        Self {
 28            client: Arc::new(RealStripeClient::new(client.clone())),
 29            state: RwLock::default(),
 30        }
 31    }
 32
 33    #[cfg(test)]
 34    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 35        Self {
 36            client,
 37            state: RwLock::default(),
 38        }
 39    }
 40
 41    pub fn client(&self) -> &Arc<dyn StripeClient> {
 42        &self.client
 43    }
 44
 45    pub async fn initialize(&self) -> Result<()> {
 46        log::info!("StripeBilling: initializing");
 47
 48        let mut state = self.state.write().await;
 49
 50        let prices = self.client.list_prices().await?;
 51
 52        for price in prices {
 53            if let Some(lookup_key) = price.lookup_key.clone() {
 54                state.prices_by_lookup_key.insert(lookup_key, price);
 55            }
 56        }
 57
 58        log::info!("StripeBilling: initialized");
 59
 60        Ok(())
 61    }
 62
 63    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
 64        self.find_price_id_by_lookup_key("zed-pro").await
 65    }
 66
 67    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
 68        self.find_price_id_by_lookup_key("zed-free").await
 69    }
 70
 71    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
 72        self.state
 73            .read()
 74            .await
 75            .prices_by_lookup_key
 76            .get(lookup_key)
 77            .map(|price| price.id.clone())
 78            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 79    }
 80
 81    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
 82        self.state
 83            .read()
 84            .await
 85            .prices_by_lookup_key
 86            .get(lookup_key)
 87            .cloned()
 88            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 89    }
 90
 91    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
 92    /// not already exist.
 93    ///
 94    /// Always returns a new Stripe customer if the email address is `None`.
 95    pub async fn find_or_create_customer_by_email(
 96        &self,
 97        email_address: Option<&str>,
 98    ) -> Result<StripeCustomerId> {
 99        let existing_customer = if let Some(email) = email_address {
100            let customers = self.client.list_customers_by_email(email).await?;
101
102            customers.first().cloned()
103        } else {
104            None
105        };
106
107        let customer_id = if let Some(existing_customer) = existing_customer {
108            existing_customer.id
109        } else {
110            let customer = self
111                .client
112                .create_customer(crate::stripe_client::CreateCustomerParams {
113                    email: email_address,
114                })
115                .await?;
116
117            customer.id
118        };
119
120        Ok(customer_id)
121    }
122
123    pub async fn subscribe_to_zed_free(
124        &self,
125        customer_id: StripeCustomerId,
126    ) -> Result<StripeSubscription> {
127        let zed_free_price_id = self.zed_free_price_id().await?;
128
129        let existing_subscriptions = self
130            .client
131            .list_subscriptions_for_customer(&customer_id)
132            .await?;
133
134        let existing_active_subscription =
135            existing_subscriptions.into_iter().find(|subscription| {
136                subscription.status == SubscriptionStatus::Active
137                    || subscription.status == SubscriptionStatus::Trialing
138            });
139        if let Some(subscription) = existing_active_subscription {
140            return Ok(subscription);
141        }
142
143        let params = StripeCreateSubscriptionParams {
144            customer: customer_id,
145            items: vec![StripeCreateSubscriptionItems {
146                price: Some(zed_free_price_id),
147                quantity: Some(1),
148            }],
149            automatic_tax: Some(StripeAutomaticTax { enabled: true }),
150        };
151
152        let subscription = self.client.create_subscription(params).await?;
153
154        Ok(subscription)
155    }
156}