1use std::sync::Arc;
2
3use anyhow::{Context as _, anyhow};
4use chrono::Utc;
5use collections::HashMap;
6use stripe::SubscriptionStatus;
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use crate::Result;
11use crate::db::billing_subscription::SubscriptionKind;
12use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
13use crate::stripe_client::{
14 RealStripeClient, StripeAutomaticTax, StripeBillingAddressCollection,
15 StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
16 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
17 StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
18 StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
19 StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName,
20 StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
21 StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
22 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
23 UpdateSubscriptionItems, UpdateSubscriptionParams,
24};
25
26pub struct StripeBilling {
27 state: RwLock<StripeBillingState>,
28 client: Arc<dyn StripeClient>,
29}
30
31#[derive(Default)]
32struct StripeBillingState {
33 prices_by_lookup_key: HashMap<String, StripePrice>,
34}
35
36impl StripeBilling {
37 pub fn new(client: Arc<stripe::Client>) -> Self {
38 Self {
39 client: Arc::new(RealStripeClient::new(client.clone())),
40 state: RwLock::default(),
41 }
42 }
43
44 #[cfg(test)]
45 pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
46 Self {
47 client,
48 state: RwLock::default(),
49 }
50 }
51
52 pub fn client(&self) -> &Arc<dyn StripeClient> {
53 &self.client
54 }
55
56 pub async fn initialize(&self) -> Result<()> {
57 log::info!("StripeBilling: initializing");
58
59 let mut state = self.state.write().await;
60
61 let prices = self.client.list_prices().await?;
62
63 for price in prices {
64 if let Some(lookup_key) = price.lookup_key.clone() {
65 state.prices_by_lookup_key.insert(lookup_key, price);
66 }
67 }
68
69 log::info!("StripeBilling: initialized");
70
71 Ok(())
72 }
73
74 pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
75 self.find_price_id_by_lookup_key("zed-pro").await
76 }
77
78 pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
79 self.find_price_id_by_lookup_key("zed-free").await
80 }
81
82 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
83 self.state
84 .read()
85 .await
86 .prices_by_lookup_key
87 .get(lookup_key)
88 .map(|price| price.id.clone())
89 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
90 }
91
92 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
93 self.state
94 .read()
95 .await
96 .prices_by_lookup_key
97 .get(lookup_key)
98 .cloned()
99 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
100 }
101
102 pub async fn determine_subscription_kind(
103 &self,
104 subscription: &StripeSubscription,
105 ) -> Option<SubscriptionKind> {
106 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
107 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
108
109 subscription.items.iter().find_map(|item| {
110 let price = item.price.as_ref()?;
111
112 if price.id == zed_pro_price_id {
113 Some(if subscription.status == SubscriptionStatus::Trialing {
114 SubscriptionKind::ZedProTrial
115 } else {
116 SubscriptionKind::ZedPro
117 })
118 } else if price.id == zed_free_price_id {
119 Some(SubscriptionKind::ZedFree)
120 } else {
121 None
122 }
123 })
124 }
125
126 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
127 /// not already exist.
128 ///
129 /// Always returns a new Stripe customer if the email address is `None`.
130 pub async fn find_or_create_customer_by_email(
131 &self,
132 email_address: Option<&str>,
133 ) -> Result<StripeCustomerId> {
134 let existing_customer = if let Some(email) = email_address {
135 let customers = self.client.list_customers_by_email(email).await?;
136
137 customers.first().cloned()
138 } else {
139 None
140 };
141
142 let customer_id = if let Some(existing_customer) = existing_customer {
143 existing_customer.id
144 } else {
145 let customer = self
146 .client
147 .create_customer(crate::stripe_client::CreateCustomerParams {
148 email: email_address,
149 })
150 .await?;
151
152 customer.id
153 };
154
155 Ok(customer_id)
156 }
157
158 pub async fn subscribe_to_price(
159 &self,
160 subscription_id: &StripeSubscriptionId,
161 price: &StripePrice,
162 ) -> Result<()> {
163 let subscription = self.client.get_subscription(subscription_id).await?;
164
165 if subscription_contains_price(&subscription, &price.id) {
166 return Ok(());
167 }
168
169 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
170
171 let price_per_unit = price.unit_amount.unwrap_or_default();
172 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
173
174 self.client
175 .update_subscription(
176 subscription_id,
177 UpdateSubscriptionParams {
178 items: Some(vec![UpdateSubscriptionItems {
179 price: Some(price.id.clone()),
180 }]),
181 trial_settings: Some(StripeSubscriptionTrialSettings {
182 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
183 missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
184 },
185 }),
186 },
187 )
188 .await?;
189
190 Ok(())
191 }
192
193 pub async fn bill_model_request_usage(
194 &self,
195 customer_id: &StripeCustomerId,
196 event_name: &str,
197 requests: i32,
198 ) -> Result<()> {
199 let timestamp = Utc::now().timestamp();
200 let idempotency_key = Uuid::new_v4();
201
202 self.client
203 .create_meter_event(StripeCreateMeterEventParams {
204 identifier: &format!("model_requests/{}", idempotency_key),
205 event_name,
206 payload: StripeCreateMeterEventPayload {
207 value: requests as u64,
208 stripe_customer_id: customer_id,
209 },
210 timestamp: Some(timestamp),
211 })
212 .await?;
213
214 Ok(())
215 }
216
217 pub async fn checkout_with_zed_pro(
218 &self,
219 customer_id: &StripeCustomerId,
220 github_login: &str,
221 success_url: &str,
222 ) -> Result<String> {
223 let zed_pro_price_id = self.zed_pro_price_id().await?;
224
225 let mut params = StripeCreateCheckoutSessionParams::default();
226 params.mode = Some(StripeCheckoutSessionMode::Subscription);
227 params.customer = Some(customer_id);
228 params.client_reference_id = Some(github_login);
229 params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
230 price: Some(zed_pro_price_id.to_string()),
231 quantity: Some(1),
232 }]);
233 params.success_url = Some(success_url);
234 params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
235 params.customer_update = Some(StripeCustomerUpdate {
236 address: Some(StripeCustomerUpdateAddress::Auto),
237 name: Some(StripeCustomerUpdateName::Auto),
238 shipping: None,
239 });
240 params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
241
242 let session = self.client.create_checkout_session(params).await?;
243 Ok(session.url.context("no checkout session URL")?)
244 }
245
246 pub async fn checkout_with_zed_pro_trial(
247 &self,
248 customer_id: &StripeCustomerId,
249 github_login: &str,
250 feature_flags: Vec<String>,
251 success_url: &str,
252 ) -> Result<String> {
253 let zed_pro_price_id = self.zed_pro_price_id().await?;
254
255 let eligible_for_extended_trial = feature_flags
256 .iter()
257 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
258
259 let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
260
261 let mut subscription_metadata = std::collections::HashMap::new();
262 if eligible_for_extended_trial {
263 subscription_metadata.insert(
264 "promo_feature_flag".to_string(),
265 AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
266 );
267 }
268
269 let mut params = StripeCreateCheckoutSessionParams::default();
270 params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
271 trial_period_days: Some(trial_period_days),
272 trial_settings: Some(StripeSubscriptionTrialSettings {
273 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
274 missing_payment_method:
275 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
276 },
277 }),
278 metadata: if !subscription_metadata.is_empty() {
279 Some(subscription_metadata)
280 } else {
281 None
282 },
283 });
284 params.mode = Some(StripeCheckoutSessionMode::Subscription);
285 params.payment_method_collection =
286 Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
287 params.customer = Some(customer_id);
288 params.client_reference_id = Some(github_login);
289 params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
290 price: Some(zed_pro_price_id.to_string()),
291 quantity: Some(1),
292 }]);
293 params.success_url = Some(success_url);
294 params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
295 params.customer_update = Some(StripeCustomerUpdate {
296 address: Some(StripeCustomerUpdateAddress::Auto),
297 name: Some(StripeCustomerUpdateName::Auto),
298 shipping: None,
299 });
300 params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
301
302 let session = self.client.create_checkout_session(params).await?;
303 Ok(session.url.context("no checkout session URL")?)
304 }
305
306 pub async fn subscribe_to_zed_free(
307 &self,
308 customer_id: StripeCustomerId,
309 ) -> Result<StripeSubscription> {
310 let zed_free_price_id = self.zed_free_price_id().await?;
311
312 let existing_subscriptions = self
313 .client
314 .list_subscriptions_for_customer(&customer_id)
315 .await?;
316
317 let existing_active_subscription =
318 existing_subscriptions.into_iter().find(|subscription| {
319 subscription.status == SubscriptionStatus::Active
320 || subscription.status == SubscriptionStatus::Trialing
321 });
322 if let Some(subscription) = existing_active_subscription {
323 return Ok(subscription);
324 }
325
326 let params = StripeCreateSubscriptionParams {
327 customer: customer_id,
328 items: vec![StripeCreateSubscriptionItems {
329 price: Some(zed_free_price_id),
330 quantity: Some(1),
331 }],
332 automatic_tax: Some(StripeAutomaticTax { enabled: true }),
333 };
334
335 let subscription = self.client.create_subscription(params).await?;
336
337 Ok(subscription)
338 }
339}
340
341fn subscription_contains_price(
342 subscription: &StripeSubscription,
343 price_id: &StripePriceId,
344) -> bool {
345 subscription.items.iter().any(|item| {
346 item.price
347 .as_ref()
348 .map_or(false, |price| price.id == *price_id)
349 })
350}