stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::Result;
  4use crate::db::billing_subscription::SubscriptionKind;
  5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  6use anyhow::{Context as _, anyhow};
  7use chrono::Utc;
  8use collections::HashMap;
  9use serde::{Deserialize, Serialize};
 10use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
 11use tokio::sync::RwLock;
 12use uuid::Uuid;
 13
 14pub struct StripeBilling {
 15    state: RwLock<StripeBillingState>,
 16    client: Arc<stripe::Client>,
 17}
 18
 19#[derive(Default)]
 20struct StripeBillingState {
 21    meters_by_event_name: HashMap<String, StripeMeter>,
 22    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 23    prices_by_lookup_key: HashMap<String, stripe::Price>,
 24}
 25
 26impl StripeBilling {
 27    pub fn new(client: Arc<stripe::Client>) -> Self {
 28        Self {
 29            client,
 30            state: RwLock::default(),
 31        }
 32    }
 33
 34    pub async fn initialize(&self) -> Result<()> {
 35        log::info!("StripeBilling: initializing");
 36
 37        let mut state = self.state.write().await;
 38
 39        let (meters, prices) = futures::try_join!(
 40            StripeMeter::list(&self.client),
 41            stripe::Price::list(
 42                &self.client,
 43                &stripe::ListPrices {
 44                    limit: Some(100),
 45                    ..Default::default()
 46                }
 47            )
 48        )?;
 49
 50        for meter in meters.data {
 51            state
 52                .meters_by_event_name
 53                .insert(meter.event_name.clone(), meter);
 54        }
 55
 56        for price in prices.data {
 57            if let Some(lookup_key) = price.lookup_key.clone() {
 58                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 59            }
 60
 61            if let Some(recurring) = price.recurring {
 62                if let Some(meter) = recurring.meter {
 63                    state.price_ids_by_meter_id.insert(meter, price.id);
 64                }
 65            }
 66        }
 67
 68        log::info!("StripeBilling: initialized");
 69
 70        Ok(())
 71    }
 72
 73    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 74        self.find_price_id_by_lookup_key("zed-pro").await
 75    }
 76
 77    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 78        self.find_price_id_by_lookup_key("zed-free").await
 79    }
 80
 81    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 82        self.state
 83            .read()
 84            .await
 85            .prices_by_lookup_key
 86            .get(lookup_key)
 87            .map(|price| price.id.clone())
 88            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 89    }
 90
 91    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 92        self.state
 93            .read()
 94            .await
 95            .prices_by_lookup_key
 96            .get(lookup_key)
 97            .cloned()
 98            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 99    }
