stripe_billing.rs

  1use std::sync::Arc;
  2
  3use crate::Result;
  4use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  5use anyhow::{Context as _, anyhow};
  6use chrono::Utc;
  7use collections::HashMap;
  8use serde::{Deserialize, Serialize};
  9use stripe::PriceId;
 10use tokio::sync::RwLock;
 11use uuid::Uuid;
 12
 13pub struct StripeBilling {
 14    state: RwLock<StripeBillingState>,
 15    client: Arc<stripe::Client>,
 16}
 17
 18#[derive(Default)]
 19struct StripeBillingState {
 20    meters_by_event_name: HashMap<String, StripeMeter>,
 21    price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
 22    prices_by_lookup_key: HashMap<String, stripe::Price>,
 23}
 24
 25impl StripeBilling {
 26    pub fn new(client: Arc<stripe::Client>) -> Self {
 27        Self {
 28            client,
 29            state: RwLock::default(),
 30        }
 31    }
 32
 33    pub async fn initialize(&self) -> Result<()> {
 34        log::info!("StripeBilling: initializing");
 35
 36        let mut state = self.state.write().await;
 37
 38        let (meters, prices) = futures::try_join!(
 39            StripeMeter::list(&self.client),
 40            stripe::Price::list(
 41                &self.client,
 42                &stripe::ListPrices {
 43                    limit: Some(100),
 44                    ..Default::default()
 45                }
 46            )
 47        )?;
 48
 49        for meter in meters.data {
 50            state
 51                .meters_by_event_name
 52                .insert(meter.event_name.clone(), meter);
 53        }
 54
 55        for price in prices.data {
 56            if let Some(lookup_key) = price.lookup_key.clone() {
 57                state.prices_by_lookup_key.insert(lookup_key, price.clone());
 58            }
 59
 60            if let Some(recurring) = price.recurring {
 61                if let Some(meter) = recurring.meter {
 62                    state.price_ids_by_meter_id.insert(meter, price.id);
 63                }
 64            }
 65        }
 66
 67        log::info!("StripeBilling: initialized");
 68
 69        Ok(())
 70    }
 71
 72    pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
 73        self.find_price_id_by_lookup_key("zed-pro").await
 74    }
 75
 76    pub async fn zed_free_price_id(&self) -> Result<PriceId> {
 77        self.find_price_id_by_lookup_key("zed-free").await
 78    }
 79
 80    pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
 81        self.state
 82            .read()
 83            .await
 84            .prices_by_lookup_key
 85            .get(lookup_key)
 86            .map(|price| price.id.clone())
 87            .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
 88    }
 89
 90    pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
 91        self.state
 92            .read()
 93            .await
 94            .prices_by_lookup_key
 95            .get(lookup_key)
 96            .cloned()
 97            .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
 98    }
 99
