stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use pretty_assertions::assert_eq;
  4
  5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  6use crate::stripe_billing::StripeBilling;
  7use crate::stripe_client::{
  8    FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
  9    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
 10    StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
 11    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
 12    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 13    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
 14};
 15
 16fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 17    let stripe_client = Arc::new(FakeStripeClient::new());
 18    let stripe_billing = StripeBilling::test(stripe_client.clone());
 19
 20    (stripe_billing, stripe_client)
 21}
 22
 23#[gpui::test]
 24async fn test_initialize() {
 25    let (stripe_billing, stripe_client) = make_stripe_billing();
 26
 27    // Add test meters
 28    let meter1 = StripeMeter {
 29        id: StripeMeterId("meter_1".into()),
 30        event_name: "event_1".to_string(),
 31    };
 32    let meter2 = StripeMeter {
 33        id: StripeMeterId("meter_2".into()),
 34        event_name: "event_2".to_string(),
 35    };
 36    stripe_client
 37        .meters
 38        .lock()
 39        .insert(meter1.id.clone(), meter1);
 40    stripe_client
 41        .meters
 42        .lock()
 43        .insert(meter2.id.clone(), meter2);
 44
 45    // Add test prices
 46    let price1 = StripePrice {
 47        id: StripePriceId("price_1".into()),
 48        unit_amount: Some(1_000),
 49        lookup_key: Some("zed-pro".to_string()),
 50        recurring: None,
 51    };
 52    let price2 = StripePrice {
 53        id: StripePriceId("price_2".into()),
 54        unit_amount: Some(0),
 55        lookup_key: Some("zed-free".to_string()),
 56        recurring: None,
 57    };
 58    let price3 = StripePrice {
 59        id: StripePriceId("price_3".into()),
 60        unit_amount: Some(500),
 61        lookup_key: None,
 62        recurring: Some(StripePriceRecurring {
 63            meter: Some("meter_1".to_string()),
 64        }),
 65    };
 66    stripe_client
 67        .prices
 68        .lock()
 69        .insert(price1.id.clone(), price1);
 70    stripe_client
 71        .prices
 72        .lock()
 73        .insert(price2.id.clone(), price2);
 74    stripe_client
 75        .prices
 76        .lock()
 77        .insert(price3.id.clone(), price3);
 78
 79    // Initialize the billing system
 80    stripe_billing.initialize().await.unwrap();
 81
 82    // Verify that prices can be found by lookup key
 83    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 84    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 85
 86    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 87    assert_eq!(zed_free_price_id.to_string(), "price_2");
 88
 89    // Verify that a price can be found by lookup key
 90    let zed_pro_price = stripe_billing
 91        .find_price_by_lookup_key("zed-pro")
 92        .await
 93        .unwrap();
 94    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 95    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 96
 97    // Verify that finding a non-existent lookup key returns an error
 98    let result = stripe_billing
 99        .find_price_by_lookup_key("non-existent")
100        .await;
101    assert!(result.is_err());
102}
103
104#[gpui::test]
105async fn test_find_or_create_customer_by_email() {
106    let (stripe_billing, stripe_client) = make_stripe_billing();
107
108    // Create a customer with an email that doesn't yet correspond to a customer.
109    {
110        let email = "user@example.com";
111
112        let customer_id = stripe_billing
113            .find_or_create_customer_by_email(Some(email))
114            .await
115            .unwrap();
116
117        let customer = stripe_client
118            .customers
119            .lock()
120            .get(&customer_id)
121            .unwrap()
122            .clone();
123        assert_eq!(customer.email.as_deref(), Some(email));
124    }
125
126    // Create a customer with an email that corresponds to an existing customer.
127    {
128        let email = "user2@example.com";
129
130        let existing_customer_id = stripe_billing
131            .find_or_create_customer_by_email(Some(email))
132            .await
133            .unwrap();
134
135        let customer_id = stripe_billing
136            .find_or_create_customer_by_email(Some(email))
137            .await
138            .unwrap();
139        assert_eq!(customer_id, existing_customer_id);
140
141        let customer = stripe_client
142            .customers
143            .lock()
144            .get(&customer_id)
145            .unwrap()
146            .clone();
147        assert_eq!(customer.email.as_deref(), Some(email));
148    }
149}
150
151#[gpui::test]
152async fn test_subscribe_to_price() {
153    let (stripe_billing, stripe_client) = make_stripe_billing();
154
155    let price = StripePrice {
156        id: StripePriceId("price_test".into()),
157        unit_amount: Some(2000),
158        lookup_key: Some("test-price".to_string()),
159        recurring: None,
160    };
161    stripe_client
162        .prices
163        .lock()
164        .insert(price.id.clone(), price.clone());
165
166    let subscription = StripeSubscription {
167        id: StripeSubscriptionId("sub_test".into()),
168        items: vec![],
169    };
170    stripe_client
171        .subscriptions
172        .lock()
173        .insert(subscription.id.clone(), subscription.clone());
174
175    stripe_billing
176        .subscribe_to_price(&subscription.id, &price)
177        .await
178        .unwrap();
179
180    let update_subscription_calls = stripe_client
181        .update_subscription_calls
182        .lock()
183        .iter()
184        .map(|(id, params)| (id.clone(), params.clone()))
185        .collect::<Vec<_>>();
186    assert_eq!(update_subscription_calls.len(), 1);
187    assert_eq!(update_subscription_calls[0].0, subscription.id);
188    assert_eq!(
189        update_subscription_calls[0].1.items,
190        Some(vec![UpdateSubscriptionItems {
191            price: Some(price.id.clone())
192        }])
193    );
194
195    // Subscribing to a price that is already on the subscription is a no-op.
196    {
197        let subscription = StripeSubscription {
198            id: StripeSubscriptionId("sub_test".into()),
199            items: vec![StripeSubscriptionItem {
200                id: StripeSubscriptionItemId("si_test".into()),
201                price: Some(price.clone()),
202            }],
203        };
204        stripe_client
205            .subscriptions
206            .lock()
207            .insert(subscription.id.clone(), subscription.clone());
208
209        stripe_billing
210            .subscribe_to_price(&subscription.id, &price)
211            .await
212            .unwrap();
213
214        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
215    }
216}
217
218#[gpui::test]
219async fn test_bill_model_request_usage() {
220    let (stripe_billing, stripe_client) = make_stripe_billing();
221
222    let customer_id = StripeCustomerId("cus_test".into());
223
224    stripe_billing
225        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
226        .await
227        .unwrap();
228
229    let create_meter_event_calls = stripe_client
230        .create_meter_event_calls
231        .lock()
232        .iter()
233        .cloned()
234        .collect::<Vec<_>>();
235    assert_eq!(create_meter_event_calls.len(), 1);
236    assert!(
237        create_meter_event_calls[0]
238            .identifier
239            .starts_with("model_requests/")
240    );
241    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
242    assert_eq!(
243        create_meter_event_calls[0].event_name.as_ref(),
244        "some_model/requests"
245    );
246    assert_eq!(create_meter_event_calls[0].value, 73);
247}
248
249#[gpui::test]
250async fn test_checkout_with_zed_pro() {
251    let (stripe_billing, stripe_client) = make_stripe_billing();
252
253    let customer_id = StripeCustomerId("cus_test".into());
254    let github_login = "zeduser1";
255    let success_url = "https://example.com/success";
256
257    // It returns an error when the Zed Pro price doesn't exist.
258    {
259        let result = stripe_billing
260            .checkout_with_zed_pro(&customer_id, github_login, success_url)
261            .await;
262
263        assert!(result.is_err());
264        assert_eq!(
265            result.err().unwrap().to_string(),
266            r#"no price ID found for "zed-pro""#
267        );
268    }
269
270    // Successful checkout.
271    {
272        let price = StripePrice {
273            id: StripePriceId("price_1".into()),
274            unit_amount: Some(2000),
275            lookup_key: Some("zed-pro".to_string()),
276            recurring: None,
277        };
278        stripe_client
279            .prices
280            .lock()
281            .insert(price.id.clone(), price.clone());
282
283        stripe_billing.initialize().await.unwrap();
284
285        let checkout_url = stripe_billing
286            .checkout_with_zed_pro(&customer_id, github_login, success_url)
287            .await
288            .unwrap();
289
290        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
291
292        let create_checkout_session_calls = stripe_client
293            .create_checkout_session_calls
294            .lock()
295            .drain(..)
296            .collect::<Vec<_>>();
297        assert_eq!(create_checkout_session_calls.len(), 1);
298        let call = create_checkout_session_calls.into_iter().next().unwrap();
299        assert_eq!(call.customer, Some(customer_id));
300        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
301        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
302        assert_eq!(
303            call.line_items,
304            Some(vec![StripeCreateCheckoutSessionLineItems {
305                price: Some(price.id.to_string()),
306                quantity: Some(1)
307            }])
308        );
309        assert_eq!(call.payment_method_collection, None);
310        assert_eq!(call.subscription_data, None);
311        assert_eq!(call.success_url.as_deref(), Some(success_url));
312    }
313}
314
315#[gpui::test]
316async fn test_checkout_with_zed_pro_trial() {
317    let (stripe_billing, stripe_client) = make_stripe_billing();
318
319    let customer_id = StripeCustomerId("cus_test".into());
320    let github_login = "zeduser1";
321    let success_url = "https://example.com/success";
322
323    // It returns an error when the Zed Pro price doesn't exist.
324    {
325        let result = stripe_billing
326            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
327            .await;
328
329        assert!(result.is_err());
330        assert_eq!(
331            result.err().unwrap().to_string(),
332            r#"no price ID found for "zed-pro""#
333        );
334    }
335
336    let price = StripePrice {
337        id: StripePriceId("price_1".into()),
338        unit_amount: Some(2000),
339        lookup_key: Some("zed-pro".to_string()),
340        recurring: None,
341    };
342    stripe_client
343        .prices
344        .lock()
345        .insert(price.id.clone(), price.clone());
346
347    stripe_billing.initialize().await.unwrap();
348
349    // Successful checkout.
350    {
351        let checkout_url = stripe_billing
352            .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
353            .await
354            .unwrap();
355
356        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
357
358        let create_checkout_session_calls = stripe_client
359            .create_checkout_session_calls
360            .lock()
361            .drain(..)
362            .collect::<Vec<_>>();
363        assert_eq!(create_checkout_session_calls.len(), 1);
364        let call = create_checkout_session_calls.into_iter().next().unwrap();
365        assert_eq!(call.customer.as_ref(), Some(&customer_id));
366        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
367        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
368        assert_eq!(
369            call.line_items,
370            Some(vec![StripeCreateCheckoutSessionLineItems {
371                price: Some(price.id.to_string()),
372                quantity: Some(1)
373            }])
374        );
375        assert_eq!(
376            call.payment_method_collection,
377            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
378        );
379        assert_eq!(
380            call.subscription_data,
381            Some(StripeCreateCheckoutSessionSubscriptionData {
382                trial_period_days: Some(14),
383                trial_settings: Some(StripeSubscriptionTrialSettings {
384                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
385                        missing_payment_method:
386                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
387                    },
388                }),
389                metadata: None,
390            })
391        );
392        assert_eq!(call.success_url.as_deref(), Some(success_url));
393    }
394
395    // Successful checkout with extended trial.
396    {
397        let checkout_url = stripe_billing
398            .checkout_with_zed_pro_trial(
399                &customer_id,
400                github_login,
401                vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
402                success_url,
403            )
404            .await
405            .unwrap();
406
407        assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
408
409        let create_checkout_session_calls = stripe_client
410            .create_checkout_session_calls
411            .lock()
412            .drain(..)
413            .collect::<Vec<_>>();
414        assert_eq!(create_checkout_session_calls.len(), 1);
415        let call = create_checkout_session_calls.into_iter().next().unwrap();
416        assert_eq!(call.customer, Some(customer_id));
417        assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
418        assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
419        assert_eq!(
420            call.line_items,
421            Some(vec![StripeCreateCheckoutSessionLineItems {
422                price: Some(price.id.to_string()),
423                quantity: Some(1)
424            }])
425        );
426        assert_eq!(
427            call.payment_method_collection,
428            Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
429        );
430        assert_eq!(
431            call.subscription_data,
432            Some(StripeCreateCheckoutSessionSubscriptionData {
433                trial_period_days: Some(60),
434                trial_settings: Some(StripeSubscriptionTrialSettings {
435                    end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
436                        missing_payment_method:
437                            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
438                    },
439                }),
440                metadata: Some(std::collections::HashMap::from_iter([(
441                    "promo_feature_flag".into(),
442                    AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
443                )])),
444            })
445        );
446        assert_eq!(call.success_url.as_deref(), Some(success_url));
447    }
448}