stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use chrono::{Duration, Utc};
  4use pretty_assertions::assert_eq;
  5
  6use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  7use crate::stripe_billing::StripeBilling;
  8use crate::stripe_client::{
  9    FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
 10    StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
 11    StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeMeter, StripeMeterId,
 12    StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
 13    StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
 14    StripeSubscriptionTrialSettingsEndBehavior,
 15    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 16};
 17
 18fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 19    let stripe_client = Arc::new(FakeStripeClient::new());
 20    let stripe_billing = StripeBilling::test(stripe_client.clone());
 21
 22    (stripe_billing, stripe_client)
 23}
 24
 25#[gpui::test]
 26async fn test_initialize() {
 27    let (stripe_billing, stripe_client) = make_stripe_billing();
 28
 29    // Add test meters
 30    let meter1 = StripeMeter {
 31        id: StripeMeterId("meter_1".into()),
 32        event_name: "event_1".to_string(),
 33    };
 34    let meter2 = StripeMeter {
 35        id: StripeMeterId("meter_2".into()),
 36        event_name: "event_2".to_string(),
 37    };
 38    stripe_client
 39        .meters
 40        .lock()
 41        .insert(meter1.id.clone(), meter1);
 42    stripe_client
 43        .meters
 44        .lock()
 45        .insert(meter2.id.clone(), meter2);
 46
 47    // Add test prices
 48    let price1 = StripePrice {
 49        id: StripePriceId("price_1".into()),
 50        unit_amount: Some(1_000),
 51        lookup_key: Some("zed-pro".to_string()),
 52        recurring: None,
 53    };
 54    let price2 = StripePrice {
 55        id: StripePriceId("price_2".into()),
 56        unit_amount: Some(0),
 57        lookup_key: Some("zed-free".to_string()),
 58        recurring: None,
 59    };
 60    let price3 = StripePrice {
 61        id: StripePriceId("price_3".into()),
 62        unit_amount: Some(500),
 63        lookup_key: None,
 64        recurring: Some(StripePriceRecurring {
 65            meter: Some("meter_1".to_string()),
 66        }),
 67    };
 68    stripe_client
 69        .prices
 70        .lock()
 71        .insert(price1.id.clone(), price1);
 72    stripe_client
 73        .prices
 74        .lock()
 75        .insert(price2.id.clone(), price2);
 76    stripe_client
 77        .prices
 78        .lock()
 79        .insert(price3.id.clone(), price3);
 80
 81    // Initialize the billing system
 82    stripe_billing.initialize().await.unwrap();
 83
 84    // Verify that prices can be found by lookup key
 85    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 86    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 87
 88    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 89    assert_eq!(zed_free_price_id.to_string(), "price_2");
 90
 91    // Verify that a price can be found by lookup key
 92    let zed_pro_price = stripe_billing
 93        .find_price_by_lookup_key("zed-pro")
 94        .await
 95        .unwrap();
 96    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 97    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 98
 99    // Verify that finding a non-existent lookup key returns an error
100    let result = stripe_billing
101        .find_price_by_lookup_key("non-existent")
102        .await;
103    assert!(result.is_err());
104}
105
106#[gpui::test]
107async fn test_find_or_create_customer_by_email() {
108    let (stripe_billing, stripe_client) = make_stripe_billing();
109
110    // Create a customer with an email that doesn't yet correspond to a customer.
111    {
112        let email = "user@example.com";
113
114        let customer_id = stripe_billing
115            .find_or_create_customer_by_email(Some(email))
116            .await
117            .unwrap();
118
119        let customer = stripe_client
120            .customers
121            .lock()
122            .get(&customer_id)
123            .unwrap()
124            .clone();
125        assert_eq!(customer.email.as_deref(), Some(email));
126    }
127
128    // Create a customer with an email that corresponds to an existing customer.
129    {
130        let email = "user2@example.com";
131
132        let existing_customer_id = stripe_billing
133            .find_or_create_customer_by_email(Some(email))
134            .await
135            .unwrap();
136
137        let customer_id = stripe_billing
138            .find_or_create_customer_by_email(Some(email))
139            .await
140            .unwrap();
141        assert_eq!(customer_id, existing_customer_id);
142
143        let customer = stripe_client
144            .customers
145            .lock()
146            .get(&customer_id)
147            .unwrap()
148            .clone();
149        assert_eq!(customer.email.as_deref(), Some(email));
150    }
151}
152
153#[gpui::test]
154async fn test_subscribe_to_price() {
155    let (stripe_billing, stripe_client) = make_stripe_billing();
156
157    let price = StripePrice {
158        id: StripePriceId("price_test".into()),
159        unit_amount: Some(2000),
160        lookup_key: Some("test-price".to_string()),
161        recurring: None,
162    };
163    stripe_client
164        .prices
165        .lock()
166        .insert(price.id.clone(), price.clone());
167
168    let now = Utc::now();
169    let subscription = StripeSubscription {
170        id: StripeSubscriptionId("sub_test".into()),
171        customer: StripeCustomerId("cus_test".into()),
172        status: stripe::SubscriptionStatus::Active,
173        current_period_start: now.timestamp(),
174        current_period_end: (now + Duration::days(30)).timestamp(),
175        items: vec![],
176        cancel_at: None,
177        cancellation_details: None,
178    };
179    stripe_client
180        .subscriptions
181        .lock()
182        .insert(subscription.id.clone(), subscription.clone());
183
184    stripe_billing
185        .subscribe_to_price(&subscription.id, &price)
186        .await
187        .unwrap();
188
189    let update_subscription_calls = stripe_client
190        .update_subscription_calls
191        .lock()
192        .iter()
193        .map(|(id, params)| (id.clone(), params.clone()))
194        .collect::<Vec<_>>();
195    assert_eq!(update_subscription_calls.len(), 1);
196    assert_eq!(update_subscription_calls[0].0, subscription.id);
197    assert_eq!(
198        update_subscription_calls[0].1.items,
199        Some(vec![UpdateSubscriptionItems {
200            price: Some(price.id.clone())
201        }])
202    );
203
204    // Subscribing to a price that is already on the subscription is a no-op.
205    {
206        let now = Utc::now();
207        let subscription = StripeSubscription {
208            id: StripeSubscriptionId("sub_test".into()),
209            customer: StripeCustomerId("cus_test".into()),
210            status: stripe::SubscriptionStatus::Active,
211            current_period_start: now.timestamp(),
212            current_period_end: (now + Duration::days(30)).timestamp(),
213            items: vec![StripeSubscriptionItem {
214                id: StripeSubscriptionItemId("si_test".into()),
215                price: Some(price.clone()),
216            }],
217            cancel_at: None,
218            cancellation_details: None,
219        };
220        stripe_client
221            .subscriptions
222            .lock()
223            .insert(subscription.id.clone(), subscription.clone());
224
225        stripe_billing
226            .subscribe_to_price(&subscription.id, &price)
227            .await
228            .unwrap();
229
230        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
231    }
232}
233
234#[gpui::test]
235async fn test_subscribe_to_zed_free() {
236    let (stripe_billing, stripe_client) = make_stripe_billing();
237
238    let zed_pro_price = StripePrice {
239        id: StripePriceId("price_1".into()),
240        unit_amount: Some(0),
241        lookup_key: Some("zed-pro".to_string()),
242        recurring: None,
243    };
244    stripe_client
245        .prices
246        .lock()
247        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
248    let zed_free_price = StripePrice {
249        id: StripePriceId("price_2".into()),
250        unit_amount: Some(0),
251        lookup_key: Some("zed-free".to_string()),
252        recurring: None,
253    };
254    stripe_client
255        .prices
256        .lock()
257        .insert(zed_free_price.id.clone(), zed_free_price.clone());
258
259    stripe_billing.initialize().await.unwrap();
260
261    // Customer is subscribed to Zed Free when not already subscribed to a plan.
262    {
263        let customer_id = StripeCustomerId("cus_no_plan".into());
264
265        let subscription = stripe_billing
266            .subscribe_to_zed_free(customer_id)
267            .await
268            .unwrap();
269
270        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
271    }
272
273    // Customer is not subscribed to Zed Free when they already have an active subscription.
274    {
275        let customer_id = StripeCustomerId("cus_active_subscription".into());
276
277        let now = Utc::now();
278        let existing_subscription = StripeSubscription {
279            id: StripeSubscriptionId("sub_existing_active".into()),
280            customer: customer_id.clone(),
281            status: stripe::SubscriptionStatus::Active,
282            current_period_start: now.timestamp(),
283            current_period_end: (now + Duration::days(30)).timestamp(),
284            items: vec![StripeSubscriptionItem {
285                id: StripeSubscriptionItemId("si_test".into()),
286                price: Some(zed_pro_price.clone()),
287            }],
288            cancel_at: None,
289            cancellation_details: None,
290        };
291        stripe_client.subscriptions.lock().insert(
292            existing_subscription.id.clone(),
293            existing_subscription.clone(),
294        );
295
296        let subscription = stripe_billing
297            .subscribe_to_zed_free(customer_id)
298            .await
299            .unwrap();
300
301        assert_eq!(subscription, existing_subscription);
302    }
303
304    // Customer is not subscribed to Zed Free when they already have a trial subscription.
305    {
306        let customer_id = StripeCustomerId("cus_trial_subscription".into());
307
308        let now = Utc::now();
309        let existing_subscription = StripeSubscription {
310            id: StripeSubscriptionId("sub_existing_trial".into()),
311            customer: customer_id.clone(),
312            status: stripe::SubscriptionStatus::Trialing,
313            current_period_start: now.timestamp(),
314            current_period_end: (now + Duration::days(14)).timestamp(),
315            items: vec![StripeSubscriptionItem {
316                id: StripeSubscriptionItemId("si_test".into()),
317                price: Some(zed_pro_price.clone()),
318            }],
319            cancel_at: None,
320            cancellation_details: None,
321        };
322        stripe_client.subscriptions.lock().insert(
323            existing_subscription.id.clone(),
324            existing_subscription.clone(),
325        );
326
327        let subscription = stripe_billing
328            .subscribe_to_zed_free(customer_id)
329            .await
330            .unwrap();
331
332        assert_eq!(subscription, existing_subscription);
333    }
334}
335
336#[gpui::test]
337async fn test_bill_model_request_usage() {
338    let (stripe_billing, stripe_client) = make_stripe_billing();
339
340    let customer_id = StripeCustomerId("cus_test".into());
341
342    stripe_billing
343        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
344        .await
345        .unwrap();
346
347    let create_meter_event_calls = stripe_client
348        .create_meter_event_calls
349        .lock()
350        .iter()
351        .cloned()
352        .collect::<Vec<_>>();
353    assert_eq!(create_meter_event_calls.len(), 1);
354    assert!(
355        create_meter_event_calls[0]
356            .identifier
357            .starts_with("model_requests/")
358    );
359    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
360    assert_eq!(
361        create_meter_event_calls[0].event_name.as_ref(),
362        "some_model/requests"
363    );
364    assert_eq!(create_meter_event_calls[0].value, 73);
365}
366
367#[gpui::test]
368async fn test_checkout_with_zed_pro() {
369    let (stripe_billing, stripe_client) = make_stripe_billing();
370
371    let customer_id = StripeCustomerId("cus_test".into());
372    let github_login = "zeduser1";
373    let success_url = "https://example.com/success";
374
375    // It returns an error when the Zed Pro price doesn't exist.
376    {
377        let result = stripe_billing
378            .checkout_with_zed_pro(&customer_id, github_login, success_url)
379            .await;
380
381        assert!(result.is_err());
382        assert_eq!(
383            result.err().unwrap().to_string(),
384            r#"no price ID found for "zed-pro""#
385        );
386    }
387
388    // Successful checkout.
389    {
390        let price = StripePrice {
391            id: StripePriceId("price_1".into()),
392            unit_amount: Some(2000),
393            lookup_key: Some("zed-pro".to_string()),
394            recurring: None,
395        };
396        stripe_client
397            .prices
398            .lock()
399            .insert(price.id.clone(), price.clone());
400
401        stripe_billing.initialize().await.unwrap();
402
403        let checkout_url = stripe_billing
404            .checkout_with_zed_pro(&customer_id, github_login, success_url)
405            .await
406            .unwrap();
407
408        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
409
410        let create_checkout_session_calls = stripe_client
411            .create_checkout_session_calls
412            .lock()
413            .drain(..)
414            .collect::<Vec<_>>();
415        assert_eq!(create_checkout_session_calls.len(), 1);
416        let call = create_checkout_session_calls.into_iter().next().unwrap();
417        assert_eq!(call.customer, Some(customer_id));
418        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
419        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
420        assert_eq!(
421            call.line_items,
422            Some(vec![StripeCreateCheckoutSessionLineItems {
423                price: Some(price.id.to_string()),
424                quantity: Some(1)
425            }])
426        );
427        assert_eq!(call.payment_method_collection, None);
428        assert_eq!(call.subscription_data, None);
429        assert_eq!(call.success_url.as_deref(), Some(success_url));
430        assert_eq!(
431            call.billing_address_collection,
432            Some(StripeBillingAddressCollection::Required)
433        );
434    }
435}
436
437#[gpui::test]
438async fn test_checkout_with_zed_pro_trial() {
439    let (stripe_billing, stripe_client) = make_stripe_billing();
440
441    let customer_id = StripeCustomerId("cus_test".into());
442    let github_login = "zeduser1";
443    let success_url = "https://example.com/success";
444
445    // It returns an error when the Zed Pro price doesn't exist.
446    {
447        let result = stripe_billing
448            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
449            .await;
450
451        assert!(result.is_err());
452        assert_eq!(
453            result.err().unwrap().to_string(),
454            r#"no price ID found for "zed-pro""#
455        );
456    }
457
458    let price = StripePrice {
459        id: StripePriceId("price_1".into()),
460        unit_amount: Some(2000),
461        lookup_key: Some("zed-pro".to_string()),
462        recurring: None,
463    };
464    stripe_client
465        .prices
466        .lock()
467        .insert(price.id.clone(), price.clone());
468
469    stripe_billing.initialize().await.unwrap();
470
471    // Successful checkout.
472    {
473        let checkout_url = stripe_billing
474            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
475            .await
476            .unwrap();
477
478        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
479
480        let create_checkout_session_calls = stripe_client
481            .create_checkout_session_calls
482            .lock()
483            .drain(..)
484            .collect::<Vec<_>>();
485        assert_eq!(create_checkout_session_calls.len(), 1);
486        let call = create_checkout_session_calls.into_iter().next().unwrap();
487        assert_eq!(call.customer.as_ref(), Some(&customer_id));
488        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
489        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
490        assert_eq!(
491            call.line_items,
492            Some(vec![StripeCreateCheckoutSessionLineItems {
493                price: Some(price.id.to_string()),
494                quantity: Some(1)
495            }])
496        );
497        assert_eq!(
498            call.payment_method_collection,
499            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
500        );
501        assert_eq!(
502            call.subscription_data,
503            Some(StripeCreateCheckoutSessionSubscriptionData {
504                trial_period_days: Some(14),
505                trial_settings: Some(StripeSubscriptionTrialSettings {
506                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
507                        missing_payment_method:
508                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
509                    },
510                }),
511                metadata: None,
512            })
513        );
514        assert_eq!(call.success_url.as_deref(), Some(success_url));
515        assert_eq!(
516            call.billing_address_collection,
517            Some(StripeBillingAddressCollection::Required)
518        );
519    }
520
521    // Successful checkout with extended trial.
522    {
523        let checkout_url = stripe_billing
524            .checkout_with_zed_pro_trial(
525                &customer_id,
526                github_login,
527                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
528                success_url,
529            )
530            .await
531            .unwrap();
532
533        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
534
535        let create_checkout_session_calls = stripe_client
536            .create_checkout_session_calls
537            .lock()
538            .drain(..)
539            .collect::<Vec<_>>();
540        assert_eq!(create_checkout_session_calls.len(), 1);
541        let call = create_checkout_session_calls.into_iter().next().unwrap();
542        assert_eq!(call.customer, Some(customer_id));
543        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
544        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
545        assert_eq!(
546            call.line_items,
547            Some(vec![StripeCreateCheckoutSessionLineItems {
548                price: Some(price.id.to_string()),
549                quantity: Some(1)
550            }])
551        );
552        assert_eq!(
553            call.payment_method_collection,
554            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
555        );
556        assert_eq!(
557            call.subscription_data,
558            Some(StripeCreateCheckoutSessionSubscriptionData {
559                trial_period_days: Some(60),
560                trial_settings: Some(StripeSubscriptionTrialSettings {
561                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
562                        missing_payment_method:
563                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
564                    },
565                }),
566                metadata: Some(std::collections::HashMap::from_iter([(
567                    "promo_feature_flag".into(),
568                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
569                )])),
570            })
571        );
572        assert_eq!(call.success_url.as_deref(), Some(success_url));
573        assert_eq!(
574            call.billing_address_collection,
575            Some(StripeBillingAddressCollection::Required)
576        );
577    }
578}