fake_stripe_client.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{Duration, Utc};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use uuid::Uuid;
  9
 10use crate::stripe_client::{
 11    CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
 12    StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 13    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 14    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 15    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
 16    StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
 17    StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, UpdateCustomerParams,
 18    UpdateSubscriptionParams,
 19};
 20
 21#[derive(Debug, Clone)]
 22pub struct StripeCreateMeterEventCall {
 23    pub identifier: Arc<str>,
 24    pub event_name: Arc<str>,
 25    pub value: u64,
 26    pub stripe_customer_id: StripeCustomerId,
 27    pub timestamp: Option<i64>,
 28}
 29
 30#[derive(Debug, Clone)]
 31pub struct StripeCreateCheckoutSessionCall {
 32    pub customer: Option<StripeCustomerId>,
 33    pub client_reference_id: Option<String>,
 34    pub mode: Option<StripeCheckoutSessionMode>,
 35    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
 36    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
 37    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
 38    pub success_url: Option<String>,
 39    pub billing_address_collection: Option<StripeBillingAddressCollection>,
 40    pub customer_update: Option<StripeCustomerUpdate>,
 41}
 42
 43pub struct FakeStripeClient {
 44    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
 45    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
 46    pub update_subscription_calls:
 47        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
 48    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
 49    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
 50    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
 51    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
 52}
 53
 54impl FakeStripeClient {
 55    pub fn new() -> Self {
 56        Self {
 57            customers: Arc::new(Mutex::new(HashMap::default())),
 58            subscriptions: Arc::new(Mutex::new(HashMap::default())),
 59            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
 60            prices: Arc::new(Mutex::new(HashMap::default())),
 61            meters: Arc::new(Mutex::new(HashMap::default())),
 62            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
 63            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
 64        }
 65    }
 66}
 67
 68#[async_trait]
 69impl StripeClient for FakeStripeClient {
 70    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 71        Ok(self
 72            .customers
 73            .lock()
 74            .values()
 75            .filter(|customer| customer.email.as_deref() == Some(email))
 76            .cloned()
 77            .collect())
 78    }
 79
 80    async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
 81        self.customers
 82            .lock()
 83            .get(customer_id)
 84            .cloned()
 85            .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
 86    }
 87
 88    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 89        let customer = StripeCustomer {
 90            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
 91            email: params.email.map(|email| email.to_string()),
 92        };
 93
 94        self.customers
 95            .lock()
 96            .insert(customer.id.clone(), customer.clone());
 97
 98        Ok(customer)
 99    }
100
101    async fn update_customer(
102        &self,
103        customer_id: &StripeCustomerId,
104        params: UpdateCustomerParams<'_>,
105    ) -> Result<StripeCustomer> {
106        let mut customers = self.customers.lock();
107        if let Some(customer) = customers.get_mut(customer_id) {
108            if let Some(email) = params.email {
109                customer.email = Some(email.to_string());
110            }
111            Ok(customer.clone())
112        } else {
113            Err(anyhow!("no customer found for {customer_id:?}"))
114        }
115    }
116
117    async fn list_subscriptions_for_customer(
118        &self,
119        customer_id: &StripeCustomerId,
120    ) -> Result<Vec<StripeSubscription>> {
121        let subscriptions = self
122            .subscriptions
123            .lock()
124            .values()
125            .filter(|subscription| subscription.customer == *customer_id)
126            .cloned()
127            .collect();
128
129        Ok(subscriptions)
130    }
131
132    async fn get_subscription(
133        &self,
134        subscription_id: &StripeSubscriptionId,
135    ) -> Result<StripeSubscription> {
136        self.subscriptions
137            .lock()
138            .get(subscription_id)
139            .cloned()
140            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
141    }
142
143    async fn create_subscription(
144        &self,
145        params: StripeCreateSubscriptionParams,
146    ) -> Result<StripeSubscription> {
147        let now = Utc::now();
148
149        let subscription = StripeSubscription {
150            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
151            customer: params.customer,
152            status: stripe::SubscriptionStatus::Active,
153            current_period_start: now.timestamp(),
154            current_period_end: (now + Duration::days(30)).timestamp(),
155            items: params
156                .items
157                .into_iter()
158                .map(|item| StripeSubscriptionItem {
159                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
160                    price: item
161                        .price
162                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
163                })
164                .collect(),
165            cancel_at: None,
166            cancellation_details: None,
167        };
168
169        self.subscriptions
170            .lock()
171            .insert(subscription.id.clone(), subscription.clone());
172
173        Ok(subscription)
174    }
175
176    async fn update_subscription(
177        &self,
178        subscription_id: &StripeSubscriptionId,
179        params: UpdateSubscriptionParams,
180    ) -> Result<()> {
181        let subscription = self.get_subscription(subscription_id).await?;
182
183        self.update_subscription_calls
184            .lock()
185            .push((subscription.id, params));
186
187        Ok(())
188    }
189
190    async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
191        // TODO: Implement fake subscription cancellation.
192        let _ = subscription_id;
193
194        Ok(())
195    }
196
197    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
198        let prices = self.prices.lock().values().cloned().collect();
199
200        Ok(prices)
201    }
202
203    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
204        let meters = self.meters.lock().values().cloned().collect();
205
206        Ok(meters)
207    }
208
209    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
210        self.create_meter_event_calls
211            .lock()
212            .push(StripeCreateMeterEventCall {
213                identifier: params.identifier.into(),
214                event_name: params.event_name.into(),
215                value: params.payload.value,
216                stripe_customer_id: params.payload.stripe_customer_id.clone(),
217                timestamp: params.timestamp,
218            });
219
220        Ok(())
221    }
222
223    async fn create_checkout_session(
224        &self,
225        params: StripeCreateCheckoutSessionParams<'_>,
226    ) -> Result<StripeCheckoutSession> {
227        self.create_checkout_session_calls
228            .lock()
229            .push(StripeCreateCheckoutSessionCall {
230                customer: params.customer.cloned(),
231                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
232                mode: params.mode,
233                line_items: params.line_items,
234                payment_method_collection: params.payment_method_collection,
235                subscription_data: params.subscription_data,
236                success_url: params.success_url.map(|url| url.to_string()),
237                billing_address_collection: params.billing_address_collection,
238                customer_update: params.customer_update,
239            });
240
241        Ok(StripeCheckoutSession {
242            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
243        })
244    }
245}