1use std::sync::Arc;
2
3use crate::Result;
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
6use anyhow::{Context as _, anyhow};
7use chrono::Utc;
8use collections::HashMap;
9use serde::{Deserialize, Serialize};
10use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct StripeBilling {
15 state: RwLock<StripeBillingState>,
16 client: Arc<stripe::Client>,
17}
18
19#[derive(Default)]
20struct StripeBillingState {
21 meters_by_event_name: HashMap<String, StripeMeter>,
22 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
23 prices_by_lookup_key: HashMap<String, stripe::Price>,
24}
25
26impl StripeBilling {
27 pub fn new(client: Arc<stripe::Client>) -> Self {
28 Self {
29 client,
30 state: RwLock::default(),
31 }
32 }
33
34 pub async fn initialize(&self) -> Result<()> {
35 log::info!("StripeBilling: initializing");
36
37 let mut state = self.state.write().await;
38
39 let (meters, prices) = futures::try_join!(
40 StripeMeter::list(&self.client),
41 stripe::Price::list(
42 &self.client,
43 &stripe::ListPrices {
44 limit: Some(100),
45 ..Default::default()
46 }
47 )
48 )?;
49
50 for meter in meters.data {
51 state
52 .meters_by_event_name
53 .insert(meter.event_name.clone(), meter);
54 }
55
56 for price in prices.data {
57 if let Some(lookup_key) = price.lookup_key.clone() {
58 state.prices_by_lookup_key.insert(lookup_key, price.clone());
59 }
60
61 if let Some(recurring) = price.recurring {
62 if let Some(meter) = recurring.meter {
63 state.price_ids_by_meter_id.insert(meter, price.id);
64 }
65 }
66 }
67
68 log::info!("StripeBilling: initialized");
69
70 Ok(())
71 }
72
73 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
74 self.find_price_id_by_lookup_key("zed-pro").await
75 }
76
77 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
78 self.find_price_id_by_lookup_key("zed-free").await
79 }
80
81 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
82 self.state
83 .read()
84 .await
85 .prices_by_lookup_key
86 .get(lookup_key)
87 .map(|price| price.id.clone())
88 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
89 }
90
91 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
92 self.state
93 .read()
94 .await
95 .prices_by_lookup_key
96 .get(lookup_key)
97 .cloned()
98 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
99 }
100
101 pub async fn determine_subscription_kind(
102 &self,
103 subscription: &stripe::Subscription,
104 ) -> Option<SubscriptionKind> {
105 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108 subscription.items.data.iter().find_map(|item| {
109 let price = item.price.as_ref()?;
110
111 if price.id == zed_pro_price_id {
112 Some(if subscription.status == SubscriptionStatus::Trialing {
113 SubscriptionKind::ZedProTrial
114 } else {
115 SubscriptionKind::ZedPro
116 })
117 } else if price.id == zed_free_price_id {
118 Some(SubscriptionKind::ZedFree)
119 } else {
120 None
121 }
122 })
123 }
124
125 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
126 /// not already exist.
127 ///
128 /// Always returns a new Stripe customer if the email address is `None`.
129 pub async fn find_or_create_customer_by_email(
130 &self,
131 email_address: Option<&str>,
132 ) -> Result<CustomerId> {
133 let existing_customer = if let Some(email) = email_address {
134 let customers = Customer::list(
135 &self.client,
136 &stripe::ListCustomers {
137 email: Some(email),
138 ..Default::default()
139 },
140 )
141 .await?;
142
143 customers.data.first().cloned()
144 } else {
145 None
146 };
147
148 let customer_id = if let Some(existing_customer) = existing_customer {
149 existing_customer.id
150 } else {
151 let customer = Customer::create(
152 &self.client,
153 CreateCustomer {
154 email: email_address,
155 ..Default::default()
156 },
157 )
158 .await?;
159
160 customer.id
161 };
162
163 Ok(customer_id)
164 }
165
166 pub async fn subscribe_to_price(
167 &self,
168 subscription_id: &stripe::SubscriptionId,
169 price: &stripe::Price,
170 ) -> Result<()> {
171 let subscription =
172 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
173
174 if subscription_contains_price(&subscription, &price.id) {
175 return Ok(());
176 }
177
178 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
179
180 let price_per_unit = price.unit_amount.unwrap_or_default();
181 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
182
183 stripe::Subscription::update(
184 &self.client,
185 subscription_id,
186 stripe::UpdateSubscription {
187 items: Some(vec![stripe::UpdateSubscriptionItems {
188 price: Some(price.id.to_string()),
189 ..Default::default()
190 }]),
191 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
192 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
193 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
194 },
195 }),
196 ..Default::default()
197 },
198 )
199 .await?;
200
201 Ok(())
202 }
203
204 pub async fn bill_model_request_usage(
205 &self,
206 customer_id: &stripe::CustomerId,
207 event_name: &str,
208 requests: i32,
209 ) -> Result<()> {
210 let timestamp = Utc::now().timestamp();
211 let idempotency_key = Uuid::new_v4();
212
213 StripeMeterEvent::create(
214 &self.client,
215 StripeCreateMeterEventParams {
216 identifier: &format!("model_requests/{}", idempotency_key),
217 event_name,
218 payload: StripeCreateMeterEventPayload {
219 value: requests as u64,
220 stripe_customer_id: customer_id,
221 },
222 timestamp: Some(timestamp),
223 },
224 )
225 .await?;
226
227 Ok(())
228 }
229
230 pub async fn checkout_with_zed_pro(
231 &self,
232 customer_id: stripe::CustomerId,
233 github_login: &str,
234 success_url: &str,
235 ) -> Result<String> {
236 let zed_pro_price_id = self.zed_pro_price_id().await?;
237
238 let mut params = stripe::CreateCheckoutSession::new();
239 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
240 params.customer = Some(customer_id);
241 params.client_reference_id = Some(github_login);
242 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
243 price: Some(zed_pro_price_id.to_string()),
244 quantity: Some(1),
245 ..Default::default()
246 }]);
247 params.success_url = Some(success_url);
248
249 let session = stripe::CheckoutSession::create(&self.client, params).await?;
250 Ok(session.url.context("no checkout session URL")?)
251 }
252
253 pub async fn checkout_with_zed_pro_trial(
254 &self,
255 customer_id: stripe::CustomerId,
256 github_login: &str,
257 feature_flags: Vec<String>,
258 success_url: &str,
259 ) -> Result<String> {
260 let zed_pro_price_id = self.zed_pro_price_id().await?;
261
262 let eligible_for_extended_trial = feature_flags
263 .iter()
264 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
265
266 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
267
268 let mut subscription_metadata = std::collections::HashMap::new();
269 if eligible_for_extended_trial {
270 subscription_metadata.insert(
271 "promo_feature_flag".to_string(),
272 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
273 );
274 }
275
276 let mut params = stripe::CreateCheckoutSession::new();
277 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
278 trial_period_days: Some(trial_period_days),
279 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
280 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
281 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
282 }
283 }),
284 metadata: if !subscription_metadata.is_empty() {
285 Some(subscription_metadata)
286 } else {
287 None
288 },
289 ..Default::default()
290 });
291 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
292 params.payment_method_collection =
293 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
294 params.customer = Some(customer_id);
295 params.client_reference_id = Some(github_login);
296 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
297 price: Some(zed_pro_price_id.to_string()),
298 quantity: Some(1),
299 ..Default::default()
300 }]);
301 params.success_url = Some(success_url);
302
303 let session = stripe::CheckoutSession::create(&self.client, params).await?;
304 Ok(session.url.context("no checkout session URL")?)
305 }
306
307 pub async fn subscribe_to_zed_free(
308 &self,
309 customer_id: stripe::CustomerId,
310 ) -> Result<stripe::Subscription> {
311 let zed_free_price_id = self.zed_free_price_id().await?;
312
313 let existing_subscriptions = stripe::Subscription::list(
314 &self.client,
315 &stripe::ListSubscriptions {
316 customer: Some(customer_id.clone()),
317 status: None,
318 ..Default::default()
319 },
320 )
321 .await?;
322
323 let existing_zed_free_subscription =
324 existing_subscriptions
325 .data
326 .into_iter()
327 .find(|subscription| {
328 subscription.status == SubscriptionStatus::Active
329 && subscription.items.data.iter().any(|item| {
330 item.price
331 .as_ref()
332 .map_or(false, |price| price.id == zed_free_price_id)
333 })
334 });
335 if let Some(subscription) = existing_zed_free_subscription {
336 return Ok(subscription);
337 }
338
339 let mut params = stripe::CreateSubscription::new(customer_id);
340 params.items = Some(vec![stripe::CreateSubscriptionItems {
341 price: Some(zed_free_price_id.to_string()),
342 quantity: Some(1),
343 ..Default::default()
344 }]);
345
346 let subscription = stripe::Subscription::create(&self.client, params).await?;
347
348 Ok(subscription)
349 }
350
351 pub async fn checkout_with_zed_free(
352 &self,
353 customer_id: stripe::CustomerId,
354 github_login: &str,
355 success_url: &str,
356 ) -> Result<String> {
357 let zed_free_price_id = self.zed_free_price_id().await?;
358
359 let mut params = stripe::CreateCheckoutSession::new();
360 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
361 params.payment_method_collection =
362 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
363 params.customer = Some(customer_id);
364 params.client_reference_id = Some(github_login);
365 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
366 price: Some(zed_free_price_id.to_string()),
367 quantity: Some(1),
368 ..Default::default()
369 }]);
370 params.success_url = Some(success_url);
371
372 let session = stripe::CheckoutSession::create(&self.client, params).await?;
373 Ok(session.url.context("no checkout session URL")?)
374 }
375}
376
377#[derive(Clone, Deserialize)]
378struct StripeMeter {
379 id: String,
380 event_name: String,
381}
382
383impl StripeMeter {
384 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
385 #[derive(Serialize)]
386 struct Params {
387 #[serde(skip_serializing_if = "Option::is_none")]
388 limit: Option<u64>,
389 }
390
391 client.get_query("/billing/meters", Params { limit: Some(100) })
392 }
393}
394
395#[derive(Deserialize)]
396struct StripeMeterEvent {
397 identifier: String,
398}
399
400impl StripeMeterEvent {
401 pub async fn create(
402 client: &stripe::Client,
403 params: StripeCreateMeterEventParams<'_>,
404 ) -> Result<Self, stripe::StripeError> {
405 let identifier = params.identifier;
406 match client.post_form("/billing/meter_events", params).await {
407 Ok(event) => Ok(event),
408 Err(stripe::StripeError::Stripe(error)) => {
409 if error.http_status == 400
410 && error
411 .message
412 .as_ref()
413 .map_or(false, |message| message.contains(identifier))
414 {
415 Ok(Self {
416 identifier: identifier.to_string(),
417 })
418 } else {
419 Err(stripe::StripeError::Stripe(error))
420 }
421 }
422 Err(error) => Err(error),
423 }
424 }
425}
426
427#[derive(Serialize)]
428struct StripeCreateMeterEventParams<'a> {
429 identifier: &'a str,
430 event_name: &'a str,
431 payload: StripeCreateMeterEventPayload<'a>,
432 timestamp: Option<i64>,
433}
434
435#[derive(Serialize)]
436struct StripeCreateMeterEventPayload<'a> {
437 value: u64,
438 stripe_customer_id: &'a stripe::CustomerId,
439}
440
441fn subscription_contains_price(
442 subscription: &stripe::Subscription,
443 price_id: &stripe::PriceId,
444) -> bool {
445 subscription.items.data.iter().any(|item| {
446 item.price
447 .as_ref()
448 .map_or(false, |price| price.id == *price_id)
449 })
450}