stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::Result;
  4use crate::db::billing_subscription::SubscriptionKind;
  5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  6use anyhow::{Context as _, anyhow};
  7use chrono::Utc;
  8use collections::HashMap;
  9use serde::{Deserialize, Serialize};
 10use stripe::{PriceId, SubscriptionStatus};
 11use tokio::sync::RwLock;
 12use uuid::Uuid;
 13
 14pub struct StripeBilling {
 15    state: RwLock<StripeBillingState>,
 16    client: Arc<stripe::Client>,
 17}
 18
 19#[derive(Default)]
 20struct StripeBillingState {
 21    meters_by_event_name: HashMap<String, StripeMeter>,
 22    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 23    prices_by_lookup_key: HashMap<String, stripe::Price>,
 24}
 25
 26impl StripeBilling {
 27    pub fn new(client: Arc<stripe::Client>) -> Self {
 28        Self {
 29            client,
 30            state: RwLock::default(),
 31        }
 32    }
 33
 34    pub async fn initialize(&self) -> Result<()> {
 35        log::info!("StripeBilling: initializing");
 36
 37        let mut state = self.state.write().await;
 38
 39        let (meters, prices) = futures::try_join!(
 40            StripeMeter::list(&self.client),
 41            stripe::Price::list(
 42                &self.client,
 43                &stripe::ListPrices {
 44                    limit: Some(100),
 45                    ..Default::default()
 46                }
 47            )
 48        )?;
 49
 50        for meter in meters.data {
 51            state
 52                .meters_by_event_name
 53                .insert(meter.event_name.clone(), meter);
 54        }
 55
 56        for price in prices.data {
 57            if let Some(lookup_key) = price.lookup_key.clone() {
 58                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 59            }
 60
 61            if let Some(recurring) = price.recurring {
 62                if let Some(meter) = recurring.meter {
 63                    state.price_ids_by_meter_id.insert(meter, price.id);
 64                }
 65            }
 66        }
 67
 68        log::info!("StripeBilling: initialized");
 69
 70        Ok(())
 71    }
 72
 73    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 74        self.find_price_id_by_lookup_key("zed-pro").await
 75    }
 76
 77    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 78        self.find_price_id_by_lookup_key("zed-free").await
 79    }
 80
 81    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 82        self.state
 83            .read()
 84            .await
 85            .prices_by_lookup_key
 86            .get(lookup_key)
 87            .map(|price| price.id.clone())
 88            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 89    }
 90
 91    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 92        self.state
 93            .read()
 94            .await
 95            .prices_by_lookup_key
 96            .get(lookup_key)
 97            .cloned()
 98            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 99    }
