fake_stripe_client.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{Duration, Utc};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use uuid::Uuid;
  9
 10use crate::stripe_client::{
 11    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
 12    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 13    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 14    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
 15    StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
 16    StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
 17    StripeSubscriptionItemId, UpdateSubscriptionParams,
 18};
 19
 20#[derive(Debug, Clone)]
 21pub struct StripeCreateMeterEventCall {
 22    pub identifier: Arc<str>,
 23    pub event_name: Arc<str>,
 24    pub value: u64,
 25    pub stripe_customer_id: StripeCustomerId,
 26    pub timestamp: Option<i64>,
 27}
 28
/// One recorded call to `create_checkout_session` on [`FakeStripeClient`].
///
/// The request is stored in owned form: `customer` is cloned, `client_reference_id` and
/// `success_url` are converted to `String`, and the remaining fields are moved in from the
/// request unchanged.
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
    pub customer: Option<StripeCustomerId>,
    pub client_reference_id: Option<String>,
    pub mode: Option<StripeCheckoutSessionMode>,
    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
    pub success_url: Option<String>,
}
 39
/// Fake [`StripeClient`] that keeps all of its state in process memory; it performs no
/// network calls.
///
/// Every store is a public `Arc<Mutex<…>>` field (`parking_lot` mutex), so callers can seed
/// state directly and inspect recorded calls afterwards.
pub struct FakeStripeClient {
    /// Customers keyed by id; `create_customer` inserts here and `list_customers_by_email`
    /// reads from here.
    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
    /// Subscriptions keyed by id; `create_subscription` inserts here, while
    /// `get_subscription` and `list_subscriptions_for_customer` read from here.
    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
    /// Log of `update_subscription` calls (subscription id + params), appended in call order.
    /// Those calls do not modify the stored subscription.
    pub update_subscription_calls:
        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
    /// Prices keyed by id; read by `list_prices` and by `create_subscription` to resolve item
    /// prices. No method in this file inserts here, so prices have to be seeded directly.
    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
    /// Meters keyed by id; read by `list_meters`. No method in this file inserts here.
    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
    /// Log of `create_meter_event` calls, appended in call order.
    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
    /// Log of `create_checkout_session` calls, appended in call order.
    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
 50
 51impl FakeStripeClient {
 52    pub fn new() -> Self {
 53        Self {
 54            customers: Arc::new(Mutex::new(HashMap::default())),
 55            subscriptions: Arc::new(Mutex::new(HashMap::default())),
 56            update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
 57            prices: Arc::new(Mutex::new(HashMap::default())),
 58            meters: Arc::new(Mutex::new(HashMap::default())),
 59            create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
 60            create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
 61        }
 62    }
 63}
 64
 65#[async_trait]
 66impl StripeClient for FakeStripeClient {
 67    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 68        Ok(self
 69            .customers
 70            .lock()
 71            .values()
 72            .filter(|customer| customer.email.as_deref() == Some(email))
 73            .cloned()
 74            .collect())
 75    }
 76
 77    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 78        let customer = StripeCustomer {
 79            id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
 80            email: params.email.map(|email| email.to_string()),
 81        };
 82
 83        self.customers
 84            .lock()
 85            .insert(customer.id.clone(), customer.clone());
 86
 87        Ok(customer)
 88    }
 89
 90    async fn list_subscriptions_for_customer(
 91        &self,
 92        customer_id: &StripeCustomerId,
 93    ) -> Result<Vec<StripeSubscription>> {
 94        let subscriptions = self
 95            .subscriptions
 96            .lock()
 97            .values()
 98            .filter(|subscription| subscription.customer == *customer_id)
 99            .cloned()
100            .collect();
101
102        Ok(subscriptions)
103    }
104
105    async fn get_subscription(
106        &self,
107        subscription_id: &StripeSubscriptionId,
108    ) -> Result<StripeSubscription> {
109        self.subscriptions
110            .lock()
111            .get(subscription_id)
112            .cloned()
113            .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
114    }
115
116    async fn create_subscription(
117        &self,
118        params: StripeCreateSubscriptionParams,
119    ) -> Result<StripeSubscription> {
120        let now = Utc::now();
121
122        let subscription = StripeSubscription {
123            id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
124            customer: params.customer,
125            status: stripe::SubscriptionStatus::Active,
126            current_period_start: now.timestamp(),
127            current_period_end: (now + Duration::days(30)).timestamp(),
128            items: params
129                .items
130                .into_iter()
131                .map(|item| StripeSubscriptionItem {
132                    id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
133                    price: item
134                        .price
135                        .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
136                })
137                .collect(),
138        };
139
140        self.subscriptions
141            .lock()
142            .insert(subscription.id.clone(), subscription.clone());
143
144        Ok(subscription)
145    }
146
147    async fn update_subscription(
148        &self,
149        subscription_id: &StripeSubscriptionId,
150        params: UpdateSubscriptionParams,
151    ) -> Result<()> {
152        let subscription = self.get_subscription(subscription_id).await?;
153
154        self.update_subscription_calls
155            .lock()
156            .push((subscription.id, params));
157
158        Ok(())
159    }
160
161    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
162        let prices = self.prices.lock().values().cloned().collect();
163
164        Ok(prices)
165    }
166
167    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
168        let meters = self.meters.lock().values().cloned().collect();
169
170        Ok(meters)
171    }
172
173    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
174        self.create_meter_event_calls
175            .lock()
176            .push(StripeCreateMeterEventCall {
177                identifier: params.identifier.into(),
178                event_name: params.event_name.into(),
179                value: params.payload.value,
180                stripe_customer_id: params.payload.stripe_customer_id.clone(),
181                timestamp: params.timestamp,
182            });
183
184        Ok(())
185    }
186
187    async fn create_checkout_session(
188        &self,
189        params: StripeCreateCheckoutSessionParams<'_>,
190    ) -> Result<StripeCheckoutSession> {
191        self.create_checkout_session_calls
192            .lock()
193            .push(StripeCreateCheckoutSessionCall {
194                customer: params.customer.cloned(),
195                client_reference_id: params.client_reference_id.map(|id| id.to_string()),
196                mode: params.mode,
197                line_items: params.line_items,
198                payment_method_collection: params.payment_method_collection,
199                subscription_data: params.subscription_data,
200                success_url: params.success_url.map(|url| url.to_string()),
201            });
202
203        Ok(StripeCheckoutSession {
204            url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
205        })
206    }
207}