1use std::sync::Arc;
2
3use chrono::{Duration, Utc};
4use pretty_assertions::assert_eq;
5
6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
7use crate::stripe_billing::StripeBilling;
8use crate::stripe_client::{
9 FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
10 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
11 StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
12 StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
13 StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
14 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
15};
16
17fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
18 let stripe_client = Arc::new(FakeStripeClient::new());
19 let stripe_billing = StripeBilling::test(stripe_client.clone());
20
21 (stripe_billing, stripe_client)
22}
23
24#[gpui::test]
25async fn test_initialize() {
26 let (stripe_billing, stripe_client) = make_stripe_billing();
27
28 // Add test meters
29 let meter1 = StripeMeter {
30 id: StripeMeterId("meter_1".into()),
31 event_name: "event_1".to_string(),
32 };
33 let meter2 = StripeMeter {
34 id: StripeMeterId("meter_2".into()),
35 event_name: "event_2".to_string(),
36 };
37 stripe_client
38 .meters
39 .lock()
40 .insert(meter1.id.clone(), meter1);
41 stripe_client
42 .meters
43 .lock()
44 .insert(meter2.id.clone(), meter2);
45
46 // Add test prices
47 let price1 = StripePrice {
48 id: StripePriceId("price_1".into()),
49 unit_amount: Some(1_000),
50 lookup_key: Some("zed-pro".to_string()),
51 recurring: None,
52 };
53 let price2 = StripePrice {
54 id: StripePriceId("price_2".into()),
55 unit_amount: Some(0),
56 lookup_key: Some("zed-free".to_string()),
57 recurring: None,
58 };
59 let price3 = StripePrice {
60 id: StripePriceId("price_3".into()),
61 unit_amount: Some(500),
62 lookup_key: None,
63 recurring: Some(StripePriceRecurring {
64 meter: Some("meter_1".to_string()),
65 }),
66 };
67 stripe_client
68 .prices
69 .lock()
70 .insert(price1.id.clone(), price1);
71 stripe_client
72 .prices
73 .lock()
74 .insert(price2.id.clone(), price2);
75 stripe_client
76 .prices
77 .lock()
78 .insert(price3.id.clone(), price3);
79
80 // Initialize the billing system
81 stripe_billing.initialize().await.unwrap();
82
83 // Verify that prices can be found by lookup key
84 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
85 assert_eq!(zed_pro_price_id.to_string(), "price_1");
86
87 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
88 assert_eq!(zed_free_price_id.to_string(), "price_2");
89
90 // Verify that a price can be found by lookup key
91 let zed_pro_price = stripe_billing
92 .find_price_by_lookup_key("zed-pro")
93 .await
94 .unwrap();
95 assert_eq!(zed_pro_price.id.to_string(), "price_1");
96 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
97
98 // Verify that finding a non-existent lookup key returns an error
99 let result = stripe_billing
100 .find_price_by_lookup_key("non-existent")
101 .await;
102 assert!(result.is_err());
103}
104
105#[gpui::test]
106async fn test_find_or_create_customer_by_email() {
107 let (stripe_billing, stripe_client) = make_stripe_billing();
108
109 // Create a customer with an email that doesn't yet correspond to a customer.
110 {
111 let email = "user@example.com";
112
113 let customer_id = stripe_billing
114 .find_or_create_customer_by_email(Some(email))
115 .await
116 .unwrap();
117
118 let customer = stripe_client
119 .customers
120 .lock()
121 .get(&customer_id)
122 .unwrap()
123 .clone();
124 assert_eq!(customer.email.as_deref(), Some(email));
125 }
126
127 // Create a customer with an email that corresponds to an existing customer.
128 {
129 let email = "user2@example.com";
130
131 let existing_customer_id = stripe_billing
132 .find_or_create_customer_by_email(Some(email))
133 .await
134 .unwrap();
135
136 let customer_id = stripe_billing
137 .find_or_create_customer_by_email(Some(email))
138 .await
139 .unwrap();
140 assert_eq!(customer_id, existing_customer_id);
141
142 let customer = stripe_client
143 .customers
144 .lock()
145 .get(&customer_id)
146 .unwrap()
147 .clone();
148 assert_eq!(customer.email.as_deref(), Some(email));
149 }
150}
151
152#[gpui::test]
153async fn test_subscribe_to_price() {
154 let (stripe_billing, stripe_client) = make_stripe_billing();
155
156 let price = StripePrice {
157 id: StripePriceId("price_test".into()),
158 unit_amount: Some(2000),
159 lookup_key: Some("test-price".to_string()),
160 recurring: None,
161 };
162 stripe_client
163 .prices
164 .lock()
165 .insert(price.id.clone(), price.clone());
166
167 let now = Utc::now();
168 let subscription = StripeSubscription {
169 id: StripeSubscriptionId("sub_test".into()),
170 customer: StripeCustomerId("cus_test".into()),
171 status: stripe::SubscriptionStatus::Active,
172 current_period_start: now.timestamp(),
173 current_period_end: (now + Duration::days(30)).timestamp(),
174 items: vec![],
175 cancel_at: None,
176 cancellation_details: None,
177 };
178 stripe_client
179 .subscriptions
180 .lock()
181 .insert(subscription.id.clone(), subscription.clone());
182
183 stripe_billing
184 .subscribe_to_price(&subscription.id, &price)
185 .await
186 .unwrap();
187
188 let update_subscription_calls = stripe_client
189 .update_subscription_calls
190 .lock()
191 .iter()
192 .map(|(id, params)| (id.clone(), params.clone()))
193 .collect::<Vec<_>>();
194 assert_eq!(update_subscription_calls.len(), 1);
195 assert_eq!(update_subscription_calls[0].0, subscription.id);
196 assert_eq!(
197 update_subscription_calls[0].1.items,
198 Some(vec![UpdateSubscriptionItems {
199 price: Some(price.id.clone())
200 }])
201 );
202
203 // Subscribing to a price that is already on the subscription is a no-op.
204 {
205 let now = Utc::now();
206 let subscription = StripeSubscription {
207 id: StripeSubscriptionId("sub_test".into()),
208 customer: StripeCustomerId("cus_test".into()),
209 status: stripe::SubscriptionStatus::Active,
210 current_period_start: now.timestamp(),
211 current_period_end: (now + Duration::days(30)).timestamp(),
212 items: vec![StripeSubscriptionItem {
213 id: StripeSubscriptionItemId("si_test".into()),
214 price: Some(price.clone()),
215 }],
216 cancel_at: None,
217 cancellation_details: None,
218 };
219 stripe_client
220 .subscriptions
221 .lock()
222 .insert(subscription.id.clone(), subscription.clone());
223
224 stripe_billing
225 .subscribe_to_price(&subscription.id, &price)
226 .await
227 .unwrap();
228
229 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
230 }
231}
232
233#[gpui::test]
234async fn test_subscribe_to_zed_free() {
235 let (stripe_billing, stripe_client) = make_stripe_billing();
236
237 let zed_pro_price = StripePrice {
238 id: StripePriceId("price_1".into()),
239 unit_amount: Some(0),
240 lookup_key: Some("zed-pro".to_string()),
241 recurring: None,
242 };
243 stripe_client
244 .prices
245 .lock()
246 .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
247 let zed_free_price = StripePrice {
248 id: StripePriceId("price_2".into()),
249 unit_amount: Some(0),
250 lookup_key: Some("zed-free".to_string()),
251 recurring: None,
252 };
253 stripe_client
254 .prices
255 .lock()
256 .insert(zed_free_price.id.clone(), zed_free_price.clone());
257
258 stripe_billing.initialize().await.unwrap();
259
260 // Customer is subscribed to Zed Free when not already subscribed to a plan.
261 {
262 let customer_id = StripeCustomerId("cus_no_plan".into());
263
264 let subscription = stripe_billing
265 .subscribe_to_zed_free(customer_id)
266 .await
267 .unwrap();
268
269 assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
270 }
271
272 // Customer is not subscribed to Zed Free when they already have an active subscription.
273 {
274 let customer_id = StripeCustomerId("cus_active_subscription".into());
275
276 let now = Utc::now();
277 let existing_subscription = StripeSubscription {
278 id: StripeSubscriptionId("sub_existing_active".into()),
279 customer: customer_id.clone(),
280 status: stripe::SubscriptionStatus::Active,
281 current_period_start: now.timestamp(),
282 current_period_end: (now + Duration::days(30)).timestamp(),
283 items: vec![StripeSubscriptionItem {
284 id: StripeSubscriptionItemId("si_test".into()),
285 price: Some(zed_pro_price.clone()),
286 }],
287 cancel_at: None,
288 cancellation_details: None,
289 };
290 stripe_client.subscriptions.lock().insert(
291 existing_subscription.id.clone(),
292 existing_subscription.clone(),
293 );
294
295 let subscription = stripe_billing
296 .subscribe_to_zed_free(customer_id)
297 .await
298 .unwrap();
299
300 assert_eq!(subscription, existing_subscription);
301 }
302
303 // Customer is not subscribed to Zed Free when they already have a trial subscription.
304 {
305 let customer_id = StripeCustomerId("cus_trial_subscription".into());
306
307 let now = Utc::now();
308 let existing_subscription = StripeSubscription {
309 id: StripeSubscriptionId("sub_existing_trial".into()),
310 customer: customer_id.clone(),
311 status: stripe::SubscriptionStatus::Trialing,
312 current_period_start: now.timestamp(),
313 current_period_end: (now + Duration::days(14)).timestamp(),
314 items: vec![StripeSubscriptionItem {
315 id: StripeSubscriptionItemId("si_test".into()),
316 price: Some(zed_pro_price.clone()),
317 }],
318 cancel_at: None,
319 cancellation_details: None,
320 };
321 stripe_client.subscriptions.lock().insert(
322 existing_subscription.id.clone(),
323 existing_subscription.clone(),
324 );
325
326 let subscription = stripe_billing
327 .subscribe_to_zed_free(customer_id)
328 .await
329 .unwrap();
330
331 assert_eq!(subscription, existing_subscription);
332 }
333}
334
335#[gpui::test]
336async fn test_bill_model_request_usage() {
337 let (stripe_billing, stripe_client) = make_stripe_billing();
338
339 let customer_id = StripeCustomerId("cus_test".into());
340
341 stripe_billing
342 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
343 .await
344 .unwrap();
345
346 let create_meter_event_calls = stripe_client
347 .create_meter_event_calls
348 .lock()
349 .iter()
350 .cloned()
351 .collect::<Vec<_>>();
352 assert_eq!(create_meter_event_calls.len(), 1);
353 assert!(
354 create_meter_event_calls[0]
355 .identifier
356 .starts_with("model_requests/")
357 );
358 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
359 assert_eq!(
360 create_meter_event_calls[0].event_name.as_ref(),
361 "some_model/requests"
362 );
363 assert_eq!(create_meter_event_calls[0].value, 73);
364}
365
366#[gpui::test]
367async fn test_checkout_with_zed_pro() {
368 let (stripe_billing, stripe_client) = make_stripe_billing();
369
370 let customer_id = StripeCustomerId("cus_test".into());
371 let github_login = "zeduser1";
372 let success_url = "https://example.com/success";
373
374 // It returns an error when the Zed Pro price doesn't exist.
375 {
376 let result = stripe_billing
377 .checkout_with_zed_pro(&customer_id, github_login, success_url)
378 .await;
379
380 assert!(result.is_err());
381 assert_eq!(
382 result.err().unwrap().to_string(),
383 r#"no price ID found for "zed-pro""#
384 );
385 }
386
387 // Successful checkout.
388 {
389 let price = StripePrice {
390 id: StripePriceId("price_1".into()),
391 unit_amount: Some(2000),
392 lookup_key: Some("zed-pro".to_string()),
393 recurring: None,
394 };
395 stripe_client
396 .prices
397 .lock()
398 .insert(price.id.clone(), price.clone());
399
400 stripe_billing.initialize().await.unwrap();
401
402 let checkout_url = stripe_billing
403 .checkout_with_zed_pro(&customer_id, github_login, success_url)
404 .await
405 .unwrap();
406
407 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
408
409 let create_checkout_session_calls = stripe_client
410 .create_checkout_session_calls
411 .lock()
412 .drain(..)
413 .collect::<Vec<_>>();
414 assert_eq!(create_checkout_session_calls.len(), 1);
415 let call = create_checkout_session_calls.into_iter().next().unwrap();
416 assert_eq!(call.customer, Some(customer_id));
417 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
418 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
419 assert_eq!(
420 call.line_items,
421 Some(vec![StripeCreateCheckoutSessionLineItems {
422 price: Some(price.id.to_string()),
423 quantity: Some(1)
424 }])
425 );
426 assert_eq!(call.payment_method_collection, None);
427 assert_eq!(call.subscription_data, None);
428 assert_eq!(call.success_url.as_deref(), Some(success_url));
429 }
430}
431
432#[gpui::test]
433async fn test_checkout_with_zed_pro_trial() {
434 let (stripe_billing, stripe_client) = make_stripe_billing();
435
436 let customer_id = StripeCustomerId("cus_test".into());
437 let github_login = "zeduser1";
438 let success_url = "https://example.com/success";
439
440 // It returns an error when the Zed Pro price doesn't exist.
441 {
442 let result = stripe_billing
443 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
444 .await;
445
446 assert!(result.is_err());
447 assert_eq!(
448 result.err().unwrap().to_string(),
449 r#"no price ID found for "zed-pro""#
450 );
451 }
452
453 let price = StripePrice {
454 id: StripePriceId("price_1".into()),
455 unit_amount: Some(2000),
456 lookup_key: Some("zed-pro".to_string()),
457 recurring: None,
458 };
459 stripe_client
460 .prices
461 .lock()
462 .insert(price.id.clone(), price.clone());
463
464 stripe_billing.initialize().await.unwrap();
465
466 // Successful checkout.
467 {
468 let checkout_url = stripe_billing
469 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
470 .await
471 .unwrap();
472
473 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
474
475 let create_checkout_session_calls = stripe_client
476 .create_checkout_session_calls
477 .lock()
478 .drain(..)
479 .collect::<Vec<_>>();
480 assert_eq!(create_checkout_session_calls.len(), 1);
481 let call = create_checkout_session_calls.into_iter().next().unwrap();
482 assert_eq!(call.customer.as_ref(), Some(&customer_id));
483 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
484 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
485 assert_eq!(
486 call.line_items,
487 Some(vec![StripeCreateCheckoutSessionLineItems {
488 price: Some(price.id.to_string()),
489 quantity: Some(1)
490 }])
491 );
492 assert_eq!(
493 call.payment_method_collection,
494 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
495 );
496 assert_eq!(
497 call.subscription_data,
498 Some(StripeCreateCheckoutSessionSubscriptionData {
499 trial_period_days: Some(14),
500 trial_settings: Some(StripeSubscriptionTrialSettings {
501 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
502 missing_payment_method:
503 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
504 },
505 }),
506 metadata: None,
507 })
508 );
509 assert_eq!(call.success_url.as_deref(), Some(success_url));
510 }
511
512 // Successful checkout with extended trial.
513 {
514 let checkout_url = stripe_billing
515 .checkout_with_zed_pro_trial(
516 &customer_id,
517 github_login,
518 vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
519 success_url,
520 )
521 .await
522 .unwrap();
523
524 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
525
526 let create_checkout_session_calls = stripe_client
527 .create_checkout_session_calls
528 .lock()
529 .drain(..)
530 .collect::<Vec<_>>();
531 assert_eq!(create_checkout_session_calls.len(), 1);
532 let call = create_checkout_session_calls.into_iter().next().unwrap();
533 assert_eq!(call.customer, Some(customer_id));
534 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
535 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
536 assert_eq!(
537 call.line_items,
538 Some(vec![StripeCreateCheckoutSessionLineItems {
539 price: Some(price.id.to_string()),
540 quantity: Some(1)
541 }])
542 );
543 assert_eq!(
544 call.payment_method_collection,
545 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
546 );
547 assert_eq!(
548 call.subscription_data,
549 Some(StripeCreateCheckoutSessionSubscriptionData {
550 trial_period_days: Some(60),
551 trial_settings: Some(StripeSubscriptionTrialSettings {
552 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
553 missing_payment_method:
554 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
555 },
556 }),
557 metadata: Some(std::collections::HashMap::from_iter([(
558 "promo_feature_flag".into(),
559 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
560 )])),
561 })
562 );
563 assert_eq!(call.success_url.as_deref(), Some(success_url));
564 }
565}