stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, anyhow};
  4use chrono::Utc;
  5use collections::HashMap;
  6use stripe::SubscriptionStatus;
  7use tokio::sync::RwLock;
  8use uuid::Uuid;
  9
 10use crate::Result;
 11use crate::db::billing_subscription::SubscriptionKind;
 12use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 13use crate::stripe_client::{
 14    RealStripeClient, StripeAutomaticTax, StripeBillingAddressCollection,
 15    StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 16    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 17    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 18    StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
 19    StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName,
 20    StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
 21    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 22    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
 23    UpdateSubscriptionItems, UpdateSubscriptionParams,
 24};
 25
 26pub struct StripeBilling {
 27    state: RwLock<StripeBillingState>,
 28    client: Arc<dyn StripeClient>,
 29}
 30
 31#[derive(Default)]
 32struct StripeBillingState {
 33    prices_by_lookup_key: HashMap<String, StripePrice>,
 34}
 35
 36impl StripeBilling {
 37    pub fn new(client: Arc<stripe::Client>) -> Self {
 38        Self {
 39            client: Arc::new(RealStripeClient::new(client.clone())),
 40            state: RwLock::default(),
 41        }
 42    }
 43
 44    #[cfg(test)]
 45    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 46        Self {
 47            client,
 48            state: RwLock::default(),
 49        }
 50    }
 51
 52    pub fn client(&self) -> &Arc<dyn StripeClient> {
 53        &self.client
 54    }
 55
 56    pub async fn initialize(&self) -> Result<()> {
 57        log::info!("StripeBilling: initializing");
 58
 59        let mut state = self.state.write().await;
 60
 61        let prices = self.client.list_prices().await?;
 62
 63        for price in prices {
 64            if let Some(lookup_key) = price.lookup_key.clone() {
 65                state.prices_by_lookup_key.insert(lookup_key, price);
 66            }
 67        }
 68
 69        log::info!("StripeBilling: initialized");
 70
 71        Ok(())
 72    }
 73
 74    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
 75        self.find_price_id_by_lookup_key("zed-pro").await
 76    }
 77
 78    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
 79        self.find_price_id_by_lookup_key("zed-free").await
 80    }
 81
 82    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
 83        self.state
 84            .read()
 85            .await
 86            .prices_by_lookup_key
 87            .get(lookup_key)
 88            .map(|price| price.id.clone())
 89            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 90    }
 91
 92    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
 93        self.state
 94            .read()
 95            .await
 96            .prices_by_lookup_key
 97            .get(lookup_key)
 98            .cloned()
 99            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
