stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use chrono::{Duration, Utc};
  4use pretty_assertions::assert_eq;
  5
  6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  7use crate::stripe_billing::StripeBilling;
  8use crate::stripe_client::{
  9    FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
 10    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
 11    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
 12    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
 13    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 14    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 15};
 16
 17fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 18    let stripe_client = Arc::new(FakeStripeClient::new());
 19    let stripe_billing = StripeBilling::test(stripe_client.clone());
 20
 21    (stripe_billing, stripe_client)
 22}
 23
 24#[gpui::test]
 25async fn test_initialize() {
 26    let (stripe_billing, stripe_client) = make_stripe_billing();
 27
 28    // Add test meters
 29    let meter1 = StripeMeter {
 30        id: StripeMeterId("meter_1".into()),
 31        event_name: "event_1".to_string(),
 32    };
 33    let meter2 = StripeMeter {
 34        id: StripeMeterId("meter_2".into()),
 35        event_name: "event_2".to_string(),
 36    };
 37    stripe_client
 38        .meters
 39        .lock()
 40        .insert(meter1.id.clone(), meter1);
 41    stripe_client
 42        .meters
 43        .lock()
 44        .insert(meter2.id.clone(), meter2);
 45
 46    // Add test prices
 47    let price1 = StripePrice {
 48        id: StripePriceId("price_1".into()),
 49        unit_amount: Some(1_000),
 50        lookup_key: Some("zed-pro".to_string()),
 51        recurring: None,
 52    };
 53    let price2 = StripePrice {
 54        id: StripePriceId("price_2".into()),
 55        unit_amount: Some(0),
 56        lookup_key: Some("zed-free".to_string()),
 57        recurring: None,
 58    };
 59    let price3 = StripePrice {
 60        id: StripePriceId("price_3".into()),
 61        unit_amount: Some(500),
 62        lookup_key: None,
 63        recurring: Some(StripePriceRecurring {
 64            meter: Some("meter_1".to_string()),
 65        }),
 66    };
 67    stripe_client
 68        .prices
 69        .lock()
 70        .insert(price1.id.clone(), price1);
 71    stripe_client
 72        .prices
 73        .lock()
 74        .insert(price2.id.clone(), price2);
 75    stripe_client
 76        .prices
 77        .lock()
 78        .insert(price3.id.clone(), price3);
 79
 80    // Initialize the billing system
 81    stripe_billing.initialize().await.unwrap();
 82
 83    // Verify that prices can be found by lookup key
 84    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 85    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 86
 87    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 88    assert_eq!(zed_free_price_id.to_string(), "price_2");
 89
 90    // Verify that a price can be found by lookup key
 91    let zed_pro_price = stripe_billing
 92        .find_price_by_lookup_key("zed-pro")
 93        .await
 94        .unwrap();
 95    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 96    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 97
 98    // Verify that finding a non-existent lookup key returns an error
 99    let result = stripe_billing
100        .find_price_by_lookup_key("non-existent")
101        .await;
102    assert!(result.is_err());
103}
104
105#[gpui::test]
106async fn test_find_or_create_customer_by_email() {
107    let (stripe_billing, stripe_client) = make_stripe_billing();
108
109    // Create a customer with an email that doesn't yet correspond to a customer.
110    {
111        let email = "user@example.com";
112
113        let customer_id = stripe_billing
114            .find_or_create_customer_by_email(Some(email))
115            .await
116            .unwrap();
117
118        let customer = stripe_client
119            .customers
120            .lock()
121            .get(&customer_id)
122            .unwrap()
123            .clone();
124        assert_eq!(customer.email.as_deref(), Some(email));
125    }
126
127    // Create a customer with an email that corresponds to an existing customer.
128    {
129        let email = "user2@example.com";
130
131        let existing_customer_id = stripe_billing
132            .find_or_create_customer_by_email(Some(email))
133            .await
134            .unwrap();
135
136        let customer_id = stripe_billing
137            .find_or_create_customer_by_email(Some(email))
138            .await
139            .unwrap();
140        assert_eq!(customer_id, existing_customer_id);
141
142        let customer = stripe_client
143            .customers
144            .lock()
145            .get(&customer_id)
146            .unwrap()
147            .clone();
148        assert_eq!(customer.email.as_deref(), Some(email));
149    }
150}
151
152#[gpui::test]
153async fn test_subscribe_to_price() {
154    let (stripe_billing, stripe_client) = make_stripe_billing();
155
156    let price = StripePrice {
157        id: StripePriceId("price_test".into()),
158        unit_amount: Some(2000),
159        lookup_key: Some("test-price".to_string()),
160        recurring: None,
161    };
162    stripe_client
163        .prices
164        .lock()
165        .insert(price.id.clone(), price.clone());
166
167    let now = Utc::now();
168    let subscription = StripeSubscription {
169        id: StripeSubscriptionId("sub_test".into()),
170        customer: StripeCustomerId("cus_test".into()),
171        status: stripe::SubscriptionStatus::Active,
172        current_period_start: now.timestamp(),
173        current_period_end: (now + Duration::days(30)).timestamp(),
174        items: vec![],
175    };
176    stripe_client
177        .subscriptions
178        .lock()
179        .insert(subscription.id.clone(), subscription.clone());
180
181    stripe_billing
182        .subscribe_to_price(&subscription.id, &price)
183        .await
184        .unwrap();
185
186    let update_subscription_calls = stripe_client
187        .update_subscription_calls
188        .lock()
189        .iter()
190        .map(|(id, params)| (id.clone(), params.clone()))
191        .collect::<Vec<_>>();
192    assert_eq!(update_subscription_calls.len(), 1);
193    assert_eq!(update_subscription_calls[0].0, subscription.id);
194    assert_eq!(
195        update_subscription_calls[0].1.items,
196        Some(vec![UpdateSubscriptionItems {
197            price: Some(price.id.clone())
198        }])
199    );
200
201    // Subscribing to a price that is already on the subscription is a no-op.
202    {
203        let now = Utc::now();
204        let subscription = StripeSubscription {
205            id: StripeSubscriptionId("sub_test".into()),
206            customer: StripeCustomerId("cus_test".into()),
207            status: stripe::SubscriptionStatus::Active,
208            current_period_start: now.timestamp(),
209            current_period_end: (now + Duration::days(30)).timestamp(),
210            items: vec![StripeSubscriptionItem {
211                id: StripeSubscriptionItemId("si_test".into()),
212                price: Some(price.clone()),
213            }],
214        };
215        stripe_client
216            .subscriptions
217            .lock()
218            .insert(subscription.id.clone(), subscription.clone());
219
220        stripe_billing
221            .subscribe_to_price(&subscription.id, &price)
222            .await
223            .unwrap();
224
225        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
226    }
227}
228
229#[gpui::test]
230async fn test_subscribe_to_zed_free() {
231    let (stripe_billing, stripe_client) = make_stripe_billing();
232
233    let zed_pro_price = StripePrice {
234        id: StripePriceId("price_1".into()),
235        unit_amount: Some(0),
236        lookup_key: Some("zed-pro".to_string()),
237        recurring: None,
238    };
239    stripe_client
240        .prices
241        .lock()
242        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
243    let zed_free_price = StripePrice {
244        id: StripePriceId("price_2".into()),
245        unit_amount: Some(0),
246        lookup_key: Some("zed-free".to_string()),
247        recurring: None,
248    };
249    stripe_client
250        .prices
251        .lock()
252        .insert(zed_free_price.id.clone(), zed_free_price.clone());
253
254    stripe_billing.initialize().await.unwrap();
255
256    // Customer is subscribed to Zed Free when not already subscribed to a plan.
257    {
258        let customer_id = StripeCustomerId("cus_no_plan".into());
259
260        let subscription = stripe_billing
261            .subscribe_to_zed_free(customer_id)
262            .await
263            .unwrap();
264
265        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
266    }
267
268    // Customer is not subscribed to Zed Free when they already have an active subscription.
269    {
270        let customer_id = StripeCustomerId("cus_active_subscription".into());
271
272        let now = Utc::now();
273        let existing_subscription = StripeSubscription {
274            id: StripeSubscriptionId("sub_existing_active".into()),
275            customer: customer_id.clone(),
276            status: stripe::SubscriptionStatus::Active,
277            current_period_start: now.timestamp(),
278            current_period_end: (now + Duration::days(30)).timestamp(),
279            items: vec![StripeSubscriptionItem {
280                id: StripeSubscriptionItemId("si_test".into()),
281                price: Some(zed_pro_price.clone()),
282            }],
283        };
284        stripe_client.subscriptions.lock().insert(
285            existing_subscription.id.clone(),
286            existing_subscription.clone(),
287        );
288
289        let subscription = stripe_billing
290            .subscribe_to_zed_free(customer_id)
291            .await
292            .unwrap();
293
294        assert_eq!(subscription, existing_subscription);
295    }
296
297    // Customer is not subscribed to Zed Free when they already have a trial subscription.
298    {
299        let customer_id = StripeCustomerId("cus_trial_subscription".into());
300
301        let now = Utc::now();
302        let existing_subscription = StripeSubscription {
303            id: StripeSubscriptionId("sub_existing_trial".into()),
304            customer: customer_id.clone(),
305            status: stripe::SubscriptionStatus::Trialing,
306            current_period_start: now.timestamp(),
307            current_period_end: (now + Duration::days(14)).timestamp(),
308            items: vec![StripeSubscriptionItem {
309                id: StripeSubscriptionItemId("si_test".into()),
310                price: Some(zed_pro_price.clone()),
311            }],
312        };
313        stripe_client.subscriptions.lock().insert(
314            existing_subscription.id.clone(),
315            existing_subscription.clone(),
316        );
317
318        let subscription = stripe_billing
319            .subscribe_to_zed_free(customer_id)
320            .await
321            .unwrap();
322
323        assert_eq!(subscription, existing_subscription);
324    }
325}
326
327#[gpui::test]
328async fn test_bill_model_request_usage() {
329    let (stripe_billing, stripe_client) = make_stripe_billing();
330
331    let customer_id = StripeCustomerId("cus_test".into());
332
333    stripe_billing
334        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
335        .await
336        .unwrap();
337
338    let create_meter_event_calls = stripe_client
339        .create_meter_event_calls
340        .lock()
341        .iter()
342        .cloned()
343        .collect::<Vec<_>>();
344    assert_eq!(create_meter_event_calls.len(), 1);
345    assert!(
346        create_meter_event_calls[0]
347            .identifier
348            .starts_with("model_requests/")
349    );
350    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
351    assert_eq!(
352        create_meter_event_calls[0].event_name.as_ref(),
353        "some_model/requests"
354    );
355    assert_eq!(create_meter_event_calls[0].value, 73);
356}
357
358#[gpui::test]
359async fn test_checkout_with_zed_pro() {
360    let (stripe_billing, stripe_client) = make_stripe_billing();
361
362    let customer_id = StripeCustomerId("cus_test".into());
363    let github_login = "zeduser1";
364    let success_url = "https://example.com/success";
365
366    // It returns an error when the Zed Pro price doesn't exist.
367    {
368        let result = stripe_billing
369            .checkout_with_zed_pro(&customer_id, github_login, success_url)
370            .await;
371
372        assert!(result.is_err());
373        assert_eq!(
374            result.err().unwrap().to_string(),
375            r#"no price ID found for "zed-pro""#
376        );
377    }
378
379    // Successful checkout.
380    {
381        let price = StripePrice {
382            id: StripePriceId("price_1".into()),
383            unit_amount: Some(2000),
384            lookup_key: Some("zed-pro".to_string()),
385            recurring: None,
386        };
387        stripe_client
388            .prices
389            .lock()
390            .insert(price.id.clone(), price.clone());
391
392        stripe_billing.initialize().await.unwrap();
393
394        let checkout_url = stripe_billing
395            .checkout_with_zed_pro(&customer_id, github_login, success_url)
396            .await
397            .unwrap();
398
399        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
400
401        let create_checkout_session_calls = stripe_client
402            .create_checkout_session_calls
403            .lock()
404            .drain(..)
405            .collect::<Vec<_>>();
406        assert_eq!(create_checkout_session_calls.len(), 1);
407        let call = create_checkout_session_calls.into_iter().next().unwrap();
408        assert_eq!(call.customer, Some(customer_id));
409        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
410        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
411        assert_eq!(
412            call.line_items,
413            Some(vec![StripeCreateCheckoutSessionLineItems {
414                price: Some(price.id.to_string()),
415                quantity: Some(1)
416            }])
417        );
418        assert_eq!(call.payment_method_collection, None);
419        assert_eq!(call.subscription_data, None);
420        assert_eq!(call.success_url.as_deref(), Some(success_url));
421    }
422}
423
424#[gpui::test]
425async fn test_checkout_with_zed_pro_trial() {
426    let (stripe_billing, stripe_client) = make_stripe_billing();
427
428    let customer_id = StripeCustomerId("cus_test".into());
429    let github_login = "zeduser1";
430    let success_url = "https://example.com/success";
431
432    // It returns an error when the Zed Pro price doesn't exist.
433    {
434        let result = stripe_billing
435            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
436            .await;
437
438        assert!(result.is_err());
439        assert_eq!(
440            result.err().unwrap().to_string(),
441            r#"no price ID found for "zed-pro""#
442        );
443    }
444
445    let price = StripePrice {
446        id: StripePriceId("price_1".into()),
447        unit_amount: Some(2000),
448        lookup_key: Some("zed-pro".to_string()),
449        recurring: None,
450    };
451    stripe_client
452        .prices
453        .lock()
454        .insert(price.id.clone(), price.clone());
455
456    stripe_billing.initialize().await.unwrap();
457
458    // Successful checkout.
459    {
460        let checkout_url = stripe_billing
461            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
462            .await
463            .unwrap();
464
465        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
466
467        let create_checkout_session_calls = stripe_client
468            .create_checkout_session_calls
469            .lock()
470            .drain(..)
471            .collect::<Vec<_>>();
472        assert_eq!(create_checkout_session_calls.len(), 1);
473        let call = create_checkout_session_calls.into_iter().next().unwrap();
474        assert_eq!(call.customer.as_ref(), Some(&customer_id));
475        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
476        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
477        assert_eq!(
478            call.line_items,
479            Some(vec![StripeCreateCheckoutSessionLineItems {
480                price: Some(price.id.to_string()),
481                quantity: Some(1)
482            }])
483        );
484        assert_eq!(
485            call.payment_method_collection,
486            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
487        );
488        assert_eq!(
489            call.subscription_data,
490            Some(StripeCreateCheckoutSessionSubscriptionData {
491                trial_period_days: Some(14),
492                trial_settings: Some(StripeSubscriptionTrialSettings {
493                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
494                        missing_payment_method:
495                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
496                    },
497                }),
498                metadata: None,
499            })
500        );
501        assert_eq!(call.success_url.as_deref(), Some(success_url));
502    }
503
504    // Successful checkout with extended trial.
505    {
506        let checkout_url = stripe_billing
507            .checkout_with_zed_pro_trial(
508                &customer_id,
509                github_login,
510                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
511                success_url,
512            )
513            .await
514            .unwrap();
515
516        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
517
518        let create_checkout_session_calls = stripe_client
519            .create_checkout_session_calls
520            .lock()
521            .drain(..)
522            .collect::<Vec<_>>();
523        assert_eq!(create_checkout_session_calls.len(), 1);
524        let call = create_checkout_session_calls.into_iter().next().unwrap();
525        assert_eq!(call.customer, Some(customer_id));
526        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
527        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
528        assert_eq!(
529            call.line_items,
530            Some(vec![StripeCreateCheckoutSessionLineItems {
531                price: Some(price.id.to_string()),
532                quantity: Some(1)
533            }])
534        );
535        assert_eq!(
536            call.payment_method_collection,
537            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
538        );
539        assert_eq!(
540            call.subscription_data,
541            Some(StripeCreateCheckoutSessionSubscriptionData {
542                trial_period_days: Some(60),
543                trial_settings: Some(StripeSubscriptionTrialSettings {
544                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
545                        missing_payment_method:
546                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
547                    },
548                }),
549                metadata: Some(std::collections::HashMap::from_iter([(
550                    "promo_feature_flag".into(),
551                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
552                )])),
553            })
554        );
555        assert_eq!(call.success_url.as_deref(), Some(success_url));
556    }
557}