stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, anyhow};
  4use chrono::Utc;
  5use collections::HashMap;
  6use serde::{Deserialize, Serialize};
  7use stripe::SubscriptionStatus;
  8use tokio::sync::RwLock;
  9use uuid::Uuid;
 10
 11use crate::Result;
 12use crate::db::billing_subscription::SubscriptionKind;
 13use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 14use crate::stripe_client::{
 15    RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
 16};
 17
/// Entry point for Zed's Stripe billing operations: price lookups, customer and
/// subscription management, checkout sessions, and usage metering.
pub struct StripeBilling {
    /// Cache of Stripe meters and prices, populated by `StripeBilling::initialize`.
    state: RwLock<StripeBillingState>,
    /// Direct handle to the `stripe` crate client. Still used by the methods that have not
    /// yet been moved onto the `StripeClient` abstraction; in tests it is a dummy client.
    real_client: Arc<stripe::Client>,
    /// Stripe API abstraction: a `RealStripeClient` in production, a `FakeStripeClient`
    /// in tests.
    client: Arc<dyn StripeClient>,
}
 23
/// In-memory cache of Stripe billing objects, keyed for the lookups `StripeBilling` performs.
/// Starts empty (via `Default`) and is filled by `StripeBilling::initialize`.
#[derive(Default)]
struct StripeBillingState {
    /// Meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// Price IDs of recurring prices, keyed by the price's `recurring.meter` value.
    /// Prices without a recurring meter are not stored here.
    price_ids_by_meter_id: HashMap<String, StripePriceId>,
    /// Prices keyed by their Stripe lookup key (e.g. `"zed-pro"`, `"zed-free"`).
    /// Prices without a lookup key are not stored here.
    prices_by_lookup_key: HashMap<String, StripePrice>,
}
 30
 31impl StripeBilling {
 32    pub fn new(client: Arc<stripe::Client>) -> Self {
 33        Self {
 34            client: Arc::new(RealStripeClient::new(client.clone())),
 35            real_client: client,
 36            state: RwLock::default(),
 37        }
 38    }
 39
 40    #[cfg(test)]
 41    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 42        Self {
 43            // This is just temporary until we can remove all usages of the real Stripe client.
 44            real_client: Arc::new(stripe::Client::new("sk_test")),
 45            client,
 46            state: RwLock::default(),
 47        }
 48    }
 49
 50    pub async fn initialize(&self) -> Result<()> {
 51        log::info!("StripeBilling: initializing");
 52
 53        let mut state = self.state.write().await;
 54
 55        let (meters, prices) =
 56            futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
 57
 58        for meter in meters {
 59            state
 60                .meters_by_event_name
 61                .insert(meter.event_name.clone(), meter);
 62        }
 63
 64        for price in prices {
 65            if let Some(lookup_key) = price.lookup_key.clone() {
 66                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 67            }
 68
 69            if let Some(recurring) = price.recurring {
 70                if let Some(meter) = recurring.meter {
 71                    state.price_ids_by_meter_id.insert(meter, price.id);
 72                }
 73            }
 74        }
 75
 76        log::info!("StripeBilling: initialized");
 77
 78        Ok(())
 79    }
 80
 81    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
 82        self.find_price_id_by_lookup_key("zed-pro").await
 83    }
 84
 85    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
 86        self.find_price_id_by_lookup_key("zed-free").await
 87    }
 88
 89    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
 90        self.state
 91            .read()
 92            .await
 93            .prices_by_lookup_key
 94            .get(lookup_key)
 95            .map(|price| price.id.clone())
 96            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 97    }
 98
 99    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
