stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::Result;
  4use crate::db::billing_subscription::SubscriptionKind;
  5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  6use anyhow::{Context as _, anyhow};
  7use chrono::Utc;
  8use collections::HashMap;
  9use serde::{Deserialize, Serialize};
 10use stripe::{PriceId, SubscriptionStatus};
 11use tokio::sync::RwLock;
 12use uuid::Uuid;
 13
/// Wraps a Stripe client together with cached meter and price lookups used for billing.
pub struct StripeBilling {
    /// Meter and price lookup tables; filled in by `initialize`.
    state: RwLock<StripeBillingState>,
    /// Shared Stripe API client used for every request made by this type.
    client: Arc<stripe::Client>,
}
 18
/// In-memory lookup tables built from the meters and prices listed from Stripe.
#[derive(Default)]
struct StripeBillingState {
    /// Meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// Price IDs keyed by the meter referenced from the price's `recurring.meter`.
    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
    /// Full prices keyed by their `lookup_key`; prices without a lookup key are not stored.
    prices_by_lookup_key: HashMap<String, stripe::Price>,
}
 25
 26impl StripeBilling {
 27    pub fn new(client: Arc<stripe::Client>) -> Self {
 28        Self {
 29            client,
 30            state: RwLock::default(),
 31        }
 32    }
 33
 34    pub async fn initialize(&self) -> Result<()> {
 35        log::info!("StripeBilling: initializing");
 36
 37        let mut state = self.state.write().await;
 38
 39        let (meters, prices) = futures::try_join!(
 40            StripeMeter::list(&self.client),
 41            stripe::Price::list(
 42                &self.client,
 43                &stripe::ListPrices {
 44                    limit: Some(100),
 45                    ..Default::default()
 46                }
 47            )
 48        )?;
 49
 50        for meter in meters.data {
 51            state
 52                .meters_by_event_name
 53                .insert(meter.event_name.clone(), meter);
 54        }
 55
 56        for price in prices.data {
 57            if let Some(lookup_key) = price.lookup_key.clone() {
 58                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 59            }
 60
 61            if let Some(recurring) = price.recurring {
 62                if let Some(meter) = recurring.meter {
 63                    state.price_ids_by_meter_id.insert(meter, price.id);
 64                }
 65            }
 66        }
 67
 68        log::info!("StripeBilling: initialized");
 69
 70        Ok(())
 71    }
 72
 73    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 74        self.find_price_id_by_lookup_key("zed-pro").await
 75    }
 76
 77    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 78        self.find_price_id_by_lookup_key("zed-free").await
 79    }
 80
 81    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 82        self.state
 83            .read()
 84            .await
 85            .prices_by_lookup_key
 86            .get(lookup_key)
 87            .map(|price| price.id.clone())
 88            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 89    }
 90
 91    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 92        self.state
 93            .read()
 94            .await
 95            .prices_by_lookup_key
 96            .get(lookup_key)
 97            .cloned()
 98            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 99    }
