stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::anyhow;
  4use chrono::Utc;
  5use collections::HashMap;
  6use stripe::SubscriptionStatus;
  7use tokio::sync::RwLock;
  8use uuid::Uuid;
  9
 10use crate::Result;
 11use crate::db::billing_subscription::SubscriptionKind;
 12use crate::stripe_client::{
 13    RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
 14    StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
 15    StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
 16    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 17    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 18    UpdateSubscriptionParams,
 19};
 20
 21pub struct StripeBilling {
 22    state: RwLock<StripeBillingState>,
 23    client: Arc<dyn StripeClient>,
 24}
 25
 26#[derive(Default)]
 27struct StripeBillingState {
 28    prices_by_lookup_key: HashMap<String, StripePrice>,
 29}
 30
 31impl StripeBilling {
 32    pub fn new(client: Arc<stripe::Client>) -> Self {
 33        Self {
 34            client: Arc::new(RealStripeClient::new(client.clone())),
 35            state: RwLock::default(),
 36        }
 37    }
 38
 39    #[cfg(test)]
 40    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 41        Self {
 42            client,
 43            state: RwLock::default(),
 44        }
 45    }
 46
 47    pub fn client(&self) -> &Arc<dyn StripeClient> {
 48        &self.client
 49    }
 50
 51    pub async fn initialize(&self) -> Result<()> {
 52        log::info!("StripeBilling: initializing");
 53
 54        let mut state = self.state.write().await;
 55
 56        let prices = self.client.list_prices().await?;
 57
 58        for price in prices {
 59            if let Some(lookup_key) = price.lookup_key.clone() {
 60                state.prices_by_lookup_key.insert(lookup_key, price);
 61            }
 62        }
 63
 64        log::info!("StripeBilling: initialized");
 65
 66        Ok(())
 67    }
 68
 69    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
 70        self.find_price_id_by_lookup_key("zed-pro").await
 71    }
 72
 73    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
 74        self.find_price_id_by_lookup_key("zed-free").await
 75    }
 76
 77    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
 78        self.state
 79            .read()
 80            .await
 81            .prices_by_lookup_key
 82            .get(lookup_key)
 83            .map(|price| price.id.clone())
 84            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 85    }
 86
 87    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
 88        self.state
 89            .read()
 90            .await
 91            .prices_by_lookup_key
 92            .get(lookup_key)
 93            .cloned()
 94            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 95    }
 96
 97    pub async fn determine_subscription_kind(
 98        &self,
 99        subscription: &StripeSubscription,
100    ) -> Option<SubscriptionKind> {
101        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
102        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
103
104        subscription.items.iter().find_map(|item| {
105            let price = item.price.as_ref()?;
106
107            if price.id == zed_pro_price_id {
108                Some(if subscription.status == SubscriptionStatus::Trialing {
109                    SubscriptionKind::ZedProTrial
110                } else {
111                    SubscriptionKind::ZedPro
112                })
113            } else if price.id == zed_free_price_id {
114                Some(SubscriptionKind::ZedFree)
115            } else {
116                None
117            }
118        })
119    }
120
121    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
122    /// not already exist.
123    ///
124    /// Always returns a new Stripe customer if the email address is `None`.
125    pub async fn find_or_create_customer_by_email(
126        &self,
127        email_address: Option<&str>,
128    ) -> Result<StripeCustomerId> {
129        let existing_customer = if let Some(email) = email_address {
130            let customers = self.client.list_customers_by_email(email).await?;
131
132            customers.first().cloned()
133        } else {
134            None
135        };
136
137        let customer_id = if let Some(existing_customer) = existing_customer {
138            existing_customer.id
139        } else {
140            let customer = self
141                .client
142                .create_customer(crate::stripe_client::CreateCustomerParams {
143                    email: email_address,
144                })
145                .await?;
146
147            customer.id
148        };
149
150        Ok(customer_id)
151    }
152
153    pub async fn subscribe_to_price(
154        &self,
155        subscription_id: &StripeSubscriptionId,
156        price: &StripePrice,
157    ) -> Result<()> {
158        let subscription = self.client.get_subscription(subscription_id).await?;
159
160        if subscription_contains_price(&subscription, &price.id) {
161            return Ok(());
162        }
163
164        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
165
166        let price_per_unit = price.unit_amount.unwrap_or_default();
167        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
168
169        self.client
170            .update_subscription(
171                subscription_id,
172                UpdateSubscriptionParams {
173                    items: Some(vec![UpdateSubscriptionItems {
174                        price: Some(price.id.clone()),
175                    }]),
176                    trial_settings: Some(StripeSubscriptionTrialSettings {
177                        end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
178                            missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
179                        },
180                    }),
181                },
182            )
183            .await?;
184
185        Ok(())
186    }
187
188    pub async fn bill_model_request_usage(
189        &self,
190        customer_id: &StripeCustomerId,
191        event_name: &str,
192        requests: i32,
193    ) -> Result<()> {
194        let timestamp = Utc::now().timestamp();
195        let idempotency_key = Uuid::new_v4();
196
197        self.client
198            .create_meter_event(StripeCreateMeterEventParams {
199                identifier: &format!("model_requests/{}", idempotency_key),
200                event_name,
201                payload: StripeCreateMeterEventPayload {
202                    value: requests as u64,
203                    stripe_customer_id: customer_id,
204                },
205                timestamp: Some(timestamp),
206            })
207            .await?;
208
209        Ok(())
210    }
211
212    pub async fn subscribe_to_zed_free(
213        &self,
214        customer_id: StripeCustomerId,
215    ) -> Result<StripeSubscription> {
216        let zed_free_price_id = self.zed_free_price_id().await?;
217
218        let existing_subscriptions = self
219            .client
220            .list_subscriptions_for_customer(&customer_id)
221            .await?;
222
223        let existing_active_subscription =
224            existing_subscriptions.into_iter().find(|subscription| {
225                subscription.status == SubscriptionStatus::Active
226                    || subscription.status == SubscriptionStatus::Trialing
227            });
228        if let Some(subscription) = existing_active_subscription {
229            return Ok(subscription);
230        }
231
232        let params = StripeCreateSubscriptionParams {
233            customer: customer_id,
234            items: vec![StripeCreateSubscriptionItems {
235                price: Some(zed_free_price_id),
236                quantity: Some(1),
237            }],
238            automatic_tax: Some(StripeAutomaticTax { enabled: true }),
239        };
240
241        let subscription = self.client.create_subscription(params).await?;
242
243        Ok(subscription)
244    }
245}
246
247fn subscription_contains_price(
248    subscription: &StripeSubscription,
249    price_id: &StripePriceId,
250) -> bool {
251    subscription.items.iter().any(|item| {
252        item.price
253            .as_ref()
254            .map_or(false, |price| price.id == *price_id)
255    })
256}