fake_stripe_client.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{Duration, Utc};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use uuid::Uuid;
  9
 10use crate::stripe_client::{
 11    CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
 12    StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 13    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 14    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 15    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
 16    StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
 17    StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection,
 18    UpdateCustomerParams, UpdateSubscriptionParams,
 19};
 20
 21#[derive(Debug, Clone)]
 22pub struct StripeCreateMeterEventCall {
 23    pub identifier: Arc<str>,
 24    pub event_name: Arc<str>,
 25    pub value: u64,
 26    pub stripe_customer_id: StripeCustomerId,
 27    pub timestamp: Option<i64>,
 28}
 29
 30#[derive(Debug, Clone)]
 31pub struct StripeCreateCheckoutSessionCall {
 32    pub customer: Option<StripeCustomerId>,
 33    pub client_reference_id: Option<String>,
 34    pub mode: Option<StripeCheckoutSessionMode>,
 35    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
 36    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
 37    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
 38    pub success_url: Option<String>,
 39    pub billing_address_collection: Option<StripeBillingAddressCollection>,
 40    pub customer_update: Option<StripeCustomerUpdate>,
 41    pub tax_id_collection: Option<StripeTaxIdCollection>,
 42}
 43
 44pub struct FakeStripeClient {
 45    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
 46    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
 47    pub update_subscription_calls:
 48        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
 49    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
 50    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 51    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
 52    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
 53}
 54
 55impl FakeStripeClient {
 56    pub fn new() -> Self {
 57        Self {
 58            customers: Arc::new(Mutex::new(HashMap::default())),
 59            subscriptions: Arc::new(Mutex::new(HashMap::default())),
 60            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
 61            prices: Arc::new(Mutex::new(HashMap::default())),
 62            meters: Arc::new(Mutex::new(HashMap::default())),
 63            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
 64            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
 65        }
 66    }
 67}
 68
 69#[async_trait]
 70impl StripeClient for FakeStripeClient {
 71    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 72        Ok(self
 73            .customers
 74            .lock()
 75            .values()
 76            .filter(|customer| customer.email.as_deref() == Some(email))
 77            .cloned()
 78            .collect())
 79    }
 80
 81    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
 82        self.customers
 83            .lock()
 84            .get(customer_id)
 85            .cloned()
 86            .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
 87    }
 88
 89    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 90        let customer = StripeCustomer {
 91            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
 92            email: params.email.map(|email| email.to_string()),
 93        };
 94
 95        self.customers
 96            .lock()
 97            .insert(customer.id.clone(), customer.clone());
 98
 99        Ok(customer)
100    }
101
102    async fn update_customer(
103        &self,
104        customer_id: &StripeCustomerId,
105        params: UpdateCustomerParams<'_>,
106    ) -> Result<StripeCustomer> {
107        let mut customers = self.customers.lock();
108        if let Some(customer) = customers.get_mut(customer_id) {
109            if let Some(email) = params.email {
110                customer.email = Some(email.to_string());
111            }
112            Ok(customer.clone())
113        } else {
114            Err(anyhow!("no customer found for {customer_id:?}"))
115        }
116    }
117
118    async fn list_subscriptions_for_customer(
119        &self,
120        customer_id: &StripeCustomerId,
121    ) -> Result<Vec<StripeSubscription>> {
122        let subscriptions = self
123            .subscriptions
124            .lock()
125            .values()
126            .filter(|subscription| subscription.customer == *customer_id)
127            .cloned()
128            .collect();
129
130        Ok(subscriptions)
131    }
132
133    async fn get_subscription(
134        &self,
135        subscription_id: &StripeSubscriptionId,
136    ) -> Result<StripeSubscription> {
137        self.subscriptions
138            .lock()
139            .get(subscription_id)
140            .cloned()
141            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
142    }
143
144    async fn create_subscription(
145        &self,
146        params: StripeCreateSubscriptionParams,
147    ) -> Result<StripeSubscription> {
148        let now = Utc::now();
149
150        let subscription = StripeSubscription {
151            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
152            customer: params.customer,
153            status: stripe::SubscriptionStatus::Active,
154            current_period_start: now.timestamp(),
155            current_period_end: (now + Duration::days(30)).timestamp(),
156            items: params
157                .items
158                .into_iter()
159                .map(|item| StripeSubscriptionItem {
160                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
161                    price: item
162                        .price
163                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
164                })
165                .collect(),
166            cancel_at: None,
167            cancellation_details: None,
168        };
169
170        self.subscriptions
171            .lock()
172            .insert(subscription.id.clone(), subscription.clone());
173
174        Ok(subscription)
175    }
176
177    async fn update_subscription(
178        &self,
179        subscription_id: &StripeSubscriptionId,
180        params: UpdateSubscriptionParams,
181    ) -> Result<()> {
182        let subscription = self.get_subscription(subscription_id).await?;
183
184        self.update_subscription_calls
185            .lock()
186            .push((subscription.id, params));
187
188        Ok(())
189    }
190
191    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
192        // TODO: Implement fake subscription cancellation.
193        let _ = subscription_id;
194
195        Ok(())
196    }
197
198    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
199        let prices = self.prices.lock().values().cloned().collect();
200
201        Ok(prices)
202    }
203
204    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
205        let meters = self.meters.lock().values().cloned().collect();
206
207        Ok(meters)
208    }
209
210    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
211        self.create_meter_event_calls
212            .lock()
213            .push(StripeCreateMeterEventCall {
214                identifier: params.identifier.into(),
215                event_name: params.event_name.into(),
216                value: params.payload.value,
217                stripe_customer_id: params.payload.stripe_customer_id.clone(),
218                timestamp: params.timestamp,
219            });
220
221        Ok(())
222    }
223
224    async fn create_checkout_session(
225        &self,
226        params: StripeCreateCheckoutSessionParams<'_>,
227    ) -> Result<StripeCheckoutSession> {
228        self.create_checkout_session_calls
229            .lock()
230            .push(StripeCreateCheckoutSessionCall {
231                customer: params.customer.cloned(),
232                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
233                mode: params.mode,
234                line_items: params.line_items,
235                payment_method_collection: params.payment_method_collection,
236                subscription_data: params.subscription_data,
237                success_url: params.success_url.map(|url| url.to_string()),
238                billing_address_collection: params.billing_address_collection,
239                customer_update: params.customer_update,
240                tax_id_collection: params.tax_id_collection,
241            });
242
243        Ok(StripeCheckoutSession {
244            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
245        })
246    }
247}