1use std::sync::Arc;
2
3use pretty_assertions::assert_eq;
4
5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
6use crate::stripe_billing::StripeBilling;
7use crate::stripe_client::{
8 FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
9 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
10 StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
11 StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
12 StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
13 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
14};
15
16fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
17 let stripe_client = Arc::new(FakeStripeClient::new());
18 let stripe_billing = StripeBilling::test(stripe_client.clone());
19
20 (stripe_billing, stripe_client)
21}
22
23#[gpui::test]
24async fn test_initialize() {
25 let (stripe_billing, stripe_client) = make_stripe_billing();
26
27 // Add test meters
28 let meter1 = StripeMeter {
29 id: StripeMeterId("meter_1".into()),
30 event_name: "event_1".to_string(),
31 };
32 let meter2 = StripeMeter {
33 id: StripeMeterId("meter_2".into()),
34 event_name: "event_2".to_string(),
35 };
36 stripe_client
37 .meters
38 .lock()
39 .insert(meter1.id.clone(), meter1);
40 stripe_client
41 .meters
42 .lock()
43 .insert(meter2.id.clone(), meter2);
44
45 // Add test prices
46 let price1 = StripePrice {
47 id: StripePriceId("price_1".into()),
48 unit_amount: Some(1_000),
49 lookup_key: Some("zed-pro".to_string()),
50 recurring: None,
51 };
52 let price2 = StripePrice {
53 id: StripePriceId("price_2".into()),
54 unit_amount: Some(0),
55 lookup_key: Some("zed-free".to_string()),
56 recurring: None,
57 };
58 let price3 = StripePrice {
59 id: StripePriceId("price_3".into()),
60 unit_amount: Some(500),
61 lookup_key: None,
62 recurring: Some(StripePriceRecurring {
63 meter: Some("meter_1".to_string()),
64 }),
65 };
66 stripe_client
67 .prices
68 .lock()
69 .insert(price1.id.clone(), price1);
70 stripe_client
71 .prices
72 .lock()
73 .insert(price2.id.clone(), price2);
74 stripe_client
75 .prices
76 .lock()
77 .insert(price3.id.clone(), price3);
78
79 // Initialize the billing system
80 stripe_billing.initialize().await.unwrap();
81
82 // Verify that prices can be found by lookup key
83 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
84 assert_eq!(zed_pro_price_id.to_string(), "price_1");
85
86 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
87 assert_eq!(zed_free_price_id.to_string(), "price_2");
88
89 // Verify that a price can be found by lookup key
90 let zed_pro_price = stripe_billing
91 .find_price_by_lookup_key("zed-pro")
92 .await
93 .unwrap();
94 assert_eq!(zed_pro_price.id.to_string(), "price_1");
95 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
96
97 // Verify that finding a non-existent lookup key returns an error
98 let result = stripe_billing
99 .find_price_by_lookup_key("non-existent")
100 .await;
101 assert!(result.is_err());
102}
103
104#[gpui::test]
105async fn test_find_or_create_customer_by_email() {
106 let (stripe_billing, stripe_client) = make_stripe_billing();
107
108 // Create a customer with an email that doesn't yet correspond to a customer.
109 {
110 let email = "user@example.com";
111
112 let customer_id = stripe_billing
113 .find_or_create_customer_by_email(Some(email))
114 .await
115 .unwrap();
116
117 let customer = stripe_client
118 .customers
119 .lock()
120 .get(&customer_id)
121 .unwrap()
122 .clone();
123 assert_eq!(customer.email.as_deref(), Some(email));
124 }
125
126 // Create a customer with an email that corresponds to an existing customer.
127 {
128 let email = "user2@example.com";
129
130 let existing_customer_id = stripe_billing
131 .find_or_create_customer_by_email(Some(email))
132 .await
133 .unwrap();
134
135 let customer_id = stripe_billing
136 .find_or_create_customer_by_email(Some(email))
137 .await
138 .unwrap();
139 assert_eq!(customer_id, existing_customer_id);
140
141 let customer = stripe_client
142 .customers
143 .lock()
144 .get(&customer_id)
145 .unwrap()
146 .clone();
147 assert_eq!(customer.email.as_deref(), Some(email));
148 }
149}
150
151#[gpui::test]
152async fn test_subscribe_to_price() {
153 let (stripe_billing, stripe_client) = make_stripe_billing();
154
155 let price = StripePrice {
156 id: StripePriceId("price_test".into()),
157 unit_amount: Some(2000),
158 lookup_key: Some("test-price".to_string()),
159 recurring: None,
160 };
161 stripe_client
162 .prices
163 .lock()
164 .insert(price.id.clone(), price.clone());
165
166 let subscription = StripeSubscription {
167 id: StripeSubscriptionId("sub_test".into()),
168 items: vec![],
169 };
170 stripe_client
171 .subscriptions
172 .lock()
173 .insert(subscription.id.clone(), subscription.clone());
174
175 stripe_billing
176 .subscribe_to_price(&subscription.id, &price)
177 .await
178 .unwrap();
179
180 let update_subscription_calls = stripe_client
181 .update_subscription_calls
182 .lock()
183 .iter()
184 .map(|(id, params)| (id.clone(), params.clone()))
185 .collect::<Vec<_>>();
186 assert_eq!(update_subscription_calls.len(), 1);
187 assert_eq!(update_subscription_calls[0].0, subscription.id);
188 assert_eq!(
189 update_subscription_calls[0].1.items,
190 Some(vec![UpdateSubscriptionItems {
191 price: Some(price.id.clone())
192 }])
193 );
194
195 // Subscribing to a price that is already on the subscription is a no-op.
196 {
197 let subscription = StripeSubscription {
198 id: StripeSubscriptionId("sub_test".into()),
199 items: vec![StripeSubscriptionItem {
200 id: StripeSubscriptionItemId("si_test".into()),
201 price: Some(price.clone()),
202 }],
203 };
204 stripe_client
205 .subscriptions
206 .lock()
207 .insert(subscription.id.clone(), subscription.clone());
208
209 stripe_billing
210 .subscribe_to_price(&subscription.id, &price)
211 .await
212 .unwrap();
213
214 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
215 }
216}
217
218#[gpui::test]
219async fn test_bill_model_request_usage() {
220 let (stripe_billing, stripe_client) = make_stripe_billing();
221
222 let customer_id = StripeCustomerId("cus_test".into());
223
224 stripe_billing
225 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
226 .await
227 .unwrap();
228
229 let create_meter_event_calls = stripe_client
230 .create_meter_event_calls
231 .lock()
232 .iter()
233 .cloned()
234 .collect::<Vec<_>>();
235 assert_eq!(create_meter_event_calls.len(), 1);
236 assert!(
237 create_meter_event_calls[0]
238 .identifier
239 .starts_with("model_requests/")
240 );
241 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
242 assert_eq!(
243 create_meter_event_calls[0].event_name.as_ref(),
244 "some_model/requests"
245 );
246 assert_eq!(create_meter_event_calls[0].value, 73);
247}
248
249#[gpui::test]
250async fn test_checkout_with_zed_pro() {
251 let (stripe_billing, stripe_client) = make_stripe_billing();
252
253 let customer_id = StripeCustomerId("cus_test".into());
254 let github_login = "zeduser1";
255 let success_url = "https://example.com/success";
256
257 // It returns an error when the Zed Pro price doesn't exist.
258 {
259 let result = stripe_billing
260 .checkout_with_zed_pro(&customer_id, github_login, success_url)
261 .await;
262
263 assert!(result.is_err());
264 assert_eq!(
265 result.err().unwrap().to_string(),
266 r#"no price ID found for "zed-pro""#
267 );
268 }
269
270 // Successful checkout.
271 {
272 let price = StripePrice {
273 id: StripePriceId("price_1".into()),
274 unit_amount: Some(2000),
275 lookup_key: Some("zed-pro".to_string()),
276 recurring: None,
277 };
278 stripe_client
279 .prices
280 .lock()
281 .insert(price.id.clone(), price.clone());
282
283 stripe_billing.initialize().await.unwrap();
284
285 let checkout_url = stripe_billing
286 .checkout_with_zed_pro(&customer_id, github_login, success_url)
287 .await
288 .unwrap();
289
290 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
291
292 let create_checkout_session_calls = stripe_client
293 .create_checkout_session_calls
294 .lock()
295 .drain(..)
296 .collect::<Vec<_>>();
297 assert_eq!(create_checkout_session_calls.len(), 1);
298 let call = create_checkout_session_calls.into_iter().next().unwrap();
299 assert_eq!(call.customer, Some(customer_id));
300 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
301 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
302 assert_eq!(
303 call.line_items,
304 Some(vec![StripeCreateCheckoutSessionLineItems {
305 price: Some(price.id.to_string()),
306 quantity: Some(1)
307 }])
308 );
309 assert_eq!(call.payment_method_collection, None);
310 assert_eq!(call.subscription_data, None);
311 assert_eq!(call.success_url.as_deref(), Some(success_url));
312 }
313}
314
315#[gpui::test]
316async fn test_checkout_with_zed_pro_trial() {
317 let (stripe_billing, stripe_client) = make_stripe_billing();
318
319 let customer_id = StripeCustomerId("cus_test".into());
320 let github_login = "zeduser1";
321 let success_url = "https://example.com/success";
322
323 // It returns an error when the Zed Pro price doesn't exist.
324 {
325 let result = stripe_billing
326 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
327 .await;
328
329 assert!(result.is_err());
330 assert_eq!(
331 result.err().unwrap().to_string(),
332 r#"no price ID found for "zed-pro""#
333 );
334 }
335
336 let price = StripePrice {
337 id: StripePriceId("price_1".into()),
338 unit_amount: Some(2000),
339 lookup_key: Some("zed-pro".to_string()),
340 recurring: None,
341 };
342 stripe_client
343 .prices
344 .lock()
345 .insert(price.id.clone(), price.clone());
346
347 stripe_billing.initialize().await.unwrap();
348
349 // Successful checkout.
350 {
351 let checkout_url = stripe_billing
352 .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
353 .await
354 .unwrap();
355
356 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
357
358 let create_checkout_session_calls = stripe_client
359 .create_checkout_session_calls
360 .lock()
361 .drain(..)
362 .collect::<Vec<_>>();
363 assert_eq!(create_checkout_session_calls.len(), 1);
364 let call = create_checkout_session_calls.into_iter().next().unwrap();
365 assert_eq!(call.customer.as_ref(), Some(&customer_id));
366 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
367 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
368 assert_eq!(
369 call.line_items,
370 Some(vec![StripeCreateCheckoutSessionLineItems {
371 price: Some(price.id.to_string()),
372 quantity: Some(1)
373 }])
374 );
375 assert_eq!(
376 call.payment_method_collection,
377 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
378 );
379 assert_eq!(
380 call.subscription_data,
381 Some(StripeCreateCheckoutSessionSubscriptionData {
382 trial_period_days: Some(14),
383 trial_settings: Some(StripeSubscriptionTrialSettings {
384 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
385 missing_payment_method:
386 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
387 },
388 }),
389 metadata: None,
390 })
391 );
392 assert_eq!(call.success_url.as_deref(), Some(success_url));
393 }
394
395 // Successful checkout with extended trial.
396 {
397 let checkout_url = stripe_billing
398 .checkout_with_zed_pro_trial(
399 &customer_id,
400 github_login,
401 vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
402 success_url,
403 )
404 .await
405 .unwrap();
406
407 assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
408
409 let create_checkout_session_calls = stripe_client
410 .create_checkout_session_calls
411 .lock()
412 .drain(..)
413 .collect::<Vec<_>>();
414 assert_eq!(create_checkout_session_calls.len(), 1);
415 let call = create_checkout_session_calls.into_iter().next().unwrap();
416 assert_eq!(call.customer, Some(customer_id));
417 assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
418 assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
419 assert_eq!(
420 call.line_items,
421 Some(vec![StripeCreateCheckoutSessionLineItems {
422 price: Some(price.id.to_string()),
423 quantity: Some(1)
424 }])
425 );
426 assert_eq!(
427 call.payment_method_collection,
428 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
429 );
430 assert_eq!(
431 call.subscription_data,
432 Some(StripeCreateCheckoutSessionSubscriptionData {
433 trial_period_days: Some(60),
434 trial_settings: Some(StripeSubscriptionTrialSettings {
435 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
436 missing_payment_method:
437 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
438 },
439 }),
440 metadata: Some(std::collections::HashMap::from_iter([(
441 "promo_feature_flag".into(),
442 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
443 )])),
444 })
445 );
446 assert_eq!(call.success_url.as_deref(), Some(success_url));
447 }
448}