fake_stripe_client.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{Duration, Utc};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use uuid::Uuid;
  9
 10use crate::stripe_client::{
 11    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
 12    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 13    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 14    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 15    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
 16    StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
 17    StripeSubscriptionItemId, UpdateCustomerParams, UpdateSubscriptionParams,
 18};
 19
 20#[derive(Debug, Clone)]
 21pub struct StripeCreateMeterEventCall {
 22    pub identifier: Arc<str>,
 23    pub event_name: Arc<str>,
 24    pub value: u64,
 25    pub stripe_customer_id: StripeCustomerId,
 26    pub timestamp: Option<i64>,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub struct StripeCreateCheckoutSessionCall {
 31    pub customer: Option<StripeCustomerId>,
 32    pub client_reference_id: Option<String>,
 33    pub mode: Option<StripeCheckoutSessionMode>,
 34    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
 35    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
 36    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
 37    pub success_url: Option<String>,
 38}
 39
 40pub struct FakeStripeClient {
 41    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
 42    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
 43    pub update_subscription_calls:
 44        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
 45    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
 46    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 47    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
 48    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
 49}
 50
 51impl FakeStripeClient {
 52    pub fn new() -> Self {
 53        Self {
 54            customers: Arc::new(Mutex::new(HashMap::default())),
 55            subscriptions: Arc::new(Mutex::new(HashMap::default())),
 56            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
 57            prices: Arc::new(Mutex::new(HashMap::default())),
 58            meters: Arc::new(Mutex::new(HashMap::default())),
 59            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
 60            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
 61        }
 62    }
 63}
 64
 65#[async_trait]
 66impl StripeClient for FakeStripeClient {
 67    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 68        Ok(self
 69            .customers
 70            .lock()
 71            .values()
 72            .filter(|customer| customer.email.as_deref() == Some(email))
 73            .cloned()
 74            .collect())
 75    }
 76
 77    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
 78        self.customers
 79            .lock()
 80            .get(customer_id)
 81            .cloned()
 82            .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
 83    }
 84
 85    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 86        let customer = StripeCustomer {
 87            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
 88            email: params.email.map(|email| email.to_string()),
 89        };
 90
 91        self.customers
 92            .lock()
 93            .insert(customer.id.clone(), customer.clone());
 94
 95        Ok(customer)
 96    }
 97
 98    async fn update_customer(
 99        &self,
100        customer_id: &StripeCustomerId,
101        params: UpdateCustomerParams<'_>,
102    ) -> Result<StripeCustomer> {
103        let mut customers = self.customers.lock();
104        if let Some(customer) = customers.get_mut(customer_id) {
105            if let Some(email) = params.email {
106                customer.email = Some(email.to_string());
107            }
108            Ok(customer.clone())
109        } else {
110            Err(anyhow!("no customer found for {customer_id:?}"))
111        }
112    }
113
114    async fn list_subscriptions_for_customer(
115        &self,
116        customer_id: &StripeCustomerId,
117    ) -> Result<Vec<StripeSubscription>> {
118        let subscriptions = self
119            .subscriptions
120            .lock()
121            .values()
122            .filter(|subscription| subscription.customer == *customer_id)
123            .cloned()
124            .collect();
125
126        Ok(subscriptions)
127    }
128
129    async fn get_subscription(
130        &self,
131        subscription_id: &StripeSubscriptionId,
132    ) -> Result<StripeSubscription> {
133        self.subscriptions
134            .lock()
135            .get(subscription_id)
136            .cloned()
137            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
138    }
139
140    async fn create_subscription(
141        &self,
142        params: StripeCreateSubscriptionParams,
143    ) -> Result<StripeSubscription> {
144        let now = Utc::now();
145
146        let subscription = StripeSubscription {
147            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
148            customer: params.customer,
149            status: stripe::SubscriptionStatus::Active,
150            current_period_start: now.timestamp(),
151            current_period_end: (now + Duration::days(30)).timestamp(),
152            items: params
153                .items
154                .into_iter()
155                .map(|item| StripeSubscriptionItem {
156                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
157                    price: item
158                        .price
159                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
160                })
161                .collect(),
162            cancel_at: None,
163            cancellation_details: None,
164        };
165
166        self.subscriptions
167            .lock()
168            .insert(subscription.id.clone(), subscription.clone());
169
170        Ok(subscription)
171    }
172
173    async fn update_subscription(
174        &self,
175        subscription_id: &StripeSubscriptionId,
176        params: UpdateSubscriptionParams,
177    ) -> Result<()> {
178        let subscription = self.get_subscription(subscription_id).await?;
179
180        self.update_subscription_calls
181            .lock()
182            .push((subscription.id, params));
183
184        Ok(())
185    }
186
187    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
188        // TODO: Implement fake subscription cancellation.
189        let _ = subscription_id;
190
191        Ok(())
192    }
193
194    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
195        let prices = self.prices.lock().values().cloned().collect();
196
197        Ok(prices)
198    }
199
200    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
201        let meters = self.meters.lock().values().cloned().collect();
202
203        Ok(meters)
204    }
205
206    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
207        self.create_meter_event_calls
208            .lock()
209            .push(StripeCreateMeterEventCall {
210                identifier: params.identifier.into(),
211                event_name: params.event_name.into(),
212                value: params.payload.value,
213                stripe_customer_id: params.payload.stripe_customer_id.clone(),
214                timestamp: params.timestamp,
215            });
216
217        Ok(())
218    }
219
220    async fn create_checkout_session(
221        &self,
222        params: StripeCreateCheckoutSessionParams<'_>,
223    ) -> Result<StripeCheckoutSession> {
224        self.create_checkout_session_calls
225            .lock()
226            .push(StripeCreateCheckoutSessionCall {
227                customer: params.customer.cloned(),
228                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
229                mode: params.mode,
230                line_items: params.line_items,
231                payment_method_collection: params.payment_method_collection,
232                subscription_data: params.subscription_data,
233                success_url: params.success_url.map(|url| url.to_string()),
234            });
235
236        Ok(StripeCheckoutSession {
237            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
238        })
239    }
240}