100        self.state
101            .read()
102            .await
103            .prices_by_lookup_key
104            .get(lookup_key)
105            .cloned()
106            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
107    }
108
109    pub async fn determine_subscription_kind(
110        &self,
111        subscription: &stripe::Subscription,
112    ) -> Option<SubscriptionKind> {
113        let zed_pro_price_id: stripe::PriceId =
114            self.zed_pro_price_id().await.ok()?.try_into().ok()?;
115        let zed_free_price_id: stripe::PriceId =
116            self.zed_free_price_id().await.ok()?.try_into().ok()?;
117
118        subscription.items.data.iter().find_map(|item| {
119            let price = item.price.as_ref()?;
120
121            if price.id == zed_pro_price_id {
122                Some(if subscription.status == SubscriptionStatus::Trialing {
123                    SubscriptionKind::ZedProTrial
124                } else {
125                    SubscriptionKind::ZedPro
126                })
127            } else if price.id == zed_free_price_id {
128                Some(SubscriptionKind::ZedFree)
129            } else {
130                None
131            }
132        })
133    }
134
135    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
136    /// not already exist.
137    ///
138    /// Always returns a new Stripe customer if the email address is `None`.
139    pub async fn find_or_create_customer_by_email(
140        &self,
141        email_address: Option<&str>,
142    ) -> Result<StripeCustomerId> {
143        let existing_customer = if let Some(email) = email_address {
144            let customers = self.client.list_customers_by_email(email).await?;
145
146            customers.first().cloned()
147        } else {
148            None
149        };
150
151        let customer_id = if let Some(existing_customer) = existing_customer {
152            existing_customer.id
153        } else {
154            let customer = self
155                .client
156                .create_customer(crate::stripe_client::CreateCustomerParams {
157                    email: email_address,
158                })
159                .await?;
160
161            customer.id
162        };
163
164        Ok(customer_id)
165    }
166
167    pub async fn subscribe_to_price(
168        &self,
169        subscription_id: &stripe::SubscriptionId,
170        price: &StripePrice,
171    ) -> Result<()> {
172        let subscription =
173            stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
174
175        let price_id = price.id.clone().try_into()?;
176        if subscription_contains_price(&subscription, &price_id) {
177            return Ok(());
178        }
179
180        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
181
182        let price_per_unit = price.unit_amount.unwrap_or_default();
183        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
184
185        stripe::Subscription::update(
186            &self.real_client,
187            subscription_id,
188            stripe::UpdateSubscription {
189                items: Some(vec![stripe::UpdateSubscriptionItems {
190                    price: Some(price.id.to_string()),
191                    ..Default::default()
192                }]),
193                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
194                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
195                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
196                    },
197                }),
198                ..Default::default()
199            },
200        )
201        .await?;
202
203        Ok(())
204    }
205
206    pub async fn bill_model_request_usage(
207        &self,
208        customer_id: &stripe::CustomerId,
209        event_name: &str,
210        requests: i32,
211    ) -> Result<()> {
212        let timestamp = Utc::now().timestamp();
213        let idempotency_key = Uuid::new_v4();
214
215        StripeMeterEvent::create(
216            &self.real_client,
217            StripeCreateMeterEventParams {
218                identifier: &format!("model_requests/{}", idempotency_key),
219                event_name,
220                payload: StripeCreateMeterEventPayload {
221                    value: requests as u64,
222                    stripe_customer_id: customer_id,
223                },
224                timestamp: Some(timestamp),
225            },
226        )
227        .await?;
228
229        Ok(())
230    }
231
232    pub async fn checkout_with_zed_pro(
233        &self,
234        customer_id: stripe::CustomerId,
235        github_login: &str,
236        success_url: &str,
237    ) -> Result<String> {
238        let zed_pro_price_id = self.zed_pro_price_id().await?;
239
240        let mut params = stripe::CreateCheckoutSession::new();
241        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
242        params.customer = Some(customer_id);
243        params.client_reference_id = Some(github_login);
244        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
245            price: Some(zed_pro_price_id.to_string()),
246            quantity: Some(1),
247            ..Default::default()
248        }]);
249        params.success_url = Some(success_url);
250
251        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
252        Ok(session.url.context("no checkout session URL")?)
253    }
254
255    pub async fn checkout_with_zed_pro_trial(
256        &self,
257        customer_id: stripe::CustomerId,
258        github_login: &str,
259        feature_flags: Vec<String>,
260        success_url: &str,
261    ) -> Result<String> {
262        let zed_pro_price_id = self.zed_pro_price_id().await?;
263
264        let eligible_for_extended_trial = feature_flags
265            .iter()
266            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
267
268        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
269
270        let mut subscription_metadata = std::collections::HashMap::new();
271        if eligible_for_extended_trial {
272            subscription_metadata.insert(
273                "promo_feature_flag".to_string(),
274                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
275            );
276        }
277
278        let mut params = stripe::CreateCheckoutSession::new();
279        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
280            trial_period_days: Some(trial_period_days),
281            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
282                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
283                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
284                }
285            }),
286            metadata: if !subscription_metadata.is_empty() {
287                Some(subscription_metadata)
288            } else {
289                None
290            },
291            ..Default::default()
292        });
293        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
294        params.payment_method_collection =
295            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
296        params.customer = Some(customer_id);
297        params.client_reference_id = Some(github_login);
298        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
299            price: Some(zed_pro_price_id.to_string()),
300            quantity: Some(1),
301            ..Default::default()
302        }]);
303        params.success_url = Some(success_url);
304
305        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
306        Ok(session.url.context("no checkout session URL")?)
307    }
308
309    pub async fn subscribe_to_zed_free(
310        &self,
311        customer_id: stripe::CustomerId,
312    ) -> Result<stripe::Subscription> {
313        let zed_free_price_id = self.zed_free_price_id().await?;
314
315        let existing_subscriptions = stripe::Subscription::list(
316            &self.real_client,
317            &stripe::ListSubscriptions {
318                customer: Some(customer_id.clone()),
319                status: None,
320                ..Default::default()
321            },
322        )
323        .await?;
324
325        let existing_active_subscription =
326            existing_subscriptions
327                .data
328                .into_iter()
329                .find(|subscription| {
330                    subscription.status == SubscriptionStatus::Active
331                        || subscription.status == SubscriptionStatus::Trialing
332                });
333        if let Some(subscription) = existing_active_subscription {
334            return Ok(subscription);
335        }
336
337        let mut params = stripe::CreateSubscription::new(customer_id);
338        params.items = Some(vec![stripe::CreateSubscriptionItems {
339            price: Some(zed_free_price_id.to_string()),
340            quantity: Some(1),
341            ..Default::default()
342        }]);
343
344        let subscription = stripe::Subscription::create(&self.real_client, params).await?;
345
346        Ok(subscription)
347    }
348
349    pub async fn checkout_with_zed_free(
350        &self,
351        customer_id: stripe::CustomerId,
352        github_login: &str,
353        success_url: &str,
354    ) -> Result<String> {
355        let zed_free_price_id = self.zed_free_price_id().await?;
356
357        let mut params = stripe::CreateCheckoutSession::new();
358        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
359        params.payment_method_collection =
360            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
361        params.customer = Some(customer_id);
362        params.client_reference_id = Some(github_login);
363        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
364            price: Some(zed_free_price_id.to_string()),
365            quantity: Some(1),
366            ..Default::default()
367        }]);
368        params.success_url = Some(success_url);
369
370        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
371        Ok(session.url.context("no checkout session URL")?)
372    }
373}
374
/// Response body deserialized from a `POST /billing/meter_events` request, or synthesized by
/// `StripeMeterEvent::create` when Stripe reports the event identifier as a duplicate.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the meter event.
    identifier: String,
}
379
380impl StripeMeterEvent {
381    pub async fn create(
382        client: &stripe::Client,
383        params: StripeCreateMeterEventParams<'_>,
384    ) -> Result<Self, stripe::StripeError> {
385        let identifier = params.identifier;
386        match client.post_form("/billing/meter_events", params).await {
387            Ok(event) => Ok(event),
388            Err(stripe::StripeError::Stripe(error)) => {
389                if error.http_status == 400
390                    && error
391                        .message
392                        .as_ref()
393                        .map_or(false, |message| message.contains(identifier))
394                {
395                    Ok(Self {
396                        identifier: identifier.to_string(),
397                    })
398                } else {
399                    Err(stripe::StripeError::Stripe(error))
400                }
401            }
402            Err(error) => Err(error),
403        }
404    }
405}
406
/// Form parameters posted to Stripe's `/billing/meter_events` endpoint.
///
/// Field names are serialized as-is into the request, so they must match Stripe's parameter
/// names.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier of this event. `StripeMeterEvent::create` treats a 400 error whose message
    /// contains this identifier as success.
    identifier: &'a str,
    /// Event name of the meter the usage is reported against.
    event_name: &'a str,
    /// Usage value and the customer it is attributed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time as a Unix timestamp in seconds (`bill_model_request_usage` passes
    /// `Utc::now().timestamp()`).
    timestamp: Option<i64>,
}
414
/// The `payload` portion of a meter event; field names are serialized as-is.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount of usage to record (the number of model requests, as used in this file).
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
420
421fn subscription_contains_price(
422    subscription: &stripe::Subscription,
423    price_id: &stripe::PriceId,
424) -> bool {
425    subscription.items.data.iter().any(|item| {
426        item.price
427            .as_ref()
428            .map_or(false, |price| price.id == *price_id)
429    })
430}