1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use chrono::{Duration, Utc};
6use collections::HashMap;
7use parking_lot::Mutex;
8use uuid::Uuid;
9
10use crate::stripe_client::{
11 CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
12 StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
13 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
14 StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
15 StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
16 StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
17 StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection,
18 UpdateCustomerParams, UpdateSubscriptionParams,
19};
20
21#[derive(Debug, Clone)]
22pub struct StripeCreateMeterEventCall {
23 pub identifier: Arc<str>,
24 pub event_name: Arc<str>,
25 pub value: u64,
26 pub stripe_customer_id: StripeCustomerId,
27 pub timestamp: Option<i64>,
28}
29
30#[derive(Debug, Clone)]
31pub struct StripeCreateCheckoutSessionCall {
32 pub customer: Option<StripeCustomerId>,
33 pub client_reference_id: Option<String>,
34 pub mode: Option<StripeCheckoutSessionMode>,
35 pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
36 pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
37 pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
38 pub success_url: Option<String>,
39 pub billing_address_collection: Option<StripeBillingAddressCollection>,
40 pub customer_update: Option<StripeCustomerUpdate>,
41 pub tax_id_collection: Option<StripeTaxIdCollection>,
42}
43
44pub struct FakeStripeClient {
45 pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
46 pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
47 pub update_subscription_calls:
48 Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
49 pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
50 pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
51 pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
52 pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
53}
54
55impl FakeStripeClient {
56 pub fn new() -> Self {
57 Self {
58 customers: Arc::new(Mutex::new(HashMap::default())),
59 subscriptions: Arc::new(Mutex::new(HashMap::default())),
60 update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
61 prices: Arc::new(Mutex::new(HashMap::default())),
62 meters: Arc::new(Mutex::new(HashMap::default())),
63 create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
64 create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
65 }
66 }
67}
68
69#[async_trait]
70impl StripeClient for FakeStripeClient {
71 async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
72 Ok(self
73 .customers
74 .lock()
75 .values()
76 .filter(|customer| customer.email.as_deref() == Some(email))
77 .cloned()
78 .collect())
79 }
80
81 async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
82 self.customers
83 .lock()
84 .get(customer_id)
85 .cloned()
86 .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
87 }
88
89 async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
90 let customer = StripeCustomer {
91 id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
92 email: params.email.map(|email| email.to_string()),
93 };
94
95 self.customers
96 .lock()
97 .insert(customer.id.clone(), customer.clone());
98
99 Ok(customer)
100 }
101
102 async fn update_customer(
103 &self,
104 customer_id: &StripeCustomerId,
105 params: UpdateCustomerParams<'_>,
106 ) -> Result<StripeCustomer> {
107 let mut customers = self.customers.lock();
108 if let Some(customer) = customers.get_mut(customer_id) {
109 if let Some(email) = params.email {
110 customer.email = Some(email.to_string());
111 }
112 Ok(customer.clone())
113 } else {
114 Err(anyhow!("no customer found for {customer_id:?}"))
115 }
116 }
117
118 async fn list_subscriptions_for_customer(
119 &self,
120 customer_id: &StripeCustomerId,
121 ) -> Result<Vec<StripeSubscription>> {
122 let subscriptions = self
123 .subscriptions
124 .lock()
125 .values()
126 .filter(|subscription| subscription.customer == *customer_id)
127 .cloned()
128 .collect();
129
130 Ok(subscriptions)
131 }
132
133 async fn get_subscription(
134 &self,
135 subscription_id: &StripeSubscriptionId,
136 ) -> Result<StripeSubscription> {
137 self.subscriptions
138 .lock()
139 .get(subscription_id)
140 .cloned()
141 .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
142 }
143
144 async fn create_subscription(
145 &self,
146 params: StripeCreateSubscriptionParams,
147 ) -> Result<StripeSubscription> {
148 let now = Utc::now();
149
150 let subscription = StripeSubscription {
151 id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
152 customer: params.customer,
153 status: stripe::SubscriptionStatus::Active,
154 current_period_start: now.timestamp(),
155 current_period_end: (now + Duration::days(30)).timestamp(),
156 items: params
157 .items
158 .into_iter()
159 .map(|item| StripeSubscriptionItem {
160 id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
161 price: item
162 .price
163 .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
164 })
165 .collect(),
166 cancel_at: None,
167 cancellation_details: None,
168 };
169
170 self.subscriptions
171 .lock()
172 .insert(subscription.id.clone(), subscription.clone());
173
174 Ok(subscription)
175 }
176
177 async fn update_subscription(
178 &self,
179 subscription_id: &StripeSubscriptionId,
180 params: UpdateSubscriptionParams,
181 ) -> Result<()> {
182 let subscription = self.get_subscription(subscription_id).await?;
183
184 self.update_subscription_calls
185 .lock()
186 .push((subscription.id, params));
187
188 Ok(())
189 }
190
191 async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
192 // TODO: Implement fake subscription cancellation.
193 let _ = subscription_id;
194
195 Ok(())
196 }
197
198 async fn list_prices(&self) -> Result<Vec<StripePrice>> {
199 let prices = self.prices.lock().values().cloned().collect();
200
201 Ok(prices)
202 }
203
204 async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
205 let meters = self.meters.lock().values().cloned().collect();
206
207 Ok(meters)
208 }
209
210 async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
211 self.create_meter_event_calls
212 .lock()
213 .push(StripeCreateMeterEventCall {
214 identifier: params.identifier.into(),
215 event_name: params.event_name.into(),
216 value: params.payload.value,
217 stripe_customer_id: params.payload.stripe_customer_id.clone(),
218 timestamp: params.timestamp,
219 });
220
221 Ok(())
222 }
223
224 async fn create_checkout_session(
225 &self,
226 params: StripeCreateCheckoutSessionParams<'_>,
227 ) -> Result<StripeCheckoutSession> {
228 self.create_checkout_session_calls
229 .lock()
230 .push(StripeCreateCheckoutSessionCall {
231 customer: params.customer.cloned(),
232 client_reference_id: params.client_reference_id.map(|id| id.to_string()),
233 mode: params.mode,
234 line_items: params.line_items,
235 payment_method_collection: params.payment_method_collection,
236 subscription_data: params.subscription_data,
237 success_url: params.success_url.map(|url| url.to_string()),
238 billing_address_collection: params.billing_address_collection,
239 customer_update: params.customer_update,
240 tax_id_collection: params.tax_id_collection,
241 });
242
243 Ok(StripeCheckoutSession {
244 url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
245 })
246 }
247}