100
101    pub async fn determine_subscription_kind(
102        &self,
103        subscription: &stripe::Subscription,
104    ) -> Option<SubscriptionKind> {
105        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108        subscription.items.data.iter().find_map(|item| {
109            let price = item.price.as_ref()?;
110
111            if price.id == zed_pro_price_id {
112                Some(if subscription.status == SubscriptionStatus::Trialing {
113                    SubscriptionKind::ZedProTrial
114                } else {
115                    SubscriptionKind::ZedPro
116                })
117            } else if price.id == zed_free_price_id {
118                Some(SubscriptionKind::ZedFree)
119            } else {
120                None
121            }
122        })
123    }
124
125    pub async fn subscribe_to_price(
126        &self,
127        subscription_id: &stripe::SubscriptionId,
128        price: &stripe::Price,
129    ) -> Result<()> {
130        let subscription =
131            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
132
133        if subscription_contains_price(&subscription, &price.id) {
134            return Ok(());
135        }
136
137        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
138
139        let price_per_unit = price.unit_amount.unwrap_or_default();
140        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
141
142        stripe::Subscription::update(
143            &self.client,
144            subscription_id,
145            stripe::UpdateSubscription {
146                items: Some(vec![stripe::UpdateSubscriptionItems {
147                    price: Some(price.id.to_string()),
148                    ..Default::default()
149                }]),
150                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
151                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
152                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
153                    },
154                }),
155                ..Default::default()
156            },
157        )
158        .await?;
159
160        Ok(())
161    }
162
163    pub async fn bill_model_request_usage(
164        &self,
165        customer_id: &stripe::CustomerId,
166        event_name: &str,
167        requests: i32,
168    ) -> Result<()> {
169        let timestamp = Utc::now().timestamp();
170        let idempotency_key = Uuid::new_v4();
171
172        StripeMeterEvent::create(
173            &self.client,
174            StripeCreateMeterEventParams {
175                identifier: &format!("model_requests/{}", idempotency_key),
176                event_name,
177                payload: StripeCreateMeterEventPayload {
178                    value: requests as u64,
179                    stripe_customer_id: customer_id,
180                },
181                timestamp: Some(timestamp),
182            },
183        )
184        .await?;
185
186        Ok(())
187    }
188
189    pub async fn checkout_with_zed_pro(
190        &self,
191        customer_id: stripe::CustomerId,
192        github_login: &str,
193        success_url: &str,
194    ) -> Result<String> {
195        let zed_pro_price_id = self.zed_pro_price_id().await?;
196
197        let mut params = stripe::CreateCheckoutSession::new();
198        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
199        params.customer = Some(customer_id);
200        params.client_reference_id = Some(github_login);
201        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
202            price: Some(zed_pro_price_id.to_string()),
203            quantity: Some(1),
204            ..Default::default()
205        }]);
206        params.success_url = Some(success_url);
207
208        let session = stripe::CheckoutSession::create(&self.client, params).await?;
209        Ok(session.url.context("no checkout session URL")?)
210    }
211
212    pub async fn checkout_with_zed_pro_trial(
213        &self,
214        customer_id: stripe::CustomerId,
215        github_login: &str,
216        feature_flags: Vec<String>,
217        success_url: &str,
218    ) -> Result<String> {
219        let zed_pro_price_id = self.zed_pro_price_id().await?;
220
221        let eligible_for_extended_trial = feature_flags
222            .iter()
223            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
224
225        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
226
227        let mut subscription_metadata = std::collections::HashMap::new();
228        if eligible_for_extended_trial {
229            subscription_metadata.insert(
230                "promo_feature_flag".to_string(),
231                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
232            );
233        }
234
235        let mut params = stripe::CreateCheckoutSession::new();
236        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
237            trial_period_days: Some(trial_period_days),
238            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
239                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
240                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
241                }
242            }),
243            metadata: if !subscription_metadata.is_empty() {
244                Some(subscription_metadata)
245            } else {
246                None
247            },
248            ..Default::default()
249        });
250        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
251        params.payment_method_collection =
252            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
253        params.customer = Some(customer_id);
254        params.client_reference_id = Some(github_login);
255        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
256            price: Some(zed_pro_price_id.to_string()),
257            quantity: Some(1),
258            ..Default::default()
259        }]);
260        params.success_url = Some(success_url);
261
262        let session = stripe::CheckoutSession::create(&self.client, params).await?;
263        Ok(session.url.context("no checkout session URL")?)
264    }
265
266    pub async fn subscribe_to_zed_free(
267        &self,
268        customer_id: stripe::CustomerId,
269    ) -> Result<stripe::Subscription> {
270        let zed_free_price_id = self.zed_free_price_id().await?;
271
272        let mut params = stripe::CreateSubscription::new(customer_id);
273        params.items = Some(vec![stripe::CreateSubscriptionItems {
274            price: Some(zed_free_price_id.to_string()),
275            quantity: Some(1),
276            ..Default::default()
277        }]);
278
279        let subscription = stripe::Subscription::create(&self.client, params).await?;
280
281        Ok(subscription)
282    }
283
284    pub async fn checkout_with_zed_free(
285        &self,
286        customer_id: stripe::CustomerId,
287        github_login: &str,
288        success_url: &str,
289    ) -> Result<String> {
290        let zed_free_price_id = self.zed_free_price_id().await?;
291
292        let mut params = stripe::CreateCheckoutSession::new();
293        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
294        params.payment_method_collection =
295            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
296        params.customer = Some(customer_id);
297        params.client_reference_id = Some(github_login);
298        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
299            price: Some(zed_free_price_id.to_string()),
300            quantity: Some(1),
301            ..Default::default()
302        }]);
303        params.success_url = Some(success_url);
304
305        let session = stripe::CheckoutSession::create(&self.client, params).await?;
306        Ok(session.url.context("no checkout session URL")?)
307    }
308}
309
/// A billing meter as deserialized from the `/billing/meters` list response.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// The meter's `id` field from the response.
    id: String,
    /// The meter's event name; used as the key when meters are cached.
    event_name: String,
}
315
316impl StripeMeter {
317    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
318        #[derive(Serialize)]
319        struct Params {
320            #[serde(skip_serializing_if = "Option::is_none")]
321            limit: Option<u64>,
322        }
323
324        client.get_query("/billing/meters", Params { limit: Some(100) })
325    }
326}
327
/// A meter event as returned from the `/billing/meter_events` endpoint.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// The identifier the event was created with.
    identifier: String,
}
332
333impl StripeMeterEvent {
334    pub async fn create(
335        client: &stripe::Client,
336        params: StripeCreateMeterEventParams<'_>,
337    ) -> Result<Self, stripe::StripeError> {
338        let identifier = params.identifier;
339        match client.post_form("/billing/meter_events", params).await {
340            Ok(event) => Ok(event),
341            Err(stripe::StripeError::Stripe(error)) => {
342                if error.http_status == 400
343                    && error
344                        .message
345                        .as_ref()
346                        .map_or(false, |message| message.contains(identifier))
347                {
348                    Ok(Self {
349                        identifier: identifier.to_string(),
350                    })
351                } else {
352                    Err(stripe::StripeError::Stripe(error))
353                }
354            }
355            Err(error) => Err(error),
356        }
357    }
358}
359
/// Form parameters sent when creating a meter event.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Caller-chosen identifier for the event.
    identifier: &'a str,
    /// Name of the meter event being reported.
    event_name: &'a str,
    /// Usage value and the customer it is attributed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Optional event time; callers in this file pass a Unix timestamp in seconds.
    timestamp: Option<i64>,
}
367
/// Payload portion of a meter event creation request.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount of usage being reported.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
373
374fn subscription_contains_price(
375    subscription: &stripe::Subscription,
376    price_id: &stripe::PriceId,
377) -> bool {
378    subscription.items.data.iter().any(|item| {
379        item.price
380            .as_ref()
381            .map_or(false, |price| price.id == *price_id)
382    })
383}