1use std::sync::Arc;
2
3use anyhow::anyhow;
4use collections::HashMap;
5use stripe::SubscriptionStatus;
6use tokio::sync::RwLock;
7
8use crate::Result;
9use crate::stripe_client::{
10 RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
11 StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
12 StripeSubscription,
13};
14
15pub struct StripeBilling {
16 state: RwLock<StripeBillingState>,
17 client: Arc<dyn StripeClient>,
18}
19
20#[derive(Default)]
21struct StripeBillingState {
22 prices_by_lookup_key: HashMap<String, StripePrice>,
23}
24
25impl StripeBilling {
26 pub fn new(client: Arc<stripe::Client>) -> Self {
27 Self {
28 client: Arc::new(RealStripeClient::new(client.clone())),
29 state: RwLock::default(),
30 }
31 }
32
33 #[cfg(test)]
34 pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
35 Self {
36 client,
37 state: RwLock::default(),
38 }
39 }
40
41 pub fn client(&self) -> &Arc<dyn StripeClient> {
42 &self.client
43 }
44
45 pub async fn initialize(&self) -> Result<()> {
46 log::info!("StripeBilling: initializing");
47
48 let mut state = self.state.write().await;
49
50 let prices = self.client.list_prices().await?;
51
52 for price in prices {
53 if let Some(lookup_key) = price.lookup_key.clone() {
54 state.prices_by_lookup_key.insert(lookup_key, price);
55 }
56 }
57
58 log::info!("StripeBilling: initialized");
59
60 Ok(())
61 }
62
63 pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
64 self.find_price_id_by_lookup_key("zed-pro").await
65 }
66
67 pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
68 self.find_price_id_by_lookup_key("zed-free").await
69 }
70
71 pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
72 self.state
73 .read()
74 .await
75 .prices_by_lookup_key
76 .get(lookup_key)
77 .map(|price| price.id.clone())
78 .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
79 }
80
81 pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
82 self.state
83 .read()
84 .await
85 .prices_by_lookup_key
86 .get(lookup_key)
87 .cloned()
88 .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
89 }
90
91 /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
92 /// not already exist.
93 ///
94 /// Always returns a new Stripe customer if the email address is `None`.
95 pub async fn find_or_create_customer_by_email(
96 &self,
97 email_address: Option<&str>,
98 ) -> Result<StripeCustomerId> {
99 let existing_customer = if let Some(email) = email_address {
100 let customers = self.client.list_customers_by_email(email).await?;
101
102 customers.first().cloned()
103 } else {
104 None
105 };
106
107 let customer_id = if let Some(existing_customer) = existing_customer {
108 existing_customer.id
109 } else {
110 let customer = self
111 .client
112 .create_customer(crate::stripe_client::CreateCustomerParams {
113 email: email_address,
114 })
115 .await?;
116
117 customer.id
118 };
119
120 Ok(customer_id)
121 }
122
123 pub async fn subscribe_to_zed_free(
124 &self,
125 customer_id: StripeCustomerId,
126 ) -> Result<StripeSubscription> {
127 let zed_free_price_id = self.zed_free_price_id().await?;
128
129 let existing_subscriptions = self
130 .client
131 .list_subscriptions_for_customer(&customer_id)
132 .await?;
133
134 let existing_active_subscription =
135 existing_subscriptions.into_iter().find(|subscription| {
136 subscription.status == SubscriptionStatus::Active
137 || subscription.status == SubscriptionStatus::Trialing
138 });
139 if let Some(subscription) = existing_active_subscription {
140 return Ok(subscription);
141 }
142
143 let params = StripeCreateSubscriptionParams {
144 customer: customer_id,
145 items: vec![StripeCreateSubscriptionItems {
146 price: Some(zed_free_price_id),
147 quantity: Some(1),
148 }],
149 automatic_tax: Some(StripeAutomaticTax { enabled: true }),
150 };
151
152 let subscription = self.client.create_subscription(params).await?;
153
154 Ok(subscription)
155 }
156}