stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, anyhow};
  4use chrono::Utc;
  5use collections::HashMap;
  6use serde::{Deserialize, Serialize};
  7use stripe::{PriceId, SubscriptionStatus};
  8use tokio::sync::RwLock;
  9use uuid::Uuid;
 10
 11use crate::Result;
 12use crate::db::billing_subscription::SubscriptionKind;
 13use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 14use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
 15
 16pub struct StripeBilling {
 17    state: RwLock<StripeBillingState>,
 18    real_client: Arc<stripe::Client>,
 19    client: Arc<dyn StripeClient>,
 20}
 21
 22#[derive(Default)]
 23struct StripeBillingState {
 24    meters_by_event_name: HashMap<String, StripeMeter>,
 25    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 26    prices_by_lookup_key: HashMap<String, stripe::Price>,
 27}
 28
 29impl StripeBilling {
 30    pub fn new(client: Arc<stripe::Client>) -> Self {
 31        Self {
 32            client: Arc::new(RealStripeClient::new(client.clone())),
 33            real_client: client,
 34            state: RwLock::default(),
 35        }
 36    }
 37
 38    #[cfg(test)]
 39    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 40        Self {
 41            // This is just temporary until we can remove all usages of the real Stripe client.
 42            real_client: Arc::new(stripe::Client::new("sk_test")),
 43            client,
 44            state: RwLock::default(),
 45        }
 46    }
 47
 48    pub async fn initialize(&self) -> Result<()> {
 49        log::info!("StripeBilling: initializing");
 50
 51        let mut state = self.state.write().await;
 52
 53        let (meters, prices) = futures::try_join!(
 54            StripeMeter::list(&self.real_client),
 55            stripe::Price::list(
 56                &self.real_client,
 57                &stripe::ListPrices {
 58                    limit: Some(100),
 59                    ..Default::default()
 60                }
 61            )
 62        )?;
 63
 64        for meter in meters.data {
 65            state
 66                .meters_by_event_name
 67                .insert(meter.event_name.clone(), meter);
 68        }
 69
 70        for price in prices.data {
 71            if let Some(lookup_key) = price.lookup_key.clone() {
 72                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 73            }
 74
 75            if let Some(recurring) = price.recurring {
 76                if let Some(meter) = recurring.meter {
 77                    state.price_ids_by_meter_id.insert(meter, price.id);
 78                }
 79            }
 80        }
 81
 82        log::info!("StripeBilling: initialized");
 83
 84        Ok(())
 85    }
 86
 87    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 88        self.find_price_id_by_lookup_key("zed-pro").await
 89    }
 90
 91    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 92        self.find_price_id_by_lookup_key("zed-free").await
 93    }
 94
 95    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 96        self.state
 97            .read()
 98            .await
 99            .prices_by_lookup_key
