1use std::sync::Arc;
2
3use chrono::{Duration, Utc};
4use pretty_assertions::assert_eq;
5
6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
7use crate::stripe_billing::StripeBilling;
8use crate::stripe_client::{
9 FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
10 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
11 StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
12 StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
13 StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
14 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
15};
16
17fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
18 let stripe_client = Arc::new(FakeStripeClient::new());
19 let stripe_billing = StripeBilling::test(stripe_client.clone());
20
21 (stripe_billing, stripe_client)
22}
23
24#[gpui::test]
25async fn test_initialize() {
26 let (stripe_billing, stripe_client) = make_stripe_billing();
27
28 // Add test meters
29 let meter1 = StripeMeter {
30 id: StripeMeterId("meter_1".into()),
31 event_name: "event_1".to_string(),
32 };
33 let meter2 = StripeMeter {
34 id: StripeMeterId("meter_2".into()),
35 event_name: "event_2".to_string(),
36 };
37 stripe_client
38 .meters
39 .lock()
40 .insert(meter1.id.clone(), meter1);
41 stripe_client
42 .meters
43 .lock()
44 .insert(meter2.id.clone(), meter2);
45
46 // Add test prices
47 let price1 = StripePrice {
48 id: StripePriceId("price_1".into()),
49 unit_amount: Some(1_000),
50 lookup_key: Some("zed-pro".to_string()),
51 recurring: None,
52 };
53 let price2 = StripePrice {
54 id: StripePriceId("price_2".into()),
55 unit_amount: Some(0),
56 lookup_key: Some("zed-free".to_string()),
57 recurring: None,
58 };
59 let price3 = StripePrice {
60 id: StripePriceId("price_3".into()),
61 unit_amount: Some(500),
62 lookup_key: None,
63 recurring: Some(StripePriceRecurring {
64 meter: Some("meter_1".to_string()),
65 }),
66 };
67 stripe_client
68 .prices
69 .lock()
70 .insert(price1.id.clone(), price1);
71 stripe_client
72 .prices
73 .lock()
74 .insert(price2.id.clone(), price2);
75 stripe_client
76 .prices
77 .lock()
78 .insert(price3.id.clone(), price3);
79
80 // Initialize the billing system
81 stripe_billing.initialize().await.unwrap();
82
83 // Verify that prices can be found by lookup key
84 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
85 assert_eq!(zed_pro_price_id.to_string(), "price_1");
86
87 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
88 assert_eq!(zed_free_price_id.to_string(), "price_2");
89
90 // Verify that a price can be found by lookup key
91 let zed_pro_price = stripe_billing
92 .find_price_by_lookup_key("zed-pro")
93 .await
94 .unwrap();
95 assert_eq!(zed_pro_price.id.to_string(), "price_1");
96 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
97
98 // Verify that finding a non-existent lookup key returns an error
99 let result = stripe_billing
100 .find_price_by_lookup_key("non-existent")
101 .await;
102 assert!(result.is_err());
103}
104
105#[gpui::test]
106async fn test_find_or_create_customer_by_email() {
107 let (stripe_billing, stripe_client) = make_stripe_billing();
108
109 // Create a customer with an email that doesn't yet correspond to a customer.
110 {
111 let email = "user@example.com";
112
113 let customer_id = stripe_billing
114 .find_or_create_customer_by_email(Some(email))
115 .await
116 .unwrap();
117
118 let customer = stripe_client
119 .customers
120 .lock()
121 .get(&customer_id)
122 .unwrap()
123 .clone();
124 assert_eq!(customer.email.as_deref(), Some(email));
125 }
126
127 // Create a customer with an email that corresponds to an existing customer.
128 {
129 let email = "user2@example.com";
130
131 let existing_customer_id = stripe_billing
132 .find_or_create_customer_by_email(Some(email))
133 .await
134 .unwrap();
135
136 let customer_id = stripe_billing
137 .find_or_create_customer_by_email(Some(email))
138 .await
139 .unwrap();
140 assert_eq!(customer_id, existing_customer_id);
141
142 let customer = stripe_client
143 .customers
144 .lock()
145 .get(&customer_id)
146 .unwrap()
147 .clone();
148 assert_eq!(customer.email.as_deref(), Some(email));
149 }
150}
151
152#[gpui::test]
153async fn test_subscribe_to_price() {
154 let (stripe_billing, stripe_client) = make_stripe_billing();
155
156 let price = StripePrice {
157 id: StripePriceId("price_test".into()),
158 unit_amount: Some(2000),
159 lookup_key: Some("test-price".to_string()),
160 recurring: None,
161 };
162 stripe_client
163 .prices
164 .lock()
165 .insert(price.id.clone(), price.clone());
166
167 let now = Utc::now();
168 let subscription = StripeSubscription {
169 id: StripeSubscriptionId("sub_test".into()),
170 customer: StripeCustomerId("cus_test".into()),
171 status: stripe::SubscriptionStatus::Active,
172 current_period_start: now.timestamp(),
173 current_period_end: (now + Duration::days(30)).timestamp(),
174 items: vec![],
175 };
176 stripe_client
177 .subscriptions
178 .lock()
179 .insert(subscription.id.clone(), subscription.clone());
180
181 stripe_billing
182 .subscribe_to_price(&subscription.id, &price)
183 .await
184 .unwrap();
185
186 let update_subscription_calls = stripe_client
187 .update_subscription_calls
188 .lock()
189 .iter()
190 .map(|(id, params)| (id.clone(), params.clone()))
191 .collect::<Vec<_>>();
192 assert_eq!(update_subscription_calls.len(), 1);
193 assert_eq!(update_subscription_calls[0].0, subscription.id);
194 assert_eq!(
195 update_subscription_calls[0].1.items,
196 Some(vec![UpdateSubscriptionItems {
197 price: Some(price.id.clone())
198 }])
199 );
200
201 // Subscribing to a price that is already on the subscription is a no-op.
202 {
203 let now = Utc::now();
204 let subscription = StripeSubscription {
205 id: StripeSubscriptionId("sub_test".into()),
206 customer: StripeCustomerId("cus_test".into()),
207 status: stripe::SubscriptionStatus::Active,
208 current_period_start: now.timestamp(),
209 current_period_end: (now + Duration::days(30)).timestamp(),
210 items: vec![StripeSubscriptionItem {
211 id: StripeSubscriptionItemId("si_test".into()),
212 price: Some(price.clone()),
213 }],
214 };
215 stripe_client
216 .subscriptions
217 .lock()
218 .insert(subscription.id.clone(), subscription.clone());
219
220 stripe_billing
221 .subscribe_to_price(&subscription.id, &price)
222 .await
223 .unwrap();
224
225 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
226 }
227}
228
229#[gpui::test]
230async fn test_subscribe_to_zed_free() {
231 let (stripe_billing, stripe_client) = make_stripe_billing();
232
233 let zed_pro_price = StripePrice {
234 id: StripePriceId("price_1".into()),
235 unit_amount: Some(0),
236 lookup_key: Some("zed-pro".to_string()),
237 recurring: None,
238 };
239 stripe_client
240 .prices
241 .lock()
242 .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
243 let zed_free_price = StripePrice {
244 id: StripePriceId("price_2".into()),
245 unit_amount: Some(0),
246 lookup_key: Some("zed-free".to_string()),
247 recurring: None,
248 };
249 stripe_client
250 .prices
251 .lock()
252 .insert(zed_free_price.id.clone(), zed_free_price.clone());
253
254 stripe_billing.initialize().await.unwrap();
255
256 // Customer is subscribed to Zed Free when not already subscribed to a plan.
257 {
258 let customer_id = StripeCustomerId("cus_no_plan".into());
259
260 let subscription = stripe_billing
261 .subscribe_to_zed_free(customer_id)
262 .await
263 .unwrap();
264
265 assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
266 }
267
268 // Customer is not subscribed to Zed Free when they already have an active subscription.
269 {
270 let customer_id = StripeCustomerId("cus_active_subscription".into());
271
272 let now = Utc::now();
273 let existing_subscription = StripeSubscription {
274 id: StripeSubscriptionId("sub_existing_active".into()),
275 customer: customer_id.clone(),
276 status: stripe::SubscriptionStatus::Active,
277 current_period_start: now.timestamp(),
278 current_period_end: (now + Duration::days(30)).timestamp(),
279 items: vec![StripeSubscriptionItem {
280 id: StripeSubscriptionItemId("si_test".into()),
281 price: Some(zed_pro_price.clone()),
282 }],
283 };
284 stripe_client.subscriptions.lock().insert(
285 existing_subscription.id.clone(),
286 existing_subscription.clone(),
287 );
288
289 let subscription = stripe_billing
290 .subscribe_to_zed_free(customer_id)
291 .await
292 .unwrap();
293
294 assert_eq!(subscription, existing_subscription);
295 }
296
297 // Customer is not subscribed to Zed Free when they already have a trial subscription.
298 {
299 let customer_id = StripeCustomerId("cus_trial_subscription".into());
300
301 let now = Utc::now();
302 let existing_subscription = StripeSubscription {
303 id: StripeSubscriptionId("sub_existing_trial".into()),
304 customer: customer_id.clone(),
305 status: stripe::SubscriptionStatus::Trialing,
306 current_period_start: now.timestamp(),
307 current_period_end: (now + Duration::days(14)).timestamp(),
308 items: vec![StripeSubscriptionItem {
309 id: StripeSubscriptionItemId("si_test".into()),
310 price: Some(zed_pro_price.clone()),
311 }],
312 };
313 stripe_client.subscriptions.lock().insert(
314 existing_subscription.id.clone(),
315 existing_subscription.clone(),
316 );
317
318 let subscription = stripe_billing
319 .subscribe_to_zed_free(customer_id)
320 .await
321 .unwrap();
322
323 assert_eq!(subscription, existing_subscription);
324 }
325}
326
327#[gpui::test]
328async fn test_bill_model_request_usage() {
329 let (stripe_billing, stripe_client) = make_stripe_billing();
330
331 let customer_id = StripeCustomerId("cus_test".into());
332
333 stripe_billing
334 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
335 .await
336 .unwrap();
337
338 let create_meter_event_calls = stripe_client
339 .create_meter_event_calls
340 .lock()
341 .iter()
342 .cloned()
343 .collect::<Vec<_>>();
344 assert_eq!(create_meter_event_calls.len(), 1);
345 assert!(
346 create_meter_event_calls[0]
347 .identifier
348 .starts_with("model_requests/")
349 );
350 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
351 assert_eq!(
352 create_meter_event_calls[0].event_name.as_ref(),
353 "some_model/requests"
354 );
355 assert_eq!(create_meter_event_calls[0].value, 73);
356}
357
358#[gpui::test]
359async fn test_checkout_with_zed_pro() {
360 let (stripe_billing, stripe_client) = make_stripe_billing();
361
362 let customer_id = StripeCustomerId("cus_test".into());
363 let github_login = "zeduser1";
364 let success_url = "https://example.com/success";
365
366 // It returns an error when the Zed Pro price doesn't exist.
367 {
368 let result = stripe_billing
369 .checkout_with_zed_pro(&customer_id, github_login, success_url)
370 .await;
371
372 assert!(result.is_err());
373 assert_eq!(
374 result.err().unwrap().to_string(),
375 r#"no price ID found for "zed-pro""#
376 );
377 }
378
379 // Successful checkout.
380 {
381 let price = StripePrice {
382 id: StripePriceId("price_1".into()),
383 unit_amount: Some(2000),
384 lookup_key: Some("zed-pro".to_string()),
385 recurring: None,
386 };
387 stripe_client
388 .prices
389 .lock()
390 .insert(price.id.clone(), price.clone());
391
392 stripe_billing.initialize().await.unwrap();
393
394 let checkout_url = stripe_billing
395 .checkout_with_zed_pro(&customer_id, github_login, success_url)
396 .await
397 .unwrap();
398
399 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
400
401 let create_checkout_session_calls = stripe_client
402 .create_checkout_session_calls
403 .lock()
404 .drain(..)
405 .collect::<Vec<_>>();
406 assert_eq!(create_checkout_session_calls.len(), 1);
407 let call = create_checkout_session_calls.into_iter().next().unwrap();
408 assert_eq!(call.customer, Some(customer_id));
409 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
410 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
411 assert_eq!(
412 call.line_items,
413 Some(vec![StripeCreateCheckoutSessionLineItems {
414 price: Some(price.id.to_string()),
415 quantity: Some(1)
416 }])
417 );
418 assert_eq!(call.payment_method_collection, None);
419 assert_eq!(call.subscription_data, None);
420 assert_eq!(call.success_url.as_deref(), Some(success_url));
421 }
422}
423
424#[gpui::test]
425async fn test_checkout_with_zed_pro_trial() {
426 let (stripe_billing, stripe_client) = make_stripe_billing();
427
428 let customer_id = StripeCustomerId("cus_test".into());
429 let github_login = "zeduser1";
430 let success_url = "https://example.com/success";
431
432 // It returns an error when the Zed Pro price doesn't exist.
433 {
434 let result = stripe_billing
435 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
436 .await;
437
438 assert!(result.is_err());
439 assert_eq!(
440 result.err().unwrap().to_string(),
441 r#"no price ID found for "zed-pro""#
442 );
443 }
444
445 let price = StripePrice {
446 id: StripePriceId("price_1".into()),
447 unit_amount: Some(2000),
448 lookup_key: Some("zed-pro".to_string()),
449 recurring: None,
450 };
451 stripe_client
452 .prices
453 .lock()
454 .insert(price.id.clone(), price.clone());
455
456 stripe_billing.initialize().await.unwrap();
457
458 // Successful checkout.
459 {
460 let checkout_url = stripe_billing
461 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
462 .await
463 .unwrap();
464
465 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
466
467 let create_checkout_session_calls = stripe_client
468 .create_checkout_session_calls
469 .lock()
470 .drain(..)
471 .collect::<Vec<_>>();
472 assert_eq!(create_checkout_session_calls.len(), 1);
473 let call = create_checkout_session_calls.into_iter().next().unwrap();
474 assert_eq!(call.customer.as_ref(), Some(&customer_id));
475 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
476 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
477 assert_eq!(
478 call.line_items,
479 Some(vec![StripeCreateCheckoutSessionLineItems {
480 price: Some(price.id.to_string()),
481 quantity: Some(1)
482 }])
483 );
484 assert_eq!(
485 call.payment_method_collection,
486 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
487 );
488 assert_eq!(
489 call.subscription_data,
490 Some(StripeCreateCheckoutSessionSubscriptionData {
491 trial_period_days: Some(14),
492 trial_settings: Some(StripeSubscriptionTrialSettings {
493 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
494 missing_payment_method:
495 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
496 },
497 }),
498 metadata: None,
499 })
500 );
501 assert_eq!(call.success_url.as_deref(), Some(success_url));
502 }
503
504 // Successful checkout with extended trial.
505 {
506 let checkout_url = stripe_billing
507 .checkout_with_zed_pro_trial(
508 &customer_id,
509 github_login,
510 vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
511 success_url,
512 )
513 .await
514 .unwrap();
515
516 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
517
518 let create_checkout_session_calls = stripe_client
519 .create_checkout_session_calls
520 .lock()
521 .drain(..)
522 .collect::<Vec<_>>();
523 assert_eq!(create_checkout_session_calls.len(), 1);
524 let call = create_checkout_session_calls.into_iter().next().unwrap();
525 assert_eq!(call.customer, Some(customer_id));
526 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
527 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
528 assert_eq!(
529 call.line_items,
530 Some(vec![StripeCreateCheckoutSessionLineItems {
531 price: Some(price.id.to_string()),
532 quantity: Some(1)
533 }])
534 );
535 assert_eq!(
536 call.payment_method_collection,
537 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
538 );
539 assert_eq!(
540 call.subscription_data,
541 Some(StripeCreateCheckoutSessionSubscriptionData {
542 trial_period_days: Some(60),
543 trial_settings: Some(StripeSubscriptionTrialSettings {
544 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
545 missing_payment_method:
546 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
547 },
548 }),
549 metadata: Some(std::collections::HashMap::from_iter([(
550 "promo_feature_flag".into(),
551 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
552 )])),
553 })
554 );
555 assert_eq!(call.success_url.as_deref(), Some(success_url));
556 }
557}