100    pub async fn subscribe_to_price(
101        &self,
102        subscription_id: &stripe::SubscriptionId,
103        price: &stripe::Price,
104    ) -> Result<()> {
105        let subscription =
106            stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
107
108        if subscription_contains_price(&subscription, &price.id) {
109            return Ok(());
110        }
111
112        const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
113
114        let price_per_unit = price.unit_amount.unwrap_or_default();
115        let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
116
117        stripe::Subscription::update(
118            &self.client,
119            subscription_id,
120            stripe::UpdateSubscription {
121                items: Some(vec![stripe::UpdateSubscriptionItems {
122                    price: Some(price.id.to_string()),
123                    ..Default::default()
124                }]),
125                trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
126                    end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
127                        missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
128                    },
129                }),
130                ..Default::default()
131            },
132        )
133        .await?;
134
135        Ok(())
136    }
137
138    pub async fn bill_model_request_usage(
139        &self,
140        customer_id: &stripe::CustomerId,
141        event_name: &str,
142        requests: i32,
143    ) -> Result<()> {
144        let timestamp = Utc::now().timestamp();
145        let idempotency_key = Uuid::new_v4();
146
147        StripeMeterEvent::create(
148            &self.client,
149            StripeCreateMeterEventParams {
150                identifier: &format!("model_requests/{}", idempotency_key),
151                event_name,
152                payload: StripeCreateMeterEventPayload {
153                    value: requests as u64,
154                    stripe_customer_id: customer_id,
155                },
156                timestamp: Some(timestamp),
157            },
158        )
159        .await?;
160
161        Ok(())
162    }
163
164    pub async fn checkout_with_zed_pro(
165        &self,
166        customer_id: stripe::CustomerId,
167        github_login: &str,
168        success_url: &str,
169    ) -> Result<String> {
170        let zed_pro_price_id = self.zed_pro_price_id().await?;
171
172        let mut params = stripe::CreateCheckoutSession::new();
173        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
174        params.customer = Some(customer_id);
175        params.client_reference_id = Some(github_login);
176        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
177            price: Some(zed_pro_price_id.to_string()),
178            quantity: Some(1),
179            ..Default::default()
180        }]);
181        params.success_url = Some(success_url);
182
183        let session = stripe::CheckoutSession::create(&self.client, params).await?;
184        Ok(session.url.context("no checkout session URL")?)
185    }
186
187    pub async fn checkout_with_zed_pro_trial(
188        &self,
189        customer_id: stripe::CustomerId,
190        github_login: &str,
191        feature_flags: Vec<String>,
192        success_url: &str,
193    ) -> Result<String> {
194        let zed_pro_price_id = self.zed_pro_price_id().await?;
195
196        let eligible_for_extended_trial = feature_flags
197            .iter()
198            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
199
200        let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
201
202        let mut subscription_metadata = std::collections::HashMap::new();
203        if eligible_for_extended_trial {
204            subscription_metadata.insert(
205                "promo_feature_flag".to_string(),
206                AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
207            );
208        }
209
210        let mut params = stripe::CreateCheckoutSession::new();
211        params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
212            trial_period_days: Some(trial_period_days),
213            trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
214                end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
215                    missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
216                }
217            }),
218            metadata: if !subscription_metadata.is_empty() {
219                Some(subscription_metadata)
220            } else {
221                None
222            },
223            ..Default::default()
224        });
225        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
226        params.payment_method_collection =
227            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
228        params.customer = Some(customer_id);
229        params.client_reference_id = Some(github_login);
230        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
231            price: Some(zed_pro_price_id.to_string()),
232            quantity: Some(1),
233            ..Default::default()
234        }]);
235        params.success_url = Some(success_url);
236
237        let session = stripe::CheckoutSession::create(&self.client, params).await?;
238        Ok(session.url.context("no checkout session URL")?)
239    }
240
241    pub async fn checkout_with_zed_free(
242        &self,
243        customer_id: stripe::CustomerId,
244        github_login: &str,
245        success_url: &str,
246    ) -> Result<String> {
247        let zed_free_price_id = self.zed_free_price_id().await?;
248
249        let mut params = stripe::CreateCheckoutSession::new();
250        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
251        params.payment_method_collection =
252            Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
253        params.customer = Some(customer_id);
254        params.client_reference_id = Some(github_login);
255        params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
256            price: Some(zed_free_price_id.to_string()),
257            quantity: Some(1),
258            ..Default::default()
259        }]);
260        params.success_url = Some(success_url);
261
262        let session = stripe::CheckoutSession::create(&self.client, params).await?;
263        Ok(session.url.context("no checkout session URL")?)
264    }
265}
266
267#[derive(Clone, Deserialize)]
268struct StripeMeter {
269    id: String,
270    event_name: String,
271}
272
273impl StripeMeter {
274    pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
275        #[derive(Serialize)]
276        struct Params {
277            #[serde(skip_serializing_if = "Option::is_none")]
278            limit: Option<u64>,
279        }
280
281        client.get_query("/billing/meters", Params { limit: Some(100) })
282    }
283}
284
/// Result of creating a meter event via `StripeMeterEvent::create`.
///
/// Deserialized from Stripe's response on success. On the duplicate path in
/// `create`, it is constructed locally from the request's identifier.
#[derive(Deserialize)]
struct StripeMeterEvent {
    /// Identifier of the meter event.
    identifier: String,
}
289
290impl StripeMeterEvent {
291    pub async fn create(
292        client: &stripe::Client,
293        params: StripeCreateMeterEventParams<'_>,
294    ) -> Result<Self, stripe::StripeError> {
295        let identifier = params.identifier;
296        match client.post_form("/billing/meter_events", params).await {
297            Ok(event) => Ok(event),
298            Err(stripe::StripeError::Stripe(error)) => {
299                if error.http_status == 400
300                    && error
301                        .message
302                        .as_ref()
303                        .map_or(false, |message| message.contains(identifier))
304                {
305                    Ok(Self {
306                        identifier: identifier.to_string(),
307                    })
308                } else {
309                    Err(stripe::StripeError::Stripe(error))
310                }
311            }
312            Err(error) => Err(error),
313        }
314    }
315}
316
/// Form parameters posted to `/billing/meter_events` by `StripeMeterEvent::create`.
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
    /// Caller-chosen event identifier. `create` treats a 400 error whose message
    /// mentions it as an already-recorded duplicate.
    identifier: &'a str,
    /// Event name of the meter to record usage against.
    event_name: &'a str,
    /// Usage amount and the customer it is attributed to.
    payload: StripeCreateMeterEventPayload<'a>,
    /// Event time. The visible caller passes `Utc::now().timestamp()` (Unix seconds).
    timestamp: Option<i64>,
}
324
/// Payload of a meter event: the amount of usage and the customer it belongs to.
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
    /// Usage amount to record (the number of model requests, for the visible caller).
    value: u64,
    /// Stripe customer the usage is attributed to.
    stripe_customer_id: &'a stripe::CustomerId,
}
330
331fn subscription_contains_price(
332    subscription: &stripe::Subscription,
333    price_id: &stripe::PriceId,
334) -> bool {
335    subscription.items.data.iter().any(|item| {
336        item.price
337            .as_ref()
338            .map_or(false, |price| price.id == *price_id)
339    })
340}