real_stripe_client.rs

  1use std::str::FromStr as _;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result, anyhow};
  5use async_trait::async_trait;
  6use serde::Serialize;
  7use stripe::{
  8    CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
  9    CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
 10    CreateCheckoutSessionSubscriptionDataTrialSettings,
 11    CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
 12    CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
 13    CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
 14    SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
 15    UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
 16    UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
 17};
 18
 19use crate::stripe_client::{
 20    CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
 21    StripeCheckoutSessionPaymentMethodCollection, StripeClient,
 22    StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
 23    StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
 24    StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
 25    StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
 26    StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
 27    StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
 28};
 29
 30pub struct RealStripeClient {
 31    client: Arc<stripe::Client>,
 32}
 33
 34impl RealStripeClient {
 35    pub fn new(client: Arc<stripe::Client>) -> Self {
 36        Self { client }
 37    }
 38}
 39
 40#[async_trait]
 41impl StripeClient for RealStripeClient {
 42    async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
 43        let response = Customer::list(
 44            &self.client,
 45            &ListCustomers {
 46                email: Some(email),
 47                ..Default::default()
 48            },
 49        )
 50        .await?;
 51
 52        Ok(response
 53            .data
 54            .into_iter()
 55            .map(StripeCustomer::from)
 56            .collect())
 57    }
 58
 59    async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
 60        let customer = Customer::create(
 61            &self.client,
 62            CreateCustomer {
 63                email: params.email,
 64                ..Default::default()
 65            },
 66        )
 67        .await?;
 68
 69        Ok(StripeCustomer::from(customer))
 70    }
 71
 72    async fn get_subscription(
 73        &self,
 74        subscription_id: &StripeSubscriptionId,
 75    ) -> Result<StripeSubscription> {
 76        let subscription_id = subscription_id.try_into()?;
 77
 78        let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
 79
 80        Ok(StripeSubscription::from(subscription))
 81    }
 82
 83    async fn update_subscription(
 84        &self,
 85        subscription_id: &StripeSubscriptionId,
 86        params: UpdateSubscriptionParams,
 87    ) -> Result<()> {
 88        let subscription_id = subscription_id.try_into()?;
 89
 90        stripe::Subscription::update(
 91            &self.client,
 92            &subscription_id,
 93            stripe::UpdateSubscription {
 94                items: params.items.map(|items| {
 95                    items
 96                        .into_iter()
 97                        .map(|item| UpdateSubscriptionItems {
 98                            price: item.price.map(|price| price.to_string()),
 99                            ..Default::default()
100                        })
101                        .collect()
102                }),
103                trial_settings: params.trial_settings.map(Into::into),
104                ..Default::default()
105            },
106        )
107        .await?;
108
109        Ok(())
110    }
111
112    async fn list_prices(&self) -> Result<Vec<StripePrice>> {
113        let response = stripe::Price::list(
114            &self.client,
115            &stripe::ListPrices {
116                limit: Some(100),
117                ..Default::default()
118            },
119        )
120        .await?;
121
122        Ok(response.data.into_iter().map(StripePrice::from).collect())
123    }
124
125    async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
126        #[derive(Serialize)]
127        struct Params {
128            #[serde(skip_serializing_if = "Option::is_none")]
129            limit: Option<u64>,
130        }
131
132        let response = self
133            .client
134            .get_query::<stripe::List<StripeMeter>, _>(
135                "/billing/meters",
136                Params { limit: Some(100) },
137            )
138            .await?;
139
140        Ok(response.data)
141    }
142
143    async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
144        let identifier = params.identifier;
145        match self.client.post_form("/billing/meter_events", params).await {
146            Ok(event) => Ok(event),
147            Err(stripe::StripeError::Stripe(error)) => {
148                if error.http_status == 400
149                    && error
150                        .message
151                        .as_ref()
152                        .map_or(false, |message| message.contains(identifier))
153                {
154                    Ok(())
155                } else {
156                    Err(anyhow!(stripe::StripeError::Stripe(error)))
157                }
158            }
159            Err(error) => Err(anyhow!(error)),
160        }
161    }
162
163    async fn create_checkout_session(
164        &self,
165        params: StripeCreateCheckoutSessionParams<'_>,
166    ) -> Result<StripeCheckoutSession> {
167        let params = params.try_into()?;
168        let session = CheckoutSession::create(&self.client, params).await?;
169
170        Ok(session.into())
171    }
172}
173
174impl From<CustomerId> for StripeCustomerId {
175    fn from(value: CustomerId) -> Self {
176        Self(value.as_str().into())
177    }
178}
179
180impl TryFrom<StripeCustomerId> for CustomerId {
181    type Error = anyhow::Error;
182
183    fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
184        Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
185    }
186}
187
188impl TryFrom<&StripeCustomerId> for CustomerId {
189    type Error = anyhow::Error;
190
191    fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
192        Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
193    }
194}
195
196impl From<Customer> for StripeCustomer {
197    fn from(value: Customer) -> Self {
198        StripeCustomer {
199            id: value.id.into(),
200            email: value.email,
201        }
202    }
203}
204
205impl From<SubscriptionId> for StripeSubscriptionId {
206    fn from(value: SubscriptionId) -> Self {
207        Self(value.as_str().into())
208    }
209}
210
211impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
212    type Error = anyhow::Error;
213
214    fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
215        Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
216    }
217}
218
219impl From<Subscription> for StripeSubscription {
220    fn from(value: Subscription) -> Self {
221        Self {
222            id: value.id.into(),
223            items: value.items.data.into_iter().map(Into::into).collect(),
224        }
225    }
226}
227
228impl From<SubscriptionItemId> for StripeSubscriptionItemId {
229    fn from(value: SubscriptionItemId) -> Self {
230        Self(value.as_str().into())
231    }
232}
233
234impl From<SubscriptionItem> for StripeSubscriptionItem {
235    fn from(value: SubscriptionItem) -> Self {
236        Self {
237            id: value.id.into(),
238            price: value.price.map(Into::into),
239        }
240    }
241}
242
243impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
244    fn from(value: StripeSubscriptionTrialSettings) -> Self {
245        Self {
246            end_behavior: value.end_behavior.into(),
247        }
248    }
249}
250
251impl From<StripeSubscriptionTrialSettingsEndBehavior>
252    for UpdateSubscriptionTrialSettingsEndBehavior
253{
254    fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
255        Self {
256            missing_payment_method: value.missing_payment_method.into(),
257        }
258    }
259}
260
261impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
262    for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
263{
264    fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
265        match value {
266            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
267            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
268                Self::CreateInvoice
269            }
270            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
271        }
272    }
273}
274
275impl From<PriceId> for StripePriceId {
276    fn from(value: PriceId) -> Self {
277        Self(value.as_str().into())
278    }
279}
280
281impl TryFrom<StripePriceId> for PriceId {
282    type Error = anyhow::Error;
283
284    fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
285        Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
286    }
287}
288
289impl From<Price> for StripePrice {
290    fn from(value: Price) -> Self {
291        Self {
292            id: value.id.into(),
293            unit_amount: value.unit_amount,
294            lookup_key: value.lookup_key,
295            recurring: value.recurring.map(StripePriceRecurring::from),
296        }
297    }
298}
299
300impl From<Recurring> for StripePriceRecurring {
301    fn from(value: Recurring) -> Self {
302        Self { meter: value.meter }
303    }
304}
305
306impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
307    type Error = anyhow::Error;
308
309    fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
310        Ok(Self {
311            customer: value
312                .customer
313                .map(|customer_id| customer_id.try_into())
314                .transpose()?,
315            client_reference_id: value.client_reference_id,
316            mode: value.mode.map(Into::into),
317            line_items: value
318                .line_items
319                .map(|line_items| line_items.into_iter().map(Into::into).collect()),
320            payment_method_collection: value.payment_method_collection.map(Into::into),
321            subscription_data: value.subscription_data.map(Into::into),
322            success_url: value.success_url,
323            ..Default::default()
324        })
325    }
326}
327
328impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
329    fn from(value: StripeCheckoutSessionMode) -> Self {
330        match value {
331            StripeCheckoutSessionMode::Payment => Self::Payment,
332            StripeCheckoutSessionMode::Setup => Self::Setup,
333            StripeCheckoutSessionMode::Subscription => Self::Subscription,
334        }
335    }
336}
337
338impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
339    fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
340        Self {
341            price: value.price,
342            quantity: value.quantity,
343            ..Default::default()
344        }
345    }
346}
347
348impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
349    fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
350        match value {
351            StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
352            StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
353        }
354    }
355}
356
357impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
358    fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
359        Self {
360            trial_period_days: value.trial_period_days,
361            trial_settings: value.trial_settings.map(Into::into),
362            metadata: value.metadata,
363            ..Default::default()
364        }
365    }
366}
367
368impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
369    fn from(value: StripeSubscriptionTrialSettings) -> Self {
370        Self {
371            end_behavior: value.end_behavior.into(),
372        }
373    }
374}
375
376impl From<StripeSubscriptionTrialSettingsEndBehavior>
377    for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
378{
379    fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
380        Self {
381            missing_payment_method: value.missing_payment_method.into(),
382        }
383    }
384}
385
386impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
387    for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
388{
389    fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
390        match value {
391            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
392            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
393                Self::CreateInvoice
394            }
395            StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
396        }
397    }
398}
399
400impl From<CheckoutSession> for StripeCheckoutSession {
401    fn from(value: CheckoutSession) -> Self {
402        Self { url: value.url }
403    }
404}