1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use chrono::{Duration, Utc};
6use collections::HashMap;
7use parking_lot::Mutex;
8use uuid::Uuid;
9
10use crate::stripe_client::{
11 CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
12 StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
13 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
14 StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
15 StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
16 StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
17 StripeSubscriptionItemId, UpdateCustomerParams, UpdateSubscriptionParams,
18};
19
/// An owned record of one `create_meter_event` call made against [`FakeStripeClient`].
///
/// `StripeCreateMeterEventParams` borrows its inputs, so each field here is an owned copy
/// that can be inspected after the call has returned.
#[derive(Debug, Clone)]
pub struct StripeCreateMeterEventCall {
    /// Copied from `params.identifier`.
    pub identifier: Arc<str>,
    /// Copied from `params.event_name`.
    pub event_name: Arc<str>,
    /// Copied from `params.payload.value`.
    pub value: u64,
    /// Cloned from `params.payload.stripe_customer_id`.
    pub stripe_customer_id: StripeCustomerId,
    /// Copied from `params.timestamp`; `None` when the caller did not supply one.
    /// NOTE(review): presumably Unix seconds, as Stripe uses elsewhere — confirm.
    pub timestamp: Option<i64>,
}
28
/// An owned record of one `create_checkout_session` call made against [`FakeStripeClient`].
///
/// Mirrors `StripeCreateCheckoutSessionParams`, with borrowed fields converted to owned
/// values so the record can be inspected after the call has returned.
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
    /// Cloned from `params.customer`.
    pub customer: Option<StripeCustomerId>,
    /// `params.client_reference_id` converted to an owned `String`.
    pub client_reference_id: Option<String>,
    /// Moved from `params.mode`.
    pub mode: Option<StripeCheckoutSessionMode>,
    /// Moved from `params.line_items`.
    pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
    /// Moved from `params.payment_method_collection`.
    pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
    /// Moved from `params.subscription_data`.
    pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
    /// `params.success_url` converted to an owned `String`.
    pub success_url: Option<String>,
    /// Moved from `params.billing_address_collection`.
    pub billing_address_collection: Option<StripeBillingAddressCollection>,
}
40
/// An in-memory implementation of [`StripeClient`] that never contacts Stripe.
///
/// All state lives in public `Arc<Mutex<…>>` fields. Code holding the client can therefore
/// seed data (presumably tests — e.g. `prices` and `meters`, which the client never fills
/// itself) and inspect the calls the client recorded.
pub struct FakeStripeClient {
    /// Customers keyed by id; inserted by `create_customer`, mutated by `update_customer`.
    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
    /// Subscriptions keyed by id; inserted by `create_subscription`.
    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
    /// One entry per successful `update_subscription` call, in call order.
    /// Those calls do not modify the stored subscription; they are only recorded here.
    pub update_subscription_calls:
        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
    /// Prices keyed by id. The `StripeClient` impl only reads this map (`list_prices`, and
    /// price resolution in `create_subscription`), so it must be seeded externally.
    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
    /// Meters keyed by id. Only read by `list_meters`, so it must be seeded externally.
    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
    /// One entry per `create_meter_event` call, in call order.
    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
    /// One entry per `create_checkout_session` call, in call order.
    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
51
52impl FakeStripeClient {
53 pub fn new() -> Self {
54 Self {
55 customers: Arc::new(Mutex::new(HashMap::default())),
56 subscriptions: Arc::new(Mutex::new(HashMap::default())),
57 update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
58 prices: Arc::new(Mutex::new(HashMap::default())),
59 meters: Arc::new(Mutex::new(HashMap::default())),
60 create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
61 create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
62 }
63 }
64}
65
66#[async_trait]
67impl StripeClient for FakeStripeClient {
68 async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
69 Ok(self
70 .customers
71 .lock()
72 .values()
73 .filter(|customer| customer.email.as_deref() == Some(email))
74 .cloned()
75 .collect())
76 }
77
78 async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
79 self.customers
80 .lock()
81 .get(customer_id)
82 .cloned()
83 .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
84 }
85
86 async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
87 let customer = StripeCustomer {
88 id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
89 email: params.email.map(|email| email.to_string()),
90 };
91
92 self.customers
93 .lock()
94 .insert(customer.id.clone(), customer.clone());
95
96 Ok(customer)
97 }
98
99 async fn update_customer(
100 &self,
101 customer_id: &StripeCustomerId,
102 params: UpdateCustomerParams<'_>,
103 ) -> Result<StripeCustomer> {
104 let mut customers = self.customers.lock();
105 if let Some(customer) = customers.get_mut(customer_id) {
106 if let Some(email) = params.email {
107 customer.email = Some(email.to_string());
108 }
109 Ok(customer.clone())
110 } else {
111 Err(anyhow!("no customer found for {customer_id:?}"))
112 }
113 }
114
115 async fn list_subscriptions_for_customer(
116 &self,
117 customer_id: &StripeCustomerId,
118 ) -> Result<Vec<StripeSubscription>> {
119 let subscriptions = self
120 .subscriptions
121 .lock()
122 .values()
123 .filter(|subscription| subscription.customer == *customer_id)
124 .cloned()
125 .collect();
126
127 Ok(subscriptions)
128 }
129
130 async fn get_subscription(
131 &self,
132 subscription_id: &StripeSubscriptionId,
133 ) -> Result<StripeSubscription> {
134 self.subscriptions
135 .lock()
136 .get(subscription_id)
137 .cloned()
138 .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
139 }
140
141 async fn create_subscription(
142 &self,
143 params: StripeCreateSubscriptionParams,
144 ) -> Result<StripeSubscription> {
145 let now = Utc::now();
146
147 let subscription = StripeSubscription {
148 id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
149 customer: params.customer,
150 status: stripe::SubscriptionStatus::Active,
151 current_period_start: now.timestamp(),
152 current_period_end: (now + Duration::days(30)).timestamp(),
153 items: params
154 .items
155 .into_iter()
156 .map(|item| StripeSubscriptionItem {
157 id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
158 price: item
159 .price
160 .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
161 })
162 .collect(),
163 cancel_at: None,
164 cancellation_details: None,
165 };
166
167 self.subscriptions
168 .lock()
169 .insert(subscription.id.clone(), subscription.clone());
170
171 Ok(subscription)
172 }
173
174 async fn update_subscription(
175 &self,
176 subscription_id: &StripeSubscriptionId,
177 params: UpdateSubscriptionParams,
178 ) -> Result<()> {
179 let subscription = self.get_subscription(subscription_id).await?;
180
181 self.update_subscription_calls
182 .lock()
183 .push((subscription.id, params));
184
185 Ok(())
186 }
187
188 async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
189 // TODO: Implement fake subscription cancellation.
190 let _ = subscription_id;
191
192 Ok(())
193 }
194
195 async fn list_prices(&self) -> Result<Vec<StripePrice>> {
196 let prices = self.prices.lock().values().cloned().collect();
197
198 Ok(prices)
199 }
200
201 async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
202 let meters = self.meters.lock().values().cloned().collect();
203
204 Ok(meters)
205 }
206
207 async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
208 self.create_meter_event_calls
209 .lock()
210 .push(StripeCreateMeterEventCall {
211 identifier: params.identifier.into(),
212 event_name: params.event_name.into(),
213 value: params.payload.value,
214 stripe_customer_id: params.payload.stripe_customer_id.clone(),
215 timestamp: params.timestamp,
216 });
217
218 Ok(())
219 }
220
221 async fn create_checkout_session(
222 &self,
223 params: StripeCreateCheckoutSessionParams<'_>,
224 ) -> Result<StripeCheckoutSession> {
225 self.create_checkout_session_calls
226 .lock()
227 .push(StripeCreateCheckoutSessionCall {
228 customer: params.customer.cloned(),
229 client_reference_id: params.client_reference_id.map(|id| id.to_string()),
230 mode: params.mode,
231 line_items: params.line_items,
232 payment_method_collection: params.payment_method_collection,
233 subscription_data: params.subscription_data,
234 success_url: params.success_url.map(|url| url.to_string()),
235 billing_address_collection: params.billing_address_collection,
236 });
237
238 Ok(StripeCheckoutSession {
239 url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
240 })
241 }
242}