1use std::sync::Arc;
2
3use chrono::{Duration, Utc};
4use pretty_assertions::assert_eq;
5
6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
7use crate::stripe_billing::StripeBilling;
8use crate::stripe_client::{
9 FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
10 StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
11 StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeMeter, StripeMeterId,
12 StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
13 StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
14 StripeSubscriptionTrialSettingsEndBehavior,
15 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
16};
17
18fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
19 let stripe_client = Arc::new(FakeStripeClient::new());
20 let stripe_billing = StripeBilling::test(stripe_client.clone());
21
22 (stripe_billing, stripe_client)
23}
24
25#[gpui::test]
26async fn test_initialize() {
27 let (stripe_billing, stripe_client) = make_stripe_billing();
28
29 // Add test meters
30 let meter1 = StripeMeter {
31 id: StripeMeterId("meter_1".into()),
32 event_name: "event_1".to_string(),
33 };
34 let meter2 = StripeMeter {
35 id: StripeMeterId("meter_2".into()),
36 event_name: "event_2".to_string(),
37 };
38 stripe_client
39 .meters
40 .lock()
41 .insert(meter1.id.clone(), meter1);
42 stripe_client
43 .meters
44 .lock()
45 .insert(meter2.id.clone(), meter2);
46
47 // Add test prices
48 let price1 = StripePrice {
49 id: StripePriceId("price_1".into()),
50 unit_amount: Some(1_000),
51 lookup_key: Some("zed-pro".to_string()),
52 recurring: None,
53 };
54 let price2 = StripePrice {
55 id: StripePriceId("price_2".into()),
56 unit_amount: Some(0),
57 lookup_key: Some("zed-free".to_string()),
58 recurring: None,
59 };
60 let price3 = StripePrice {
61 id: StripePriceId("price_3".into()),
62 unit_amount: Some(500),
63 lookup_key: None,
64 recurring: Some(StripePriceRecurring {
65 meter: Some("meter_1".to_string()),
66 }),
67 };
68 stripe_client
69 .prices
70 .lock()
71 .insert(price1.id.clone(), price1);
72 stripe_client
73 .prices
74 .lock()
75 .insert(price2.id.clone(), price2);
76 stripe_client
77 .prices
78 .lock()
79 .insert(price3.id.clone(), price3);
80
81 // Initialize the billing system
82 stripe_billing.initialize().await.unwrap();
83
84 // Verify that prices can be found by lookup key
85 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
86 assert_eq!(zed_pro_price_id.to_string(), "price_1");
87
88 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
89 assert_eq!(zed_free_price_id.to_string(), "price_2");
90
91 // Verify that a price can be found by lookup key
92 let zed_pro_price = stripe_billing
93 .find_price_by_lookup_key("zed-pro")
94 .await
95 .unwrap();
96 assert_eq!(zed_pro_price.id.to_string(), "price_1");
97 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
98
99 // Verify that finding a non-existent lookup key returns an error
100 let result = stripe_billing
101 .find_price_by_lookup_key("non-existent")
102 .await;
103 assert!(result.is_err());
104}
105
106#[gpui::test]
107async fn test_find_or_create_customer_by_email() {
108 let (stripe_billing, stripe_client) = make_stripe_billing();
109
110 // Create a customer with an email that doesn't yet correspond to a customer.
111 {
112 let email = "user@example.com";
113
114 let customer_id = stripe_billing
115 .find_or_create_customer_by_email(Some(email))
116 .await
117 .unwrap();
118
119 let customer = stripe_client
120 .customers
121 .lock()
122 .get(&customer_id)
123 .unwrap()
124 .clone();
125 assert_eq!(customer.email.as_deref(), Some(email));
126 }
127
128 // Create a customer with an email that corresponds to an existing customer.
129 {
130 let email = "user2@example.com";
131
132 let existing_customer_id = stripe_billing
133 .find_or_create_customer_by_email(Some(email))
134 .await
135 .unwrap();
136
137 let customer_id = stripe_billing
138 .find_or_create_customer_by_email(Some(email))
139 .await
140 .unwrap();
141 assert_eq!(customer_id, existing_customer_id);
142
143 let customer = stripe_client
144 .customers
145 .lock()
146 .get(&customer_id)
147 .unwrap()
148 .clone();
149 assert_eq!(customer.email.as_deref(), Some(email));
150 }
151}
152
153#[gpui::test]
154async fn test_subscribe_to_price() {
155 let (stripe_billing, stripe_client) = make_stripe_billing();
156
157 let price = StripePrice {
158 id: StripePriceId("price_test".into()),
159 unit_amount: Some(2000),
160 lookup_key: Some("test-price".to_string()),
161 recurring: None,
162 };
163 stripe_client
164 .prices
165 .lock()
166 .insert(price.id.clone(), price.clone());
167
168 let now = Utc::now();
169 let subscription = StripeSubscription {
170 id: StripeSubscriptionId("sub_test".into()),
171 customer: StripeCustomerId("cus_test".into()),
172 status: stripe::SubscriptionStatus::Active,
173 current_period_start: now.timestamp(),
174 current_period_end: (now + Duration::days(30)).timestamp(),
175 items: vec![],
176 cancel_at: None,
177 cancellation_details: None,
178 };
179 stripe_client
180 .subscriptions
181 .lock()
182 .insert(subscription.id.clone(), subscription.clone());
183
184 stripe_billing
185 .subscribe_to_price(&subscription.id, &price)
186 .await
187 .unwrap();
188
189 let update_subscription_calls = stripe_client
190 .update_subscription_calls
191 .lock()
192 .iter()
193 .map(|(id, params)| (id.clone(), params.clone()))
194 .collect::<Vec<_>>();
195 assert_eq!(update_subscription_calls.len(), 1);
196 assert_eq!(update_subscription_calls[0].0, subscription.id);
197 assert_eq!(
198 update_subscription_calls[0].1.items,
199 Some(vec![UpdateSubscriptionItems {
200 price: Some(price.id.clone())
201 }])
202 );
203
204 // Subscribing to a price that is already on the subscription is a no-op.
205 {
206 let now = Utc::now();
207 let subscription = StripeSubscription {
208 id: StripeSubscriptionId("sub_test".into()),
209 customer: StripeCustomerId("cus_test".into()),
210 status: stripe::SubscriptionStatus::Active,
211 current_period_start: now.timestamp(),
212 current_period_end: (now + Duration::days(30)).timestamp(),
213 items: vec![StripeSubscriptionItem {
214 id: StripeSubscriptionItemId("si_test".into()),
215 price: Some(price.clone()),
216 }],
217 cancel_at: None,
218 cancellation_details: None,
219 };
220 stripe_client
221 .subscriptions
222 .lock()
223 .insert(subscription.id.clone(), subscription.clone());
224
225 stripe_billing
226 .subscribe_to_price(&subscription.id, &price)
227 .await
228 .unwrap();
229
230 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
231 }
232}
233
234#[gpui::test]
235async fn test_subscribe_to_zed_free() {
236 let (stripe_billing, stripe_client) = make_stripe_billing();
237
238 let zed_pro_price = StripePrice {
239 id: StripePriceId("price_1".into()),
240 unit_amount: Some(0),
241 lookup_key: Some("zed-pro".to_string()),
242 recurring: None,
243 };
244 stripe_client
245 .prices
246 .lock()
247 .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
248 let zed_free_price = StripePrice {
249 id: StripePriceId("price_2".into()),
250 unit_amount: Some(0),
251 lookup_key: Some("zed-free".to_string()),
252 recurring: None,
253 };
254 stripe_client
255 .prices
256 .lock()
257 .insert(zed_free_price.id.clone(), zed_free_price.clone());
258
259 stripe_billing.initialize().await.unwrap();
260
261 // Customer is subscribed to Zed Free when not already subscribed to a plan.
262 {
263 let customer_id = StripeCustomerId("cus_no_plan".into());
264
265 let subscription = stripe_billing
266 .subscribe_to_zed_free(customer_id)
267 .await
268 .unwrap();
269
270 assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
271 }
272
273 // Customer is not subscribed to Zed Free when they already have an active subscription.
274 {
275 let customer_id = StripeCustomerId("cus_active_subscription".into());
276
277 let now = Utc::now();
278 let existing_subscription = StripeSubscription {
279 id: StripeSubscriptionId("sub_existing_active".into()),
280 customer: customer_id.clone(),
281 status: stripe::SubscriptionStatus::Active,
282 current_period_start: now.timestamp(),
283 current_period_end: (now + Duration::days(30)).timestamp(),
284 items: vec![StripeSubscriptionItem {
285 id: StripeSubscriptionItemId("si_test".into()),
286 price: Some(zed_pro_price.clone()),
287 }],
288 cancel_at: None,
289 cancellation_details: None,
290 };
291 stripe_client.subscriptions.lock().insert(
292 existing_subscription.id.clone(),
293 existing_subscription.clone(),
294 );
295
296 let subscription = stripe_billing
297 .subscribe_to_zed_free(customer_id)
298 .await
299 .unwrap();
300
301 assert_eq!(subscription, existing_subscription);
302 }
303
304 // Customer is not subscribed to Zed Free when they already have a trial subscription.
305 {
306 let customer_id = StripeCustomerId("cus_trial_subscription".into());
307
308 let now = Utc::now();
309 let existing_subscription = StripeSubscription {
310 id: StripeSubscriptionId("sub_existing_trial".into()),
311 customer: customer_id.clone(),
312 status: stripe::SubscriptionStatus::Trialing,
313 current_period_start: now.timestamp(),
314 current_period_end: (now + Duration::days(14)).timestamp(),
315 items: vec![StripeSubscriptionItem {
316 id: StripeSubscriptionItemId("si_test".into()),
317 price: Some(zed_pro_price.clone()),
318 }],
319 cancel_at: None,
320 cancellation_details: None,
321 };
322 stripe_client.subscriptions.lock().insert(
323 existing_subscription.id.clone(),
324 existing_subscription.clone(),
325 );
326
327 let subscription = stripe_billing
328 .subscribe_to_zed_free(customer_id)
329 .await
330 .unwrap();
331
332 assert_eq!(subscription, existing_subscription);
333 }
334}
335
336#[gpui::test]
337async fn test_bill_model_request_usage() {
338 let (stripe_billing, stripe_client) = make_stripe_billing();
339
340 let customer_id = StripeCustomerId("cus_test".into());
341
342 stripe_billing
343 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
344 .await
345 .unwrap();
346
347 let create_meter_event_calls = stripe_client
348 .create_meter_event_calls
349 .lock()
350 .iter()
351 .cloned()
352 .collect::<Vec<_>>();
353 assert_eq!(create_meter_event_calls.len(), 1);
354 assert!(
355 create_meter_event_calls[0]
356 .identifier
357 .starts_with("model_requests/")
358 );
359 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
360 assert_eq!(
361 create_meter_event_calls[0].event_name.as_ref(),
362 "some_model/requests"
363 );
364 assert_eq!(create_meter_event_calls[0].value, 73);
365}
366
367#[gpui::test]
368async fn test_checkout_with_zed_pro() {
369 let (stripe_billing, stripe_client) = make_stripe_billing();
370
371 let customer_id = StripeCustomerId("cus_test".into());
372 let github_login = "zeduser1";
373 let success_url = "https://example.com/success";
374
375 // It returns an error when the Zed Pro price doesn't exist.
376 {
377 let result = stripe_billing
378 .checkout_with_zed_pro(&customer_id, github_login, success_url)
379 .await;
380
381 assert!(result.is_err());
382 assert_eq!(
383 result.err().unwrap().to_string(),
384 r#"no price ID found for "zed-pro""#
385 );
386 }
387
388 // Successful checkout.
389 {
390 let price = StripePrice {
391 id: StripePriceId("price_1".into()),
392 unit_amount: Some(2000),
393 lookup_key: Some("zed-pro".to_string()),
394 recurring: None,
395 };
396 stripe_client
397 .prices
398 .lock()
399 .insert(price.id.clone(), price.clone());
400
401 stripe_billing.initialize().await.unwrap();
402
403 let checkout_url = stripe_billing
404 .checkout_with_zed_pro(&customer_id, github_login, success_url)
405 .await
406 .unwrap();
407
408 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
409
410 let create_checkout_session_calls = stripe_client
411 .create_checkout_session_calls
412 .lock()
413 .drain(..)
414 .collect::<Vec<_>>();
415 assert_eq!(create_checkout_session_calls.len(), 1);
416 let call = create_checkout_session_calls.into_iter().next().unwrap();
417 assert_eq!(call.customer, Some(customer_id));
418 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
419 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
420 assert_eq!(
421 call.line_items,
422 Some(vec![StripeCreateCheckoutSessionLineItems {
423 price: Some(price.id.to_string()),
424 quantity: Some(1)
425 }])
426 );
427 assert_eq!(call.payment_method_collection, None);
428 assert_eq!(call.subscription_data, None);
429 assert_eq!(call.success_url.as_deref(), Some(success_url));
430 assert_eq!(
431 call.billing_address_collection,
432 Some(StripeBillingAddressCollection::Required)
433 );
434 }
435}
436
437#[gpui::test]
438async fn test_checkout_with_zed_pro_trial() {
439 let (stripe_billing, stripe_client) = make_stripe_billing();
440
441 let customer_id = StripeCustomerId("cus_test".into());
442 let github_login = "zeduser1";
443 let success_url = "https://example.com/success";
444
445 // It returns an error when the Zed Pro price doesn't exist.
446 {
447 let result = stripe_billing
448 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
449 .await;
450
451 assert!(result.is_err());
452 assert_eq!(
453 result.err().unwrap().to_string(),
454 r#"no price ID found for "zed-pro""#
455 );
456 }
457
458 let price = StripePrice {
459 id: StripePriceId("price_1".into()),
460 unit_amount: Some(2000),
461 lookup_key: Some("zed-pro".to_string()),
462 recurring: None,
463 };
464 stripe_client
465 .prices
466 .lock()
467 .insert(price.id.clone(), price.clone());
468
469 stripe_billing.initialize().await.unwrap();
470
471 // Successful checkout.
472 {
473 let checkout_url = stripe_billing
474 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
475 .await
476 .unwrap();
477
478 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
479
480 let create_checkout_session_calls = stripe_client
481 .create_checkout_session_calls
482 .lock()
483 .drain(..)
484 .collect::<Vec<_>>();
485 assert_eq!(create_checkout_session_calls.len(), 1);
486 let call = create_checkout_session_calls.into_iter().next().unwrap();
487 assert_eq!(call.customer.as_ref(), Some(&customer_id));
488 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
489 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
490 assert_eq!(
491 call.line_items,
492 Some(vec![StripeCreateCheckoutSessionLineItems {
493 price: Some(price.id.to_string()),
494 quantity: Some(1)
495 }])
496 );
497 assert_eq!(
498 call.payment_method_collection,
499 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
500 );
501 assert_eq!(
502 call.subscription_data,
503 Some(StripeCreateCheckoutSessionSubscriptionData {
504 trial_period_days: Some(14),
505 trial_settings: Some(StripeSubscriptionTrialSettings {
506 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
507 missing_payment_method:
508 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
509 },
510 }),
511 metadata: None,
512 })
513 );
514 assert_eq!(call.success_url.as_deref(), Some(success_url));
515 assert_eq!(
516 call.billing_address_collection,
517 Some(StripeBillingAddressCollection::Required)
518 );
519 }
520
521 // Successful checkout with extended trial.
522 {
523 let checkout_url = stripe_billing
524 .checkout_with_zed_pro_trial(
525 &customer_id,
526 github_login,
527 vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
528 success_url,
529 )
530 .await
531 .unwrap();
532
533 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
534
535 let create_checkout_session_calls = stripe_client
536 .create_checkout_session_calls
537 .lock()
538 .drain(..)
539 .collect::<Vec<_>>();
540 assert_eq!(create_checkout_session_calls.len(), 1);
541 let call = create_checkout_session_calls.into_iter().next().unwrap();
542 assert_eq!(call.customer, Some(customer_id));
543 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
544 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
545 assert_eq!(
546 call.line_items,
547 Some(vec![StripeCreateCheckoutSessionLineItems {
548 price: Some(price.id.to_string()),
549 quantity: Some(1)
550 }])
551 );
552 assert_eq!(
553 call.payment_method_collection,
554 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
555 );
556 assert_eq!(
557 call.subscription_data,
558 Some(StripeCreateCheckoutSessionSubscriptionData {
559 trial_period_days: Some(60),
560 trial_settings: Some(StripeSubscriptionTrialSettings {
561 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
562 missing_payment_method:
563 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
564 },
565 }),
566 metadata: Some(std::collections::HashMap::from_iter([(
567 "promo_feature_flag".into(),
568 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
569 )])),
570 })
571 );
572 assert_eq!(call.success_url.as_deref(), Some(success_url));
573 assert_eq!(
574 call.billing_address_collection,
575 Some(StripeBillingAddressCollection::Required)
576 );
577 }
578}