100    }
101
102    pub async fn determine_subscription_kind(
103        &self,
104        subscription: &StripeSubscription,
105    ) -> Option<SubscriptionKind> {
106        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
107        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
108
109        subscription.items.iter().find_map(|item| {
110            let price = item.price.as_ref()?;
111
112            if price.id == zed_pro_price_id {
113                Some(if subscription.status == SubscriptionStatus::Trialing {
114                    SubscriptionKind::ZedProTrial
115                } else {
116                    SubscriptionKind::ZedPro
117                })
118            } else if price.id == zed_free_price_id {
119                Some(SubscriptionKind::ZedFree)
120            } else {
121                None
122            }
123        })
124    }
125
126    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
127    /// not already exist.
128    ///
129    /// Always returns a new Stripe customer if the email address is `None`.
130    pub async fn find_or_create_customer_by_email(
131        &self,
132        email_address: Option<&str>,
133    ) -> Result<StripeCustomerId> {
134        let existing_customer = if let Some(email) = email_address {
135            let customers = self.client.list_customers_by_email(email).await?;
136
137            customers.first().cloned()
138        } else {
139            None
140        };
141
142        let customer_id = if let Some(existing_customer) = existing_customer {
143            existing_customer.id
144        } else {
145            let customer = self
146                .client
147                .create_customer(crate::stripe_client::CreateCustomerParams {
148                    email: email_address,
149                })
150                .await?;
151
152            customer.id
153        };
154
155        Ok(customer_id)
156    }
157
158    pub async fn subscribe_to_price(
159        &self,
160        subscription_id: &StripeSubscriptionId,
161        price: &StripePrice,
162    ) -> Result<()> {
163        let subscription = self.client.get_subscription(subscription_id).await?;
164
165        if subscription_contains_price(&subscription, &price.id) {
166            return Ok(());
167        }
168
169        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
170
171        let price_per_unit = price.unit_amount.unwrap_or_default();
172        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
173
174        self.client
175            .update_subscription(
176                subscription_id,
177                UpdateSubscriptionParams {
178                    items: Some(vec![UpdateSubscriptionItems {
179                        price: Some(price.id.clone()),
180                    }]),
181                    trial_settings: Some(StripeSubscriptionTrialSettings {
182                        end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
183                            missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
184                        },
185                    }),
186                },
187            )
188            .await?;
189
190        Ok(())
191    }
192
193    pub async fn bill_model_request_usage(
194        &self,
195        customer_id: &StripeCustomerId,
196        event_name: &str,
197        requests: i32,
198    ) -> Result<()> {
199        let timestamp = Utc::now().timestamp();
200        let idempotency_key = Uuid::new_v4();
201
202        self.client
203            .create_meter_event(StripeCreateMeterEventParams {
204                identifier: &format!("model_requests/{}", idempotency_key),
205                event_name,
206                payload: StripeCreateMeterEventPayload {
207                    value: requests as u64,
208                    stripe_customer_id: customer_id,
209                },
210                timestamp: Some(timestamp),
211            })
212            .await?;
213
214        Ok(())
215    }
216
217    pub async fn checkout_with_zed_pro(
218        &self,
219        customer_id: &StripeCustomerId,
220        github_login: &str,
221        success_url: &str,
222    ) -> Result<String> {
223        let zed_pro_price_id = self.zed_pro_price_id().await?;
224
225        let mut params = StripeCreateCheckoutSessionParams::default();
226        params.mode = Some(StripeCheckoutSessionMode::Subscription);
227        params.customer = Some(customer_id);
228        params.client_reference_id = Some(github_login);
229        params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
230            price: Some(zed_pro_price_id.to_string()),
231            quantity: Some(1),
232        }]);
233        params.success_url = Some(success_url);
234        params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
235        params.customer_update = Some(StripeCustomerUpdate {
236            address: Some(StripeCustomerUpdateAddress::Auto),
237            name: Some(StripeCustomerUpdateName::Auto),
238            shipping: None,
239        });
240        params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
241
242        let session = self.client.create_checkout_session(params).await?;
243        Ok(session.url.context("no checkout session URL")?)
244    }
245
246    pub async fn checkout_with_zed_pro_trial(
247        &self,
248        customer_id: &StripeCustomerId,
249        github_login: &str,
250        feature_flags: Vec<String>,
251        success_url: &str,
252    ) -> Result<String> {
253        let zed_pro_price_id = self.zed_pro_price_id().await?;
254
255        let eligible_for_extended_trial = feature_flags
256            .iter()
257            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
258
259        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
260
261        let mut subscription_metadata = std::collections::HashMap::new();
262        if eligible_for_extended_trial {
263            subscription_metadata.insert(
264                "promo_feature_flag".to_string(),
265                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
266            );
267        }
268
269        let mut params = StripeCreateCheckoutSessionParams::default();
270        params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
271            trial_period_days: Some(trial_period_days),
272            trial_settings: Some(StripeSubscriptionTrialSettings {
273                end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
274                    missing_payment_method:
275                        StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
276                },
277            }),
278            metadata: if !subscription_metadata.is_empty() {
279                Some(subscription_metadata)
280            } else {
281                None
282            },
283        });
284        params.mode = Some(StripeCheckoutSessionMode::Subscription);
285        params.payment_method_collection =
286            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
287        params.customer = Some(customer_id);
288        params.client_reference_id = Some(github_login);
289        params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
290            price: Some(zed_pro_price_id.to_string()),
291            quantity: Some(1),
292        }]);
293        params.success_url = Some(success_url);
294        params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
295        params.customer_update = Some(StripeCustomerUpdate {
296            address: Some(StripeCustomerUpdateAddress::Auto),
297            name: Some(StripeCustomerUpdateName::Auto),
298            shipping: None,
299        });
300        params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
301
302        let session = self.client.create_checkout_session(params).await?;
303        Ok(session.url.context("no checkout session URL")?)
304    }
305
306    pub async fn subscribe_to_zed_free(
307        &self,
308        customer_id: StripeCustomerId,
309    ) -> Result<StripeSubscription> {
310        let zed_free_price_id = self.zed_free_price_id().await?;
311
312        let existing_subscriptions = self
313            .client
314            .list_subscriptions_for_customer(&customer_id)
315            .await?;
316
317        let existing_active_subscription =
318            existing_subscriptions.into_iter().find(|subscription| {
319                subscription.status == SubscriptionStatus::Active
320                    || subscription.status == SubscriptionStatus::Trialing
321            });
322        if let Some(subscription) = existing_active_subscription {
323            return Ok(subscription);
324        }
325
326        let params = StripeCreateSubscriptionParams {
327            customer: customer_id,
328            items: vec![StripeCreateSubscriptionItems {
329                price: Some(zed_free_price_id),
330                quantity: Some(1),
331            }],
332            automatic_tax: Some(StripeAutomaticTax { enabled: true }),
333        };
334
335        let subscription = self.client.create_subscription(params).await?;
336
337        Ok(subscription)
338    }
339}
340
341fn subscription_contains_price(
342    subscription: &StripeSubscription,
343    price_id: &StripePriceId,
344) -> bool {
345    subscription.items.iter().any(|item| {
346        item.price
347            .as_ref()
348            .map_or(false, |price| price.id == *price_id)
349    })
350}