1use std::sync::Arc;
2
3use crate::Result;
4use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
5use anyhow::{Context as _, anyhow};
6use chrono::Utc;
7use collections::HashMap;
8use serde::{Deserialize, Serialize};
9use stripe::PriceId;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13pub struct StripeBilling {
14 state: RwLock<StripeBillingState>,
15 client: Arc<stripe::Client>,
16}
17
18#[derive(Default)]
19struct StripeBillingState {
20 meters_by_event_name: HashMap<String, StripeMeter>,
21 price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
22 prices_by_lookup_key: HashMap<String, stripe::Price>,
23}
24
25impl StripeBilling {
26 pub fn new(client: Arc<stripe::Client>) -> Self {
27 Self {
28 client,
29 state: RwLock::default(),
30 }
31 }
32
33 pub async fn initialize(&self) -> Result<()> {
34 log::info!("StripeBilling: initializing");
35
36 let mut state = self.state.write().await;
37
38 let (meters, prices) = futures::try_join!(
39 StripeMeter::list(&self.client),
40 stripe::Price::list(
41 &self.client,
42 &stripe::ListPrices {
43 limit: Some(100),
44 ..Default::default()
45 }
46 )
47 )?;
48
49 for meter in meters.data {
50 state
51 .meters_by_event_name
52 .insert(meter.event_name.clone(), meter);
53 }
54
55 for price in prices.data {
56 if let Some(lookup_key) = price.lookup_key.clone() {
57 state.prices_by_lookup_key.insert(lookup_key, price.clone());
58 }
59
60 if let Some(recurring) = price.recurring {
61 if let Some(meter) = recurring.meter {
62 state.price_ids_by_meter_id.insert(meter, price.id);
63 }
64 }
65 }
66
67 log::info!("StripeBilling: initialized");
68
69 Ok(())
70 }
71
72 pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
73 self.find_price_id_by_lookup_key("zed-pro").await
74 }
75
76 pub async fn zed_free_price_id(&self) -> Result<PriceId> {
77 self.find_price_id_by_lookup_key("zed-free").await
78 }
79
80 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
81 self.state
82 .read()
83 .await
84 .prices_by_lookup_key
85 .get(lookup_key)
86 .map(|price| price.id.clone())
87 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
88 }
89
90 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
91 self.state
92 .read()
93 .await
94 .prices_by_lookup_key
95 .get(lookup_key)
96 .cloned()
97 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
98 }
99
100 pub async fn subscribe_to_price(
101 &self,
102 subscription_id: &stripe::SubscriptionId,
103 price: &stripe::Price,
104 ) -> Result<()> {
105 let subscription =
106 stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
107
108 if subscription_contains_price(&subscription, &price.id) {
109 return Ok(());
110 }
111
112 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
113
114 let price_per_unit = price.unit_amount.unwrap_or_default();
115 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
116
117 stripe::Subscription::update(
118 &self.client,
119 subscription_id,
120 stripe::UpdateSubscription {
121 items: Some(vec![stripe::UpdateSubscriptionItems {
122 price: Some(price.id.to_string()),
123 ..Default::default()
124 }]),
125 trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
126 end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
127 missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
128 },
129 }),
130 ..Default::default()
131 },
132 )
133 .await?;
134
135 Ok(())
136 }
137
138 pub async fn bill_model_request_usage(
139 &self,
140 customer_id: &stripe::CustomerId,
141 event_name: &str,
142 requests: i32,
143 ) -> Result<()> {
144 let timestamp = Utc::now().timestamp();
145 let idempotency_key = Uuid::new_v4();
146
147 StripeMeterEvent::create(
148 &self.client,
149 StripeCreateMeterEventParams {
150 identifier: &format!("model_requests/{}", idempotency_key),
151 event_name,
152 payload: StripeCreateMeterEventPayload {
153 value: requests as u64,
154 stripe_customer_id: customer_id,
155 },
156 timestamp: Some(timestamp),
157 },
158 )
159 .await?;
160
161 Ok(())
162 }
163
164 pub async fn checkout_with_zed_pro(
165 &self,
166 customer_id: stripe::CustomerId,
167 github_login: &str,
168 success_url: &str,
169 ) -> Result<String> {
170 let zed_pro_price_id = self.zed_pro_price_id().await?;
171
172 let mut params = stripe::CreateCheckoutSession::new();
173 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
174 params.customer = Some(customer_id);
175 params.client_reference_id = Some(github_login);
176 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
177 price: Some(zed_pro_price_id.to_string()),
178 quantity: Some(1),
179 ..Default::default()
180 }]);
181 params.success_url = Some(success_url);
182
183 let session = stripe::CheckoutSession::create(&self.client, params).await?;
184 Ok(session.url.context("no checkout session URL")?)
185 }
186
187 pub async fn checkout_with_zed_pro_trial(
188 &self,
189 customer_id: stripe::CustomerId,
190 github_login: &str,
191 feature_flags: Vec<String>,
192 success_url: &str,
193 ) -> Result<String> {
194 let zed_pro_price_id = self.zed_pro_price_id().await?;
195
196 let eligible_for_extended_trial = feature_flags
197 .iter()
198 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
199
200 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
201
202 let mut subscription_metadata = std::collections::HashMap::new();
203 if eligible_for_extended_trial {
204 subscription_metadata.insert(
205 "promo_feature_flag".to_string(),
206 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
207 );
208 }
209
210 let mut params = stripe::CreateCheckoutSession::new();
211 params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
212 trial_period_days: Some(trial_period_days),
213 trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
214 end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
215 missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
216 }
217 }),
218 metadata: if !subscription_metadata.is_empty() {
219 Some(subscription_metadata)
220 } else {
221 None
222 },
223 ..Default::default()
224 });
225 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
226 params.payment_method_collection =
227 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
228 params.customer = Some(customer_id);
229 params.client_reference_id = Some(github_login);
230 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
231 price: Some(zed_pro_price_id.to_string()),
232 quantity: Some(1),
233 ..Default::default()
234 }]);
235 params.success_url = Some(success_url);
236
237 let session = stripe::CheckoutSession::create(&self.client, params).await?;
238 Ok(session.url.context("no checkout session URL")?)
239 }
240
241 pub async fn checkout_with_zed_free(
242 &self,
243 customer_id: stripe::CustomerId,
244 github_login: &str,
245 success_url: &str,
246 ) -> Result<String> {
247 let zed_free_price_id = self.zed_free_price_id().await?;
248
249 let mut params = stripe::CreateCheckoutSession::new();
250 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
251 params.payment_method_collection =
252 Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
253 params.customer = Some(customer_id);
254 params.client_reference_id = Some(github_login);
255 params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
256 price: Some(zed_free_price_id.to_string()),
257 quantity: Some(1),
258 ..Default::default()
259 }]);
260 params.success_url = Some(success_url);
261
262 let session = stripe::CheckoutSession::create(&self.client, params).await?;
263 Ok(session.url.context("no checkout session URL")?)
264 }
265}
266
267#[derive(Clone, Deserialize)]
268struct StripeMeter {
269 id: String,
270 event_name: String,
271}
272
273impl StripeMeter {
274 pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
275 #[derive(Serialize)]
276 struct Params {
277 #[serde(skip_serializing_if = "Option::is_none")]
278 limit: Option<u64>,
279 }
280
281 client.get_query("/billing/meters", Params { limit: Some(100) })
282 }
283}
284
285#[derive(Deserialize)]
286struct StripeMeterEvent {
287 identifier: String,
288}
289
290impl StripeMeterEvent {
291 pub async fn create(
292 client: &stripe::Client,
293 params: StripeCreateMeterEventParams<'_>,
294 ) -> Result<Self, stripe::StripeError> {
295 let identifier = params.identifier;
296 match client.post_form("/billing/meter_events", params).await {
297 Ok(event) => Ok(event),
298 Err(stripe::StripeError::Stripe(error)) => {
299 if error.http_status == 400
300 && error
301 .message
302 .as_ref()
303 .map_or(false, |message| message.contains(identifier))
304 {
305 Ok(Self {
306 identifier: identifier.to_string(),
307 })
308 } else {
309 Err(stripe::StripeError::Stripe(error))
310 }
311 }
312 Err(error) => Err(error),
313 }
314 }
315}
316
317#[derive(Serialize)]
318struct StripeCreateMeterEventParams<'a> {
319 identifier: &'a str,
320 event_name: &'a str,
321 payload: StripeCreateMeterEventPayload<'a>,
322 timestamp: Option<i64>,
323}
324
325#[derive(Serialize)]
326struct StripeCreateMeterEventPayload<'a> {
327 value: u64,
328 stripe_customer_id: &'a stripe::CustomerId,
329}
330
331fn subscription_contains_price(
332 subscription: &stripe::Subscription,
333 price_id: &stripe::PriceId,
334) -> bool {
335 subscription.items.data.iter().any(|item| {
336 item.price
337 .as_ref()
338 .map_or(false, |price| price.id == *price_id)
339 })
340}