stripe_billing.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, anyhow};
  4use chrono::Utc;
  5use collections::HashMap;
  6use serde::{Deserialize, Serialize};
  7use stripe::SubscriptionStatus;
  8use tokio::sync::RwLock;
  9use uuid::Uuid;
 10
 11use crate::Result;
 12use crate::db::billing_subscription::SubscriptionKind;
 13use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
 14use crate::stripe_client::{
 15    RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
 16    StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
 17    UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
 18    UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
 19};
 20
/// Stripe billing integration: caches Stripe meters and prices, and issues billing-related
/// Stripe requests (customers, subscriptions, checkout sessions, meter events).
pub struct StripeBilling {
    /// Cached meters and prices, populated by [`StripeBilling::initialize`].
    state: RwLock<StripeBillingState>,
    /// Direct `stripe` crate client, still used by methods that have not yet been moved onto
    /// the [`StripeClient`] abstraction.
    real_client: Arc<stripe::Client>,
    /// Abstracted Stripe client: a [`RealStripeClient`] in production, a fake client in tests.
    client: Arc<dyn StripeClient>,
}
 26
/// In-memory cache of Stripe billing data, filled by `StripeBilling::initialize`.
#[derive(Default)]
struct StripeBillingState {
    /// Meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// IDs of recurring prices, keyed by the meter ID the price's `recurring.meter` refers to.
    price_ids_by_meter_id: HashMap<String, StripePriceId>,
    /// Prices keyed by their lookup key (e.g. `"zed-pro"`, `"zed-free"`).
    prices_by_lookup_key: HashMap<String, StripePrice>,
}
 33
 34impl StripeBilling {
 35    pub fn new(client: Arc<stripe::Client>) -> Self {
 36        Self {
 37            client: Arc::new(RealStripeClient::new(client.clone())),
 38            real_client: client,
 39            state: RwLock::default(),
 40        }
 41    }
 42
 43    #[cfg(test)]
 44    pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
 45        Self {
 46            // This is just temporary until we can remove all usages of the real Stripe client.
 47            real_client: Arc::new(stripe::Client::new("sk_test")),
 48            client,
 49            state: RwLock::default(),
 50        }
 51    }
 52
 53    pub async fn initialize(&self) -> Result<()> {
 54        log::info!("StripeBilling: initializing");
 55
 56        let mut state = self.state.write().await;
 57
 58        let (meters, prices) =
 59            futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
 60
 61        for meter in meters {
 62            state
 63                .meters_by_event_name
 64                .insert(meter.event_name.clone(), meter);
 65        }
 66
 67        for price in prices {
 68            if let Some(lookup_key) = price.lookup_key.clone() {
 69                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 70            }
 71
 72            if let Some(recurring) = price.recurring {
 73                if let Some(meter) = recurring.meter {
 74                    state.price_ids_by_meter_id.insert(meter, price.id);
 75                }
 76            }
 77        }
 78
 79        log::info!("StripeBilling: initialized");
 80
 81        Ok(())
 82    }
 83
 84    pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
 85        self.find_price_id_by_lookup_key("zed-pro").await
 86    }
 87
 88    pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
 89        self.find_price_id_by_lookup_key("zed-free").await
 90    }
 91
 92    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
 93        self.state
 94            .read()
 95            .await
 96            .prices_by_lookup_key
 97            .get(lookup_key)
 98            .map(|price| price.id.clone())
 99            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
