billing_subscriptions.rs

  1use crate::db::billing_subscription::{
  2    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  3};
  4
  5use super::*;
  6
  7#[derive(Debug)]
  8pub struct CreateBillingSubscriptionParams {
  9    pub billing_customer_id: BillingCustomerId,
 10    pub kind: Option<SubscriptionKind>,
 11    pub stripe_subscription_id: String,
 12    pub stripe_subscription_status: StripeSubscriptionStatus,
 13    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
 14    pub stripe_current_period_start: Option<i64>,
 15    pub stripe_current_period_end: Option<i64>,
 16}
 17
 18#[derive(Debug, Default)]
 19pub struct UpdateBillingSubscriptionParams {
 20    pub billing_customer_id: ActiveValue<BillingCustomerId>,
 21    pub kind: ActiveValue<Option<SubscriptionKind>>,
 22    pub stripe_subscription_id: ActiveValue<String>,
 23    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
 24    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
 25    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
 26    pub stripe_current_period_start: ActiveValue<Option<i64>>,
 27    pub stripe_current_period_end: ActiveValue<Option<i64>>,
 28}
 29
 30impl Database {
 31    /// Creates a new billing subscription.
 32    pub async fn create_billing_subscription(
 33        &self,
 34        params: &CreateBillingSubscriptionParams,
 35    ) -> Result<()> {
 36        self.transaction(|tx| async move {
 37            billing_subscription::Entity::insert(billing_subscription::ActiveModel {
 38                billing_customer_id: ActiveValue::set(params.billing_customer_id),
 39                kind: ActiveValue::set(params.kind),
 40                stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
 41                stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
 42                stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
 43                stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
 44                stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
 45                ..Default::default()
 46            })
 47            .exec_without_returning(&*tx)
 48            .await?;
 49
 50            Ok(())
 51        })
 52        .await
 53    }
 54
 55    /// Updates the specified billing subscription.
 56    pub async fn update_billing_subscription(
 57        &self,
 58        id: BillingSubscriptionId,
 59        params: &UpdateBillingSubscriptionParams,
 60    ) -> Result<()> {
 61        self.transaction(|tx| async move {
 62            billing_subscription::Entity::update(billing_subscription::ActiveModel {
 63                id: ActiveValue::set(id),
 64                billing_customer_id: params.billing_customer_id.clone(),
 65                kind: params.kind.clone(),
 66                stripe_subscription_id: params.stripe_subscription_id.clone(),
 67                stripe_subscription_status: params.stripe_subscription_status.clone(),
 68                stripe_cancel_at: params.stripe_cancel_at.clone(),
 69                stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
 70                stripe_current_period_start: params.stripe_current_period_start.clone(),
 71                stripe_current_period_end: params.stripe_current_period_end.clone(),
 72                created_at: ActiveValue::not_set(),
 73            })
 74            .exec(&*tx)
 75            .await?;
 76
 77            Ok(())
 78        })
 79        .await
 80    }
 81
 82    /// Returns the billing subscription with the specified ID.
 83    pub async fn get_billing_subscription_by_id(
 84        &self,
 85        id: BillingSubscriptionId,
 86    ) -> Result<Option<billing_subscription::Model>> {
 87        self.transaction(|tx| async move {
 88            Ok(billing_subscription::Entity::find_by_id(id)
 89                .one(&*tx)
 90                .await?)
 91        })
 92        .await
 93    }
 94
 95    /// Returns the billing subscription with the specified Stripe subscription ID.
 96    pub async fn get_billing_subscription_by_stripe_subscription_id(
 97        &self,
 98        stripe_subscription_id: &str,
 99    ) -> Result<Option<billing_subscription::Model>> {
100        self.transaction(|tx| async move {
101            Ok(billing_subscription::Entity::find()
102                .filter(
103                    billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
104                )
105                .one(&*tx)
106                .await?)
107        })
108        .await
109    }
110
111    pub async fn get_active_billing_subscription(
112        &self,
113        user_id: UserId,
114    ) -> Result<Option<billing_subscription::Model>> {
115        self.transaction(|tx| async move {
116            Ok(billing_subscription::Entity::find()
117                .inner_join(billing_customer::Entity)
118                .filter(billing_customer::Column::UserId.eq(user_id))
119                .filter(
120                    Condition::all()
121                        .add(
122                            Condition::any()
123                                .add(
124                                    billing_subscription::Column::StripeSubscriptionStatus
125                                        .eq(StripeSubscriptionStatus::Active),
126                                )
127                                .add(
128                                    billing_subscription::Column::StripeSubscriptionStatus
129                                        .eq(StripeSubscriptionStatus::Trialing),
130                                ),
131                        )
132                        .add(billing_subscription::Column::Kind.is_not_null()),
133                )
134                .one(&*tx)
135                .await?)
136        })
137        .await
138    }
139
140    /// Returns all of the billing subscriptions for the user with the specified ID.
141    ///
142    /// Note that this returns the subscriptions regardless of their status.
143    /// If you're wanting to check if a use has an active billing subscription,
144    /// use `get_active_billing_subscriptions` instead.
145    pub async fn get_billing_subscriptions(
146        &self,
147        user_id: UserId,
148    ) -> Result<Vec<billing_subscription::Model>> {
149        self.transaction(|tx| async move {
150            let subscriptions = billing_subscription::Entity::find()
151                .inner_join(billing_customer::Entity)
152                .filter(billing_customer::Column::UserId.eq(user_id))
153                .order_by_asc(billing_subscription::Column::Id)
154                .all(&*tx)
155                .await?;
156
157            Ok(subscriptions)
158        })
159        .await
160    }
161
162    pub async fn get_active_billing_subscriptions(
163        &self,
164        user_ids: HashSet<UserId>,
165    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
166        self.transaction(|tx| {
167            let user_ids = user_ids.clone();
168            async move {
169                let mut rows = billing_subscription::Entity::find()
170                    .inner_join(billing_customer::Entity)
171                    .select_also(billing_customer::Entity)
172                    .filter(billing_customer::Column::UserId.is_in(user_ids))
173                    .filter(
174                        billing_subscription::Column::StripeSubscriptionStatus
175                            .eq(StripeSubscriptionStatus::Active),
176                    )
177                    .filter(billing_subscription::Column::Kind.is_null())
178                    .order_by_asc(billing_subscription::Column::Id)
179                    .stream(&*tx)
180                    .await?;
181
182                let mut subscriptions = HashMap::default();
183                while let Some(row) = rows.next().await {
184                    if let (subscription, Some(customer)) = row? {
185                        subscriptions.insert(customer.user_id, (customer, subscription));
186                    }
187                }
188                Ok(subscriptions)
189            }
190        })
191        .await
192    }
193
194    pub async fn get_active_zed_pro_billing_subscriptions(
195        &self,
196        user_ids: HashSet<UserId>,
197    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
198        self.transaction(|tx| {
199            let user_ids = user_ids.clone();
200            async move {
201                let mut rows = billing_subscription::Entity::find()
202                    .inner_join(billing_customer::Entity)
203                    .select_also(billing_customer::Entity)
204                    .filter(billing_customer::Column::UserId.is_in(user_ids))
205                    .filter(
206                        billing_subscription::Column::StripeSubscriptionStatus
207                            .eq(StripeSubscriptionStatus::Active),
208                    )
209                    .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
210                    .order_by_asc(billing_subscription::Column::Id)
211                    .stream(&*tx)
212                    .await?;
213
214                let mut subscriptions = HashMap::default();
215                while let Some(row) = rows.next().await {
216                    if let (subscription, Some(customer)) = row? {
217                        subscriptions.insert(customer.user_id, (customer, subscription));
218                    }
219                }
220                Ok(subscriptions)
221            }
222        })
223        .await
224    }
225
226    /// Returns whether the user has an active billing subscription.
227    pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
228        Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
229    }
230
231    /// Returns the count of the active billing subscriptions for the user with the specified ID.
232    pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
233        self.transaction(|tx| async move {
234            let count = billing_subscription::Entity::find()
235                .inner_join(billing_customer::Entity)
236                .filter(
237                    billing_customer::Column::UserId.eq(user_id).and(
238                        billing_subscription::Column::StripeSubscriptionStatus
239                            .eq(StripeSubscriptionStatus::Active)
240                            .or(billing_subscription::Column::StripeSubscriptionStatus
241                                .eq(StripeSubscriptionStatus::Trialing)),
242                    ),
243                )
244                .count(&*tx)
245                .await?;
246
247            Ok(count as usize)
248        })
249        .await
250    }
251}