1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use chrono::{Duration, Utc};
6use collections::HashMap;
7use parking_lot::Mutex;
8use uuid::Uuid;
9
10use crate::stripe_client::{
11 CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
12 StripeCheckoutSessionPaymentMethodCollection, StripeClient,
13 StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
14 StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
15 StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
16 StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
17 StripeSubscriptionItemId, UpdateCustomerParams, UpdateSubscriptionParams,
18};
19
20#[derive(Debug, Clone)]
21pub struct StripeCreateMeterEventCall {
22 pub identifier: Arc<str>,
23 pub event_name: Arc<str>,
24 pub value: u64,
25 pub stripe_customer_id: StripeCustomerId,
26 pub timestamp: Option<i64>,
27}
28
29#[derive(Debug, Clone)]
30pub struct StripeCreateCheckoutSessionCall {
31 pub customer: Option<StripeCustomerId>,
32 pub client_reference_id: Option<String>,
33 pub mode: Option<StripeCheckoutSessionMode>,
34 pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
35 pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
36 pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
37 pub success_url: Option<String>,
38}
39
/// An in-memory implementation of [`StripeClient`] that keeps all of its state in process
/// instead of calling Stripe.
///
/// Every collection sits behind its own `Arc<Mutex<_>>`. Presumably this is so a caller can
/// keep handles to seed data (customers, subscriptions, prices, meters) and to inspect the
/// recorded calls while the client is used through the trait (NOTE(review): confirm against
/// callers).
pub struct FakeStripeClient {
    /// Customers keyed by ID. `create_customer` inserts into it and the customer lookups read it.
    pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
    /// Subscriptions keyed by ID. `create_subscription` inserts into it and the subscription
    /// lookups read it.
    pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
    /// Every `update_subscription` call, in call order. The params are only recorded here;
    /// they are not applied to `subscriptions`.
    pub update_subscription_calls:
        Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
    /// Prices returned by `list_prices` and used to resolve subscription item prices.
    pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
    /// Meters returned by `list_meters`.
    pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
    /// Every `create_meter_event` call, in call order.
    pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
    /// Every `create_checkout_session` call, in call order.
    pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
50
51impl FakeStripeClient {
52 pub fn new() -> Self {
53 Self {
54 customers: Arc::new(Mutex::new(HashMap::default())),
55 subscriptions: Arc::new(Mutex::new(HashMap::default())),
56 update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
57 prices: Arc::new(Mutex::new(HashMap::default())),
58 meters: Arc::new(Mutex::new(HashMap::default())),
59 create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
60 create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
61 }
62 }
63}
64
65#[async_trait]
66impl StripeClient for FakeStripeClient {
67 async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
68 Ok(self
69 .customers
70 .lock()
71 .values()
72 .filter(|customer| customer.email.as_deref() == Some(email))
73 .cloned()
74 .collect())
75 }
76
77 async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
78 self.customers
79 .lock()
80 .get(customer_id)
81 .cloned()
82 .ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
83 }
84
85 async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
86 let customer = StripeCustomer {
87 id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
88 email: params.email.map(|email| email.to_string()),
89 };
90
91 self.customers
92 .lock()
93 .insert(customer.id.clone(), customer.clone());
94
95 Ok(customer)
96 }
97
98 async fn update_customer(
99 &self,
100 customer_id: &StripeCustomerId,
101 params: UpdateCustomerParams<'_>,
102 ) -> Result<StripeCustomer> {
103 let mut customers = self.customers.lock();
104 if let Some(customer) = customers.get_mut(customer_id) {
105 if let Some(email) = params.email {
106 customer.email = Some(email.to_string());
107 }
108 Ok(customer.clone())
109 } else {
110 Err(anyhow!("no customer found for {customer_id:?}"))
111 }
112 }
113
114 async fn list_subscriptions_for_customer(
115 &self,
116 customer_id: &StripeCustomerId,
117 ) -> Result<Vec<StripeSubscription>> {
118 let subscriptions = self
119 .subscriptions
120 .lock()
121 .values()
122 .filter(|subscription| subscription.customer == *customer_id)
123 .cloned()
124 .collect();
125
126 Ok(subscriptions)
127 }
128
129 async fn get_subscription(
130 &self,
131 subscription_id: &StripeSubscriptionId,
132 ) -> Result<StripeSubscription> {
133 self.subscriptions
134 .lock()
135 .get(subscription_id)
136 .cloned()
137 .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
138 }
139
140 async fn create_subscription(
141 &self,
142 params: StripeCreateSubscriptionParams,
143 ) -> Result<StripeSubscription> {
144 let now = Utc::now();
145
146 let subscription = StripeSubscription {
147 id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
148 customer: params.customer,
149 status: stripe::SubscriptionStatus::Active,
150 current_period_start: now.timestamp(),
151 current_period_end: (now + Duration::days(30)).timestamp(),
152 items: params
153 .items
154 .into_iter()
155 .map(|item| StripeSubscriptionItem {
156 id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
157 price: item
158 .price
159 .and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
160 })
161 .collect(),
162 cancel_at: None,
163 cancellation_details: None,
164 };
165
166 self.subscriptions
167 .lock()
168 .insert(subscription.id.clone(), subscription.clone());
169
170 Ok(subscription)
171 }
172
173 async fn update_subscription(
174 &self,
175 subscription_id: &StripeSubscriptionId,
176 params: UpdateSubscriptionParams,
177 ) -> Result<()> {
178 let subscription = self.get_subscription(subscription_id).await?;
179
180 self.update_subscription_calls
181 .lock()
182 .push((subscription.id, params));
183
184 Ok(())
185 }
186
187 async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
188 // TODO: Implement fake subscription cancellation.
189 let _ = subscription_id;
190
191 Ok(())
192 }
193
194 async fn list_prices(&self) -> Result<Vec<StripePrice>> {
195 let prices = self.prices.lock().values().cloned().collect();
196
197 Ok(prices)
198 }
199
200 async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
201 let meters = self.meters.lock().values().cloned().collect();
202
203 Ok(meters)
204 }
205
206 async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
207 self.create_meter_event_calls
208 .lock()
209 .push(StripeCreateMeterEventCall {
210 identifier: params.identifier.into(),
211 event_name: params.event_name.into(),
212 value: params.payload.value,
213 stripe_customer_id: params.payload.stripe_customer_id.clone(),
214 timestamp: params.timestamp,
215 });
216
217 Ok(())
218 }
219
220 async fn create_checkout_session(
221 &self,
222 params: StripeCreateCheckoutSessionParams<'_>,
223 ) -> Result<StripeCheckoutSession> {
224 self.create_checkout_session_calls
225 .lock()
226 .push(StripeCreateCheckoutSessionCall {
227 customer: params.customer.cloned(),
228 client_reference_id: params.client_reference_id.map(|id| id.to_string()),
229 mode: params.mode,
230 line_items: params.line_items,
231 payment_method_collection: params.payment_method_collection,
232 subscription_data: params.subscription_data,
233 success_url: params.success_url.map(|url| url.to_string()),
234 });
235
236 Ok(StripeCheckoutSession {
237 url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
238 })
239 }
240}