fake_stripe_client.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{Duration, Utc};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use uuid::Uuid;
  9
 10use crate::stripe_client::{
 11    CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
 12    StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 13    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 14    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 15    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
 16    StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
 17    StripeSubscriptionItemId, UpdateCustomerParams, UpdateSubscriptionParams,
 18};
 19
 20#[derive(Debug, Clone)]
 21pub struct StripeCreateMeterEventCall {
 22    pub identifier: Arc<str>,
 23    pub event_name: Arc<str>,
 24    pub value: u64,
 25    pub stripe_customer_id: StripeCustomerId,
 26    pub timestamp: Option<i64>,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub struct StripeCreateCheckoutSessionCall {
 31    pub customer: Option<StripeCustomerId>,
 32    pub client_reference_id: Option<String>,
 33    pub mode: Option<StripeCheckoutSessionMode>,
 34    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
 35    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
 36    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
 37    pub success_url: Option<String>,
 38    pub billing_address_collection: Option<StripeBillingAddressCollection>,
 39}
 40
 41pub struct FakeStripeClient {
 42    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
 43    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
 44    pub update_subscription_calls:
 45        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
 46    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
 47    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 48    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
 49    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
 50}
 51
 52impl FakeStripeClient {
 53    pub fn new() -> Self {
 54        Self {
 55            customers: Arc::new(Mutex::new(HashMap::default())),
 56            subscriptions: Arc::new(Mutex::new(HashMap::default())),
 57            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
 58            prices: Arc::new(Mutex::new(HashMap::default())),
 59            meters: Arc::new(Mutex::new(HashMap::default())),
 60            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
 61            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
 62        }
 63    }
 64}
 65
 66#[async_trait]
 67impl StripeClient for FakeStripeClient {
 68    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 69        Ok(self
 70            .customers
 71            .lock()
 72            .values()
 73            .filter(|customer| customer.email.as_deref() == Some(email))
 74            .cloned()
 75            .collect())
 76    }
 77
 78    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
 79        self.customers
 80            .lock()
 81            .get(customer_id)
 82            .cloned()
 83            .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
 84    }
 85
 86    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 87        let customer = StripeCustomer {
 88            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
 89            email: params.email.map(|email| email.to_string()),
 90        };
 91
 92        self.customers
 93            .lock()
 94            .insert(customer.id.clone(), customer.clone());
 95
 96        Ok(customer)
 97    }
 98
 99    async fn update_customer(
100        &self,
101        customer_id: &StripeCustomerId,
102        params: UpdateCustomerParams<'_>,
103    ) -> Result<StripeCustomer> {
104        let mut customers = self.customers.lock();
105        if let Some(customer) = customers.get_mut(customer_id) {
106            if let Some(email) = params.email {
107                customer.email = Some(email.to_string());
108            }
109            Ok(customer.clone())
110        } else {
111            Err(anyhow!("no customer found for {customer_id:?}"))
112        }
113    }
114
115    async fn list_subscriptions_for_customer(
116        &self,
117        customer_id: &StripeCustomerId,
118    ) -> Result<Vec<StripeSubscription>> {
119        let subscriptions = self
120            .subscriptions
121            .lock()
122            .values()
123            .filter(|subscription| subscription.customer == *customer_id)
124            .cloned()
125            .collect();
126
127        Ok(subscriptions)
128    }
129
130    async fn get_subscription(
131        &self,
132        subscription_id: &StripeSubscriptionId,
133    ) -> Result<StripeSubscription> {
134        self.subscriptions
135            .lock()
136            .get(subscription_id)
137            .cloned()
138            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
139    }
140
141    async fn create_subscription(
142        &self,
143        params: StripeCreateSubscriptionParams,
144    ) -> Result<StripeSubscription> {
145        let now = Utc::now();
146
147        let subscription = StripeSubscription {
148            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
149            customer: params.customer,
150            status: stripe::SubscriptionStatus::Active,
151            current_period_start: now.timestamp(),
152            current_period_end: (now + Duration::days(30)).timestamp(),
153            items: params
154                .items
155                .into_iter()
156                .map(|item| StripeSubscriptionItem {
157                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
158                    price: item
159                        .price
160                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
161                })
162                .collect(),
163            cancel_at: None,
164            cancellation_details: None,
165        };
166
167        self.subscriptions
168            .lock()
169            .insert(subscription.id.clone(), subscription.clone());
170
171        Ok(subscription)
172    }
173
174    async fn update_subscription(
175        &self,
176        subscription_id: &StripeSubscriptionId,
177        params: UpdateSubscriptionParams,
178    ) -> Result<()> {
179        let subscription = self.get_subscription(subscription_id).await?;
180
181        self.update_subscription_calls
182            .lock()
183            .push((subscription.id, params));
184
185        Ok(())
186    }
187
188    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
189        // TODO: Implement fake subscription cancellation.
190        let _ = subscription_id;
191
192        Ok(())
193    }
194
195    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
196        let prices = self.prices.lock().values().cloned().collect();
197
198        Ok(prices)
199    }
200
201    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
202        let meters = self.meters.lock().values().cloned().collect();
203
204        Ok(meters)
205    }
206
207    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
208        self.create_meter_event_calls
209            .lock()
210            .push(StripeCreateMeterEventCall {
211                identifier: params.identifier.into(),
212                event_name: params.event_name.into(),
213                value: params.payload.value,
214                stripe_customer_id: params.payload.stripe_customer_id.clone(),
215                timestamp: params.timestamp,
216            });
217
218        Ok(())
219    }
220
221    async fn create_checkout_session(
222        &self,
223        params: StripeCreateCheckoutSessionParams<'_>,
224    ) -> Result<StripeCheckoutSession> {
225        self.create_checkout_session_calls
226            .lock()
227            .push(StripeCreateCheckoutSessionCall {
228                customer: params.customer.cloned(),
229                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
230                mode: params.mode,
231                line_items: params.line_items,
232                payment_method_collection: params.payment_method_collection,
233                subscription_data: params.subscription_data,
234                success_url: params.success_url.map(|url| url.to_string()),
235                billing_address_collection: params.billing_address_collection,
236            });
237
238        Ok(StripeCheckoutSession {
239            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
240        })
241    }
242}