1use std::sync::Arc;
2
3use crate::Result;
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
6use anyhow::{Context as _, anyhow};
7use chrono::Utc;
8use collections::HashMap;
9use serde::{Deserialize, Serialize};
10use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct StripeBilling {
15 state: RwLock<StripeBillingState>,
16 client: Arc<stripe::Client>,
17}
18
19#[derive(Default)]
20struct StripeBillingState {
21 meters_by_event_name: HashMap<String, StripeMeter>,
22 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
23 prices_by_lookup_key: HashMap<String, stripe::Price>,
24}
25
26impl StripeBilling {
27 pub fn new(client: Arc<stripe::Client>) -> Self {
28 Self {
29 client,
30 state: RwLock::default(),
31 }
32 }
33
34 pub async fn initialize(&self) -> Result<()> {
35 log::info!("StripeBilling: initializing");
36
37 let mut state = self.state.write().await;
38
39 let (meters, prices) = futures::try_join!(
40 StripeMeter::list(&self.client),
41 stripe::Price::list(
42 &self.client,
43 &stripe::ListPrices {
44 limit: Some(100),
45 ..Default::default()
46 }
47 )
48 )?;
49
50 for meter in meters.data {
51 state
52 .meters_by_event_name
53 .insert(meter.event_name.clone(), meter);
54 }
55
56 for price in prices.data {
57 if let Some(lookup_key) = price.lookup_key.clone() {
58 state.prices_by_lookup_key.insert(lookup_key, price.clone());
59 }
60
61 if let Some(recurring) = price.recurring {
62 if let Some(meter) = recurring.meter {
63 state.price_ids_by_meter_id.insert(meter, price.id);
64 }
65 }
66 }
67
68 log::info!("StripeBilling: initialized");
69
70 Ok(())
71 }
72
73 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
74 self.find_price_id_by_lookup_key("zed-pro").await
75 }
76
77 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
78 self.find_price_id_by_lookup_key("zed-free").await
79 }
80
81 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
82 self.state
83 .read()
84 .await
85 .prices_by_lookup_key
86 .get(lookup_key)
87 .map(|price| price.id.clone())
88 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
89 }
90
91 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
92 self.state
93 .read()
94 .await
95 .prices_by_lookup_key
96 .get(lookup_key)
97 .cloned()
98 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
99 }
100
101 pub async fn determine_subscription_kind(
102 &self,
103 subscription: &stripe::Subscription,
104 ) -> Option<SubscriptionKind> {
105 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108 subscription.items.data.iter().find_map(|item| {
109 let price = item.price.as_ref()?;
110
111 if price.id == zed_pro_price_id {
112 Some(if subscription.status == SubscriptionStatus::Trialing {
113 SubscriptionKind::ZedProTrial
114 } else {
115 SubscriptionKind::ZedPro
116 })
117 } else if price.id == zed_free_price_id {
118 Some(SubscriptionKind::ZedFree)
119 } else {
120 None
121 }
122 })
123 }
124
125 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
126 /// not already exist.
127 ///
128 /// Always returns a new Stripe customer if the email address is `None`.
129 pub async fn find_or_create_customer_by_email(
130 &self,
131 email_address: Option<&str>,
132 ) -> Result<CustomerId> {
133 let existing_customer = if let Some(email) = email_address {
134 let customers = Customer::list(
135 &self.client,
136 &stripe::ListCustomers {
137 email: Some(email),
138 ..Default::default()
139 },
140 )
141 .await?;
142
143 customers.data.first().cloned()
144 } else {
145 None
146 };
147
148 let customer_id = if let Some(existing_customer) = existing_customer {
149 existing_customer.id
150 } else {
151 let customer = Customer::create(
152 &self.client,
153 CreateCustomer {
154 email: email_address,
155 ..Default::default()
156 },
157 )
158 .await?;
159
160 customer.id
161 };
162
163 Ok(customer_id)
164 }
165
166 pub async fn subscribe_to_price(
167 &self,
168 subscription_id: &stripe::SubscriptionId,
169 price: &stripe::Price,
170 ) -> Result<()> {
171 let subscription =
172 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
173
174 if subscription_contains_price(&subscription, &price.id) {
175 return Ok(());
176 }
177
178 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
179
180 let price_per_unit = price.unit_amount.unwrap_or_default();
181 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
182
183 stripe::Subscription::update(
184 &self.client,
185 subscription_id,
186 stripe::UpdateSubscription {
187 items: Some(vec![stripe::UpdateSubscriptionItems {
188 price: Some(price.id.to_string()),
189 ..Default::default()
190 }]),
191 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
192 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
193 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
194 },
195 }),
196 ..Default::default()
197 },
198 )
199 .await?;
200
201 Ok(())
202 }
203
204 pub async fn bill_model_request_usage(
205 &self,
206 customer_id: &stripe::CustomerId,
207 event_name: &str,
208 requests: i32,
209 ) -> Result<()> {
210 let timestamp = Utc::now().timestamp();
211 let idempotency_key = Uuid::new_v4();
212
213 StripeMeterEvent::create(
214 &self.client,
215 StripeCreateMeterEventParams {
216 identifier: &format!("model_requests/{}", idempotency_key),
217 event_name,
218 payload: StripeCreateMeterEventPayload {
219 value: requests as u64,
220 stripe_customer_id: customer_id,
221 },
222 timestamp: Some(timestamp),
223 },
224 )
225 .await?;
226
227 Ok(())
228 }
229
230 pub async fn checkout_with_zed_pro(
231 &self,
232 customer_id: stripe::CustomerId,
233 github_login: &str,
234 success_url: &str,
235 ) -> Result<String> {
236 let zed_pro_price_id = self.zed_pro_price_id().await?;
237
238 let mut params = stripe::CreateCheckoutSession::new();
239 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
240 params.customer = Some(customer_id);
241 params.client_reference_id = Some(github_login);
242 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
243 price: Some(zed_pro_price_id.to_string()),
244 quantity: Some(1),
245 ..Default::default()
246 }]);
247 params.success_url = Some(success_url);
248
249 let session = stripe::CheckoutSession::create(&self.client, params).await?;
250 Ok(session.url.context("no checkout session URL")?)
251 }
252
253 pub async fn checkout_with_zed_pro_trial(
254 &self,
255 customer_id: stripe::CustomerId,
256 github_login: &str,
257 feature_flags: Vec<String>,
258 success_url: &str,
259 ) -> Result<String> {
260 let zed_pro_price_id = self.zed_pro_price_id().await?;
261
262 let eligible_for_extended_trial = feature_flags
263 .iter()
264 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
265
266 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
267
268 let mut subscription_metadata = std::collections::HashMap::new();
269 if eligible_for_extended_trial {
270 subscription_metadata.insert(
271 "promo_feature_flag".to_string(),
272 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
273 );
274 }
275
276 let mut params = stripe::CreateCheckoutSession::new();
277 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
278 trial_period_days: Some(trial_period_days),
279 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
280 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
281 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
282 }
283 }),
284 metadata: if !subscription_metadata.is_empty() {
285 Some(subscription_metadata)
286 } else {
287 None
288 },
289 ..Default::default()
290 });
291 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
292 params.payment_method_collection =
293 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
294 params.customer = Some(customer_id);
295 params.client_reference_id = Some(github_login);
296 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
297 price: Some(zed_pro_price_id.to_string()),
298 quantity: Some(1),
299 ..Default::default()
300 }]);
301 params.success_url = Some(success_url);
302
303 let session = stripe::CheckoutSession::create(&self.client, params).await?;
304 Ok(session.url.context("no checkout session URL")?)
305 }
306
307 pub async fn subscribe_to_zed_free(
308 &self,
309 customer_id: stripe::CustomerId,
310 ) -> Result<stripe::Subscription> {
311 let zed_free_price_id = self.zed_free_price_id().await?;
312
313 let existing_subscriptions = stripe::Subscription::list(
314 &self.client,
315 &stripe::ListSubscriptions {
316 customer: Some(customer_id.clone()),
317 status: None,
318 ..Default::default()
319 },
320 )
321 .await?;
322
323 let existing_active_subscription =
324 existing_subscriptions
325 .data
326 .into_iter()
327 .find(|subscription| {
328 subscription.status == SubscriptionStatus::Active
329 || subscription.status == SubscriptionStatus::Trialing
330 });
331 if let Some(subscription) = existing_active_subscription {
332 return Ok(subscription);
333 }
334
335 let mut params = stripe::CreateSubscription::new(customer_id);
336 params.items = Some(vec![stripe::CreateSubscriptionItems {
337 price: Some(zed_free_price_id.to_string()),
338 quantity: Some(1),
339 ..Default::default()
340 }]);
341
342 let subscription = stripe::Subscription::create(&self.client, params).await?;
343
344 Ok(subscription)
345 }
346
347 pub async fn checkout_with_zed_free(
348 &self,
349 customer_id: stripe::CustomerId,
350 github_login: &str,
351 success_url: &str,
352 ) -> Result<String> {
353 let zed_free_price_id = self.zed_free_price_id().await?;
354
355 let mut params = stripe::CreateCheckoutSession::new();
356 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
357 params.payment_method_collection =
358 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
359 params.customer = Some(customer_id);
360 params.client_reference_id = Some(github_login);
361 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
362 price: Some(zed_free_price_id.to_string()),
363 quantity: Some(1),
364 ..Default::default()
365 }]);
366 params.success_url = Some(success_url);
367
368 let session = stripe::CheckoutSession::create(&self.client, params).await?;
369 Ok(session.url.context("no checkout session URL")?)
370 }
371}
372
373#[derive(Clone, Deserialize)]
374struct StripeMeter {
375 id: String,
376 event_name: String,
377}
378
379impl StripeMeter {
380 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
381 #[derive(Serialize)]
382 struct Params {
383 #[serde(skip_serializing_if = "Option::is_none")]
384 limit: Option<u64>,
385 }
386
387 client.get_query("/billing/meters", Params { limit: Some(100) })
388 }
389}
390
391#[derive(Deserialize)]
392struct StripeMeterEvent {
393 identifier: String,
394}
395
396impl StripeMeterEvent {
397 pub async fn create(
398 client: &stripe::Client,
399 params: StripeCreateMeterEventParams<'_>,
400 ) -> Result<Self, stripe::StripeError> {
401 let identifier = params.identifier;
402 match client.post_form("/billing/meter_events", params).await {
403 Ok(event) => Ok(event),
404 Err(stripe::StripeError::Stripe(error)) => {
405 if error.http_status == 400
406 && error
407 .message
408 .as_ref()
409 .map_or(false, |message| message.contains(identifier))
410 {
411 Ok(Self {
412 identifier: identifier.to_string(),
413 })
414 } else {
415 Err(stripe::StripeError::Stripe(error))
416 }
417 }
418 Err(error) => Err(error),
419 }
420 }
421}
422
423#[derive(Serialize)]
424struct StripeCreateMeterEventParams<'a> {
425 identifier: &'a str,
426 event_name: &'a str,
427 payload: StripeCreateMeterEventPayload<'a>,
428 timestamp: Option<i64>,
429}
430
431#[derive(Serialize)]
432struct StripeCreateMeterEventPayload<'a> {
433 value: u64,
434 stripe_customer_id: &'a stripe::CustomerId,
435}
436
437fn subscription_contains_price(
438 subscription: &stripe::Subscription,
439 price_id: &stripe::PriceId,
440) -> bool {
441 subscription.items.data.iter().any(|item| {
442 item.price
443 .as_ref()
444 .map_or(false, |price| price.id == *price_id)
445 })
446}