1use std::sync::Arc;
2
3use crate::Result;
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
6use anyhow::{Context as _, anyhow};
7use chrono::Utc;
8use collections::HashMap;
9use serde::{Deserialize, Serialize};
10use stripe::{PriceId, SubscriptionStatus};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct StripeBilling {
15 state: RwLock<StripeBillingState>,
16 client: Arc<stripe::Client>,
17}
18
19#[derive(Default)]
20struct StripeBillingState {
21 meters_by_event_name: HashMap<String, StripeMeter>,
22 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
23 prices_by_lookup_key: HashMap<String, stripe::Price>,
24}
25
26impl StripeBilling {
27 pub fn new(client: Arc<stripe::Client>) -> Self {
28 Self {
29 client,
30 state: RwLock::default(),
31 }
32 }
33
34 pub async fn initialize(&self) -> Result<()> {
35 log::info!("StripeBilling: initializing");
36
37 let mut state = self.state.write().await;
38
39 let (meters, prices) = futures::try_join!(
40 StripeMeter::list(&self.client),
41 stripe::Price::list(
42 &self.client,
43 &stripe::ListPrices {
44 limit: Some(100),
45 ..Default::default()
46 }
47 )
48 )?;
49
50 for meter in meters.data {
51 state
52 .meters_by_event_name
53 .insert(meter.event_name.clone(), meter);
54 }
55
56 for price in prices.data {
57 if let Some(lookup_key) = price.lookup_key.clone() {
58 state.prices_by_lookup_key.insert(lookup_key, price.clone());
59 }
60
61 if let Some(recurring) = price.recurring {
62 if let Some(meter) = recurring.meter {
63 state.price_ids_by_meter_id.insert(meter, price.id);
64 }
65 }
66 }
67
68 log::info!("StripeBilling: initialized");
69
70 Ok(())
71 }
72
73 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
74 self.find_price_id_by_lookup_key("zed-pro").await
75 }
76
77 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
78 self.find_price_id_by_lookup_key("zed-free").await
79 }
80
81 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
82 self.state
83 .read()
84 .await
85 .prices_by_lookup_key
86 .get(lookup_key)
87 .map(|price| price.id.clone())
88 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
89 }
90
91 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
92 self.state
93 .read()
94 .await
95 .prices_by_lookup_key
96 .get(lookup_key)
97 .cloned()
98 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
99 }
100
101 pub async fn determine_subscription_kind(
102 &self,
103 subscription: &stripe::Subscription,
104 ) -> Option<SubscriptionKind> {
105 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
106 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
107
108 subscription.items.data.iter().find_map(|item| {
109 let price = item.price.as_ref()?;
110
111 if price.id == zed_pro_price_id {
112 Some(if subscription.status == SubscriptionStatus::Trialing {
113 SubscriptionKind::ZedProTrial
114 } else {
115 SubscriptionKind::ZedPro
116 })
117 } else if price.id == zed_free_price_id {
118 Some(SubscriptionKind::ZedFree)
119 } else {
120 None
121 }
122 })
123 }
124
125 pub async fn subscribe_to_price(
126 &self,
127 subscription_id: &stripe::SubscriptionId,
128 price: &stripe::Price,
129 ) -> Result<()> {
130 let subscription =
131 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
132
133 if subscription_contains_price(&subscription, &price.id) {
134 return Ok(());
135 }
136
137 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
138
139 let price_per_unit = price.unit_amount.unwrap_or_default();
140 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
141
142 stripe::Subscription::update(
143 &self.client,
144 subscription_id,
145 stripe::UpdateSubscription {
146 items: Some(vec![stripe::UpdateSubscriptionItems {
147 price: Some(price.id.to_string()),
148 ..Default::default()
149 }]),
150 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
151 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
152 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
153 },
154 }),
155 ..Default::default()
156 },
157 )
158 .await?;
159
160 Ok(())
161 }
162
163 pub async fn bill_model_request_usage(
164 &self,
165 customer_id: &stripe::CustomerId,
166 event_name: &str,
167 requests: i32,
168 ) -> Result<()> {
169 let timestamp = Utc::now().timestamp();
170 let idempotency_key = Uuid::new_v4();
171
172 StripeMeterEvent::create(
173 &self.client,
174 StripeCreateMeterEventParams {
175 identifier: &format!("model_requests/{}", idempotency_key),
176 event_name,
177 payload: StripeCreateMeterEventPayload {
178 value: requests as u64,
179 stripe_customer_id: customer_id,
180 },
181 timestamp: Some(timestamp),
182 },
183 )
184 .await?;
185
186 Ok(())
187 }
188
189 pub async fn checkout_with_zed_pro(
190 &self,
191 customer_id: stripe::CustomerId,
192 github_login: &str,
193 success_url: &str,
194 ) -> Result<String> {
195 let zed_pro_price_id = self.zed_pro_price_id().await?;
196
197 let mut params = stripe::CreateCheckoutSession::new();
198 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
199 params.customer = Some(customer_id);
200 params.client_reference_id = Some(github_login);
201 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
202 price: Some(zed_pro_price_id.to_string()),
203 quantity: Some(1),
204 ..Default::default()
205 }]);
206 params.success_url = Some(success_url);
207
208 let session = stripe::CheckoutSession::create(&self.client, params).await?;
209 Ok(session.url.context("no checkout session URL")?)
210 }
211
212 pub async fn checkout_with_zed_pro_trial(
213 &self,
214 customer_id: stripe::CustomerId,
215 github_login: &str,
216 feature_flags: Vec<String>,
217 success_url: &str,
218 ) -> Result<String> {
219 let zed_pro_price_id = self.zed_pro_price_id().await?;
220
221 let eligible_for_extended_trial = feature_flags
222 .iter()
223 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
224
225 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
226
227 let mut subscription_metadata = std::collections::HashMap::new();
228 if eligible_for_extended_trial {
229 subscription_metadata.insert(
230 "promo_feature_flag".to_string(),
231 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
232 );
233 }
234
235 let mut params = stripe::CreateCheckoutSession::new();
236 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
237 trial_period_days: Some(trial_period_days),
238 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
239 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
240 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
241 }
242 }),
243 metadata: if !subscription_metadata.is_empty() {
244 Some(subscription_metadata)
245 } else {
246 None
247 },
248 ..Default::default()
249 });
250 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
251 params.payment_method_collection =
252 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
253 params.customer = Some(customer_id);
254 params.client_reference_id = Some(github_login);
255 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
256 price: Some(zed_pro_price_id.to_string()),
257 quantity: Some(1),
258 ..Default::default()
259 }]);
260 params.success_url = Some(success_url);
261
262 let session = stripe::CheckoutSession::create(&self.client, params).await?;
263 Ok(session.url.context("no checkout session URL")?)
264 }
265
266 pub async fn checkout_with_zed_free(
267 &self,
268 customer_id: stripe::CustomerId,
269 github_login: &str,
270 success_url: &str,
271 ) -> Result<String> {
272 let zed_free_price_id = self.zed_free_price_id().await?;
273
274 let mut params = stripe::CreateCheckoutSession::new();
275 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
276 params.payment_method_collection =
277 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
278 params.customer = Some(customer_id);
279 params.client_reference_id = Some(github_login);
280 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
281 price: Some(zed_free_price_id.to_string()),
282 quantity: Some(1),
283 ..Default::default()
284 }]);
285 params.success_url = Some(success_url);
286
287 let session = stripe::CheckoutSession::create(&self.client, params).await?;
288 Ok(session.url.context("no checkout session URL")?)
289 }
290}
291
292#[derive(Clone, Deserialize)]
293struct StripeMeter {
294 id: String,
295 event_name: String,
296}
297
298impl StripeMeter {
299 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
300 #[derive(Serialize)]
301 struct Params {
302 #[serde(skip_serializing_if = "Option::is_none")]
303 limit: Option<u64>,
304 }
305
306 client.get_query("/billing/meters", Params { limit: Some(100) })
307 }
308}
309
310#[derive(Deserialize)]
311struct StripeMeterEvent {
312 identifier: String,
313}
314
315impl StripeMeterEvent {
316 pub async fn create(
317 client: &stripe::Client,
318 params: StripeCreateMeterEventParams<'_>,
319 ) -> Result<Self, stripe::StripeError> {
320 let identifier = params.identifier;
321 match client.post_form("/billing/meter_events", params).await {
322 Ok(event) => Ok(event),
323 Err(stripe::StripeError::Stripe(error)) => {
324 if error.http_status == 400
325 && error
326 .message
327 .as_ref()
328 .map_or(false, |message| message.contains(identifier))
329 {
330 Ok(Self {
331 identifier: identifier.to_string(),
332 })
333 } else {
334 Err(stripe::StripeError::Stripe(error))
335 }
336 }
337 Err(error) => Err(error),
338 }
339 }
340}
341
/// Form body for `POST /billing/meter_events`.
///
/// Field names are serialized as-is, so they must match the parameter names
/// Stripe expects.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Caller-chosen identifier for the event; `StripeMeterEvent::create`
    /// matches it against Stripe's error message to detect a duplicate.
    identifier: &'a str,
    /// Name of the meter event being reported.
    event_name: &'a str,
    /// Usage value and the customer it applies to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time as a Unix timestamp in seconds (callers in this file pass
    /// `Utc::now().timestamp()`).
    timestamp: Option<i64>,
}

/// Payload nested under `payload` in a meter-event request.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Amount of usage to record.
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
355
356fn subscription_contains_price(
357 subscription: &stripe::Subscription,
358 price_id: &stripe::PriceId,
359) -> bool {
360 subscription.items.data.iter().any(|item| {
361 item.price
362 .as_ref()
363 .map_or(false, |price| price.id == *price_id)
364 })
365}