1use std::sync::Arc;
2
3use chrono::{Duration, Utc};
4use pretty_assertions::assert_eq;
5
6use crate::stripe_billing::StripeBilling;
7use crate::stripe_client::{
8 FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
9 StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
10 StripeSubscriptionItemId, UpdateSubscriptionItems,
11};
12
13fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
14 let stripe_client = Arc::new(FakeStripeClient::new());
15 let stripe_billing = StripeBilling::test(stripe_client.clone());
16
17 (stripe_billing, stripe_client)
18}
19
20#[gpui::test]
21async fn test_initialize() {
22 let (stripe_billing, stripe_client) = make_stripe_billing();
23
24 // Add test meters
25 let meter1 = StripeMeter {
26 id: StripeMeterId("meter_1".into()),
27 event_name: "event_1".to_string(),
28 };
29 let meter2 = StripeMeter {
30 id: StripeMeterId("meter_2".into()),
31 event_name: "event_2".to_string(),
32 };
33 stripe_client
34 .meters
35 .lock()
36 .insert(meter1.id.clone(), meter1);
37 stripe_client
38 .meters
39 .lock()
40 .insert(meter2.id.clone(), meter2);
41
42 // Add test prices
43 let price1 = StripePrice {
44 id: StripePriceId("price_1".into()),
45 unit_amount: Some(1_000),
46 lookup_key: Some("zed-pro".to_string()),
47 recurring: None,
48 };
49 let price2 = StripePrice {
50 id: StripePriceId("price_2".into()),
51 unit_amount: Some(0),
52 lookup_key: Some("zed-free".to_string()),
53 recurring: None,
54 };
55 let price3 = StripePrice {
56 id: StripePriceId("price_3".into()),
57 unit_amount: Some(500),
58 lookup_key: None,
59 recurring: Some(StripePriceRecurring {
60 meter: Some("meter_1".to_string()),
61 }),
62 };
63 stripe_client
64 .prices
65 .lock()
66 .insert(price1.id.clone(), price1);
67 stripe_client
68 .prices
69 .lock()
70 .insert(price2.id.clone(), price2);
71 stripe_client
72 .prices
73 .lock()
74 .insert(price3.id.clone(), price3);
75
76 // Initialize the billing system
77 stripe_billing.initialize().await.unwrap();
78
79 // Verify that prices can be found by lookup key
80 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
81 assert_eq!(zed_pro_price_id.to_string(), "price_1");
82
83 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
84 assert_eq!(zed_free_price_id.to_string(), "price_2");
85
86 // Verify that a price can be found by lookup key
87 let zed_pro_price = stripe_billing
88 .find_price_by_lookup_key("zed-pro")
89 .await
90 .unwrap();
91 assert_eq!(zed_pro_price.id.to_string(), "price_1");
92 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
93
94 // Verify that finding a non-existent lookup key returns an error
95 let result = stripe_billing
96 .find_price_by_lookup_key("non-existent")
97 .await;
98 assert!(result.is_err());
99}
100
101#[gpui::test]
102async fn test_find_or_create_customer_by_email() {
103 let (stripe_billing, stripe_client) = make_stripe_billing();
104
105 // Create a customer with an email that doesn't yet correspond to a customer.
106 {
107 let email = "user@example.com";
108
109 let customer_id = stripe_billing
110 .find_or_create_customer_by_email(Some(email))
111 .await
112 .unwrap();
113
114 let customer = stripe_client
115 .customers
116 .lock()
117 .get(&customer_id)
118 .unwrap()
119 .clone();
120 assert_eq!(customer.email.as_deref(), Some(email));
121 }
122
123 // Create a customer with an email that corresponds to an existing customer.
124 {
125 let email = "user2@example.com";
126
127 let existing_customer_id = stripe_billing
128 .find_or_create_customer_by_email(Some(email))
129 .await
130 .unwrap();
131
132 let customer_id = stripe_billing
133 .find_or_create_customer_by_email(Some(email))
134 .await
135 .unwrap();
136 assert_eq!(customer_id, existing_customer_id);
137
138 let customer = stripe_client
139 .customers
140 .lock()
141 .get(&customer_id)
142 .unwrap()
143 .clone();
144 assert_eq!(customer.email.as_deref(), Some(email));
145 }
146}
147
148#[gpui::test]
149async fn test_subscribe_to_price() {
150 let (stripe_billing, stripe_client) = make_stripe_billing();
151
152 let price = StripePrice {
153 id: StripePriceId("price_test".into()),
154 unit_amount: Some(2000),
155 lookup_key: Some("test-price".to_string()),
156 recurring: None,
157 };
158 stripe_client
159 .prices
160 .lock()
161 .insert(price.id.clone(), price.clone());
162
163 let now = Utc::now();
164 let subscription = StripeSubscription {
165 id: StripeSubscriptionId("sub_test".into()),
166 customer: StripeCustomerId("cus_test".into()),
167 status: stripe::SubscriptionStatus::Active,
168 current_period_start: now.timestamp(),
169 current_period_end: (now + Duration::days(30)).timestamp(),
170 items: vec![],
171 cancel_at: None,
172 cancellation_details: None,
173 };
174 stripe_client
175 .subscriptions
176 .lock()
177 .insert(subscription.id.clone(), subscription.clone());
178
179 stripe_billing
180 .subscribe_to_price(&subscription.id, &price)
181 .await
182 .unwrap();
183
184 let update_subscription_calls = stripe_client
185 .update_subscription_calls
186 .lock()
187 .iter()
188 .map(|(id, params)| (id.clone(), params.clone()))
189 .collect::<Vec<_>>();
190 assert_eq!(update_subscription_calls.len(), 1);
191 assert_eq!(update_subscription_calls[0].0, subscription.id);
192 assert_eq!(
193 update_subscription_calls[0].1.items,
194 Some(vec![UpdateSubscriptionItems {
195 price: Some(price.id.clone())
196 }])
197 );
198
199 // Subscribing to a price that is already on the subscription is a no-op.
200 {
201 let now = Utc::now();
202 let subscription = StripeSubscription {
203 id: StripeSubscriptionId("sub_test".into()),
204 customer: StripeCustomerId("cus_test".into()),
205 status: stripe::SubscriptionStatus::Active,
206 current_period_start: now.timestamp(),
207 current_period_end: (now + Duration::days(30)).timestamp(),
208 items: vec![StripeSubscriptionItem {
209 id: StripeSubscriptionItemId("si_test".into()),
210 price: Some(price.clone()),
211 }],
212 cancel_at: None,
213 cancellation_details: None,
214 };
215 stripe_client
216 .subscriptions
217 .lock()
218 .insert(subscription.id.clone(), subscription.clone());
219
220 stripe_billing
221 .subscribe_to_price(&subscription.id, &price)
222 .await
223 .unwrap();
224
225 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
226 }
227}
228
229#[gpui::test]
230async fn test_subscribe_to_zed_free() {
231 let (stripe_billing, stripe_client) = make_stripe_billing();
232
233 let zed_pro_price = StripePrice {
234 id: StripePriceId("price_1".into()),
235 unit_amount: Some(0),
236 lookup_key: Some("zed-pro".to_string()),
237 recurring: None,
238 };
239 stripe_client
240 .prices
241 .lock()
242 .insert(zed_pro_price.id.clone(), zed_pro_price.clone());
243 let zed_free_price = StripePrice {
244 id: StripePriceId("price_2".into()),
245 unit_amount: Some(0),
246 lookup_key: Some("zed-free".to_string()),
247 recurring: None,
248 };
249 stripe_client
250 .prices
251 .lock()
252 .insert(zed_free_price.id.clone(), zed_free_price.clone());
253
254 stripe_billing.initialize().await.unwrap();
255
256 // Customer is subscribed to Zed Free when not already subscribed to a plan.
257 {
258 let customer_id = StripeCustomerId("cus_no_plan".into());
259
260 let subscription = stripe_billing
261 .subscribe_to_zed_free(customer_id)
262 .await
263 .unwrap();
264
265 assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
266 }
267
268 // Customer is not subscribed to Zed Free when they already have an active subscription.
269 {
270 let customer_id = StripeCustomerId("cus_active_subscription".into());
271
272 let now = Utc::now();
273 let existing_subscription = StripeSubscription {
274 id: StripeSubscriptionId("sub_existing_active".into()),
275 customer: customer_id.clone(),
276 status: stripe::SubscriptionStatus::Active,
277 current_period_start: now.timestamp(),
278 current_period_end: (now + Duration::days(30)).timestamp(),
279 items: vec![StripeSubscriptionItem {
280 id: StripeSubscriptionItemId("si_test".into()),
281 price: Some(zed_pro_price.clone()),
282 }],
283 cancel_at: None,
284 cancellation_details: None,
285 };
286 stripe_client.subscriptions.lock().insert(
287 existing_subscription.id.clone(),
288 existing_subscription.clone(),
289 );
290
291 let subscription = stripe_billing
292 .subscribe_to_zed_free(customer_id)
293 .await
294 .unwrap();
295
296 assert_eq!(subscription, existing_subscription);
297 }
298
299 // Customer is not subscribed to Zed Free when they already have a trial subscription.
300 {
301 let customer_id = StripeCustomerId("cus_trial_subscription".into());
302
303 let now = Utc::now();
304 let existing_subscription = StripeSubscription {
305 id: StripeSubscriptionId("sub_existing_trial".into()),
306 customer: customer_id.clone(),
307 status: stripe::SubscriptionStatus::Trialing,
308 current_period_start: now.timestamp(),
309 current_period_end: (now + Duration::days(14)).timestamp(),
310 items: vec![StripeSubscriptionItem {
311 id: StripeSubscriptionItemId("si_test".into()),
312 price: Some(zed_pro_price.clone()),
313 }],
314 cancel_at: None,
315 cancellation_details: None,
316 };
317 stripe_client.subscriptions.lock().insert(
318 existing_subscription.id.clone(),
319 existing_subscription.clone(),
320 );
321
322 let subscription = stripe_billing
323 .subscribe_to_zed_free(customer_id)
324 .await
325 .unwrap();
326
327 assert_eq!(subscription, existing_subscription);
328 }
329}
330
331#[gpui::test]
332async fn test_bill_model_request_usage() {
333 let (stripe_billing, stripe_client) = make_stripe_billing();
334
335 let customer_id = StripeCustomerId("cus_test".into());
336
337 stripe_billing
338 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
339 .await
340 .unwrap();
341
342 let create_meter_event_calls = stripe_client
343 .create_meter_event_calls
344 .lock()
345 .iter()
346 .cloned()
347 .collect::<Vec<_>>();
348 assert_eq!(create_meter_event_calls.len(), 1);
349 assert!(
350 create_meter_event_calls[0]
351 .identifier
352 .starts_with("model_requests/")
353 );
354 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
355 assert_eq!(
356 create_meter_event_calls[0].event_name.as_ref(),
357 "some_model/requests"
358 );
359 assert_eq!(create_meter_event_calls[0].value, 73);
360}