1use std::sync::Arc;
2
3use pretty_assertions::assert_eq;
4
5use crate::stripe_billing::StripeBilling;
6use crate::stripe_client::{
7 FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
8 StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
9 StripeSubscriptionItemId, UpdateSubscriptionItems,
10};
11
12fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
13 let stripe_client = Arc::new(FakeStripeClient::new());
14 let stripe_billing = StripeBilling::test(stripe_client.clone());
15
16 (stripe_billing, stripe_client)
17}
18
19#[gpui::test]
20async fn test_initialize() {
21 let (stripe_billing, stripe_client) = make_stripe_billing();
22
23 // Add test meters
24 let meter1 = StripeMeter {
25 id: StripeMeterId("meter_1".into()),
26 event_name: "event_1".to_string(),
27 };
28 let meter2 = StripeMeter {
29 id: StripeMeterId("meter_2".into()),
30 event_name: "event_2".to_string(),
31 };
32 stripe_client
33 .meters
34 .lock()
35 .insert(meter1.id.clone(), meter1);
36 stripe_client
37 .meters
38 .lock()
39 .insert(meter2.id.clone(), meter2);
40
41 // Add test prices
42 let price1 = StripePrice {
43 id: StripePriceId("price_1".into()),
44 unit_amount: Some(1_000),
45 lookup_key: Some("zed-pro".to_string()),
46 recurring: None,
47 };
48 let price2 = StripePrice {
49 id: StripePriceId("price_2".into()),
50 unit_amount: Some(0),
51 lookup_key: Some("zed-free".to_string()),
52 recurring: None,
53 };
54 let price3 = StripePrice {
55 id: StripePriceId("price_3".into()),
56 unit_amount: Some(500),
57 lookup_key: None,
58 recurring: Some(StripePriceRecurring {
59 meter: Some("meter_1".to_string()),
60 }),
61 };
62 stripe_client
63 .prices
64 .lock()
65 .insert(price1.id.clone(), price1);
66 stripe_client
67 .prices
68 .lock()
69 .insert(price2.id.clone(), price2);
70 stripe_client
71 .prices
72 .lock()
73 .insert(price3.id.clone(), price3);
74
75 // Initialize the billing system
76 stripe_billing.initialize().await.unwrap();
77
78 // Verify that prices can be found by lookup key
79 let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
80 assert_eq!(zed_pro_price_id.to_string(), "price_1");
81
82 let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
83 assert_eq!(zed_free_price_id.to_string(), "price_2");
84
85 // Verify that a price can be found by lookup key
86 let zed_pro_price = stripe_billing
87 .find_price_by_lookup_key("zed-pro")
88 .await
89 .unwrap();
90 assert_eq!(zed_pro_price.id.to_string(), "price_1");
91 assert_eq!(zed_pro_price.unit_amount, Some(1_000));
92
93 // Verify that finding a non-existent lookup key returns an error
94 let result = stripe_billing
95 .find_price_by_lookup_key("non-existent")
96 .await;
97 assert!(result.is_err());
98}
99
100#[gpui::test]
101async fn test_find_or_create_customer_by_email() {
102 let (stripe_billing, stripe_client) = make_stripe_billing();
103
104 // Create a customer with an email that doesn't yet correspond to a customer.
105 {
106 let email = "user@example.com";
107
108 let customer_id = stripe_billing
109 .find_or_create_customer_by_email(Some(email))
110 .await
111 .unwrap();
112
113 let customer = stripe_client
114 .customers
115 .lock()
116 .get(&customer_id)
117 .unwrap()
118 .clone();
119 assert_eq!(customer.email.as_deref(), Some(email));
120 }
121
122 // Create a customer with an email that corresponds to an existing customer.
123 {
124 let email = "user2@example.com";
125
126 let existing_customer_id = stripe_billing
127 .find_or_create_customer_by_email(Some(email))
128 .await
129 .unwrap();
130
131 let customer_id = stripe_billing
132 .find_or_create_customer_by_email(Some(email))
133 .await
134 .unwrap();
135 assert_eq!(customer_id, existing_customer_id);
136
137 let customer = stripe_client
138 .customers
139 .lock()
140 .get(&customer_id)
141 .unwrap()
142 .clone();
143 assert_eq!(customer.email.as_deref(), Some(email));
144 }
145}
146
147#[gpui::test]
148async fn test_subscribe_to_price() {
149 let (stripe_billing, stripe_client) = make_stripe_billing();
150
151 let price = StripePrice {
152 id: StripePriceId("price_test".into()),
153 unit_amount: Some(2000),
154 lookup_key: Some("test-price".to_string()),
155 recurring: None,
156 };
157 stripe_client
158 .prices
159 .lock()
160 .insert(price.id.clone(), price.clone());
161
162 let subscription = StripeSubscription {
163 id: StripeSubscriptionId("sub_test".into()),
164 items: vec![],
165 };
166 stripe_client
167 .subscriptions
168 .lock()
169 .insert(subscription.id.clone(), subscription.clone());
170
171 stripe_billing
172 .subscribe_to_price(&subscription.id, &price)
173 .await
174 .unwrap();
175
176 let update_subscription_calls = stripe_client
177 .update_subscription_calls
178 .lock()
179 .iter()
180 .map(|(id, params)| (id.clone(), params.clone()))
181 .collect::<Vec<_>>();
182 assert_eq!(update_subscription_calls.len(), 1);
183 assert_eq!(update_subscription_calls[0].0, subscription.id);
184 assert_eq!(
185 update_subscription_calls[0].1.items,
186 Some(vec![UpdateSubscriptionItems {
187 price: Some(price.id.clone())
188 }])
189 );
190
191 // Subscribing to a price that is already on the subscription is a no-op.
192 {
193 let subscription = StripeSubscription {
194 id: StripeSubscriptionId("sub_test".into()),
195 items: vec![StripeSubscriptionItem {
196 id: StripeSubscriptionItemId("si_test".into()),
197 price: Some(price.clone()),
198 }],
199 };
200 stripe_client
201 .subscriptions
202 .lock()
203 .insert(subscription.id.clone(), subscription.clone());
204
205 stripe_billing
206 .subscribe_to_price(&subscription.id, &price)
207 .await
208 .unwrap();
209
210 assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
211 }
212}
213
214#[gpui::test]
215async fn test_bill_model_request_usage() {
216 let (stripe_billing, stripe_client) = make_stripe_billing();
217
218 let customer_id = StripeCustomerId("cus_test".into());
219
220 stripe_billing
221 .bill_model_request_usage(&customer_id, "some_model/requests", 73)
222 .await
223 .unwrap();
224
225 let create_meter_event_calls = stripe_client
226 .create_meter_event_calls
227 .lock()
228 .iter()
229 .cloned()
230 .collect::<Vec<_>>();
231 assert_eq!(create_meter_event_calls.len(), 1);
232 assert!(
233 create_meter_event_calls[0]
234 .identifier
235 .starts_with("model_requests/")
236 );
237 assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
238 assert_eq!(
239 create_meter_event_calls[0].event_name.as_ref(),
240 "some_model/requests"
241 );
242 assert_eq!(create_meter_event_calls[0].value, 73);
243}