Detailed changes
@@ -420,12 +420,20 @@ CREATE TABLE dev_server_projects (
CREATE TABLE IF NOT EXISTS billing_subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- user_id INTEGER NOT NULL REFERENCES users(id),
- stripe_customer_id TEXT NOT NULL,
+ billing_customer_id INTEGER NOT NULL REFERENCES billing_customers(id),
stripe_subscription_id TEXT NOT NULL,
stripe_subscription_status TEXT NOT NULL
);
-CREATE INDEX "ix_billing_subscriptions_on_user_id" ON billing_subscriptions (user_id);
-CREATE INDEX "ix_billing_subscriptions_on_stripe_customer_id" ON billing_subscriptions (stripe_customer_id);
+CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id);
+
+CREATE TABLE IF NOT EXISTS billing_customers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ user_id INTEGER NOT NULL REFERENCES users(id),
+ stripe_customer_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
+CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
@@ -0,0 +1,18 @@
+CREATE TABLE IF NOT EXISTS billing_customers (
+ id SERIAL PRIMARY KEY,
+ created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT now(),
+ user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ stripe_customer_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
+CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
+
+-- Make `billing_subscriptions` reference `billing_customers` instead of having its
+-- own `user_id` and `stripe_customer_id`.
+DROP INDEX IF EXISTS "ix_billing_subscriptions_on_user_id";
+DROP INDEX IF EXISTS "ix_billing_subscriptions_on_stripe_customer_id";
+ALTER TABLE billing_subscriptions DROP COLUMN user_id;
+ALTER TABLE billing_subscriptions DROP COLUMN stripe_customer_id;
+ALTER TABLE billing_subscriptions ADD COLUMN billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id) ON DELETE CASCADE;
+CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
@@ -3,7 +3,6 @@ use std::sync::Arc;
use anyhow::{anyhow, Context};
use axum::{extract, routing::post, Extension, Json, Router};
-use collections::HashSet;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use stripe::{
@@ -11,7 +10,7 @@ use stripe::{
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
- CustomerId,
+ CreateCustomer, Customer, CustomerId,
};
use crate::db::BillingSubscriptionId;
@@ -59,28 +58,27 @@ async fn create_billing_subscription(
))?
};
- let existing_customer_id = {
- let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
- let distinct_customer_ids = existing_subscriptions
- .iter()
- .map(|subscription| subscription.stripe_customer_id.as_str())
- .collect::<HashSet<_>>();
- // Sanity: Make sure we can determine a single Stripe customer ID for the user.
- if distinct_customer_ids.len() > 1 {
- Err(anyhow!("user has multiple existing customer IDs"))?;
- }
+ let customer_id =
+ if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
+ CustomerId::from_str(&existing_customer.stripe_customer_id)
+ .context("failed to parse customer ID")?
+ } else {
+ let customer = Customer::create(
+ &stripe_client,
+ CreateCustomer {
+ email: user.email_address.as_deref(),
+ ..Default::default()
+ },
+ )
+ .await?;
- distinct_customer_ids
- .into_iter()
- .next()
- .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
- .transpose()
- }?;
+ customer.id
+ };
let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
- params.customer = existing_customer_id;
+ params.customer = Some(customer_id);
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_price_id.to_string()),
@@ -140,6 +138,14 @@ async fn manage_billing_subscription(
))?
};
+ let customer = app
+ .db
+ .get_billing_customer_by_user_id(user.id)
+ .await?
+ .ok_or_else(|| anyhow!("billing customer not found"))?;
+ let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
+ .context("failed to parse customer ID")?;
+
let subscription = if let Some(subscription_id) = body.subscription_id {
app.db
.get_billing_subscription_by_id(subscription_id)
@@ -158,9 +164,6 @@ async fn manage_billing_subscription(
.ok_or_else(|| anyhow!("user has no active subscriptions"))?
};
- let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
- .context("failed to parse customer ID")?;
-
let flow = match body.intent {
ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
@@ -45,6 +45,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard};
pub use tests::TestDb;
pub use ids::*;
+pub use queries::billing_customers::CreateBillingCustomerParams;
pub use queries::billing_subscriptions::CreateBillingSubscriptionParams;
pub use queries::contributors::ContributorSelector;
pub use sea_orm::ConnectOptions;
@@ -68,6 +68,7 @@ macro_rules! id_type {
}
id_type!(AccessTokenId);
+id_type!(BillingCustomerId);
id_type!(BillingSubscriptionId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
@@ -1,6 +1,7 @@
use super::*;
pub mod access_tokens;
+pub mod billing_customers;
pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;
@@ -0,0 +1,42 @@
+use super::*;
+
+#[derive(Debug)]
+pub struct CreateBillingCustomerParams {
+ pub user_id: UserId,
+ pub stripe_customer_id: String,
+}
+
+impl Database {
+ /// Creates a new billing customer.
+ pub async fn create_billing_customer(
+ &self,
+ params: &CreateBillingCustomerParams,
+ ) -> Result<billing_customer::Model> {
+ self.transaction(|tx| async move {
+ let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
+ user_id: ActiveValue::set(params.user_id),
+ stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
+ ..Default::default()
+ })
+ .exec_with_returning(&*tx)
+ .await?;
+
+ Ok(customer)
+ })
+ .await
+ }
+
+ /// Returns the billing customer for the user with the specified ID.
+ pub async fn get_billing_customer_by_user_id(
+ &self,
+ user_id: UserId,
+ ) -> Result<Option<billing_customer::Model>> {
+ self.transaction(|tx| async move {
+ Ok(billing_customer::Entity::find()
+ .filter(billing_customer::Column::UserId.eq(user_id))
+ .one(&*tx)
+ .await?)
+ })
+ .await
+ }
+}
@@ -4,8 +4,7 @@ use super::*;
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
- pub user_id: UserId,
- pub stripe_customer_id: String,
+ pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
}
@@ -18,8 +17,7 @@ impl Database {
) -> Result<()> {
self.transaction(|tx| async move {
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
- user_id: ActiveValue::set(params.user_id),
- stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
+ billing_customer_id: ActiveValue::set(params.billing_customer_id),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
..Default::default()
@@ -56,7 +54,8 @@ impl Database {
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
- .filter(billing_subscription::Column::UserId.eq(user_id))
+ .inner_join(billing_customer::Entity)
+ .filter(billing_customer::Column::UserId.eq(user_id))
.order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;
@@ -73,8 +72,9 @@ impl Database {
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
+ .inner_join(billing_customer::Entity)
.filter(
- billing_subscription::Column::UserId.eq(user_id).and(
+ billing_customer::Column::UserId.eq(user_id).and(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
),
@@ -1,4 +1,5 @@
pub mod access_token;
+pub mod billing_customer;
pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;
@@ -0,0 +1,39 @@
+use crate::db::{BillingCustomerId, UserId};
+use sea_orm::entity::prelude::*;
+
+/// A billing customer.
+#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_customers")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: BillingCustomerId,
+ pub user_id: UserId,
+ pub stripe_customer_id: String,
+ pub created_at: DateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ User,
+ #[sea_orm(has_many = "super::billing_subscription::Entity")]
+ BillingSubscription,
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::User.def()
+ }
+}
+
+impl Related<super::billing_subscription::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::BillingSubscription.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -1,4 +1,4 @@
-use crate::db::{BillingSubscriptionId, UserId};
+use crate::db::{BillingCustomerId, BillingSubscriptionId};
use sea_orm::entity::prelude::*;
/// A billing subscription.
@@ -7,8 +7,7 @@ use sea_orm::entity::prelude::*;
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingSubscriptionId,
- pub user_id: UserId,
- pub stripe_customer_id: String,
+ pub billing_customer_id: BillingCustomerId,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub created_at: DateTime,
@@ -17,16 +16,16 @@ pub struct Model {
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
- belongs_to = "super::user::Entity",
- from = "Column::UserId",
- to = "super::user::Column::Id"
+ belongs_to = "super::billing_customer::Entity",
+ from = "Column::BillingCustomerId",
+ to = "super::billing_customer::Column::Id"
)]
- User,
+ BillingCustomer,
}
-impl Related<super::user::Entity> for Entity {
+impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
- Relation::User.def()
+ Relation::BillingCustomer.def()
}
}
@@ -24,8 +24,8 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::access_token::Entity")]
AccessToken,
- #[sea_orm(has_many = "super::billing_subscription::Entity")]
- BillingSubscription,
+ #[sea_orm(has_one = "super::billing_customer::Entity")]
+ BillingCustomer,
#[sea_orm(has_one = "super::room_participant::Entity")]
RoomParticipant,
#[sea_orm(has_many = "super::project::Entity")]
@@ -44,6 +44,12 @@ impl Related<super::access_token::Entity> for Entity {
}
}
+impl Related<super::billing_customer::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::BillingCustomer.def()
+ }
+}
+
impl Related<super::room_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::RoomParticipant.def()
@@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::tests::new_test_user;
-use crate::db::CreateBillingSubscriptionParams;
+use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
use super::Database;
@@ -25,9 +25,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
// A user with an active subscription has one active billing subscription.
{
let user_id = new_test_user(db, "active-user@example.com").await;
+ let customer = db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id,
+ stripe_customer_id: "cus_active_user".into(),
+ })
+ .await
+ .unwrap();
+ assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
+
db.create_billing_subscription(&CreateBillingSubscriptionParams {
- user_id,
- stripe_customer_id: "cus_active_user".into(),
+ billing_customer_id: customer.id,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
})
@@ -38,10 +46,6 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
assert_eq!(subscriptions.len(), 1);
let subscription = &subscriptions[0];
- assert_eq!(
- subscription.stripe_customer_id,
- "cus_active_user".to_string()
- );
assert_eq!(
subscription.stripe_subscription_id,
"sub_active_user".to_string()
@@ -55,9 +59,17 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
// A user with a past-due subscription has no active billing subscriptions.
{
let user_id = new_test_user(db, "past-due-user@example.com").await;
+ let customer = db
+ .create_billing_customer(&CreateBillingCustomerParams {
+ user_id,
+ stripe_customer_id: "cus_past_due_user".into(),
+ })
+ .await
+ .unwrap();
+ assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
+
db.create_billing_subscription(&CreateBillingSubscriptionParams {
- user_id,
- stripe_customer_id: "cus_past_due_user".into(),
+ billing_customer_id: customer.id,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
})