1use crate::db::billing_subscription::{
2 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
3};
4
5use super::*;
6
/// Parameters for [`Database::create_billing_subscription`].
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
    /// The billing customer the new subscription belongs to.
    pub billing_customer_id: BillingCustomerId,
    /// The kind of subscription, if any.
    pub kind: Option<SubscriptionKind>,
    /// The ID of the subscription in Stripe.
    pub stripe_subscription_id: String,
    /// The Stripe status of the subscription.
    pub stripe_subscription_status: StripeSubscriptionStatus,
    /// The Stripe cancellation reason, if any.
    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
    /// Start of the current Stripe billing period, if known.
    // NOTE(review): presumably a Unix timestamp in seconds (as Stripe reports it) — confirm.
    pub stripe_current_period_start: Option<i64>,
    /// End of the current Stripe billing period, if known.
    // NOTE(review): presumably a Unix timestamp in seconds (as Stripe reports it) — confirm.
    pub stripe_current_period_end: Option<i64>,
}
17
/// Parameters for [`Database::update_billing_subscription`].
///
/// Every field is a SeaORM `ActiveValue`. The derived `Default` leaves each
/// field as `ActiveValue::NotSet`, so callers can start from
/// `UpdateBillingSubscriptionParams::default()` and set only what they need.
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
    /// The billing customer the subscription belongs to.
    pub billing_customer_id: ActiveValue<BillingCustomerId>,
    /// The kind of subscription, if any.
    pub kind: ActiveValue<Option<SubscriptionKind>>,
    /// The ID of the subscription in Stripe.
    pub stripe_subscription_id: ActiveValue<String>,
    /// The Stripe status of the subscription.
    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
    /// The Stripe `cancel_at` time, if any.
    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
    /// The Stripe cancellation reason, if any.
    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
    /// Start of the current Stripe billing period, if known.
    // NOTE(review): presumably a Unix timestamp in seconds (as Stripe reports it) — confirm.
    pub stripe_current_period_start: ActiveValue<Option<i64>>,
    /// End of the current Stripe billing period, if known.
    // NOTE(review): presumably a Unix timestamp in seconds (as Stripe reports it) — confirm.
    pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
29
30impl Database {
31 /// Creates a new billing subscription.
32 pub async fn create_billing_subscription(
33 &self,
34 params: &CreateBillingSubscriptionParams,
35 ) -> Result<()> {
36 self.transaction(|tx| async move {
37 billing_subscription::Entity::insert(billing_subscription::ActiveModel {
38 billing_customer_id: ActiveValue::set(params.billing_customer_id),
39 kind: ActiveValue::set(params.kind),
40 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
41 stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
42 stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
43 stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
44 stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
45 ..Default::default()
46 })
47 .exec_without_returning(&*tx)
48 .await?;
49
50 Ok(())
51 })
52 .await
53 }
54
55 /// Updates the specified billing subscription.
56 pub async fn update_billing_subscription(
57 &self,
58 id: BillingSubscriptionId,
59 params: &UpdateBillingSubscriptionParams,
60 ) -> Result<()> {
61 self.transaction(|tx| async move {
62 billing_subscription::Entity::update(billing_subscription::ActiveModel {
63 id: ActiveValue::set(id),
64 billing_customer_id: params.billing_customer_id.clone(),
65 stripe_subscription_id: params.stripe_subscription_id.clone(),
66 stripe_subscription_status: params.stripe_subscription_status.clone(),
67 stripe_cancel_at: params.stripe_cancel_at.clone(),
68 stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
69 ..Default::default()
70 })
71 .exec(&*tx)
72 .await?;
73
74 Ok(())
75 })
76 .await
77 }
78
79 /// Returns the billing subscription with the specified ID.
80 pub async fn get_billing_subscription_by_id(
81 &self,
82 id: BillingSubscriptionId,
83 ) -> Result<Option<billing_subscription::Model>> {
84 self.transaction(|tx| async move {
85 Ok(billing_subscription::Entity::find_by_id(id)
86 .one(&*tx)
87 .await?)
88 })
89 .await
90 }
91
92 /// Returns the billing subscription with the specified Stripe subscription ID.
93 pub async fn get_billing_subscription_by_stripe_subscription_id(
94 &self,
95 stripe_subscription_id: &str,
96 ) -> Result<Option<billing_subscription::Model>> {
97 self.transaction(|tx| async move {
98 Ok(billing_subscription::Entity::find()
99 .filter(
100 billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
101 )
102 .one(&*tx)
103 .await?)
104 })
105 .await
106 }
107
108 /// Returns all of the billing subscriptions for the user with the specified ID.
109 ///
110 /// Note that this returns the subscriptions regardless of their status.
111 /// If you're wanting to check if a use has an active billing subscription,
112 /// use `get_active_billing_subscriptions` instead.
113 pub async fn get_billing_subscriptions(
114 &self,
115 user_id: UserId,
116 ) -> Result<Vec<billing_subscription::Model>> {
117 self.transaction(|tx| async move {
118 let subscriptions = billing_subscription::Entity::find()
119 .inner_join(billing_customer::Entity)
120 .filter(billing_customer::Column::UserId.eq(user_id))
121 .order_by_asc(billing_subscription::Column::Id)
122 .all(&*tx)
123 .await?;
124
125 Ok(subscriptions)
126 })
127 .await
128 }
129
130 pub async fn get_active_billing_subscriptions(
131 &self,
132 user_ids: HashSet<UserId>,
133 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
134 self.transaction(|tx| {
135 let user_ids = user_ids.clone();
136 async move {
137 let mut rows = billing_subscription::Entity::find()
138 .inner_join(billing_customer::Entity)
139 .select_also(billing_customer::Entity)
140 .filter(billing_customer::Column::UserId.is_in(user_ids))
141 .filter(
142 billing_subscription::Column::StripeSubscriptionStatus
143 .eq(StripeSubscriptionStatus::Active),
144 )
145 .order_by_asc(billing_subscription::Column::Id)
146 .stream(&*tx)
147 .await?;
148
149 let mut subscriptions = HashMap::default();
150 while let Some(row) = rows.next().await {
151 if let (subscription, Some(customer)) = row? {
152 subscriptions.insert(customer.user_id, (customer, subscription));
153 }
154 }
155 Ok(subscriptions)
156 }
157 })
158 .await
159 }
160
161 /// Returns whether the user has an active billing subscription.
162 pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
163 Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
164 }
165
166 /// Returns the count of the active billing subscriptions for the user with the specified ID.
167 pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
168 self.transaction(|tx| async move {
169 let count = billing_subscription::Entity::find()
170 .inner_join(billing_customer::Entity)
171 .filter(
172 billing_customer::Column::UserId.eq(user_id).and(
173 billing_subscription::Column::StripeSubscriptionStatus
174 .eq(StripeSubscriptionStatus::Active),
175 ),
176 )
177 .count(&*tx)
178 .await?;
179
180 Ok(count as usize)
181 })
182 .await
183 }
184}