1use std::sync::Arc;
2
3use anyhow::{Context as _, anyhow};
4use chrono::Utc;
5use collections::HashMap;
6use serde::{Deserialize, Serialize};
7use stripe::{PriceId, SubscriptionStatus};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::Result;
12use crate::db::billing_subscription::SubscriptionKind;
13use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
14use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
15
16pub struct StripeBilling {
17 state: RwLock<StripeBillingState>,
18 real_client: Arc<stripe::Client>,
19 client: Arc<dyn StripeClient>,
20}
21
22#[derive(Default)]
23struct StripeBillingState {
24 meters_by_event_name: HashMap<String, StripeMeter>,
25 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
26 prices_by_lookup_key: HashMap<String, stripe::Price>,
27}
28
29impl StripeBilling {
30 pub fn new(client: Arc<stripe::Client>) -> Self {
31 Self {
32 client: Arc::new(RealStripeClient::new(client.clone())),
33 real_client: client,
34 state: RwLock::default(),
35 }
36 }
37
38 #[cfg(test)]
39 pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
40 Self {
41 // This is just temporary until we can remove all usages of the real Stripe client.
42 real_client: Arc::new(stripe::Client::new("sk_test")),
43 client,
44 state: RwLock::default(),
45 }
46 }
47
48 pub async fn initialize(&self) -> Result<()> {
49 log::info!("StripeBilling: initializing");
50
51 let mut state = self.state.write().await;
52
53 let (meters, prices) = futures::try_join!(
54 StripeMeter::list(&self.real_client),
55 stripe::Price::list(
56 &self.real_client,
57 &stripe::ListPrices {
58 limit: Some(100),
59 ..Default::default()
60 }
61 )
62 )?;
63
64 for meter in meters.data {
65 state
66 .meters_by_event_name
67 .insert(meter.event_name.clone(), meter);
68 }
69
70 for price in prices.data {
71 if let Some(lookup_key) = price.lookup_key.clone() {
72 state.prices_by_lookup_key.insert(lookup_key, price.clone());
73 }
74
75 if let Some(recurring) = price.recurring {
76 if let Some(meter) = recurring.meter {
77 state.price_ids_by_meter_id.insert(meter, price.id);
78 }
79 }
80 }
81
82 log::info!("StripeBilling: initialized");
83
84 Ok(())
85 }
86
87 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
88 self.find_price_id_by_lookup_key("zed-pro").await
89 }
90
91 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
92 self.find_price_id_by_lookup_key("zed-free").await
93 }
94
95 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
96 self.state
97 .read()
98 .await
99 .prices_by_lookup_key
100 .get(lookup_key)
101 .map(|price| price.id.clone())
102 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
103 }
104
105 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
106 self.state
107 .read()
108 .await
109 .prices_by_lookup_key
110 .get(lookup_key)
111 .cloned()
112 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
113 }
114
115 pub async fn determine_subscription_kind(
116 &self,
117 subscription: &stripe::Subscription,
118 ) -> Option<SubscriptionKind> {
119 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
120 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
121
122 subscription.items.data.iter().find_map(|item| {
123 let price = item.price.as_ref()?;
124
125 if price.id == zed_pro_price_id {
126 Some(if subscription.status == SubscriptionStatus::Trialing {
127 SubscriptionKind::ZedProTrial
128 } else {
129 SubscriptionKind::ZedPro
130 })
131 } else if price.id == zed_free_price_id {
132 Some(SubscriptionKind::ZedFree)
133 } else {
134 None
135 }
136 })
137 }
138
139 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
140 /// not already exist.
141 ///
142 /// Always returns a new Stripe customer if the email address is `None`.
143 pub async fn find_or_create_customer_by_email(
144 &self,
145 email_address: Option<&str>,
146 ) -> Result<StripeCustomerId> {
147 let existing_customer = if let Some(email) = email_address {
148 let customers = self.client.list_customers_by_email(email).await?;
149
150 customers.first().cloned()
151 } else {
152 None
153 };
154
155 let customer_id = if let Some(existing_customer) = existing_customer {
156 existing_customer.id
157 } else {
158 let customer = self
159 .client
160 .create_customer(crate::stripe_client::CreateCustomerParams {
161 email: email_address,
162 })
163 .await?;
164
165 customer.id
166 };
167
168 Ok(customer_id)
169 }
170
171 pub async fn subscribe_to_price(
172 &self,
173 subscription_id: &stripe::SubscriptionId,
174 price: &stripe::Price,
175 ) -> Result<()> {
176 let subscription =
177 stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
178
179 if subscription_contains_price(&subscription, &price.id) {
180 return Ok(());
181 }
182
183 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
184
185 let price_per_unit = price.unit_amount.unwrap_or_default();
186 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
187
188 stripe::Subscription::update(
189 &self.real_client,
190 subscription_id,
191 stripe::UpdateSubscription {
192 items: Some(vec![stripe::UpdateSubscriptionItems {
193 price: Some(price.id.to_string()),
194 ..Default::default()
195 }]),
196 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
197 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
198 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
199 },
200 }),
201 ..Default::default()
202 },
203 )
204 .await?;
205
206 Ok(())
207 }
208
209 pub async fn bill_model_request_usage(
210 &self,
211 customer_id: &stripe::CustomerId,
212 event_name: &str,
213 requests: i32,
214 ) -> Result<()> {
215 let timestamp = Utc::now().timestamp();
216 let idempotency_key = Uuid::new_v4();
217
218 StripeMeterEvent::create(
219 &self.real_client,
220 StripeCreateMeterEventParams {
221 identifier: &format!("model_requests/{}", idempotency_key),
222 event_name,
223 payload: StripeCreateMeterEventPayload {
224 value: requests as u64,
225 stripe_customer_id: customer_id,
226 },
227 timestamp: Some(timestamp),
228 },
229 )
230 .await?;
231
232 Ok(())
233 }
234
235 pub async fn checkout_with_zed_pro(
236 &self,
237 customer_id: stripe::CustomerId,
238 github_login: &str,
239 success_url: &str,
240 ) -> Result<String> {
241 let zed_pro_price_id = self.zed_pro_price_id().await?;
242
243 let mut params = stripe::CreateCheckoutSession::new();
244 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
245 params.customer = Some(customer_id);
246 params.client_reference_id = Some(github_login);
247 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
248 price: Some(zed_pro_price_id.to_string()),
249 quantity: Some(1),
250 ..Default::default()
251 }]);
252 params.success_url = Some(success_url);
253
254 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
255 Ok(session.url.context("no checkout session URL")?)
256 }
257
258 pub async fn checkout_with_zed_pro_trial(
259 &self,
260 customer_id: stripe::CustomerId,
261 github_login: &str,
262 feature_flags: Vec<String>,
263 success_url: &str,
264 ) -> Result<String> {
265 let zed_pro_price_id = self.zed_pro_price_id().await?;
266
267 let eligible_for_extended_trial = feature_flags
268 .iter()
269 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
270
271 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
272
273 let mut subscription_metadata = std::collections::HashMap::new();
274 if eligible_for_extended_trial {
275 subscription_metadata.insert(
276 "promo_feature_flag".to_string(),
277 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
278 );
279 }
280
281 let mut params = stripe::CreateCheckoutSession::new();
282 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
283 trial_period_days: Some(trial_period_days),
284 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
285 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
286 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
287 }
288 }),
289 metadata: if !subscription_metadata.is_empty() {
290 Some(subscription_metadata)
291 } else {
292 None
293 },
294 ..Default::default()
295 });
296 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
297 params.payment_method_collection =
298 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
299 params.customer = Some(customer_id);
300 params.client_reference_id = Some(github_login);
301 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
302 price: Some(zed_pro_price_id.to_string()),
303 quantity: Some(1),
304 ..Default::default()
305 }]);
306 params.success_url = Some(success_url);
307
308 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
309 Ok(session.url.context("no checkout session URL")?)
310 }
311
312 pub async fn subscribe_to_zed_free(
313 &self,
314 customer_id: stripe::CustomerId,
315 ) -> Result<stripe::Subscription> {
316 let zed_free_price_id = self.zed_free_price_id().await?;
317
318 let existing_subscriptions = stripe::Subscription::list(
319 &self.real_client,
320 &stripe::ListSubscriptions {
321 customer: Some(customer_id.clone()),
322 status: None,
323 ..Default::default()
324 },
325 )
326 .await?;
327
328 let existing_active_subscription =
329 existing_subscriptions
330 .data
331 .into_iter()
332 .find(|subscription| {
333 subscription.status == SubscriptionStatus::Active
334 || subscription.status == SubscriptionStatus::Trialing
335 });
336 if let Some(subscription) = existing_active_subscription {
337 return Ok(subscription);
338 }
339
340 let mut params = stripe::CreateSubscription::new(customer_id);
341 params.items = Some(vec![stripe::CreateSubscriptionItems {
342 price: Some(zed_free_price_id.to_string()),
343 quantity: Some(1),
344 ..Default::default()
345 }]);
346
347 let subscription = stripe::Subscription::create(&self.real_client, params).await?;
348
349 Ok(subscription)
350 }
351
352 pub async fn checkout_with_zed_free(
353 &self,
354 customer_id: stripe::CustomerId,
355 github_login: &str,
356 success_url: &str,
357 ) -> Result<String> {
358 let zed_free_price_id = self.zed_free_price_id().await?;
359
360 let mut params = stripe::CreateCheckoutSession::new();
361 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
362 params.payment_method_collection =
363 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
364 params.customer = Some(customer_id);
365 params.client_reference_id = Some(github_login);
366 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
367 price: Some(zed_free_price_id.to_string()),
368 quantity: Some(1),
369 ..Default::default()
370 }]);
371 params.success_url = Some(success_url);
372
373 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
374 Ok(session.url.context("no checkout session URL")?)
375 }
376}
377
/// A Stripe billing meter, deserialized from the `/billing/meters` list endpoint.
///
/// Field names must match the keys in Stripe's response.
#[derive(Clone, Deserialize)]
struct StripeMeter {
    /// The meter's Stripe ID.
    id: String,
    /// The meter's event name; `StripeBillingState` stores meters keyed by this value.
    event_name: String,
}
383
384impl StripeMeter {
385 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
386 #[derive(Serialize)]
387 struct Params {
388 #[serde(skip_serializing_if = "Option::is_none")]
389 limit: Option<u64>,
390 }
391
392 client.get_query("/billing/meters", Params { limit: Some(100) })
393 }
394}
395
/// A meter event, deserialized from the response to `POST /billing/meter_events`.
///
/// `StripeMeterEvent::create` can also build one locally when it treats an error
/// response as success.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// The event's identifier.
    identifier: String,
}
400
401impl StripeMeterEvent {
402 pub async fn create(
403 client: &stripe::Client,
404 params: StripeCreateMeterEventParams<'_>,
405 ) -> Result<Self, stripe::StripeError> {
406 let identifier = params.identifier;
407 match client.post_form("/billing/meter_events", params).await {
408 Ok(event) => Ok(event),
409 Err(stripe::StripeError::Stripe(error)) => {
410 if error.http_status == 400
411 && error
412 .message
413 .as_ref()
414 .map_or(false, |message| message.contains(identifier))
415 {
416 Ok(Self {
417 identifier: identifier.to_string(),
418 })
419 } else {
420 Err(stripe::StripeError::Stripe(error))
421 }
422 }
423 Err(error) => Err(error),
424 }
425 }
426}
427
/// Form parameters for `POST /billing/meter_events`.
///
/// Field names become the form keys, so they must match Stripe's parameter names.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier for this event. `StripeMeterEvent::create` treats a 400 error whose
    /// message contains this identifier as success.
    identifier: &'a str,
    /// The meter event name.
    event_name: &'a str,
    /// The usage value and the customer it is attributed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time; callers in this file pass `chrono`'s `Utc::now().timestamp()`
    /// (Unix seconds).
    timestamp: Option<i64>,
}
435
/// The `payload` of a meter event: the usage value and the customer it belongs to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount of usage to record (the request count in `bill_model_request_usage`).
    value: u64,
    /// The Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
441
442fn subscription_contains_price(
443 subscription: &stripe::Subscription,
444 price_id: &stripe::PriceId,
445) -> bool {
446 subscription.items.data.iter().any(|item| {
447 item.price
448 .as_ref()
449 .map_or(false, |price| price.id == *price_id)
450 })
451}