1use std::sync::Arc;
2
3use anyhow::anyhow;
4use chrono::Utc;
5use collections::HashMap;
6use stripe::SubscriptionStatus;
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use crate::Result;
11use crate::db::billing_subscription::SubscriptionKind;
12use crate::stripe_client::{
13 RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
14 StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
15 StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
16 StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
17 StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
18 UpdateSubscriptionParams,
19};
20
21pub struct StripeBilling {
22 state: RwLock<StripeBillingState>,
23 client: Arc<dyn StripeClient>,
24}
25
26#[derive(Default)]
27struct StripeBillingState {
28 prices_by_lookup_key: HashMap<String, StripePrice>,
29}
30
31impl StripeBilling {
32 pub fn new(client: Arc<stripe::Client>) -> Self {
33 Self {
34 client: Arc::new(RealStripeClient::new(client.clone())),
35 state: RwLock::default(),
36 }
37 }
38
39 #[cfg(test)]
40 pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
41 Self {
42 client,
43 state: RwLock::default(),
44 }
45 }
46
47 pub fn client(&self) -> &Arc<dyn StripeClient> {
48 &self.client
49 }
50
51 pub async fn initialize(&self) -> Result<()> {
52 log::info!("StripeBilling: initializing");
53
54 let mut state = self.state.write().await;
55
56 let prices = self.client.list_prices().await?;
57
58 for price in prices {
59 if let Some(lookup_key) = price.lookup_key.clone() {
60 state.prices_by_lookup_key.insert(lookup_key, price);
61 }
62 }
63
64 log::info!("StripeBilling: initialized");
65
66 Ok(())
67 }
68
69 pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
70 self.find_price_id_by_lookup_key("zed-pro").await
71 }
72
73 pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
74 self.find_price_id_by_lookup_key("zed-free").await
75 }
76
77 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
78 self.state
79 .read()
80 .await
81 .prices_by_lookup_key
82 .get(lookup_key)
83 .map(|price| price.id.clone())
84 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
85 }
86
87 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
88 self.state
89 .read()
90 .await
91 .prices_by_lookup_key
92 .get(lookup_key)
93 .cloned()
94 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
95 }
96
97 pub async fn determine_subscription_kind(
98 &self,
99 subscription: &StripeSubscription,
100 ) -> Option<SubscriptionKind> {
101 let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
102 let zed_free_price_id = self.zed_free_price_id().await.ok()?;
103
104 subscription.items.iter().find_map(|item| {
105 let price = item.price.as_ref()?;
106
107 if price.id == zed_pro_price_id {
108 Some(if subscription.status == SubscriptionStatus::Trialing {
109 SubscriptionKind::ZedProTrial
110 } else {
111 SubscriptionKind::ZedPro
112 })
113 } else if price.id == zed_free_price_id {
114 Some(SubscriptionKind::ZedFree)
115 } else {
116 None
117 }
118 })
119 }
120
121 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
122 /// not already exist.
123 ///
124 /// Always returns a new Stripe customer if the email address is `None`.
125 pub async fn find_or_create_customer_by_email(
126 &self,
127 email_address: Option<&str>,
128 ) -> Result<StripeCustomerId> {
129 let existing_customer = if let Some(email) = email_address {
130 let customers = self.client.list_customers_by_email(email).await?;
131
132 customers.first().cloned()
133 } else {
134 None
135 };
136
137 let customer_id = if let Some(existing_customer) = existing_customer {
138 existing_customer.id
139 } else {
140 let customer = self
141 .client
142 .create_customer(crate::stripe_client::CreateCustomerParams {
143 email: email_address,
144 })
145 .await?;
146
147 customer.id
148 };
149
150 Ok(customer_id)
151 }
152
153 pub async fn subscribe_to_price(
154 &self,
155 subscription_id: &StripeSubscriptionId,
156 price: &StripePrice,
157 ) -> Result<()> {
158 let subscription = self.client.get_subscription(subscription_id).await?;
159
160 if subscription_contains_price(&subscription, &price.id) {
161 return Ok(());
162 }
163
164 const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
165
166 let price_per_unit = price.unit_amount.unwrap_or_default();
167 let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
168
169 self.client
170 .update_subscription(
171 subscription_id,
172 UpdateSubscriptionParams {
173 items: Some(vec![UpdateSubscriptionItems {
174 price: Some(price.id.clone()),
175 }]),
176 trial_settings: Some(StripeSubscriptionTrialSettings {
177 end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
178 missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
179 },
180 }),
181 },
182 )
183 .await?;
184
185 Ok(())
186 }
187
188 pub async fn bill_model_request_usage(
189 &self,
190 customer_id: &StripeCustomerId,
191 event_name: &str,
192 requests: i32,
193 ) -> Result<()> {
194 let timestamp = Utc::now().timestamp();
195 let idempotency_key = Uuid::new_v4();
196
197 self.client
198 .create_meter_event(StripeCreateMeterEventParams {
199 identifier: &format!("model_requests/{}", idempotency_key),
200 event_name,
201 payload: StripeCreateMeterEventPayload {
202 value: requests as u64,
203 stripe_customer_id: customer_id,
204 },
205 timestamp: Some(timestamp),
206 })
207 .await?;
208
209 Ok(())
210 }
211
212 pub async fn subscribe_to_zed_free(
213 &self,
214 customer_id: StripeCustomerId,
215 ) -> Result<StripeSubscription> {
216 let zed_free_price_id = self.zed_free_price_id().await?;
217
218 let existing_subscriptions = self
219 .client
220 .list_subscriptions_for_customer(&customer_id)
221 .await?;
222
223 let existing_active_subscription =
224 existing_subscriptions.into_iter().find(|subscription| {
225 subscription.status == SubscriptionStatus::Active
226 || subscription.status == SubscriptionStatus::Trialing
227 });
228 if let Some(subscription) = existing_active_subscription {
229 return Ok(subscription);
230 }
231
232 let params = StripeCreateSubscriptionParams {
233 customer: customer_id,
234 items: vec![StripeCreateSubscriptionItems {
235 price: Some(zed_free_price_id),
236 quantity: Some(1),
237 }],
238 automatic_tax: Some(StripeAutomaticTax { enabled: true }),
239 };
240
241 let subscription = self.client.create_subscription(params).await?;
242
243 Ok(subscription)
244 }
245}
246
247fn subscription_contains_price(
248 subscription: &StripeSubscription,
249 price_id: &StripePriceId,
250) -> bool {
251 subscription.items.iter().any(|item| {
252 item.price
253 .as_ref()
254 .map_or(false, |price| price.id == *price_id)
255 })
256}