1use std::sync::Arc;
2
3use anyhow::{Context as _, anyhow};
4use chrono::Utc;
5use collections::HashMap;
6use serde::{Deserialize, Serialize};
7use stripe::SubscriptionStatus;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::Result;
12use crate::db::billing_subscription::SubscriptionKind;
13use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
14use crate::stripe_client::{
15 RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
16};
17
/// Stripe billing operations for Zed plans, plus an in-memory cache of Stripe meters and
/// prices that is populated by [`StripeBilling::initialize`].
pub struct StripeBilling {
    /// Cached meters and prices; filled in by `initialize` and read by the price lookups.
    state: RwLock<StripeBillingState>,
    /// Direct `stripe` crate client, still used by operations that have not yet been moved
    /// behind `client` (see the "temporary" note in `StripeBilling::test`).
    real_client: Arc<stripe::Client>,
    /// Abstract Stripe client: a `RealStripeClient` in `new`, a `FakeStripeClient` in tests.
    client: Arc<dyn StripeClient>,
}
23
/// Lookup tables built from the meters and prices fetched from Stripe.
#[derive(Default)]
struct StripeBillingState {
    /// Stripe meters keyed by their `event_name`.
    meters_by_event_name: HashMap<String, StripeMeter>,
    /// IDs of recurring prices, keyed by the meter referenced in the price's `recurring.meter`.
    price_ids_by_meter_id: HashMap<String, StripePriceId>,
    /// Prices keyed by their Stripe lookup key (prices without a lookup key are not indexed).
    prices_by_lookup_key: HashMap<String, StripePrice>,
}
30
31impl StripeBilling {
32 pub fn new(client: Arc<stripe::Client>) -> Self {
33 Self {
34 client: Arc::new(RealStripeClient::new(client.clone())),
35 real_client: client,
36 state: RwLock::default(),
37 }
38 }
39
40 #[cfg(test)]
41 pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
42 Self {
43 // This is just temporary until we can remove all usages of the real Stripe client.
44 real_client: Arc::new(stripe::Client::new("sk_test")),
45 client,
46 state: RwLock::default(),
47 }
48 }
49
50 pub async fn initialize(&self) -> Result<()> {
51 log::info!("StripeBilling: initializing");
52
53 let mut state = self.state.write().await;
54
55 let (meters, prices) =
56 futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
57
58 for meter in meters {
59 state
60 .meters_by_event_name
61 .insert(meter.event_name.clone(), meter);
62 }
63
64 for price in prices {
65 if let Some(lookup_key) = price.lookup_key.clone() {
66 state.prices_by_lookup_key.insert(lookup_key, price.clone());
67 }
68
69 if let Some(recurring) = price.recurring {
70 if let Some(meter) = recurring.meter {
71 state.price_ids_by_meter_id.insert(meter, price.id);
72 }
73 }
74 }
75
76 log::info!("StripeBilling: initialized");
77
78 Ok(())
79 }
80
81 pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
82 self.find_price_id_by_lookup_key("zed-pro").await
83 }
84
85 pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
86 self.find_price_id_by_lookup_key("zed-free").await
87 }
88
89 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
90 self.state
91 .read()
92 .await
93 .prices_by_lookup_key
94 .get(lookup_key)
95 .map(|price| price.id.clone())
96 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
97 }
98
99 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
100 self.state
101 .read()
102 .await
103 .prices_by_lookup_key
104 .get(lookup_key)
105 .cloned()
106 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
107 }
108
109 pub async fn determine_subscription_kind(
110 &self,
111 subscription: &stripe::Subscription,
112 ) -> Option<SubscriptionKind> {
113 let zed_pro_price_id: stripe::PriceId =
114 self.zed_pro_price_id().await.ok()?.try_into().ok()?;
115 let zed_free_price_id: stripe::PriceId =
116 self.zed_free_price_id().await.ok()?.try_into().ok()?;
117
118 subscription.items.data.iter().find_map(|item| {
119 let price = item.price.as_ref()?;
120
121 if price.id == zed_pro_price_id {
122 Some(if subscription.status == SubscriptionStatus::Trialing {
123 SubscriptionKind::ZedProTrial
124 } else {
125 SubscriptionKind::ZedPro
126 })
127 } else if price.id == zed_free_price_id {
128 Some(SubscriptionKind::ZedFree)
129 } else {
130 None
131 }
132 })
133 }
134
135 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
136 /// not already exist.
137 ///
138 /// Always returns a new Stripe customer if the email address is `None`.
139 pub async fn find_or_create_customer_by_email(
140 &self,
141 email_address: Option<&str>,
142 ) -> Result<StripeCustomerId> {
143 let existing_customer = if let Some(email) = email_address {
144 let customers = self.client.list_customers_by_email(email).await?;
145
146 customers.first().cloned()
147 } else {
148 None
149 };
150
151 let customer_id = if let Some(existing_customer) = existing_customer {
152 existing_customer.id
153 } else {
154 let customer = self
155 .client
156 .create_customer(crate::stripe_client::CreateCustomerParams {
157 email: email_address,
158 })
159 .await?;
160
161 customer.id
162 };
163
164 Ok(customer_id)
165 }
166
167 pub async fn subscribe_to_price(
168 &self,
169 subscription_id: &stripe::SubscriptionId,
170 price: &StripePrice,
171 ) -> Result<()> {
172 let subscription =
173 stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
174
175 let price_id = price.id.clone().try_into()?;
176 if subscription_contains_price(&subscription, &price_id) {
177 return Ok(());
178 }
179
180 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
181
182 let price_per_unit = price.unit_amount.unwrap_or_default();
183 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
184
185 stripe::Subscription::update(
186 &self.real_client,
187 subscription_id,
188 stripe::UpdateSubscription {
189 items: Some(vec![stripe::UpdateSubscriptionItems {
190 price: Some(price.id.to_string()),
191 ..Default::default()
192 }]),
193 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
194 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
195 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
196 },
197 }),
198 ..Default::default()
199 },
200 )
201 .await?;
202
203 Ok(())
204 }
205
206 pub async fn bill_model_request_usage(
207 &self,
208 customer_id: &stripe::CustomerId,
209 event_name: &str,
210 requests: i32,
211 ) -> Result<()> {
212 let timestamp = Utc::now().timestamp();
213 let idempotency_key = Uuid::new_v4();
214
215 StripeMeterEvent::create(
216 &self.real_client,
217 StripeCreateMeterEventParams {
218 identifier: &format!("model_requests/{}", idempotency_key),
219 event_name,
220 payload: StripeCreateMeterEventPayload {
221 value: requests as u64,
222 stripe_customer_id: customer_id,
223 },
224 timestamp: Some(timestamp),
225 },
226 )
227 .await?;
228
229 Ok(())
230 }
231
232 pub async fn checkout_with_zed_pro(
233 &self,
234 customer_id: stripe::CustomerId,
235 github_login: &str,
236 success_url: &str,
237 ) -> Result<String> {
238 let zed_pro_price_id = self.zed_pro_price_id().await?;
239
240 let mut params = stripe::CreateCheckoutSession::new();
241 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
242 params.customer = Some(customer_id);
243 params.client_reference_id = Some(github_login);
244 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
245 price: Some(zed_pro_price_id.to_string()),
246 quantity: Some(1),
247 ..Default::default()
248 }]);
249 params.success_url = Some(success_url);
250
251 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
252 Ok(session.url.context("no checkout session URL")?)
253 }
254
255 pub async fn checkout_with_zed_pro_trial(
256 &self,
257 customer_id: stripe::CustomerId,
258 github_login: &str,
259 feature_flags: Vec<String>,
260 success_url: &str,
261 ) -> Result<String> {
262 let zed_pro_price_id = self.zed_pro_price_id().await?;
263
264 let eligible_for_extended_trial = feature_flags
265 .iter()
266 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
267
268 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
269
270 let mut subscription_metadata = std::collections::HashMap::new();
271 if eligible_for_extended_trial {
272 subscription_metadata.insert(
273 "promo_feature_flag".to_string(),
274 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
275 );
276 }
277
278 let mut params = stripe::CreateCheckoutSession::new();
279 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
280 trial_period_days: Some(trial_period_days),
281 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
282 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
283 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
284 }
285 }),
286 metadata: if !subscription_metadata.is_empty() {
287 Some(subscription_metadata)
288 } else {
289 None
290 },
291 ..Default::default()
292 });
293 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
294 params.payment_method_collection =
295 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
296 params.customer = Some(customer_id);
297 params.client_reference_id = Some(github_login);
298 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
299 price: Some(zed_pro_price_id.to_string()),
300 quantity: Some(1),
301 ..Default::default()
302 }]);
303 params.success_url = Some(success_url);
304
305 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
306 Ok(session.url.context("no checkout session URL")?)
307 }
308
309 pub async fn subscribe_to_zed_free(
310 &self,
311 customer_id: stripe::CustomerId,
312 ) -> Result<stripe::Subscription> {
313 let zed_free_price_id = self.zed_free_price_id().await?;
314
315 let existing_subscriptions = stripe::Subscription::list(
316 &self.real_client,
317 &stripe::ListSubscriptions {
318 customer: Some(customer_id.clone()),
319 status: None,
320 ..Default::default()
321 },
322 )
323 .await?;
324
325 let existing_active_subscription =
326 existing_subscriptions
327 .data
328 .into_iter()
329 .find(|subscription| {
330 subscription.status == SubscriptionStatus::Active
331 || subscription.status == SubscriptionStatus::Trialing
332 });
333 if let Some(subscription) = existing_active_subscription {
334 return Ok(subscription);
335 }
336
337 let mut params = stripe::CreateSubscription::new(customer_id);
338 params.items = Some(vec![stripe::CreateSubscriptionItems {
339 price: Some(zed_free_price_id.to_string()),
340 quantity: Some(1),
341 ..Default::default()
342 }]);
343
344 let subscription = stripe::Subscription::create(&self.real_client, params).await?;
345
346 Ok(subscription)
347 }
348
349 pub async fn checkout_with_zed_free(
350 &self,
351 customer_id: stripe::CustomerId,
352 github_login: &str,
353 success_url: &str,
354 ) -> Result<String> {
355 let zed_free_price_id = self.zed_free_price_id().await?;
356
357 let mut params = stripe::CreateCheckoutSession::new();
358 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
359 params.payment_method_collection =
360 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
361 params.customer = Some(customer_id);
362 params.client_reference_id = Some(github_login);
363 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
364 price: Some(zed_free_price_id.to_string()),
365 quantity: Some(1),
366 ..Default::default()
367 }]);
368 params.success_url = Some(success_url);
369
370 let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
371 Ok(session.url.context("no checkout session URL")?)
372 }
373}
374
/// Response body of `POST /billing/meter_events`, keeping only the field used here.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the meter event, as supplied in `StripeCreateMeterEventParams::identifier`.
    identifier: String,
}
379
380impl StripeMeterEvent {
381 pub async fn create(
382 client: &stripe::Client,
383 params: StripeCreateMeterEventParams<'_>,
384 ) -> Result<Self, stripe::StripeError> {
385 let identifier = params.identifier;
386 match client.post_form("/billing/meter_events", params).await {
387 Ok(event) => Ok(event),
388 Err(stripe::StripeError::Stripe(error)) => {
389 if error.http_status == 400
390 && error
391 .message
392 .as_ref()
393 .map_or(false, |message| message.contains(identifier))
394 {
395 Ok(Self {
396 identifier: identifier.to_string(),
397 })
398 } else {
399 Err(stripe::StripeError::Stripe(error))
400 }
401 }
402 Err(error) => Err(error),
403 }
404 }
405}
406
/// Form parameters sent to `POST /billing/meter_events`.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Identifier for this event. `StripeMeterEvent::create` treats a 400 error mentioning
    /// it as "already recorded".
    identifier: &'a str,
    /// Meter event name (presumably matching a meter's `event_name` — confirm against Stripe config).
    event_name: &'a str,
    /// Usage quantity and the customer it belongs to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time as a Unix timestamp in seconds (`bill_model_request_usage` passes
    /// `Utc::now().timestamp()`).
    timestamp: Option<i64>,
}
414
/// Payload of a meter event: how much usage to record and for which customer.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage quantity; `bill_model_request_usage` sets it to the number of model requests.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
420
421fn subscription_contains_price(
422 subscription: &stripe::Subscription,
423 price_id: &stripe::PriceId,
424) -> bool {
425 subscription.items.data.iter().any(|item| {
426 item.price
427 .as_ref()
428 .map_or(false, |price| price.id == *price_id)
429 })
430}