1use std::sync::Arc;
2
3use chrono::{Duration, Utc};
4use pretty_assertions::assert_eq;
5
6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
7use crate::stripe_billing::StripeBilling;
8use crate::stripe_client::{
9 FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
10 StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
11 StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate,
12 StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice,
13 StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
14 StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
15 StripeSubscriptionTrialSettingsEndBehavior,
16 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
17};
18
19fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
20 let stripe_client = Arc::new(FakeStripeClient::new());
21 let stripe_billing = StripeBilling::test(stripe_client.clone());
22
23 (stripe_billing, stripe_client)
24}
25
26#[gpui::test]
27async fn test_initialize() {
28 let (stripe_billing, stripe_client) = make_stripe_billing();
29
30 // Add test meters
31 let meter1 = StripeMeter {
32 id: StripeMeterId("meter_1".into()),
33 event_name: "event_1".to_string(),
34 };
35 let meter2 = StripeMeter {
36 id: StripeMeterId("meter_2".into()),
37 event_name: "event_2".to_string(),
38 };
39 stripe_client
40 .meters
41 .lock()
42 .insert(meter1.id.clone(), meter1);
43 stripe_client
44 .meters
45 .lock()
46 .insert(meter2.id.clone(), meter2);
47
48 // Add test prices
49 let price1 = StripePrice {
50 id: StripePriceId("price_1".into()),
51 unit_amount: Some(1_000),
52 lookup_key: Some("zed-pro".to_string()),
53 recurring: None,
54 };
55 let price2 = StripePrice {
56 id: StripePriceId("price_2".into()),
57 unit_amount: Some(0),
58 lookup_key: Some("zed-free".to_string()),
59 recurring: None,
60 };
61 let price3 = StripePrice {
62 id: StripePriceId("price_3".into()),
63 unit_amount: Some(500),
64 lookup_key: None,
65 recurring: Some(StripePriceRecurring {
66 meter: Some("meter_1".to_string()),
67 }),
68 };
69 stripe_client
70 .prices
71 .lock()
72 .insert(price1.id.clone(), price1);
73 stripe_client
74 .prices
75 .lock()
76 .insert(price2.id.clone(), price2);
77 stripe_client
78 .prices
79 .lock()
80 .insert(price3.id.clone(), price3);
81
82 // Initialize the billing system
83 stripe_billing.initialize().await.unwrap();
84
85 // Verify that prices can be found by lookup key
86 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
87 assert_eq!(zed_pro_price_id.to_string(), "price_1");
88
89 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
90 assert_eq!(zed_free_price_id.to_string(), "price_2");
91
92 // Verify that a price can be found by lookup key
93 let zed_pro_price = stripe_billing
94 .find_price_by_lookup_key("zed-pro")
95 .await
96 .unwrap();
97 assert_eq!(zed_pro_price.id.to_string(), "price_1");
98 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
99
100 // Verify that finding a non-existent lookup key returns an error
101 let result = stripe_billing
102 .find_price_by_lookup_key("non-existent")
103 .await;
104 assert!(result.is_err());
105}
106
107#[gpui::test]
108async fn test_find_or_create_customer_by_email() {
109 let (stripe_billing, stripe_client) = make_stripe_billing();
110
111 // Create a customer with an email that doesn't yet correspond to a customer.
112 {
113 let email = "user@example.com";
114
115 let customer_id = stripe_billing
116 .find_or_create_customer_by_email(Some(email))
117 .await
118 .unwrap();
119
120 let customer = stripe_client
121 .customers
122 .lock()
123 .get(&customer_id)
124 .unwrap()
125 .clone();
126 assert_eq!(customer.email.as_deref(), Some(email));
127 }
128
129 // Create a customer with an email that corresponds to an existing customer.
130 {
131 let email = "user2@example.com";
132
133 let existing_customer_id = stripe_billing
134 .find_or_create_customer_by_email(Some(email))
135 .await
136 .unwrap();
137
138 let customer_id = stripe_billing
139 .find_or_create_customer_by_email(Some(email))
140 .await
141 .unwrap();
142 assert_eq!(customer_id, existing_customer_id);
143
144 let customer = stripe_client
145 .customers
146 .lock()
147 .get(&customer_id)
148 .unwrap()
149 .clone();
150 assert_eq!(customer.email.as_deref(), Some(email));
151 }
152}
153
154#[gpui::test]
155async fn test_subscribe_to_price() {
156 let (stripe_billing, stripe_client) = make_stripe_billing();
157
158 let price = StripePrice {
159 id: StripePriceId("price_test".into()),
160 unit_amount: Some(2000),
161 lookup_key: Some("test-price".to_string()),
162 recurring: None,
163 };
164 stripe_client
165 .prices
166 .lock()
167 .insert(price.id.clone(), price.clone());
168
169 let now = Utc::now();
170 let subscription = StripeSubscription {
171 id: StripeSubscriptionId("sub_test".into()),
172 customer: StripeCustomerId("cus_test".into()),
173 status: stripe::SubscriptionStatus::Active,
174 current_period_start: now.timestamp(),
175 current_period_end: (now + Duration::days(30)).timestamp(),
176 items: vec![],
177 cancel_at: None,
178 cancellation_details: None,
179 };
180 stripe_client
181 .subscriptions
182 .lock()
183 .insert(subscription.id.clone(), subscription.clone());
184
185 stripe_billing
186 .subscribe_to_price(&subscription.id, &price)
187 .await
188 .unwrap();
189
190 let update_subscription_calls = stripe_client
191 .update_subscription_calls
192 .lock()
193 .iter()
194 .map(|(id, params)| (id.clone(), params.clone()))
195 .collect::<Vec<_>>();
196 assert_eq!(update_subscription_calls.len(), 1);
197 assert_eq!(update_subscription_calls[0].0, subscription.id);
198 assert_eq!(
199 update_subscription_calls[0].1.items,
200 Some(vec![UpdateSubscriptionItems {
201 price: Some(price.id.clone())
202 }])
203 );
204
205 // Subscribing to a price that is already on the subscription is a no-op.
206 {
207 let now = Utc::now();
208 let subscription = StripeSubscription {
209 id: StripeSubscriptionId("sub_test".into()),
210 customer: StripeCustomerId("cus_test".into()),
211 status: stripe::SubscriptionStatus::Active,
212 current_period_start: now.timestamp(),
213 current_period_end: (now + Duration::days(30)).timestamp(),
214 items: vec![StripeSubscriptionItem {
215 id: StripeSubscriptionItemId("si_test".into()),
216 price: Some(price.clone()),
217 }],
218 cancel_at: None,
219 cancellation_details: None,
220 };
221 stripe_client
222 .subscriptions
223 .lock()
224 .insert(subscription.id.clone(), subscription.clone());
225
226 stripe_billing
227 .subscribe_to_price(&subscription.id, &price)
228 .await
229 .unwrap();
230
231 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
232 }
233}
234
235#[gpui::test]
236async fn test_subscribe_to_zed_free() {
237 let (stripe_billing, stripe_client) = make_stripe_billing();
238
239 let zed_pro_price = StripePrice {
240 id: StripePriceId("price_1".into()),
241 unit_amount: Some(0),
242 lookup_key: Some("zed-pro".to_string()),
243 recurring: None,
244 };
245 stripe_client
246 .prices
247 .lock()
248 .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
249 let zed_free_price = StripePrice {
250 id: StripePriceId("price_2".into()),
251 unit_amount: Some(0),
252 lookup_key: Some("zed-free".to_string()),
253 recurring: None,
254 };
255 stripe_client
256 .prices
257 .lock()
258 .insert(zed_free_price.id.clone(), zed_free_price.clone());
259
260 stripe_billing.initialize().await.unwrap();
261
262 // Customer is subscribed to Zed Free when not already subscribed to a plan.
263 {
264 let customer_id = StripeCustomerId("cus_no_plan".into());
265
266 let subscription = stripe_billing
267 .subscribe_to_zed_free(customer_id)
268 .await
269 .unwrap();
270
271 assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
272 }
273
274 // Customer is not subscribed to Zed Free when they already have an active subscription.
275 {
276 let customer_id = StripeCustomerId("cus_active_subscription".into());
277
278 let now = Utc::now();
279 let existing_subscription = StripeSubscription {
280 id: StripeSubscriptionId("sub_existing_active".into()),
281 customer: customer_id.clone(),
282 status: stripe::SubscriptionStatus::Active,
283 current_period_start: now.timestamp(),
284 current_period_end: (now + Duration::days(30)).timestamp(),
285 items: vec![StripeSubscriptionItem {
286 id: StripeSubscriptionItemId("si_test".into()),
287 price: Some(zed_pro_price.clone()),
288 }],
289 cancel_at: None,
290 cancellation_details: None,
291 };
292 stripe_client.subscriptions.lock().insert(
293 existing_subscription.id.clone(),
294 existing_subscription.clone(),
295 );
296
297 let subscription = stripe_billing
298 .subscribe_to_zed_free(customer_id)
299 .await
300 .unwrap();
301
302 assert_eq!(subscription, existing_subscription);
303 }
304
305 // Customer is not subscribed to Zed Free when they already have a trial subscription.
306 {
307 let customer_id = StripeCustomerId("cus_trial_subscription".into());
308
309 let now = Utc::now();
310 let existing_subscription = StripeSubscription {
311 id: StripeSubscriptionId("sub_existing_trial".into()),
312 customer: customer_id.clone(),
313 status: stripe::SubscriptionStatus::Trialing,
314 current_period_start: now.timestamp(),
315 current_period_end: (now + Duration::days(14)).timestamp(),
316 items: vec![StripeSubscriptionItem {
317 id: StripeSubscriptionItemId("si_test".into()),
318 price: Some(zed_pro_price.clone()),
319 }],
320 cancel_at: None,
321 cancellation_details: None,
322 };
323 stripe_client.subscriptions.lock().insert(
324 existing_subscription.id.clone(),
325 existing_subscription.clone(),
326 );
327
328 let subscription = stripe_billing
329 .subscribe_to_zed_free(customer_id)
330 .await
331 .unwrap();
332
333 assert_eq!(subscription, existing_subscription);
334 }
335}
336
337#[gpui::test]
338async fn test_bill_model_request_usage() {
339 let (stripe_billing, stripe_client) = make_stripe_billing();
340
341 let customer_id = StripeCustomerId("cus_test".into());
342
343 stripe_billing
344 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
345 .await
346 .unwrap();
347
348 let create_meter_event_calls = stripe_client
349 .create_meter_event_calls
350 .lock()
351 .iter()
352 .cloned()
353 .collect::<Vec<_>>();
354 assert_eq!(create_meter_event_calls.len(), 1);
355 assert!(
356 create_meter_event_calls[0]
357 .identifier
358 .starts_with("model_requests/")
359 );
360 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
361 assert_eq!(
362 create_meter_event_calls[0].event_name.as_ref(),
363 "some_model/requests"
364 );
365 assert_eq!(create_meter_event_calls[0].value, 73);
366}
367
368#[gpui::test]
369async fn test_checkout_with_zed_pro() {
370 let (stripe_billing, stripe_client) = make_stripe_billing();
371
372 let customer_id = StripeCustomerId("cus_test".into());
373 let github_login = "zeduser1";
374 let success_url = "https://example.com/success";
375
376 // It returns an error when the Zed Pro price doesn't exist.
377 {
378 let result = stripe_billing
379 .checkout_with_zed_pro(&customer_id, github_login, success_url)
380 .await;
381
382 assert!(result.is_err());
383 assert_eq!(
384 result.err().unwrap().to_string(),
385 r#"no price ID found for "zed-pro""#
386 );
387 }
388
389 // Successful checkout.
390 {
391 let price = StripePrice {
392 id: StripePriceId("price_1".into()),
393 unit_amount: Some(2000),
394 lookup_key: Some("zed-pro".to_string()),
395 recurring: None,
396 };
397 stripe_client
398 .prices
399 .lock()
400 .insert(price.id.clone(), price.clone());
401
402 stripe_billing.initialize().await.unwrap();
403
404 let checkout_url = stripe_billing
405 .checkout_with_zed_pro(&customer_id, github_login, success_url)
406 .await
407 .unwrap();
408
409 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
410
411 let create_checkout_session_calls = stripe_client
412 .create_checkout_session_calls
413 .lock()
414 .drain(..)
415 .collect::<Vec<_>>();
416 assert_eq!(create_checkout_session_calls.len(), 1);
417 let call = create_checkout_session_calls.into_iter().next().unwrap();
418 assert_eq!(call.customer, Some(customer_id));
419 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
420 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
421 assert_eq!(
422 call.line_items,
423 Some(vec![StripeCreateCheckoutSessionLineItems {
424 price: Some(price.id.to_string()),
425 quantity: Some(1)
426 }])
427 );
428 assert_eq!(call.payment_method_collection, None);
429 assert_eq!(call.subscription_data, None);
430 assert_eq!(call.success_url.as_deref(), Some(success_url));
431 assert_eq!(
432 call.billing_address_collection,
433 Some(StripeBillingAddressCollection::Required)
434 );
435 assert_eq!(
436 call.customer_update,
437 Some(StripeCustomerUpdate {
438 address: Some(StripeCustomerUpdateAddress::Auto),
439 name: Some(StripeCustomerUpdateName::Auto),
440 shipping: None,
441 })
442 );
443 }
444}
445
446#[gpui::test]
447async fn test_checkout_with_zed_pro_trial() {
448 let (stripe_billing, stripe_client) = make_stripe_billing();
449
450 let customer_id = StripeCustomerId("cus_test".into());
451 let github_login = "zeduser1";
452 let success_url = "https://example.com/success";
453
454 // It returns an error when the Zed Pro price doesn't exist.
455 {
456 let result = stripe_billing
457 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
458 .await;
459
460 assert!(result.is_err());
461 assert_eq!(
462 result.err().unwrap().to_string(),
463 r#"no price ID found for "zed-pro""#
464 );
465 }
466
467 let price = StripePrice {
468 id: StripePriceId("price_1".into()),
469 unit_amount: Some(2000),
470 lookup_key: Some("zed-pro".to_string()),
471 recurring: None,
472 };
473 stripe_client
474 .prices
475 .lock()
476 .insert(price.id.clone(), price.clone());
477
478 stripe_billing.initialize().await.unwrap();
479
480 // Successful checkout.
481 {
482 let checkout_url = stripe_billing
483 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
484 .await
485 .unwrap();
486
487 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
488
489 let create_checkout_session_calls = stripe_client
490 .create_checkout_session_calls
491 .lock()
492 .drain(..)
493 .collect::<Vec<_>>();
494 assert_eq!(create_checkout_session_calls.len(), 1);
495 let call = create_checkout_session_calls.into_iter().next().unwrap();
496 assert_eq!(call.customer.as_ref(), Some(&customer_id));
497 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
498 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
499 assert_eq!(
500 call.line_items,
501 Some(vec![StripeCreateCheckoutSessionLineItems {
502 price: Some(price.id.to_string()),
503 quantity: Some(1)
504 }])
505 );
506 assert_eq!(
507 call.payment_method_collection,
508 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
509 );
510 assert_eq!(
511 call.subscription_data,
512 Some(StripeCreateCheckoutSessionSubscriptionData {
513 trial_period_days: Some(14),
514 trial_settings: Some(StripeSubscriptionTrialSettings {
515 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
516 missing_payment_method:
517 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
518 },
519 }),
520 metadata: None,
521 })
522 );
523 assert_eq!(call.success_url.as_deref(), Some(success_url));
524 assert_eq!(
525 call.billing_address_collection,
526 Some(StripeBillingAddressCollection::Required)
527 );
528 assert_eq!(
529 call.customer_update,
530 Some(StripeCustomerUpdate {
531 address: Some(StripeCustomerUpdateAddress::Auto),
532 name: Some(StripeCustomerUpdateName::Auto),
533 shipping: None,
534 })
535 );
536 }
537
538 // Successful checkout with extended trial.
539 {
540 let checkout_url = stripe_billing
541 .checkout_with_zed_pro_trial(
542 &customer_id,
543 github_login,
544 vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
545 success_url,
546 )
547 .await
548 .unwrap();
549
550 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
551
552 let create_checkout_session_calls = stripe_client
553 .create_checkout_session_calls
554 .lock()
555 .drain(..)
556 .collect::<Vec<_>>();
557 assert_eq!(create_checkout_session_calls.len(), 1);
558 let call = create_checkout_session_calls.into_iter().next().unwrap();
559 assert_eq!(call.customer, Some(customer_id));
560 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
561 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
562 assert_eq!(
563 call.line_items,
564 Some(vec![StripeCreateCheckoutSessionLineItems {
565 price: Some(price.id.to_string()),
566 quantity: Some(1)
567 }])
568 );
569 assert_eq!(
570 call.payment_method_collection,
571 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
572 );
573 assert_eq!(
574 call.subscription_data,
575 Some(StripeCreateCheckoutSessionSubscriptionData {
576 trial_period_days: Some(60),
577 trial_settings: Some(StripeSubscriptionTrialSettings {
578 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
579 missing_payment_method:
580 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
581 },
582 }),
583 metadata: Some(std::collections::HashMap::from_iter([(
584 "promo_feature_flag".into(),
585 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
586 )])),
587 })
588 );
589 assert_eq!(call.success_url.as_deref(), Some(success_url));
590 assert_eq!(
591 call.billing_address_collection,
592 Some(StripeBillingAddressCollection::Required)
593 );
594 assert_eq!(
595 call.customer_update,
596 Some(StripeCustomerUpdate {
597 address: Some(StripeCustomerUpdateAddress::Auto),
598 name: Some(StripeCustomerUpdateName::Auto),
599 shipping: None,
600 })
601 );
602 }
603}