// stripe_billing_tests.rs

  1use std::sync::Arc;
  2
  3use chrono::{Duration, Utc};
  4use pretty_assertions::assert_eq;
  5
  6use crate::stripe_billing::StripeBilling;
  7use crate::stripe_client::{
  8    FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
  9    StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
 10    StripeSubscriptionItemId, UpdateSubscriptionItems,
 11};
 12
 13fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
 14    let stripe_client = Arc::new(FakeStripeClient::new());
 15    let stripe_billing = StripeBilling::test(stripe_client.clone());
 16
 17    (stripe_billing, stripe_client)
 18}
 19
 20#[gpui::test]
 21async fn test_initialize() {
 22    let (stripe_billing, stripe_client) = make_stripe_billing();
 23
 24    // Add test meters
 25    let meter1 = StripeMeter {
 26        id: StripeMeterId("meter_1".into()),
 27        event_name: "event_1".to_string(),
 28    };
 29    let meter2 = StripeMeter {
 30        id: StripeMeterId("meter_2".into()),
 31        event_name: "event_2".to_string(),
 32    };
 33    stripe_client
 34        .meters
 35        .lock()
 36        .insert(meter1.id.clone(), meter1);
 37    stripe_client
 38        .meters
 39        .lock()
 40        .insert(meter2.id.clone(), meter2);
 41
 42    // Add test prices
 43    let price1 = StripePrice {
 44        id: StripePriceId("price_1".into()),
 45        unit_amount: Some(1_000),
 46        lookup_key: Some("zed-pro".to_string()),
 47        recurring: None,
 48    };
 49    let price2 = StripePrice {
 50        id: StripePriceId("price_2".into()),
 51        unit_amount: Some(0),
 52        lookup_key: Some("zed-free".to_string()),
 53        recurring: None,
 54    };
 55    let price3 = StripePrice {
 56        id: StripePriceId("price_3".into()),
 57        unit_amount: Some(500),
 58        lookup_key: None,
 59        recurring: Some(StripePriceRecurring {
 60            meter: Some("meter_1".to_string()),
 61        }),
 62    };
 63    stripe_client
 64        .prices
 65        .lock()
 66        .insert(price1.id.clone(), price1);
 67    stripe_client
 68        .prices
 69        .lock()
 70        .insert(price2.id.clone(), price2);
 71    stripe_client
 72        .prices
 73        .lock()
 74        .insert(price3.id.clone(), price3);
 75
 76    // Initialize the billing system
 77    stripe_billing.initialize().await.unwrap();
 78
 79    // Verify that prices can be found by lookup key
 80    let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
 81    assert_eq!(zed_pro_price_id.to_string(), "price_1");
 82
 83    let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
 84    assert_eq!(zed_free_price_id.to_string(), "price_2");
 85
 86    // Verify that a price can be found by lookup key
 87    let zed_pro_price = stripe_billing
 88        .find_price_by_lookup_key("zed-pro")
 89        .await
 90        .unwrap();
 91    assert_eq!(zed_pro_price.id.to_string(), "price_1");
 92    assert_eq!(zed_pro_price.unit_amount, Some(1_000));
 93
 94    // Verify that finding a non-existent lookup key returns an error
 95    let result = stripe_billing
 96        .find_price_by_lookup_key("non-existent")
 97        .await;
 98    assert!(result.is_err());
 99}