100
101    pub async fn determine_subscription_kind(
102        &self,
103        subscription: &stripe::Subscription,
104    ) -> Option<SubscriptionKind> {
105        let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106        let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108        subscription.items.data.iter().find_map(|item| {
109            let price = item.price.as_ref()?;
110
111            if price.id == zed_pro_price_id {
112                Some(if subscription.status == SubscriptionStatus::Trialing {
113                    SubscriptionKind::ZedProTrial
114                } else {
115                    SubscriptionKind::ZedPro
116                })
117            } else if price.id == zed_free_price_id {
118                Some(SubscriptionKind::ZedFree)
119            } else {
120                None
121            }
122        })
123    }
124
125    pub async fn subscribe_to_price(
126        &self,
127        subscription_id: &stripe::SubscriptionId,
128        price: &stripe::Price,
129    ) -> Result<()> {
130        let subscription =
131            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
132
133        if subscription_contains_price(&subscription, &price.id) {
134            return Ok(());
135        }
136
137        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
138
139        let price_per_unit = price.unit_amount.unwrap_or_default();
140        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
141
142        stripe::Subscription::update(
143            &self.client,
144            subscription_id,
145            stripe::UpdateSubscription {
146                items: Some(vec![stripe::UpdateSubscriptionItems {
147                    price: Some(price.id.to_string()),
148                    ..Default::default()
149                }]),
150                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
151                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
152                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
153                    },
154                }),
155                ..Default::default()
156            },
157        )
158        .await?;
159
160        Ok(())
161    }
162
163    pub async fn bill_model_request_usage(
164        &self,
165        customer_id: &stripe::CustomerId,
166        event_name: &str,
167        requests: i32,
168    ) -> Result<()> {
169        let timestamp = Utc::now().timestamp();
170        let idempotency_key = Uuid::new_v4();
171
172        StripeMeterEvent::create(
173            &self.client,
174            StripeCreateMeterEventParams {
175                identifier: &format!("model_requests/{}", idempotency_key),
176                event_name,
177                payload: StripeCreateMeterEventPayload {
178                    value: requests as u64,
179                    stripe_customer_id: customer_id,
180                },
181                timestamp: Some(timestamp),
182            },
183        )
184        .await?;
185
186        Ok(())
187    }
188
189    pub async fn checkout_with_zed_pro(
190        &self,
191        customer_id: stripe::CustomerId,
192        github_login: &str,
193        success_url: &str,
194    ) -> Result<String> {
195        let zed_pro_price_id = self.zed_pro_price_id().await?;
196
197        let mut params = stripe::CreateCheckoutSession::new();
198        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
199        params.customer = Some(customer_id);
200        params.client_reference_id = Some(github_login);
201        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
202            price: Some(zed_pro_price_id.to_string()),
203            quantity: Some(1),
204            ..Default::default()
205        }]);
206        params.success_url = Some(success_url);
207
208        let session = stripe::CheckoutSession::create(&self.client, params).await?;
209        Ok(session.url.context("no checkout session URL")?)
210    }
211
212    pub async fn checkout_with_zed_pro_trial(
213        &self,
214        customer_id: stripe::CustomerId,
215        github_login: &str,
216        feature_flags: Vec<String>,
217        success_url: &str,
218    ) -> Result<String> {
219        let zed_pro_price_id = self.zed_pro_price_id().await?;
220
221        let eligible_for_extended_trial = feature_flags
222            .iter()
223            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
224
225        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
226
227        let mut subscription_metadata = std::collections::HashMap::new();
228        if eligible_for_extended_trial {
229            subscription_metadata.insert(
230                "promo_feature_flag".to_string(),
231                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
232            );
233        }
234
235        let mut params = stripe::CreateCheckoutSession::new();
236        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
237            trial_period_days: Some(trial_period_days),
238            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
239                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
240                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
241                }
242            }),
243            metadata: if !subscription_metadata.is_empty() {
244                Some(subscription_metadata)
245            } else {
246                None
247            },
248            ..Default::default()
249        });
250        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
251        params.payment_method_collection =
252            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
253        params.customer = Some(customer_id);
254        params.client_reference_id = Some(github_login);
255        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
256            price: Some(zed_pro_price_id.to_string()),
257            quantity: Some(1),
258            ..Default::default()
259        }]);
260        params.success_url = Some(success_url);
261
262        let session = stripe::CheckoutSession::create(&self.client, params).await?;
263        Ok(session.url.context("no checkout session URL")?)
264    }
265
266    pub async fn checkout_with_zed_free(
267        &self,
268        customer_id: stripe::CustomerId,
269        github_login: &str,
270        success_url: &str,
271    ) -> Result<String> {
272        let zed_free_price_id = self.zed_free_price_id().await?;
273
274        let mut params = stripe::CreateCheckoutSession::new();
275        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
276        params.payment_method_collection =
277            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
278        params.customer = Some(customer_id);
279        params.client_reference_id = Some(github_login);
280        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
281            price: Some(zed_free_price_id.to_string()),
282            quantity: Some(1),
283            ..Default::default()
284        }]);
285        params.success_url = Some(success_url);
286
287        let session = stripe::CheckoutSession::create(&self.client, params).await?;
288        Ok(session.url.context("no checkout session URL")?)
289    }
290}
291
292#[derive(Clone, Deserialize)]
293struct StripeMeter {
294    id: String,
295    event_name: String,
296}
297
298impl StripeMeter {
299    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
300        #[derive(Serialize)]
301        struct Params {
302            #[serde(skip_serializing_if = "Option::is_none")]
303            limit: Option<u64>,
304        }
305
306        client.get_query("/billing/meters", Params { limit: Some(100) })
307    }
308}
309
310#[derive(Deserialize)]
311struct StripeMeterEvent {
312    identifier: String,
313}
314
315impl StripeMeterEvent {
316    pub async fn create(
317        client: &stripe::Client,
318        params: StripeCreateMeterEventParams<'_>,
319    ) -> Result<Self, stripe::StripeError> {
320        let identifier = params.identifier;
321        match client.post_form("/billing/meter_events", params).await {
322            Ok(event) => Ok(event),
323            Err(stripe::StripeError::Stripe(error)) => {
324                if error.http_status == 400
325                    && error
326                        .message
327                        .as_ref()
328                        .map_or(false, |message| message.contains(identifier))
329                {
330                    Ok(Self {
331                        identifier: identifier.to_string(),
332                    })
333                } else {
334                    Err(stripe::StripeError::Stripe(error))
335                }
336            }
337            Err(error) => Err(error),
338        }
339    }
340}
341
342#[derive(Serialize)]
343struct StripeCreateMeterEventParams<'a> {
344    identifier: &'a str,
345    event_name: &'a str,
346    payload: StripeCreateMeterEventPayload<'a>,
347    timestamp: Option<i64>,
348}
349
350#[derive(Serialize)]
351struct StripeCreateMeterEventPayload<'a> {
352    value: u64,
353    stripe_customer_id: &'a stripe::CustomerId,
354}
355
356fn subscription_contains_price(
357    subscription: &stripe::Subscription,
358    price_id: &stripe::PriceId,
359) -> bool {
360    subscription.items.data.iter().any(|item| {
361        item.price
362            .as_ref()
363            .map_or(false, |price| price.id == *price_id)
364    })
365}