@@ -5,13 +5,14 @@ use std::time::Duration;
use anyhow::{anyhow, bail, Context};
use axum::{extract, routing::post, Extension, Json, Router};
use reqwest::StatusCode;
+use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
use stripe::{
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
- CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
+ CreateCustomer, Customer, CustomerId, EventId, EventObject, EventType, Expandable, ListEvents,
SubscriptionStatus,
};
use util::ResultExt;
@@ -19,7 +20,7 @@ use util::ResultExt;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
- CreateBillingSubscriptionParams,
+ CreateBillingSubscriptionParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams,
};
use crate::{AppState, Error, Result};
@@ -231,6 +232,7 @@ async fn poll_stripe_events(
) -> anyhow::Result<()> {
let event_types = [
EventType::CustomerCreated.to_string(),
+ EventType::CustomerUpdated.to_string(),
EventType::CustomerSubscriptionCreated.to_string(),
EventType::CustomerSubscriptionUpdated.to_string(),
EventType::CustomerSubscriptionPaused.to_string(),
@@ -255,7 +257,7 @@ async fn poll_stripe_events(
let events = stripe::Event::list(stripe_client, ¶ms).await?;
for event in events.data {
match event.type_ {
- EventType::CustomerCreated => {
+ EventType::CustomerCreated | EventType::CustomerUpdated => {
handle_customer_event(app, stripe_client, event)
.await
.log_err();
@@ -283,15 +285,59 @@ async fn poll_stripe_events(
async fn handle_customer_event(
app: &Arc<AppState>,
- stripe_client: &stripe::Client,
+ _stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Customer(customer) = event.data.object else {
bail!("unexpected event payload for {}", event.id);
};
- find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer)))
- .await?;
+ log::info!("handling Stripe {} event: {}", event.type_, event.id);
+
+ let Some(email) = customer.email else {
+ log::info!("Stripe customer has no email: skipping");
+ return Ok(());
+ };
+
+ let Some(user) = app.db.get_user_by_email(&email).await? else {
+ log::info!("no user found for email: skipping");
+ return Ok(());
+ };
+
+ if let Some(existing_customer) = app
+ .db
+ .get_billing_customer_by_stripe_customer_id(&customer.id)
+ .await?
+ {
+ if should_ignore_event(&event.id, existing_customer.last_stripe_event_id.as_deref()) {
+ log::info!(
+ "ignoring Stripe event {} based on last seen event ID",
+ event.id
+ );
+ return Ok(());
+ }
+
+ app.db
+ .update_billing_customer(
+ existing_customer.id,
+ &UpdateBillingCustomerParams {
+ // For now we just update the last event ID for the customer
+ // and leave the rest of the information as-is, as it is not
+ // likely to change.
+ last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
+ ..Default::default()
+ },
+ )
+ .await?;
+ } else {
+ app.db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id: user.id,
+ stripe_customer_id: customer.id.to_string(),
+ last_stripe_event_id: Some(event.id.to_string()),
+ })
+ .await?;
+ }
Ok(())
}
@@ -305,18 +351,60 @@ async fn handle_customer_subscription_event(
bail!("unexpected event payload for {}", event.id);
};
- let billing_customer =
- find_or_create_billing_customer(app, stripe_client, subscription.customer)
- .await?
- .ok_or_else(|| anyhow!("billing customer not found"))?;
+ log::info!("handling Stripe {} event: {}", event.type_, event.id);
+
+ let billing_customer = find_or_create_billing_customer(
+ app,
+ stripe_client,
+ // Even though we're handling a subscription event, we can still set
+ // the ID as the last seen event ID on the customer in the event that
+ // we have to create it.
+ //
+ // This is done to avoid any potential rollback in the customer's values
+ // if we then see an older event that pertains to the customer.
+ &event.id,
+ subscription.customer,
+ )
+ .await?
+ .ok_or_else(|| anyhow!("billing customer not found"))?;
+
+ if let Some(existing_subscription) = app
+ .db
+ .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
+ .await?
+ {
+ if should_ignore_event(
+ &event.id,
+ existing_subscription.last_stripe_event_id.as_deref(),
+ ) {
+ log::info!(
+ "ignoring Stripe event {} based on last seen event ID",
+ event.id
+ );
+ return Ok(());
+ }
- app.db
- .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams {
- billing_customer_id: billing_customer.id,
- stripe_subscription_id: subscription.id.to_string(),
- stripe_subscription_status: subscription.status.into(),
- })
- .await?;
+ app.db
+ .update_billing_subscription(
+ existing_subscription.id,
+ &UpdateBillingSubscriptionParams {
+ billing_customer_id: ActiveValue::set(billing_customer.id),
+ stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
+ stripe_subscription_status: ActiveValue::set(subscription.status.into()),
+ last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())),
+ },
+ )
+ .await?;
+ } else {
+ app.db
+ .create_billing_subscription(&CreateBillingSubscriptionParams {
+ billing_customer_id: billing_customer.id,
+ stripe_subscription_id: subscription.id.to_string(),
+ stripe_subscription_status: subscription.status.into(),
+ last_stripe_event_id: Some(event.id.to_string()),
+ })
+ .await?;
+ }
Ok(())
}
@@ -340,6 +428,7 @@ impl From<SubscriptionStatus> for StripeSubscriptionStatus {
async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &stripe::Client,
+ event_id: &EventId,
customer_or_id: Expandable<Customer>,
) -> anyhow::Result<Option<billing_customer::Model>> {
let customer_id = match &customer_or_id {
@@ -377,8 +466,70 @@ async fn find_or_create_billing_customer(
.create_billing_customer(&CreateBillingCustomerParams {
user_id: user.id,
stripe_customer_id: customer.id.to_string(),
+ last_stripe_event_id: Some(event_id.to_string()),
})
.await?;
Ok(Some(billing_customer))
}
+
+/// Returns whether an [`Event`] should be ignored, based on its ID and the last
+/// seen event ID for this object.
+#[inline]
+fn should_ignore_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
+ !should_apply_event(event_id, last_event_id)
+}
+
+/// Returns whether an [`Event`] should be applied, based on its ID and the last
+/// seen event ID for this object.
+fn should_apply_event(event_id: &EventId, last_event_id: Option<&str>) -> bool {
+ let Some(last_event_id) = last_event_id else {
+ return true;
+ };
+
+ event_id.as_str() < last_event_id
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_should_apply_event() {
+ let subscription_created_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNafuZSGsmh").unwrap();
+ let subscription_updated_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNa5UZLSsto").unwrap();
+
+ assert_eq!(
+ should_apply_event(
+ &subscription_created_event,
+ Some(subscription_created_event.as_str())
+ ),
+ false,
+ "Events should not be applied when the IDs are the same."
+ );
+
+ assert_eq!(
+ should_apply_event(
+ &subscription_created_event,
+ Some(subscription_updated_event.as_str())
+ ),
+ false,
+ "Events should not be applied when the last event ID is newer than the event ID."
+ );
+
+ assert_eq!(
+ should_apply_event(&subscription_created_event, None),
+ true,
+ "Events should be applied when we don't have a last event ID."
+ );
+
+ assert_eq!(
+ should_apply_event(
+ &subscription_updated_event,
+ Some(subscription_created_event.as_str())
+ ),
+ true,
+ "Events should be applied when the event ID is newer than the last event ID."
+ );
+ }
+}
@@ -1,3 +1,5 @@
+use sea_orm::IntoActiveValue;
+
use crate::db::billing_subscription::StripeSubscriptionStatus;
use super::*;
@@ -7,6 +9,15 @@ pub struct CreateBillingSubscriptionParams {
pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
+ pub last_stripe_event_id: Option<String>,
+}
+
+#[derive(Debug, Default)]
+pub struct UpdateBillingSubscriptionParams {
+ pub billing_customer_id: ActiveValue<BillingCustomerId>,
+ pub stripe_subscription_id: ActiveValue<String>,
+ pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
+ pub last_stripe_event_id: ActiveValue<Option<String>>,
}
impl Database {
@@ -20,6 +31,7 @@ impl Database {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
+ last_stripe_event_id: params.last_stripe_event_id.clone().into_active_value(),
..Default::default()
})
.exec_without_returning(&*tx)
@@ -30,24 +42,22 @@ impl Database {
.await
}
- /// Upserts the billing subscription by its Stripe subscription ID.
- pub async fn upsert_billing_subscription_by_stripe_subscription_id(
+ /// Updates the specified billing subscription.
+ pub async fn update_billing_subscription(
&self,
- params: &CreateBillingSubscriptionParams,
+ id: BillingSubscriptionId,
+ params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
- billing_subscription::Entity::insert(billing_subscription::ActiveModel {
- billing_customer_id: ActiveValue::set(params.billing_customer_id),
- stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
- stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
+ billing_subscription::Entity::update(billing_subscription::ActiveModel {
+ id: ActiveValue::set(id),
+ billing_customer_id: params.billing_customer_id.clone(),
+ stripe_subscription_id: params.stripe_subscription_id.clone(),
+ stripe_subscription_status: params.stripe_subscription_status.clone(),
+ last_stripe_event_id: params.last_stripe_event_id.clone(),
..Default::default()
})
- .on_conflict(
- OnConflict::columns([billing_subscription::Column::StripeSubscriptionId])
- .update_columns([billing_subscription::Column::StripeSubscriptionStatus])
- .to_owned(),
- )
- .exec_with_returning(&*tx)
+ .exec(&*tx)
.await?;
Ok(())
@@ -68,6 +78,22 @@ impl Database {
.await
}
+ /// Returns the billing subscription with the specified Stripe subscription ID.
+ pub async fn get_billing_subscription_by_stripe_subscription_id(
+ &self,
+ stripe_subscription_id: &str,
+ ) -> Result<Option<billing_subscription::Model>> {
+ self.transaction(|tx| async move {
+ Ok(billing_subscription::Entity::find()
+ .filter(
+ billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
+ )
+ .one(&*tx)
+ .await?)
+ })
+ .await
+ }
+
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
@@ -29,6 +29,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_active_user".into(),
+ last_stripe_event_id: None,
})
.await
.unwrap();
@@ -38,6 +39,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
+ last_stripe_event_id: None,
})
.await
.unwrap();
@@ -63,6 +65,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_past_due_user".into(),
+ last_stripe_event_id: None,
})
.await
.unwrap();
@@ -72,6 +75,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
+ last_stripe_event_id: None,
})
.await
.unwrap();