stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use chrono::{Duration, Utc};
  4use pretty_assertions::assert_eq;
  5
  6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  7use crate::stripe_billing::StripeBilling;
  8use crate::stripe_client::{
  9    FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
 10    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
 11    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
 12    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
 13    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 14    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 15};
 16
 17fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 18    let stripe_client = Arc::new(FakeStripeClient::new());
 19    let stripe_billing = StripeBilling::test(stripe_client.clone());
 20
 21    (stripe_billing, stripe_client)
 22}
 23
 24#[gpui::test]
 25async fn test_initialize() {
 26    let (stripe_billing, stripe_client) = make_stripe_billing();
 27
 28    // Add test meters
 29    let meter1 = StripeMeter {
 30        id: StripeMeterId("meter_1".into()),
 31        event_name: "event_1".to_string(),
 32    };
 33    let meter2 = StripeMeter {
 34        id: StripeMeterId("meter_2".into()),
 35        event_name: "event_2".to_string(),
 36    };
 37    stripe_client
 38        .meters
 39        .lock()
 40        .insert(meter1.id.clone(), meter1);
 41    stripe_client
 42        .meters
 43        .lock()
 44        .insert(meter2.id.clone(), meter2);
 45
 46    // Add test prices
 47    let price1 = StripePrice {
 48        id: StripePriceId("price_1".into()),
 49        unit_amount: Some(1_000),
 50        lookup_key: Some("zed-pro".to_string()),
 51        recurring: None,
 52    };
 53    let price2 = StripePrice {
 54        id: StripePriceId("price_2".into()),
 55        unit_amount: Some(0),
 56        lookup_key: Some("zed-free".to_string()),
 57        recurring: None,
 58    };
 59    let price3 = StripePrice {
 60        id: StripePriceId("price_3".into()),
 61        unit_amount: Some(500),
 62        lookup_key: None,
 63        recurring: Some(StripePriceRecurring {
 64            meter: Some("meter_1".to_string()),
 65        }),
 66    };
 67    stripe_client
 68        .prices
 69        .lock()
 70        .insert(price1.id.clone(), price1);
 71    stripe_client
 72        .prices
 73        .lock()
 74        .insert(price2.id.clone(), price2);
 75    stripe_client
 76        .prices
 77        .lock()
 78        .insert(price3.id.clone(), price3);
 79
 80    // Initialize the billing system
 81    stripe_billing.initialize().await.unwrap();
 82
 83    // Verify that prices can be found by lookup key
 84    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 85    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 86
 87    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 88    assert_eq!(zed_free_price_id.to_string(), "price_2");
 89
 90    // Verify that a price can be found by lookup key
 91    let zed_pro_price = stripe_billing
 92        .find_price_by_lookup_key("zed-pro")
 93        .await
 94        .unwrap();
 95    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 96    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 97
 98    // Verify that finding a non-existent lookup key returns an error
 99    let result = stripe_billing
100        .find_price_by_lookup_key("non-existent")
101        .await;
102    assert!(result.is_err());
103}
104
105#[gpui::test]
106async fn test_find_or_create_customer_by_email() {
107    let (stripe_billing, stripe_client) = make_stripe_billing();
108
109    // Create a customer with an email that doesn't yet correspond to a customer.
110    {
111        let email = "user@example.com";
112
113        let customer_id = stripe_billing
114            .find_or_create_customer_by_email(Some(email))
115            .await
116            .unwrap();
117
118        let customer = stripe_client
119            .customers
120            .lock()
121            .get(&customer_id)
122            .unwrap()
123            .clone();
124        assert_eq!(customer.email.as_deref(), Some(email));
125    }
126
127    // Create a customer with an email that corresponds to an existing customer.
128    {
129        let email = "user2@example.com";
130
131        let existing_customer_id = stripe_billing
132            .find_or_create_customer_by_email(Some(email))
133            .await
134            .unwrap();
135
136        let customer_id = stripe_billing
137            .find_or_create_customer_by_email(Some(email))
138            .await
139            .unwrap();
140        assert_eq!(customer_id, existing_customer_id);
141
142        let customer = stripe_client
143            .customers
144            .lock()
145            .get(&customer_id)
146            .unwrap()
147            .clone();
148        assert_eq!(customer.email.as_deref(), Some(email));
149    }
150}
151
152#[gpui::test]
153async fn test_subscribe_to_price() {
154    let (stripe_billing, stripe_client) = make_stripe_billing();
155
156    let price = StripePrice {
157        id: StripePriceId("price_test".into()),
158        unit_amount: Some(2000),
159        lookup_key: Some("test-price".to_string()),
160        recurring: None,
161    };
162    stripe_client
163        .prices
164        .lock()
165        .insert(price.id.clone(), price.clone());
166
167    let now = Utc::now();
168    let subscription = StripeSubscription {
169        id: StripeSubscriptionId("sub_test".into()),
170        customer: StripeCustomerId("cus_test".into()),
171        status: stripe::SubscriptionStatus::Active,
172        current_period_start: now.timestamp(),
173        current_period_end: (now + Duration::days(30)).timestamp(),
174        items: vec![],
175        cancel_at: None,
176        cancellation_details: None,
177    };
178    stripe_client
179        .subscriptions
180        .lock()
181        .insert(subscription.id.clone(), subscription.clone());
182
183    stripe_billing
184        .subscribe_to_price(&subscription.id, &price)
185        .await
186        .unwrap();
187
188    let update_subscription_calls = stripe_client
189        .update_subscription_calls
190        .lock()
191        .iter()
192        .map(|(id, params)| (id.clone(), params.clone()))
193        .collect::<Vec<_>>();
194    assert_eq!(update_subscription_calls.len(), 1);
195    assert_eq!(update_subscription_calls[0].0, subscription.id);
196    assert_eq!(
197        update_subscription_calls[0].1.items,
198        Some(vec![UpdateSubscriptionItems {
199            price: Some(price.id.clone())
200        }])
201    );
202
203    // Subscribing to a price that is already on the subscription is a no-op.
204    {
205        let now = Utc::now();
206        let subscription = StripeSubscription {
207            id: StripeSubscriptionId("sub_test".into()),
208            customer: StripeCustomerId("cus_test".into()),
209            status: stripe::SubscriptionStatus::Active,
210            current_period_start: now.timestamp(),
211            current_period_end: (now + Duration::days(30)).timestamp(),
212            items: vec![StripeSubscriptionItem {
213                id: StripeSubscriptionItemId("si_test".into()),
214                price: Some(price.clone()),
215            }],
216            cancel_at: None,
217            cancellation_details: None,
218        };
219        stripe_client
220            .subscriptions
221            .lock()
222            .insert(subscription.id.clone(), subscription.clone());
223
224        stripe_billing
225            .subscribe_to_price(&subscription.id, &price)
226            .await
227            .unwrap();
228
229        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
230    }
231}
232
233#[gpui::test]
234async fn test_subscribe_to_zed_free() {
235    let (stripe_billing, stripe_client) = make_stripe_billing();
236
237    let zed_pro_price = StripePrice {
238        id: StripePriceId("price_1".into()),
239        unit_amount: Some(0),
240        lookup_key: Some("zed-pro".to_string()),
241        recurring: None,
242    };
243    stripe_client
244        .prices
245        .lock()
246        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
247    let zed_free_price = StripePrice {
248        id: StripePriceId("price_2".into()),
249        unit_amount: Some(0),
250        lookup_key: Some("zed-free".to_string()),
251        recurring: None,
252    };
253    stripe_client
254        .prices
255        .lock()
256        .insert(zed_free_price.id.clone(), zed_free_price.clone());
257
258    stripe_billing.initialize().await.unwrap();
259
260    // Customer is subscribed to Zed Free when not already subscribed to a plan.
261    {
262        let customer_id = StripeCustomerId("cus_no_plan".into());
263
264        let subscription = stripe_billing
265            .subscribe_to_zed_free(customer_id)
266            .await
267            .unwrap();
268
269        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
270    }
271
272    // Customer is not subscribed to Zed Free when they already have an active subscription.
273    {
274        let customer_id = StripeCustomerId("cus_active_subscription".into());
275
276        let now = Utc::now();
277        let existing_subscription = StripeSubscription {
278            id: StripeSubscriptionId("sub_existing_active".into()),
279            customer: customer_id.clone(),
280            status: stripe::SubscriptionStatus::Active,
281            current_period_start: now.timestamp(),
282            current_period_end: (now + Duration::days(30)).timestamp(),
283            items: vec![StripeSubscriptionItem {
284                id: StripeSubscriptionItemId("si_test".into()),
285                price: Some(zed_pro_price.clone()),
286            }],
287            cancel_at: None,
288            cancellation_details: None,
289        };
290        stripe_client.subscriptions.lock().insert(
291            existing_subscription.id.clone(),
292            existing_subscription.clone(),
293        );
294
295        let subscription = stripe_billing
296            .subscribe_to_zed_free(customer_id)
297            .await
298            .unwrap();
299
300        assert_eq!(subscription, existing_subscription);
301    }
302
303    // Customer is not subscribed to Zed Free when they already have a trial subscription.
304    {
305        let customer_id = StripeCustomerId("cus_trial_subscription".into());
306
307        let now = Utc::now();
308        let existing_subscription = StripeSubscription {
309            id: StripeSubscriptionId("sub_existing_trial".into()),
310            customer: customer_id.clone(),
311            status: stripe::SubscriptionStatus::Trialing,
312            current_period_start: now.timestamp(),
313            current_period_end: (now + Duration::days(14)).timestamp(),
314            items: vec![StripeSubscriptionItem {
315                id: StripeSubscriptionItemId("si_test".into()),
316                price: Some(zed_pro_price.clone()),
317            }],
318            cancel_at: None,
319            cancellation_details: None,
320        };
321        stripe_client.subscriptions.lock().insert(
322            existing_subscription.id.clone(),
323            existing_subscription.clone(),
324        );
325
326        let subscription = stripe_billing
327            .subscribe_to_zed_free(customer_id)
328            .await
329            .unwrap();
330
331        assert_eq!(subscription, existing_subscription);
332    }
333}
334
335#[gpui::test]
336async fn test_bill_model_request_usage() {
337    let (stripe_billing, stripe_client) = make_stripe_billing();
338
339    let customer_id = StripeCustomerId("cus_test".into());
340
341    stripe_billing
342        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
343        .await
344        .unwrap();
345
346    let create_meter_event_calls = stripe_client
347        .create_meter_event_calls
348        .lock()
349        .iter()
350        .cloned()
351        .collect::<Vec<_>>();
352    assert_eq!(create_meter_event_calls.len(), 1);
353    assert!(
354        create_meter_event_calls[0]
355            .identifier
356            .starts_with("model_requests/")
357    );
358    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
359    assert_eq!(
360        create_meter_event_calls[0].event_name.as_ref(),
361        "some_model/requests"
362    );
363    assert_eq!(create_meter_event_calls[0].value, 73);
364}
365
366#[gpui::test]
367async fn test_checkout_with_zed_pro() {
368    let (stripe_billing, stripe_client) = make_stripe_billing();
369
370    let customer_id = StripeCustomerId("cus_test".into());
371    let github_login = "zeduser1";
372    let success_url = "https://example.com/success";
373
374    // It returns an error when the Zed Pro price doesn't exist.
375    {
376        let result = stripe_billing
377            .checkout_with_zed_pro(&customer_id, github_login, success_url)
378            .await;
379
380        assert!(result.is_err());
381        assert_eq!(
382            result.err().unwrap().to_string(),
383            r#"no price ID found for "zed-pro""#
384        );
385    }
386
387    // Successful checkout.
388    {
389        let price = StripePrice {
390            id: StripePriceId("price_1".into()),
391            unit_amount: Some(2000),
392            lookup_key: Some("zed-pro".to_string()),
393            recurring: None,
394        };
395        stripe_client
396            .prices
397            .lock()
398            .insert(price.id.clone(), price.clone());
399
400        stripe_billing.initialize().await.unwrap();
401
402        let checkout_url = stripe_billing
403            .checkout_with_zed_pro(&customer_id, github_login, success_url)
404            .await
405            .unwrap();
406
407        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
408
409        let create_checkout_session_calls = stripe_client
410            .create_checkout_session_calls
411            .lock()
412            .drain(..)
413            .collect::<Vec<_>>();
414        assert_eq!(create_checkout_session_calls.len(), 1);
415        let call = create_checkout_session_calls.into_iter().next().unwrap();
416        assert_eq!(call.customer, Some(customer_id));
417        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
418        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
419        assert_eq!(
420            call.line_items,
421            Some(vec![StripeCreateCheckoutSessionLineItems {
422                price: Some(price.id.to_string()),
423                quantity: Some(1)
424            }])
425        );
426        assert_eq!(call.payment_method_collection, None);
427        assert_eq!(call.subscription_data, None);
428        assert_eq!(call.success_url.as_deref(), Some(success_url));
429    }
430}
431
432#[gpui::test]
433async fn test_checkout_with_zed_pro_trial() {
434    let (stripe_billing, stripe_client) = make_stripe_billing();
435
436    let customer_id = StripeCustomerId("cus_test".into());
437    let github_login = "zeduser1";
438    let success_url = "https://example.com/success";
439
440    // It returns an error when the Zed Pro price doesn't exist.
441    {
442        let result = stripe_billing
443            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
444            .await;
445
446        assert!(result.is_err());
447        assert_eq!(
448            result.err().unwrap().to_string(),
449            r#"no price ID found for "zed-pro""#
450        );
451    }
452
453    let price = StripePrice {
454        id: StripePriceId("price_1".into()),
455        unit_amount: Some(2000),
456        lookup_key: Some("zed-pro".to_string()),
457        recurring: None,
458    };
459    stripe_client
460        .prices
461        .lock()
462        .insert(price.id.clone(), price.clone());
463
464    stripe_billing.initialize().await.unwrap();
465
466    // Successful checkout.
467    {
468        let checkout_url = stripe_billing
469            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
470            .await
471            .unwrap();
472
473        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
474
475        let create_checkout_session_calls = stripe_client
476            .create_checkout_session_calls
477            .lock()
478            .drain(..)
479            .collect::<Vec<_>>();
480        assert_eq!(create_checkout_session_calls.len(), 1);
481        let call = create_checkout_session_calls.into_iter().next().unwrap();
482        assert_eq!(call.customer.as_ref(), Some(&customer_id));
483        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
484        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
485        assert_eq!(
486            call.line_items,
487            Some(vec![StripeCreateCheckoutSessionLineItems {
488                price: Some(price.id.to_string()),
489                quantity: Some(1)
490            }])
491        );
492        assert_eq!(
493            call.payment_method_collection,
494            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
495        );
496        assert_eq!(
497            call.subscription_data,
498            Some(StripeCreateCheckoutSessionSubscriptionData {
499                trial_period_days: Some(14),
500                trial_settings: Some(StripeSubscriptionTrialSettings {
501                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
502                        missing_payment_method:
503                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
504                    },
505                }),
506                metadata: None,
507            })
508        );
509        assert_eq!(call.success_url.as_deref(), Some(success_url));
510    }
511
512    // Successful checkout with extended trial.
513    {
514        let checkout_url = stripe_billing
515            .checkout_with_zed_pro_trial(
516                &customer_id,
517                github_login,
518                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
519                success_url,
520            )
521            .await
522            .unwrap();
523
524        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
525
526        let create_checkout_session_calls = stripe_client
527            .create_checkout_session_calls
528            .lock()
529            .drain(..)
530            .collect::<Vec<_>>();
531        assert_eq!(create_checkout_session_calls.len(), 1);
532        let call = create_checkout_session_calls.into_iter().next().unwrap();
533        assert_eq!(call.customer, Some(customer_id));
534        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
535        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
536        assert_eq!(
537            call.line_items,
538            Some(vec![StripeCreateCheckoutSessionLineItems {
539                price: Some(price.id.to_string()),
540                quantity: Some(1)
541            }])
542        );
543        assert_eq!(
544            call.payment_method_collection,
545            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
546        );
547        assert_eq!(
548            call.subscription_data,
549            Some(StripeCreateCheckoutSessionSubscriptionData {
550                trial_period_days: Some(60),
551                trial_settings: Some(StripeSubscriptionTrialSettings {
552                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
553                        missing_payment_method:
554                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
555                    },
556                }),
557                metadata: Some(std::collections::HashMap::from_iter([(
558                    "promo_feature_flag".into(),
559                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
560                )])),
561            })
562        );
563        assert_eq!(call.success_url.as_deref(), Some(success_url));
564    }
565}