100    }
101
102    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
103        self.state
104            .read()
105            .await
106            .prices_by_lookup_key
107            .get(lookup_key)
108            .cloned()
109            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
110    }
111
112    pub async fn determine_subscription_kind(
113        &self,
114        subscription: &stripe::Subscription,
115    ) -> Option<SubscriptionKind> {
116        let zed_pro_price_id: stripe::PriceId =
117            self.zed_pro_price_id().await.ok()?.try_into().ok()?;
118        let zed_free_price_id: stripe::PriceId =
119            self.zed_free_price_id().await.ok()?.try_into().ok()?;
120
121        subscription.items.data.iter().find_map(|item| {
122            let price = item.price.as_ref()?;
123
124            if price.id == zed_pro_price_id {
125                Some(if subscription.status == SubscriptionStatus::Trialing {
126                    SubscriptionKind::ZedProTrial
127                } else {
128                    SubscriptionKind::ZedPro
129                })
130            } else if price.id == zed_free_price_id {
131                Some(SubscriptionKind::ZedFree)
132            } else {
133                None
134            }
135        })
136    }
137
138    /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
139    /// not already exist.
140    ///
141    /// Always returns a new Stripe customer if the email address is `None`.
142    pub async fn find_or_create_customer_by_email(
143        &self,
144        email_address: Option<&str>,
145    ) -> Result<StripeCustomerId> {
146        let existing_customer = if let Some(email) = email_address {
147            let customers = self.client.list_customers_by_email(email).await?;
148
149            customers.first().cloned()
150        } else {
151            None
152        };
153
154        let customer_id = if let Some(existing_customer) = existing_customer {
155            existing_customer.id
156        } else {
157            let customer = self
158                .client
159                .create_customer(crate::stripe_client::CreateCustomerParams {
160                    email: email_address,
161                })
162                .await?;
163
164            customer.id
165        };
166
167        Ok(customer_id)
168    }
169
170    pub async fn subscribe_to_price(
171        &self,
172        subscription_id: &StripeSubscriptionId,
173        price: &StripePrice,
174    ) -> Result<()> {
175        let subscription = self.client.get_subscription(subscription_id).await?;
176
177        if subscription_contains_price(&subscription, &price.id) {
178            return Ok(());
179        }
180
181        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
182
183        let price_per_unit = price.unit_amount.unwrap_or_default();
184        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
185
186        self.client
187            .update_subscription(
188                subscription_id,
189                UpdateSubscriptionParams {
190                    items: Some(vec![UpdateSubscriptionItems {
191                        price: Some(price.id.clone()),
192                    }]),
193                    trial_settings: Some(UpdateSubscriptionTrialSettings {
194                        end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
195                            missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
196                        },
197                    }),
198                },
199            )
200            .await?;
201
202        Ok(())
203    }
204
205    pub async fn bill_model_request_usage(
206        &self,
207        customer_id: &stripe::CustomerId,
208        event_name: &str,
209        requests: i32,
210    ) -> Result<()> {
211        let timestamp = Utc::now().timestamp();
212        let idempotency_key = Uuid::new_v4();
213
214        StripeMeterEvent::create(
215            &self.real_client,
216            StripeCreateMeterEventParams {
217                identifier: &format!("model_requests/{}", idempotency_key),
218                event_name,
219                payload: StripeCreateMeterEventPayload {
220                    value: requests as u64,
221                    stripe_customer_id: customer_id,
222                },
223                timestamp: Some(timestamp),
224            },
225        )
226        .await?;
227
228        Ok(())
229    }
230
231    pub async fn checkout_with_zed_pro(
232        &self,
233        customer_id: stripe::CustomerId,
234        github_login: &str,
235        success_url: &str,
236    ) -> Result<String> {
237        let zed_pro_price_id = self.zed_pro_price_id().await?;
238
239        let mut params = stripe::CreateCheckoutSession::new();
240        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
241        params.customer = Some(customer_id);
242        params.client_reference_id = Some(github_login);
243        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
244            price: Some(zed_pro_price_id.to_string()),
245            quantity: Some(1),
246            ..Default::default()
247        }]);
248        params.success_url = Some(success_url);
249
250        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
251        Ok(session.url.context("no checkout session URL")?)
252    }
253
254    pub async fn checkout_with_zed_pro_trial(
255        &self,
256        customer_id: stripe::CustomerId,
257        github_login: &str,
258        feature_flags: Vec<String>,
259        success_url: &str,
260    ) -> Result<String> {
261        let zed_pro_price_id = self.zed_pro_price_id().await?;
262
263        let eligible_for_extended_trial = feature_flags
264            .iter()
265            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
266
267        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
268
269        let mut subscription_metadata = std::collections::HashMap::new();
270        if eligible_for_extended_trial {
271            subscription_metadata.insert(
272                "promo_feature_flag".to_string(),
273                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
274            );
275        }
276
277        let mut params = stripe::CreateCheckoutSession::new();
278        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
279            trial_period_days: Some(trial_period_days),
280            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
281                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
282                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
283                }
284            }),
285            metadata: if !subscription_metadata.is_empty() {
286                Some(subscription_metadata)
287            } else {
288                None
289            },
290            ..Default::default()
291        });
292        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
293        params.payment_method_collection =
294            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
295        params.customer = Some(customer_id);
296        params.client_reference_id = Some(github_login);
297        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
298            price: Some(zed_pro_price_id.to_string()),
299            quantity: Some(1),
300            ..Default::default()
301        }]);
302        params.success_url = Some(success_url);
303
304        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
305        Ok(session.url.context("no checkout session URL")?)
306    }
307
308    pub async fn subscribe_to_zed_free(
309        &self,
310        customer_id: stripe::CustomerId,
311    ) -> Result<stripe::Subscription> {
312        let zed_free_price_id = self.zed_free_price_id().await?;
313
314        let existing_subscriptions = stripe::Subscription::list(
315            &self.real_client,
316            &stripe::ListSubscriptions {
317                customer: Some(customer_id.clone()),
318                status: None,
319                ..Default::default()
320            },
321        )
322        .await?;
323
324        let existing_active_subscription =
325            existing_subscriptions
326                .data
327                .into_iter()
328                .find(|subscription| {
329                    subscription.status == SubscriptionStatus::Active
330                        || subscription.status == SubscriptionStatus::Trialing
331                });
332        if let Some(subscription) = existing_active_subscription {
333            return Ok(subscription);
334        }
335
336        let mut params = stripe::CreateSubscription::new(customer_id);
337        params.items = Some(vec![stripe::CreateSubscriptionItems {
338            price: Some(zed_free_price_id.to_string()),
339            quantity: Some(1),
340            ..Default::default()
341        }]);
342
343        let subscription = stripe::Subscription::create(&self.real_client, params).await?;
344
345        Ok(subscription)
346    }
347
348    pub async fn checkout_with_zed_free(
349        &self,
350        customer_id: stripe::CustomerId,
351        github_login: &str,
352        success_url: &str,
353    ) -> Result<String> {
354        let zed_free_price_id = self.zed_free_price_id().await?;
355
356        let mut params = stripe::CreateCheckoutSession::new();
357        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
358        params.payment_method_collection =
359            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
360        params.customer = Some(customer_id);
361        params.client_reference_id = Some(github_login);
362        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
363            price: Some(zed_free_price_id.to_string()),
364            quantity: Some(1),
365            ..Default::default()
366        }]);
367        params.success_url = Some(success_url);
368
369        let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
370        Ok(session.url.context("no checkout session URL")?)
371    }
372}
373
/// Response of a Stripe `POST /billing/meter_events` request. Only the event's `identifier`
/// is deserialized.
#[derive(Deserialize)]
struct StripeMeterEvent {
    identifier: String,
}
378
379impl StripeMeterEvent {
380    pub async fn create(
381        client: &stripe::Client,
382        params: StripeCreateMeterEventParams<'_>,
383    ) -> Result<Self, stripe::StripeError> {
384        let identifier = params.identifier;
385        match client.post_form("/billing/meter_events", params).await {
386            Ok(event) => Ok(event),
387            Err(stripe::StripeError::Stripe(error)) => {
388                if error.http_status == 400
389                    && error
390                        .message
391                        .as_ref()
392                        .map_or(false, |message| message.contains(identifier))
393                {
394                    Ok(Self {
395                        identifier: identifier.to_string(),
396                    })
397                } else {
398                    Err(stripe::StripeError::Stripe(error))
399                }
400            }
401            Err(error) => Err(error),
402        }
403    }
404}
405
/// Form parameters for creating a Stripe billing meter event.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier for this event. `StripeMeterEvent::create` treats a 400 error that mentions
    /// it as success.
    identifier: &'a str,
    /// Name of the meter event being reported.
    event_name: &'a str,
    /// The usage value and the customer it applies to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time in Unix seconds, if provided.
    timestamp: Option<i64>,
}
413
/// `payload` of a meter event: the reported usage amount and the Stripe customer it applies to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage amount being reported (the model request count in `bill_model_request_usage`).
    value: u64,
    /// Stripe customer the usage is recorded against.
    stripe_customer_id: &'a stripe::CustomerId,
}
419
420fn subscription_contains_price(
421    subscription: &StripeSubscription,
422    price_id: &StripePriceId,
423) -> bool {
424    subscription.items.iter().any(|item| {
425        item.price
426            .as_ref()
427            .map_or(false, |price| price.id == *price_id)
428    })
429}