100
101    pub async fn determine_subscription_kind(
102        &self,
103        subscription: &stripe::Subscription,
104    ) -> Option<SubscriptionKind> {
105        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108        subscription.items.data.iter().find_map(|item| {
109            let price = item.price.as_ref()?;
110
111            if price.id == zed_pro_price_id {
112                Some(if subscription.status == SubscriptionStatus::Trialing {
113                    SubscriptionKind::ZedProTrial
114                } else {
115                    SubscriptionKind::ZedPro
116                })
117            } else if price.id == zed_free_price_id {
118                Some(SubscriptionKind::ZedFree)
119            } else {
120                None
121            }
122        })
123    }
124
125    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
126    /// not already exist.
127    ///
128    /// Always returns a new Stripe customer if the email address is `None`.
129    pub async fn find_or_create_customer_by_email(
130        &self,
131        email_address: Option<&str>,
132    ) -> Result<CustomerId> {
133        let existing_customer = if let Some(email) = email_address {
134            let customers = Customer::list(
135                &self.client,
136                &stripe::ListCustomers {
137                    email: Some(email),
138                    ..Default::default()
139                },
140            )
141            .await?;
142
143            customers.data.first().cloned()
144        } else {
145            None
146        };
147
148        let customer_id = if let Some(existing_customer) = existing_customer {
149            existing_customer.id
150        } else {
151            let customer = Customer::create(
152                &self.client,
153                CreateCustomer {
154                    email: email_address,
155                    ..Default::default()
156                },
157            )
158            .await?;
159
160            customer.id
161        };
162
163        Ok(customer_id)
164    }
165
166    pub async fn subscribe_to_price(
167        &self,
168        subscription_id: &stripe::SubscriptionId,
169        price: &stripe::Price,
170    ) -> Result<()> {
171        let subscription =
172            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
173
174        if subscription_contains_price(&subscription, &price.id) {
175            return Ok(());
176        }
177
178        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
179
180        let price_per_unit = price.unit_amount.unwrap_or_default();
181        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
182
183        stripe::Subscription::update(
184            &self.client,
185            subscription_id,
186            stripe::UpdateSubscription {
187                items: Some(vec![stripe::UpdateSubscriptionItems {
188                    price: Some(price.id.to_string()),
189                    ..Default::default()
190                }]),
191                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
192                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
193                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
194                    },
195                }),
196                ..Default::default()
197            },
198        )
199        .await?;
200
201        Ok(())
202    }
203
204    pub async fn bill_model_request_usage(
205        &self,
206        customer_id: &stripe::CustomerId,
207        event_name: &str,
208        requests: i32,
209    ) -> Result<()> {
210        let timestamp = Utc::now().timestamp();
211        let idempotency_key = Uuid::new_v4();
212
213        StripeMeterEvent::create(
214            &self.client,
215            StripeCreateMeterEventParams {
216                identifier: &format!("model_requests/{}", idempotency_key),
217                event_name,
218                payload: StripeCreateMeterEventPayload {
219                    value: requests as u64,
220                    stripe_customer_id: customer_id,
221                },
222                timestamp: Some(timestamp),
223            },
224        )
225        .await?;
226
227        Ok(())
228    }
229
230    pub async fn checkout_with_zed_pro(
231        &self,
232        customer_id: stripe::CustomerId,
233        github_login: &str,
234        success_url: &str,
235    ) -> Result<String> {
236        let zed_pro_price_id = self.zed_pro_price_id().await?;
237
238        let mut params = stripe::CreateCheckoutSession::new();
239        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
240        params.customer = Some(customer_id);
241        params.client_reference_id = Some(github_login);
242        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
243            price: Some(zed_pro_price_id.to_string()),
244            quantity: Some(1),
245            ..Default::default()
246        }]);
247        params.success_url = Some(success_url);
248
249        let session = stripe::CheckoutSession::create(&self.client, params).await?;
250        Ok(session.url.context("no checkout session URL")?)
251    }
252
253    pub async fn checkout_with_zed_pro_trial(
254        &self,
255        customer_id: stripe::CustomerId,
256        github_login: &str,
257        feature_flags: Vec<String>,
258        success_url: &str,
259    ) -> Result<String> {
260        let zed_pro_price_id = self.zed_pro_price_id().await?;
261
262        let eligible_for_extended_trial = feature_flags
263            .iter()
264            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
265
266        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
267
268        let mut subscription_metadata = std::collections::HashMap::new();
269        if eligible_for_extended_trial {
270            subscription_metadata.insert(
271                "promo_feature_flag".to_string(),
272                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
273            );
274        }
275
276        let mut params = stripe::CreateCheckoutSession::new();
277        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
278            trial_period_days: Some(trial_period_days),
279            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
280                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
281                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
282                }
283            }),
284            metadata: if !subscription_metadata.is_empty() {
285                Some(subscription_metadata)
286            } else {
287                None
288            },
289            ..Default::default()
290        });
291        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
292        params.payment_method_collection =
293            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
294        params.customer = Some(customer_id);
295        params.client_reference_id = Some(github_login);
296        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
297            price: Some(zed_pro_price_id.to_string()),
298            quantity: Some(1),
299            ..Default::default()
300        }]);
301        params.success_url = Some(success_url);
302
303        let session = stripe::CheckoutSession::create(&self.client, params).await?;
304        Ok(session.url.context("no checkout session URL")?)
305    }
306
307    pub async fn subscribe_to_zed_free(
308        &self,
309        customer_id: stripe::CustomerId,
310    ) -> Result<stripe::Subscription> {
311        let zed_free_price_id = self.zed_free_price_id().await?;
312
313        let existing_subscriptions = stripe::Subscription::list(
314            &self.client,
315            &stripe::ListSubscriptions {
316                customer: Some(customer_id.clone()),
317                status: None,
318                ..Default::default()
319            },
320        )
321        .await?;
322
323        let existing_active_subscription =
324            existing_subscriptions
325                .data
326                .into_iter()
327                .find(|subscription| {
328                    subscription.status == SubscriptionStatus::Active
329                        || subscription.status == SubscriptionStatus::Trialing
330                });
331        if let Some(subscription) = existing_active_subscription {
332            return Ok(subscription);
333        }
334
335        let mut params = stripe::CreateSubscription::new(customer_id);
336        params.items = Some(vec![stripe::CreateSubscriptionItems {
337            price: Some(zed_free_price_id.to_string()),
338            quantity: Some(1),
339            ..Default::default()
340        }]);
341
342        let subscription = stripe::Subscription::create(&self.client, params).await?;
343
344        Ok(subscription)
345    }
346
347    pub async fn checkout_with_zed_free(
348        &self,
349        customer_id: stripe::CustomerId,
350        github_login: &str,
351        success_url: &str,
352    ) -> Result<String> {
353        let zed_free_price_id = self.zed_free_price_id().await?;
354
355        let mut params = stripe::CreateCheckoutSession::new();
356        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
357        params.payment_method_collection =
358            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
359        params.customer = Some(customer_id);
360        params.client_reference_id = Some(github_login);
361        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
362            price: Some(zed_free_price_id.to_string()),
363            quantity: Some(1),
364            ..Default::default()
365        }]);
366        params.success_url = Some(success_url);
367
368        let session = stripe::CheckoutSession::create(&self.client, params).await?;
369        Ok(session.url.context("no checkout session URL")?)
370    }
371}
372
373#[derive(Clone, Deserialize)]
374struct StripeMeter {
375    id: String,
376    event_name: String,
377}
378
379impl StripeMeter {
380    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
381        #[derive(Serialize)]
382        struct Params {
383            #[serde(skip_serializing_if = "Option::is_none")]
384            limit: Option<u64>,
385        }
386
387        client.get_query("/billing/meters", Params { limit: Some(100) })
388    }
389}
390
391#[derive(Deserialize)]
392struct StripeMeterEvent {
393    identifier: String,
394}
395
396impl StripeMeterEvent {
397    pub async fn create(
398        client: &stripe::Client,
399        params: StripeCreateMeterEventParams<'_>,
400    ) -> Result<Self, stripe::StripeError> {
401        let identifier = params.identifier;
402        match client.post_form("/billing/meter_events", params).await {
403            Ok(event) => Ok(event),
404            Err(stripe::StripeError::Stripe(error)) => {
405                if error.http_status == 400
406                    && error
407                        .message
408                        .as_ref()
409                        .map_or(false, |message| message.contains(identifier))
410                {
411                    Ok(Self {
412                        identifier: identifier.to_string(),
413                    })
414                } else {
415                    Err(stripe::StripeError::Stripe(error))
416                }
417            }
418            Err(error) => Err(error),
419        }
420    }
421}
422
423#[derive(Serialize)]
424struct StripeCreateMeterEventParams<'a> {
425    identifier: &'a str,
426    event_name: &'a str,
427    payload: StripeCreateMeterEventPayload<'a>,
428    timestamp: Option<i64>,
429}
430
431#[derive(Serialize)]
432struct StripeCreateMeterEventPayload<'a> {
433    value: u64,
434    stripe_customer_id: &'a stripe::CustomerId,
435}
436
437fn subscription_contains_price(
438    subscription: &stripe::Subscription,
439    price_id: &stripe::PriceId,
440) -> bool {
441    subscription.items.data.iter().any(|item| {
442        item.price
443            .as_ref()
444            .map_or(false, |price| price.id == *price_id)
445    })
446}