100
101#[gpui::test]
102async fn test_find_or_create_customer_by_email() {
103    let (stripe_billing, stripe_client) = make_stripe_billing();
104
105    // Create a customer with an email that doesn't yet correspond to a customer.
106    {
107        let email = "user@example.com";
108
109        let customer_id = stripe_billing
110            .find_or_create_customer_by_email(Some(email))
111            .await
112            .unwrap();
113
114        let customer = stripe_client
115            .customers
116            .lock()
117            .get(&customer_id)
118            .unwrap()
119            .clone();
120        assert_eq!(customer.email.as_deref(), Some(email));
121    }
122
123    // Create a customer with an email that corresponds to an existing customer.
124    {
125        let email = "user2@example.com";
126
127        let existing_customer_id = stripe_billing
128            .find_or_create_customer_by_email(Some(email))
129            .await
130            .unwrap();
131
132        let customer_id = stripe_billing
133            .find_or_create_customer_by_email(Some(email))
134            .await
135            .unwrap();
136        assert_eq!(customer_id, existing_customer_id);
137
138        let customer = stripe_client
139            .customers
140            .lock()
141            .get(&customer_id)
142            .unwrap()
143            .clone();
144        assert_eq!(customer.email.as_deref(), Some(email));
145    }
146}
147
148#[gpui::test]
149async fn test_subscribe_to_price() {
150    let (stripe_billing, stripe_client) = make_stripe_billing();
151
152    let price = StripePrice {
153        id: StripePriceId("price_test".into()),
154        unit_amount: Some(2000),
155        lookup_key: Some("test-price".to_string()),
156        recurring: None,
157    };
158    stripe_client
159        .prices
160        .lock()
161        .insert(price.id.clone(), price.clone());
162
163    let now = Utc::now();
164    let subscription = StripeSubscription {
165        id: StripeSubscriptionId("sub_test".into()),
166        customer: StripeCustomerId("cus_test".into()),
167        status: stripe::SubscriptionStatus::Active,
168        current_period_start: now.timestamp(),
169        current_period_end: (now + Duration::days(30)).timestamp(),
170        items: vec![],
171        cancel_at: None,
172        cancellation_details: None,
173    };
174    stripe_client
175        .subscriptions
176        .lock()
177        .insert(subscription.id.clone(), subscription.clone());
178
179    stripe_billing
180        .subscribe_to_price(&subscription.id, &price)
181        .await
182        .unwrap();
183
184    let update_subscription_calls = stripe_client
185        .update_subscription_calls
186        .lock()
187        .iter()
188        .map(|(id, params)| (id.clone(), params.clone()))
189        .collect::<Vec<_>>();
190    assert_eq!(update_subscription_calls.len(), 1);
191    assert_eq!(update_subscription_calls[0].0, subscription.id);
192    assert_eq!(
193        update_subscription_calls[0].1.items,
194        Some(vec![UpdateSubscriptionItems {
195            price: Some(price.id.clone())
196        }])
197    );
198
199    // Subscribing to a price that is already on the subscription is a no-op.
200    {
201        let now = Utc::now();
202        let subscription = StripeSubscription {
203            id: StripeSubscriptionId("sub_test".into()),
204            customer: StripeCustomerId("cus_test".into()),
205            status: stripe::SubscriptionStatus::Active,
206            current_period_start: now.timestamp(),
207            current_period_end: (now + Duration::days(30)).timestamp(),
208            items: vec![StripeSubscriptionItem {
209                id: StripeSubscriptionItemId("si_test".into()),
210                price: Some(price.clone()),
211            }],
212            cancel_at: None,
213            cancellation_details: None,
214        };
215        stripe_client
216            .subscriptions
217            .lock()
218            .insert(subscription.id.clone(), subscription.clone());
219
220        stripe_billing
221            .subscribe_to_price(&subscription.id, &price)
222            .await
223            .unwrap();
224
225        assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
226    }
227}
228
229#[gpui::test]
230async fn test_subscribe_to_zed_free() {
231    let (stripe_billing, stripe_client) = make_stripe_billing();
232
233    let zed_pro_price = StripePrice {
234        id: StripePriceId("price_1".into()),
235        unit_amount: Some(0),
236        lookup_key: Some("zed-pro".to_string()),
237        recurring: None,
238    };
239    stripe_client
240        .prices
241        .lock()
242        .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
243    let zed_free_price = StripePrice {
244        id: StripePriceId("price_2".into()),
245        unit_amount: Some(0),
246        lookup_key: Some("zed-free".to_string()),
247        recurring: None,
248    };
249    stripe_client
250        .prices
251        .lock()
252        .insert(zed_free_price.id.clone(), zed_free_price.clone());
253
254    stripe_billing.initialize().await.unwrap();
255
256    // Customer is subscribed to Zed Free when not already subscribed to a plan.
257    {
258        let customer_id = StripeCustomerId("cus_no_plan".into());
259
260        let subscription = stripe_billing
261            .subscribe_to_zed_free(customer_id)
262            .await
263            .unwrap();
264
265        assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
266    }
267
268    // Customer is not subscribed to Zed Free when they already have an active subscription.
269    {
270        let customer_id = StripeCustomerId("cus_active_subscription".into());
271
272        let now = Utc::now();
273        let existing_subscription = StripeSubscription {
274            id: StripeSubscriptionId("sub_existing_active".into()),
275            customer: customer_id.clone(),
276            status: stripe::SubscriptionStatus::Active,
277            current_period_start: now.timestamp(),
278            current_period_end: (now + Duration::days(30)).timestamp(),
279            items: vec![StripeSubscriptionItem {
280                id: StripeSubscriptionItemId("si_test".into()),
281                price: Some(zed_pro_price.clone()),
282            }],
283            cancel_at: None,
284            cancellation_details: None,
285        };
286        stripe_client.subscriptions.lock().insert(
287            existing_subscription.id.clone(),
288            existing_subscription.clone(),
289        );
290
291        let subscription = stripe_billing
292            .subscribe_to_zed_free(customer_id)
293            .await
294            .unwrap();
295
296        assert_eq!(subscription, existing_subscription);
297    }
298
299    // Customer is not subscribed to Zed Free when they already have a trial subscription.
300    {
301        let customer_id = StripeCustomerId("cus_trial_subscription".into());
302
303        let now = Utc::now();
304        let existing_subscription = StripeSubscription {
305            id: StripeSubscriptionId("sub_existing_trial".into()),
306            customer: customer_id.clone(),
307            status: stripe::SubscriptionStatus::Trialing,
308            current_period_start: now.timestamp(),
309            current_period_end: (now + Duration::days(14)).timestamp(),
310            items: vec![StripeSubscriptionItem {
311                id: StripeSubscriptionItemId("si_test".into()),
312                price: Some(zed_pro_price.clone()),
313            }],
314            cancel_at: None,
315            cancellation_details: None,
316        };
317        stripe_client.subscriptions.lock().insert(
318            existing_subscription.id.clone(),
319            existing_subscription.clone(),
320        );
321
322        let subscription = stripe_billing
323            .subscribe_to_zed_free(customer_id)
324            .await
325            .unwrap();
326
327        assert_eq!(subscription, existing_subscription);
328    }
329}
330
331#[gpui::test]
332async fn test_bill_model_request_usage() {
333    let (stripe_billing, stripe_client) = make_stripe_billing();
334
335    let customer_id = StripeCustomerId("cus_test".into());
336
337    stripe_billing
338        .bill_model_request_usage(&customer_id, "some_model/requests", 73)
339        .await
340        .unwrap();
341
342    let create_meter_event_calls = stripe_client
343        .create_meter_event_calls
344        .lock()
345        .iter()
346        .cloned()
347        .collect::<Vec<_>>();
348    assert_eq!(create_meter_event_calls.len(), 1);
349    assert!(
350        create_meter_event_calls[0]
351            .identifier
352            .starts_with("model_requests/")
353    );
354    assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
355    assert_eq!(
356        create_meter_event_calls[0].event_name.as_ref(),
357        "some_model/requests"
358    );
359    assert_eq!(create_meter_event_calls[0].value, 73);
360}