1use std::sync::Arc;
2
3use crate::Result;
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
6use anyhow::{Context as _, anyhow};
7use chrono::Utc;
8use collections::HashMap;
9use serde::{Deserialize, Serialize};
10use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct StripeBilling {
15 state: RwLock<StripeBillingState>,
16 client: Arc<stripe::Client>,
17}
18
19#[derive(Default)]
20struct StripeBillingState {
21 meters_by_event_name: HashMap<String, StripeMeter>,
22 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
23 prices_by_lookup_key: HashMap<String, stripe::Price>,
24}
25
26impl StripeBilling {
27 pub fn new(client: Arc<stripe::Client>) -> Self {
28 Self {
29 client,
30 state: RwLock::default(),
31 }
32 }
33
34 pub async fn initialize(&self) -> Result<()> {
35 log::info!("StripeBilling: initializing");
36
37 let mut state = self.state.write().await;
38
39 let (meters, prices) = futures::try_join!(
40 StripeMeter::list(&self.client),
41 stripe::Price::list(
42 &self.client,
43 &stripe::ListPrices {
44 limit: Some(100),
45 ..Default::default()
46 }
47 )
48 )?;
49
50 for meter in meters.data {
51 state
52 .meters_by_event_name
53 .insert(meter.event_name.clone(), meter);
54 }
55
56 for price in prices.data {
57 if let Some(lookup_key) = price.lookup_key.clone() {
58 state.prices_by_lookup_key.insert(lookup_key, price.clone());
59 }
60
61 if let Some(recurring) = price.recurring {
62 if let Some(meter) = recurring.meter {
63 state.price_ids_by_meter_id.insert(meter, price.id);
64 }
65 }
66 }
67
68 log::info!("StripeBilling: initialized");
69
70 Ok(())
71 }
72
73 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
74 self.find_price_id_by_lookup_key("zed-pro").await
75 }
76
77 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
78 self.find_price_id_by_lookup_key("zed-free").await
79 }
80
81 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
82 self.state
83 .read()
84 .await
85 .prices_by_lookup_key
86 .get(lookup_key)
87 .map(|price| price.id.clone())
88 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
89 }
90
91 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
92 self.state
93 .read()
94 .await
95 .prices_by_lookup_key
96 .get(lookup_key)
97 .cloned()
98 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
99 }
100
101 pub async fn determine_subscription_kind(
102 &self,
103 subscription: &stripe::Subscription,
104 ) -> Option<SubscriptionKind> {
105 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108 subscription.items.data.iter().find_map(|item| {
109 let price = item.price.as_ref()?;
110
111 if price.id == zed_pro_price_id {
112 Some(if subscription.status == SubscriptionStatus::Trialing {
113 SubscriptionKind::ZedProTrial
114 } else {
115 SubscriptionKind::ZedPro
116 })
117 } else if price.id == zed_free_price_id {
118 Some(SubscriptionKind::ZedFree)
119 } else {
120 None
121 }
122 })
123 }
124
125 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
126 /// not already exist.
127 ///
128 /// Always returns a new Stripe customer if the email address is `None`.
129 pub async fn find_or_create_customer_by_email(
130 &self,
131 email_address: Option<&str>,
132 ) -> Result<CustomerId> {
133 let existing_customer = if let Some(email) = email_address {
134 let customers = Customer::list(
135 &self.client,
136 &stripe::ListCustomers {
137 email: Some(email),
138 ..Default::default()
139 },
140 )
141 .await?;
142
143 customers.data.first().cloned()
144 } else {
145 None
146 };
147
148 let customer_id = if let Some(existing_customer) = existing_customer {
149 existing_customer.id
150 } else {
151 let customer = Customer::create(
152 &self.client,
153 CreateCustomer {
154 email: email_address,
155 ..Default::default()
156 },
157 )
158 .await?;
159
160 customer.id
161 };
162
163 Ok(customer_id)
164 }
165
166 pub async fn subscribe_to_price(
167 &self,
168 subscription_id: &stripe::SubscriptionId,
169 price: &stripe::Price,
170 ) -> Result<()> {
171 let subscription =
172 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
173
174 if subscription_contains_price(&subscription, &price.id) {
175 return Ok(());
176 }
177
178 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
179
180 let price_per_unit = price.unit_amount.unwrap_or_default();
181 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
182
183 stripe::Subscription::update(
184 &self.client,
185 subscription_id,
186 stripe::UpdateSubscription {
187 items: Some(vec![stripe::UpdateSubscriptionItems {
188 price: Some(price.id.to_string()),
189 ..Default::default()
190 }]),
191 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
192 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
193 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
194 },
195 }),
196 ..Default::default()
197 },
198 )
199 .await?;
200
201 Ok(())
202 }
203
204 pub async fn bill_model_request_usage(
205 &self,
206 customer_id: &stripe::CustomerId,
207 event_name: &str,
208 requests: i32,
209 ) -> Result<()> {
210 let timestamp = Utc::now().timestamp();
211 let idempotency_key = Uuid::new_v4();
212
213 StripeMeterEvent::create(
214 &self.client,
215 StripeCreateMeterEventParams {
216 identifier: &format!("model_requests/{}", idempotency_key),
217 event_name,
218 payload: StripeCreateMeterEventPayload {
219 value: requests as u64,
220 stripe_customer_id: customer_id,
221 },
222 timestamp: Some(timestamp),
223 },
224 )
225 .await?;
226
227 Ok(())
228 }
229
230 pub async fn checkout_with_zed_pro(
231 &self,
232 customer_id: stripe::CustomerId,
233 github_login: &str,
234 success_url: &str,
235 ) -> Result<String> {
236 let zed_pro_price_id = self.zed_pro_price_id().await?;
237
238 let mut params = stripe::CreateCheckoutSession::new();
239 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
240 params.customer = Some(customer_id);
241 params.client_reference_id = Some(github_login);
242 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
243 price: Some(zed_pro_price_id.to_string()),
244 quantity: Some(1),
245 ..Default::default()
246 }]);
247 params.success_url = Some(success_url);
248
249 let session = stripe::CheckoutSession::create(&self.client, params).await?;
250 Ok(session.url.context("no checkout session URL")?)
251 }
252
253 pub async fn checkout_with_zed_pro_trial(
254 &self,
255 customer_id: stripe::CustomerId,
256 github_login: &str,
257 feature_flags: Vec<String>,
258 success_url: &str,
259 ) -> Result<String> {
260 let zed_pro_price_id = self.zed_pro_price_id().await?;
261
262 let eligible_for_extended_trial = feature_flags
263 .iter()
264 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
265
266 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
267
268 let mut subscription_metadata = std::collections::HashMap::new();
269 if eligible_for_extended_trial {
270 subscription_metadata.insert(
271 "promo_feature_flag".to_string(),
272 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
273 );
274 }
275
276 let mut params = stripe::CreateCheckoutSession::new();
277 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
278 trial_period_days: Some(trial_period_days),
279 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
280 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
281 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
282 }
283 }),
284 metadata: if !subscription_metadata.is_empty() {
285 Some(subscription_metadata)
286 } else {
287 None
288 },
289 ..Default::default()
290 });
291 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
292 params.payment_method_collection =
293 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
294 params.customer = Some(customer_id);
295 params.client_reference_id = Some(github_login);
296 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
297 price: Some(zed_pro_price_id.to_string()),
298 quantity: Some(1),
299 ..Default::default()
300 }]);
301 params.success_url = Some(success_url);
302
303 let session = stripe::CheckoutSession::create(&self.client, params).await?;
304 Ok(session.url.context("no checkout session URL")?)
305 }
306
307 pub async fn subscribe_to_zed_free(
308 &self,
309 customer_id: stripe::CustomerId,
310 ) -> Result<stripe::Subscription> {
311 let zed_free_price_id = self.zed_free_price_id().await?;
312
313 let mut params = stripe::CreateSubscription::new(customer_id);
314 params.items = Some(vec![stripe::CreateSubscriptionItems {
315 price: Some(zed_free_price_id.to_string()),
316 quantity: Some(1),
317 ..Default::default()
318 }]);
319
320 let subscription = stripe::Subscription::create(&self.client, params).await?;
321
322 Ok(subscription)
323 }
324
325 pub async fn checkout_with_zed_free(
326 &self,
327 customer_id: stripe::CustomerId,
328 github_login: &str,
329 success_url: &str,
330 ) -> Result<String> {
331 let zed_free_price_id = self.zed_free_price_id().await?;
332
333 let mut params = stripe::CreateCheckoutSession::new();
334 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
335 params.payment_method_collection =
336 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
337 params.customer = Some(customer_id);
338 params.client_reference_id = Some(github_login);
339 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
340 price: Some(zed_free_price_id.to_string()),
341 quantity: Some(1),
342 ..Default::default()
343 }]);
344 params.success_url = Some(success_url);
345
346 let session = stripe::CheckoutSession::create(&self.client, params).await?;
347 Ok(session.url.context("no checkout session URL")?)
348 }
349}
350
351#[derive(Clone, Deserialize)]
352struct StripeMeter {
353 id: String,
354 event_name: String,
355}
356
357impl StripeMeter {
358 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
359 #[derive(Serialize)]
360 struct Params {
361 #[serde(skip_serializing_if = "Option::is_none")]
362 limit: Option<u64>,
363 }
364
365 client.get_query("/billing/meters", Params { limit: Some(100) })
366 }
367}
368
369#[derive(Deserialize)]
370struct StripeMeterEvent {
371 identifier: String,
372}
373
374impl StripeMeterEvent {
375 pub async fn create(
376 client: &stripe::Client,
377 params: StripeCreateMeterEventParams<'_>,
378 ) -> Result<Self, stripe::StripeError> {
379 let identifier = params.identifier;
380 match client.post_form("/billing/meter_events", params).await {
381 Ok(event) => Ok(event),
382 Err(stripe::StripeError::Stripe(error)) => {
383 if error.http_status == 400
384 && error
385 .message
386 .as_ref()
387 .map_or(false, |message| message.contains(identifier))
388 {
389 Ok(Self {
390 identifier: identifier.to_string(),
391 })
392 } else {
393 Err(stripe::StripeError::Stripe(error))
394 }
395 }
396 Err(error) => Err(error),
397 }
398 }
399}
400
401#[derive(Serialize)]
402struct StripeCreateMeterEventParams<'a> {
403 identifier: &'a str,
404 event_name: &'a str,
405 payload: StripeCreateMeterEventPayload<'a>,
406 timestamp: Option<i64>,
407}
408
409#[derive(Serialize)]
410struct StripeCreateMeterEventPayload<'a> {
411 value: u64,
412 stripe_customer_id: &'a stripe::CustomerId,
413}
414
415fn subscription_contains_price(
416 subscription: &stripe::Subscription,
417 price_id: &stripe::PriceId,
418) -> bool {
419 subscription.items.data.iter().any(|item| {
420 item.price
421 .as_ref()
422 .map_or(false, |price| price.id == *price_id)
423 })
424}