stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use chrono::{Duration, Utc};
  4use pretty_assertions::assert_eq;
  5
  6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  7use crate::stripe_billing::StripeBilling;
  8use crate::stripe_client::{
  9    FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
 10    StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
 11    StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate,
 12    StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice,
 13    StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
 14    StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
 15    StripeSubscriptionTrialSettingsEndBehavior,
 16    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 17};
 18
 19fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 20    let stripe_client = Arc::new(FakeStripeClient::new());
 21    let stripe_billing = StripeBilling::test(stripe_client.clone());
 22
 23    (stripe_billing, stripe_client)
 24}
 25
 26#[gpui::test]
 27async fn test_initialize() {
 28    let (stripe_billing, stripe_client) = make_stripe_billing();
 29
 30    // Add test meters
 31    let meter1 = StripeMeter {
 32        id: StripeMeterId("meter_1".into()),
 33        event_name: "event_1".to_string(),
 34    };
 35    let meter2 = StripeMeter {
 36        id: StripeMeterId("meter_2".into()),
 37        event_name: "event_2".to_string(),
 38    };
 39    stripe_client
 40        .meters
 41        .lock()
 42        .insert(meter1.id.clone(), meter1);
 43    stripe_client
 44        .meters
 45        .lock()
 46        .insert(meter2.id.clone(), meter2);
 47
 48    // Add test prices
 49    let price1 = StripePrice {
 50        id: StripePriceId("price_1".into()),
 51        unit_amount: Some(1_000),
 52        lookup_key: Some("zed-pro".to_string()),
 53        recurring: None,
 54    };
 55    let price2 = StripePrice {
 56        id: StripePriceId("price_2".into()),
 57        unit_amount: Some(0),
 58        lookup_key: Some("zed-free".to_string()),
 59        recurring: None,
 60    };
 61    let price3 = StripePrice {
 62        id: StripePriceId("price_3".into()),
 63        unit_amount: Some(500),
 64        lookup_key: None,
 65        recurring: Some(StripePriceRecurring {
 66            meter: Some("meter_1".to_string()),
 67        }),
 68    };
 69    stripe_client
 70        .prices
 71        .lock()
 72        .insert(price1.id.clone(), price1);
 73    stripe_client
 74        .prices
 75        .lock()
 76        .insert(price2.id.clone(), price2);
 77    stripe_client
 78        .prices
 79        .lock()
 80        .insert(price3.id.clone(), price3);
 81
 82    // Initialize the billing system
 83    stripe_billing.initialize().await.unwrap();
 84
 85    // Verify that prices can be found by lookup key
 86    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 87    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 88
 89    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 90    assert_eq!(zed_free_price_id.to_string(), "price_2");
 91
 92    // Verify that a price can be found by lookup key
 93    let zed_pro_price = stripe_billing
 94        .find_price_by_lookup_key("zed-pro")
 95        .await
 96        .unwrap();
 97    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 98    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 99
100    // Verify that finding a non-existent lookup key returns an error
101    let result = stripe_billing
102        .find_price_by_lookup_key("non-existent")
103        .await;
104    assert!(result.is_err());
105}
106
107#[gpui::test]
108async fn test_find_or_create_customer_by_email() {
109    let (stripe_billing, stripe_client) = make_stripe_billing();
110
111    // Create a customer with an email that doesn't yet correspond to a customer.
112    {
113        let email = "user@example.com";
114
115        let customer_id = stripe_billing
116            .find_or_create_customer_by_email(Some(email))
117            .await
118            .unwrap();
119
120        let customer = stripe_client
121            .customers
122            .lock()
123            .get(&customer_id)
124            .unwrap()
125            .clone();
126        assert_eq!(customer.email.as_deref(), Some(email));
127    }
128
129    // Create a customer with an email that corresponds to an existing customer.
130    {
131        let email = "user2@example.com";
132
133        let existing_customer_id = stripe_billing
134            .find_or_create_customer_by_email(Some(email))
135            .await
136            .unwrap();
137
138        let customer_id = stripe_billing
139            .find_or_create_customer_by_email(Some(email))
140            .await
141            .unwrap();
142        assert_eq!(customer_id, existing_customer_id);
143
144        let customer = stripe_client
145            .customers
146            .lock()
147            .get(&customer_id)
148            .unwrap()
149            .clone();
150        assert_eq!(customer.email.as_deref(), Some(email));
151    }
152}
153
154#[gpui::test]
155async fn test_subscribe_to_price() {
156    let (stripe_billing, stripe_client) = make_stripe_billing();
157
158    let price = StripePrice {
159        id: StripePriceId("price_test".into()),
160        unit_amount: Some(2000),
161        lookup_key: Some("test-price".to_string()),
162        recurring: None,
163    };
164    stripe_client
165        .prices
166        .lock()
167        .insert(price.id.clone(), price.clone());
168
169    let now = Utc::now();
170    let subscription = StripeSubscription {
171        id: StripeSubscriptionId("sub_test".into()),
172        customer: StripeCustomerId("cus_test".into()),
173        status: stripe::SubscriptionStatus::Active,
174        current_period_start: now.timestamp(),
175        current_period_end: (now + Duration::days(30)).timestamp(),
176        items: vec![],
177        cancel_at: None,
178        cancellation_details: None,
179    };
180    stripe_client
181        .subscriptions
182        .lock()
183        .insert(subscription.id.clone(), subscription.clone());
184
185    stripe_billing
186        .subscribe_to_price(&subscription.id, &price)
187        .await
188        .unwrap();
189
190    let update_subscription_calls = stripe_client
191        .update_subscription_calls
192        .lock()
193        .iter()
194        .map(|(id, params)| (id.clone(), params.clone()))
195        .collect::<Vec<_>>();
196    assert_eq!(update_subscription_calls.len(), 1);
197    assert_eq!(update_subscription_calls[0].0, subscription.id);
198    assert_eq!(
199        update_subscription_calls[0].1.items,
200        Some(vec![UpdateSubscriptionItems {
201            price: Some(price.id.clone())
202        }])
203    );
204
205    // Subscribing to a price that is already on the subscription is a no-op.
206    {
207        let now = Utc::now();
208        let subscription = StripeSubscription {
209            id: StripeSubscriptionId("sub_test".into()),
210            customer: StripeCustomerId("cus_test".into()),
211            status: stripe::SubscriptionStatus::Active,
212            current_period_start: now.timestamp(),
213            current_period_end: (now + Duration::days(30)).timestamp(),
214            items: vec![StripeSubscriptionItem {
215                id: StripeSubscriptionItemId("si_test".into()),
216                price: Some(price.clone()),
217            }],
218            cancel_at: None,
219            cancellation_details: None,
220        };
221        stripe_client
222            .subscriptions
223            .lock()
224            .insert(subscription.id.clone(), subscription.clone());
225
226        stripe_billing
227            .subscribe_to_price(&subscription.id, &price)
228            .await
229            .unwrap();
230
231        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
232    }
233}
234
235#[gpui::test]
236async fn test_subscribe_to_zed_free() {
237    let (stripe_billing, stripe_client) = make_stripe_billing();
238
239    let zed_pro_price = StripePrice {
240        id: StripePriceId("price_1".into()),
241        unit_amount: Some(0),
242        lookup_key: Some("zed-pro".to_string()),
243        recurring: None,
244    };
245    stripe_client
246        .prices
247        .lock()
248        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
249    let zed_free_price = StripePrice {
250        id: StripePriceId("price_2".into()),
251        unit_amount: Some(0),
252        lookup_key: Some("zed-free".to_string()),
253        recurring: None,
254    };
255    stripe_client
256        .prices
257        .lock()
258        .insert(zed_free_price.id.clone(), zed_free_price.clone());
259
260    stripe_billing.initialize().await.unwrap();
261
262    // Customer is subscribed to Zed Free when not already subscribed to a plan.
263    {
264        let customer_id = StripeCustomerId("cus_no_plan".into());
265
266        let subscription = stripe_billing
267            .subscribe_to_zed_free(customer_id)
268            .await
269            .unwrap();
270
271        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
272    }
273
274    // Customer is not subscribed to Zed Free when they already have an active subscription.
275    {
276        let customer_id = StripeCustomerId("cus_active_subscription".into());
277
278        let now = Utc::now();
279        let existing_subscription = StripeSubscription {
280            id: StripeSubscriptionId("sub_existing_active".into()),
281            customer: customer_id.clone(),
282            status: stripe::SubscriptionStatus::Active,
283            current_period_start: now.timestamp(),
284            current_period_end: (now + Duration::days(30)).timestamp(),
285            items: vec![StripeSubscriptionItem {
286                id: StripeSubscriptionItemId("si_test".into()),
287                price: Some(zed_pro_price.clone()),
288            }],
289            cancel_at: None,
290            cancellation_details: None,
291        };
292        stripe_client.subscriptions.lock().insert(
293            existing_subscription.id.clone(),
294            existing_subscription.clone(),
295        );
296
297        let subscription = stripe_billing
298            .subscribe_to_zed_free(customer_id)
299            .await
300            .unwrap();
301
302        assert_eq!(subscription, existing_subscription);
303    }
304
305    // Customer is not subscribed to Zed Free when they already have a trial subscription.
306    {
307        let customer_id = StripeCustomerId("cus_trial_subscription".into());
308
309        let now = Utc::now();
310        let existing_subscription = StripeSubscription {
311            id: StripeSubscriptionId("sub_existing_trial".into()),
312            customer: customer_id.clone(),
313            status: stripe::SubscriptionStatus::Trialing,
314            current_period_start: now.timestamp(),
315            current_period_end: (now + Duration::days(14)).timestamp(),
316            items: vec![StripeSubscriptionItem {
317                id: StripeSubscriptionItemId("si_test".into()),
318                price: Some(zed_pro_price.clone()),
319            }],
320            cancel_at: None,
321            cancellation_details: None,
322        };
323        stripe_client.subscriptions.lock().insert(
324            existing_subscription.id.clone(),
325            existing_subscription.clone(),
326        );
327
328        let subscription = stripe_billing
329            .subscribe_to_zed_free(customer_id)
330            .await
331            .unwrap();
332
333        assert_eq!(subscription, existing_subscription);
334    }
335}
336
337#[gpui::test]
338async fn test_bill_model_request_usage() {
339    let (stripe_billing, stripe_client) = make_stripe_billing();
340
341    let customer_id = StripeCustomerId("cus_test".into());
342
343    stripe_billing
344        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
345        .await
346        .unwrap();
347
348    let create_meter_event_calls = stripe_client
349        .create_meter_event_calls
350        .lock()
351        .iter()
352        .cloned()
353        .collect::<Vec<_>>();
354    assert_eq!(create_meter_event_calls.len(), 1);
355    assert!(
356        create_meter_event_calls[0]
357            .identifier
358            .starts_with("model_requests/")
359    );
360    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
361    assert_eq!(
362        create_meter_event_calls[0].event_name.as_ref(),
363        "some_model/requests"
364    );
365    assert_eq!(create_meter_event_calls[0].value, 73);
366}
367
368#[gpui::test]
369async fn test_checkout_with_zed_pro() {
370    let (stripe_billing, stripe_client) = make_stripe_billing();
371
372    let customer_id = StripeCustomerId("cus_test".into());
373    let github_login = "zeduser1";
374    let success_url = "https://example.com/success";
375
376    // It returns an error when the Zed Pro price doesn't exist.
377    {
378        let result = stripe_billing
379            .checkout_with_zed_pro(&customer_id, github_login, success_url)
380            .await;
381
382        assert!(result.is_err());
383        assert_eq!(
384            result.err().unwrap().to_string(),
385            r#"no price ID found for "zed-pro""#
386        );
387    }
388
389    // Successful checkout.
390    {
391        let price = StripePrice {
392            id: StripePriceId("price_1".into()),
393            unit_amount: Some(2000),
394            lookup_key: Some("zed-pro".to_string()),
395            recurring: None,
396        };
397        stripe_client
398            .prices
399            .lock()
400            .insert(price.id.clone(), price.clone());
401
402        stripe_billing.initialize().await.unwrap();
403
404        let checkout_url = stripe_billing
405            .checkout_with_zed_pro(&customer_id, github_login, success_url)
406            .await
407            .unwrap();
408
409        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
410
411        let create_checkout_session_calls = stripe_client
412            .create_checkout_session_calls
413            .lock()
414            .drain(..)
415            .collect::<Vec<_>>();
416        assert_eq!(create_checkout_session_calls.len(), 1);
417        let call = create_checkout_session_calls.into_iter().next().unwrap();
418        assert_eq!(call.customer, Some(customer_id));
419        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
420        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
421        assert_eq!(
422            call.line_items,
423            Some(vec![StripeCreateCheckoutSessionLineItems {
424                price: Some(price.id.to_string()),
425                quantity: Some(1)
426            }])
427        );
428        assert_eq!(call.payment_method_collection, None);
429        assert_eq!(call.subscription_data, None);
430        assert_eq!(call.success_url.as_deref(), Some(success_url));
431        assert_eq!(
432            call.billing_address_collection,
433            Some(StripeBillingAddressCollection::Required)
434        );
435        assert_eq!(
436            call.customer_update,
437            Some(StripeCustomerUpdate {
438                address: Some(StripeCustomerUpdateAddress::Auto),
439                name: Some(StripeCustomerUpdateName::Auto),
440                shipping: None,
441            })
442        );
443    }
444}
445
446#[gpui::test]
447async fn test_checkout_with_zed_pro_trial() {
448    let (stripe_billing, stripe_client) = make_stripe_billing();
449
450    let customer_id = StripeCustomerId("cus_test".into());
451    let github_login = "zeduser1";
452    let success_url = "https://example.com/success";
453
454    // It returns an error when the Zed Pro price doesn't exist.
455    {
456        let result = stripe_billing
457            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
458            .await;
459
460        assert!(result.is_err());
461        assert_eq!(
462            result.err().unwrap().to_string(),
463            r#"no price ID found for "zed-pro""#
464        );
465    }
466
467    let price = StripePrice {
468        id: StripePriceId("price_1".into()),
469        unit_amount: Some(2000),
470        lookup_key: Some("zed-pro".to_string()),
471        recurring: None,
472    };
473    stripe_client
474        .prices
475        .lock()
476        .insert(price.id.clone(), price.clone());
477
478    stripe_billing.initialize().await.unwrap();
479
480    // Successful checkout.
481    {
482        let checkout_url = stripe_billing
483            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
484            .await
485            .unwrap();
486
487        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
488
489        let create_checkout_session_calls = stripe_client
490            .create_checkout_session_calls
491            .lock()
492            .drain(..)
493            .collect::<Vec<_>>();
494        assert_eq!(create_checkout_session_calls.len(), 1);
495        let call = create_checkout_session_calls.into_iter().next().unwrap();
496        assert_eq!(call.customer.as_ref(), Some(&customer_id));
497        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
498        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
499        assert_eq!(
500            call.line_items,
501            Some(vec![StripeCreateCheckoutSessionLineItems {
502                price: Some(price.id.to_string()),
503                quantity: Some(1)
504            }])
505        );
506        assert_eq!(
507            call.payment_method_collection,
508            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
509        );
510        assert_eq!(
511            call.subscription_data,
512            Some(StripeCreateCheckoutSessionSubscriptionData {
513                trial_period_days: Some(14),
514                trial_settings: Some(StripeSubscriptionTrialSettings {
515                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
516                        missing_payment_method:
517                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
518                    },
519                }),
520                metadata: None,
521            })
522        );
523        assert_eq!(call.success_url.as_deref(), Some(success_url));
524        assert_eq!(
525            call.billing_address_collection,
526            Some(StripeBillingAddressCollection::Required)
527        );
528        assert_eq!(
529            call.customer_update,
530            Some(StripeCustomerUpdate {
531                address: Some(StripeCustomerUpdateAddress::Auto),
532                name: Some(StripeCustomerUpdateName::Auto),
533                shipping: None,
534            })
535        );
536    }
537
538    // Successful checkout with extended trial.
539    {
540        let checkout_url = stripe_billing
541            .checkout_with_zed_pro_trial(
542                &customer_id,
543                github_login,
544                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
545                success_url,
546            )
547            .await
548            .unwrap();
549
550        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
551
552        let create_checkout_session_calls = stripe_client
553            .create_checkout_session_calls
554            .lock()
555            .drain(..)
556            .collect::<Vec<_>>();
557        assert_eq!(create_checkout_session_calls.len(), 1);
558        let call = create_checkout_session_calls.into_iter().next().unwrap();
559        assert_eq!(call.customer, Some(customer_id));
560        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
561        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
562        assert_eq!(
563            call.line_items,
564            Some(vec![StripeCreateCheckoutSessionLineItems {
565                price: Some(price.id.to_string()),
566                quantity: Some(1)
567            }])
568        );
569        assert_eq!(
570            call.payment_method_collection,
571            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
572        );
573        assert_eq!(
574            call.subscription_data,
575            Some(StripeCreateCheckoutSessionSubscriptionData {
576                trial_period_days: Some(60),
577                trial_settings: Some(StripeSubscriptionTrialSettings {
578                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
579                        missing_payment_method:
580                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
581                    },
582                }),
583                metadata: Some(std::collections::HashMap::from_iter([(
584                    "promo_feature_flag".into(),
585                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
586                )])),
587            })
588        );
589        assert_eq!(call.success_url.as_deref(), Some(success_url));
590        assert_eq!(
591            call.billing_address_collection,
592            Some(StripeBillingAddressCollection::Required)
593        );
594        assert_eq!(
595            call.customer_update,
596            Some(StripeCustomerUpdate {
597                address: Some(StripeCustomerUpdateAddress::Auto),
598                name: Some(StripeCustomerUpdateName::Auto),
599                shipping: None,
600            })
601        );
602    }
603}