100            .get(lookup_key)
101            .map(|price| price.id.clone())
102            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
103    }
104
105    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
106        self.state
107            .read()
108            .await
109            .prices_by_lookup_key
110            .get(lookup_key)
111            .cloned()
112            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
113    }
114
115    pub async fn determine_subscription_kind(
116        &self,
117        subscription: &stripe::Subscription,
118    ) -> Option<SubscriptionKind> {
119        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
120        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
121
122        subscription.items.data.iter().find_map(|item| {
123            let price = item.price.as_ref()?;
124
125            if price.id == zed_pro_price_id {
126                Some(if subscription.status == SubscriptionStatus::Trialing {
127                    SubscriptionKind::ZedProTrial
128                } else {
129                    SubscriptionKind::ZedPro
130                })
131            } else if price.id == zed_free_price_id {
132                Some(SubscriptionKind::ZedFree)
133            } else {
134                None
135            }
136        })
137    }
138
139    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
140    /// not already exist.
141    ///
142    /// Always returns a new Stripe customer if the email address is `None`.
143    pub async fn find_or_create_customer_by_email(
144        &self,
145        email_address: Option<&str>,
146    ) -> Result<StripeCustomerId> {
147        let existing_customer = if let Some(email) = email_address {
148            let customers = self.client.list_customers_by_email(email).await?;
149
150            customers.first().cloned()
151        } else {
152            None
153        };
154
155        let customer_id = if let Some(existing_customer) = existing_customer {
156            existing_customer.id
157        } else {
158            let customer = self
159                .client
160                .create_customer(crate::stripe_client::CreateCustomerParams {
161                    email: email_address,
162                })
163                .await?;
164
165            customer.id
166        };
167
168        Ok(customer_id)
169    }
170
171    pub async fn subscribe_to_price(
172        &self,
173        subscription_id: &stripe::SubscriptionId,
174        price: &stripe::Price,
175    ) -> Result<()> {
176        let subscription =
177            stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
178
179        if subscription_contains_price(&subscription, &price.id) {
180            return Ok(());
181        }
182
183        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
184
185        let price_per_unit = price.unit_amount.unwrap_or_default();
186        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
187
188        stripe::Subscription::update(
189            &self.real_client,
190            subscription_id,
191            stripe::UpdateSubscription {
192                items: Some(vec![stripe::UpdateSubscriptionItems {
193                    price: Some(price.id.to_string()),
194                    ..Default::default()
195                }]),
196                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
197                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
198                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
199                    },
200                }),
201                ..Default::default()
202            },
203        )
204        .await?;
205
206        Ok(())
207    }
208
209    pub async fn bill_model_request_usage(
210        &self,
211        customer_id: &stripe::CustomerId,
212        event_name: &str,
213        requests: i32,
214    ) -> Result<()> {
215        let timestamp = Utc::now().timestamp();
216        let idempotency_key = Uuid::new_v4();
217
218        StripeMeterEvent::create(
219            &self.real_client,
220            StripeCreateMeterEventParams {
221                identifier: &format!("model_requests/{}", idempotency_key),
222                event_name,
223                payload: StripeCreateMeterEventPayload {
224                    value: requests as u64,
225                    stripe_customer_id: customer_id,
226                },
227                timestamp: Some(timestamp),
228            },
229        )
230        .await?;
231
232        Ok(())
233    }
234
235    pub async fn checkout_with_zed_pro(
236        &self,
237        customer_id: stripe::CustomerId,
238        github_login: &str,
239        success_url: &str,
240    ) -> Result<String> {
241        let zed_pro_price_id = self.zed_pro_price_id().await?;
242
243        let mut params = stripe::CreateCheckoutSession::new();
244        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
245        params.customer = Some(customer_id);
246        params.client_reference_id = Some(github_login);
247        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
248            price: Some(zed_pro_price_id.to_string()),
249            quantity: Some(1),
250            ..Default::default()
251        }]);
252        params.success_url = Some(success_url);
253
254        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
255        Ok(session.url.context("no checkout session URL")?)
256    }
257
258    pub async fn checkout_with_zed_pro_trial(
259        &self,
260        customer_id: stripe::CustomerId,
261        github_login: &str,
262        feature_flags: Vec<String>,
263        success_url: &str,
264    ) -> Result<String> {
265        let zed_pro_price_id = self.zed_pro_price_id().await?;
266
267        let eligible_for_extended_trial = feature_flags
268            .iter()
269            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
270
271        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
272
273        let mut subscription_metadata = std::collections::HashMap::new();
274        if eligible_for_extended_trial {
275            subscription_metadata.insert(
276                "promo_feature_flag".to_string(),
277                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
278            );
279        }
280
281        let mut params = stripe::CreateCheckoutSession::new();
282        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
283            trial_period_days: Some(trial_period_days),
284            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
285                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
286                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
287                }
288            }),
289            metadata: if !subscription_metadata.is_empty() {
290                Some(subscription_metadata)
291            } else {
292                None
293            },
294            ..Default::default()
295        });
296        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
297        params.payment_method_collection =
298            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
299        params.customer = Some(customer_id);
300        params.client_reference_id = Some(github_login);
301        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
302            price: Some(zed_pro_price_id.to_string()),
303            quantity: Some(1),
304            ..Default::default()
305        }]);
306        params.success_url = Some(success_url);
307
308        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
309        Ok(session.url.context("no checkout session URL")?)
310    }
311
312    pub async fn subscribe_to_zed_free(
313        &self,
314        customer_id: stripe::CustomerId,
315    ) -> Result<stripe::Subscription> {
316        let zed_free_price_id = self.zed_free_price_id().await?;
317
318        let existing_subscriptions = stripe::Subscription::list(
319            &self.real_client,
320            &stripe::ListSubscriptions {
321                customer: Some(customer_id.clone()),
322                status: None,
323                ..Default::default()
324            },
325        )
326        .await?;
327
328        let existing_active_subscription =
329            existing_subscriptions
330                .data
331                .into_iter()
332                .find(|subscription| {
333                    subscription.status == SubscriptionStatus::Active
334                        || subscription.status == SubscriptionStatus::Trialing
335                });
336        if let Some(subscription) = existing_active_subscription {
337            return Ok(subscription);
338        }
339
340        let mut params = stripe::CreateSubscription::new(customer_id);
341        params.items = Some(vec![stripe::CreateSubscriptionItems {
342            price: Some(zed_free_price_id.to_string()),
343            quantity: Some(1),
344            ..Default::default()
345        }]);
346
347        let subscription = stripe::Subscription::create(&self.real_client, params).await?;
348
349        Ok(subscription)
350    }
351
352    pub async fn checkout_with_zed_free(
353        &self,
354        customer_id: stripe::CustomerId,
355        github_login: &str,
356        success_url: &str,
357    ) -> Result<String> {
358        let zed_free_price_id = self.zed_free_price_id().await?;
359
360        let mut params = stripe::CreateCheckoutSession::new();
361        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
362        params.payment_method_collection =
363            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
364        params.customer = Some(customer_id);
365        params.client_reference_id = Some(github_login);
366        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
367            price: Some(zed_free_price_id.to_string()),
368            quantity: Some(1),
369            ..Default::default()
370        }]);
371        params.success_url = Some(success_url);
372
373        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
374        Ok(session.url.context("no checkout session URL")?)
375    }
376}
377
378#[derive(Clone, Deserialize)]
379struct StripeMeter {
380    id: String,
381    event_name: String,
382}
383
384impl StripeMeter {
385    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
386        #[derive(Serialize)]
387        struct Params {
388            #[serde(skip_serializing_if = "Option::is_none")]
389            limit: Option<u64>,
390        }
391
392        client.get_query("/billing/meters", Params { limit: Some(100) })
393    }
394}
395
396#[derive(Deserialize)]
397struct StripeMeterEvent {
398    identifier: String,
399}
400
401impl StripeMeterEvent {
402    pub async fn create(
403        client: &stripe::Client,
404        params: StripeCreateMeterEventParams<'_>,
405    ) -> Result<Self, stripe::StripeError> {
406        let identifier = params.identifier;
407        match client.post_form("/billing/meter_events", params).await {
408            Ok(event) => Ok(event),
409            Err(stripe::StripeError::Stripe(error)) => {
410                if error.http_status == 400
411                    && error
412                        .message
413                        .as_ref()
414                        .map_or(false, |message| message.contains(identifier))
415                {
416                    Ok(Self {
417                        identifier: identifier.to_string(),
418                    })
419                } else {
420                    Err(stripe::StripeError::Stripe(error))
421                }
422            }
423            Err(error) => Err(error),
424        }
425    }
426}
427
428#[derive(Serialize)]
429struct StripeCreateMeterEventParams<'a> {
430    identifier: &'a str,
431    event_name: &'a str,
432    payload: StripeCreateMeterEventPayload<'a>,
433    timestamp: Option<i64>,
434}
435
436#[derive(Serialize)]
437struct StripeCreateMeterEventPayload<'a> {
438    value: u64,
439    stripe_customer_id: &'a stripe::CustomerId,
440}
441
442fn subscription_contains_price(
443    subscription: &stripe::Subscription,
444    price_id: &stripe::PriceId,
445) -> bool {
446    subscription.items.data.iter().any(|item| {
447        item.price
448            .as_ref()
449            .map_or(false, |price| price.id == *price_id)
450    })
451}