billing_subscriptions.rs

  1use crate::db::billing_subscription::{
  2    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  3};
  4
  5use super::*;
  6
  7#[derive(Debug)]
  8pub struct CreateBillingSubscriptionParams {
  9    pub billing_customer_id: BillingCustomerId,
 10    pub kind: Option<SubscriptionKind>,
 11    pub stripe_subscription_id: String,
 12    pub stripe_subscription_status: StripeSubscriptionStatus,
 13    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
 14    pub stripe_current_period_start: Option<i64>,
 15    pub stripe_current_period_end: Option<i64>,
 16}
 17
 18#[derive(Debug, Default)]
 19pub struct UpdateBillingSubscriptionParams {
 20    pub billing_customer_id: ActiveValue<BillingCustomerId>,
 21    pub kind: ActiveValue<Option<SubscriptionKind>>,
 22    pub stripe_subscription_id: ActiveValue<String>,
 23    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
 24    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
 25    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
 26    pub stripe_current_period_start: ActiveValue<Option<i64>>,
 27    pub stripe_current_period_end: ActiveValue<Option<i64>>,
 28}
 29
 30impl Database {
 31    /// Creates a new billing subscription.
 32    pub async fn create_billing_subscription(
 33        &self,
 34        params: &CreateBillingSubscriptionParams,
 35    ) -> Result<billing_subscription::Model> {
 36        self.transaction(|tx| async move {
 37            let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
 38                billing_customer_id: ActiveValue::set(params.billing_customer_id),
 39                kind: ActiveValue::set(params.kind),
 40                stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
 41                stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
 42                stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
 43                stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
 44                stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
 45                ..Default::default()
 46            })
 47            .exec(&*tx)
 48            .await?
 49            .last_insert_id;
 50
 51            Ok(billing_subscription::Entity::find_by_id(id)
 52                .one(&*tx)
 53                .await?
 54                .ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
 55        })
 56        .await
 57    }
 58
 59    /// Updates the specified billing subscription.
 60    pub async fn update_billing_subscription(
 61        &self,
 62        id: BillingSubscriptionId,
 63        params: &UpdateBillingSubscriptionParams,
 64    ) -> Result<()> {
 65        self.transaction(|tx| async move {
 66            billing_subscription::Entity::update(billing_subscription::ActiveModel {
 67                id: ActiveValue::set(id),
 68                billing_customer_id: params.billing_customer_id.clone(),
 69                kind: params.kind.clone(),
 70                stripe_subscription_id: params.stripe_subscription_id.clone(),
 71                stripe_subscription_status: params.stripe_subscription_status.clone(),
 72                stripe_cancel_at: params.stripe_cancel_at.clone(),
 73                stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
 74                stripe_current_period_start: params.stripe_current_period_start.clone(),
 75                stripe_current_period_end: params.stripe_current_period_end.clone(),
 76                created_at: ActiveValue::not_set(),
 77            })
 78            .exec(&*tx)
 79            .await?;
 80
 81            Ok(())
 82        })
 83        .await
 84    }
 85
 86    /// Returns the billing subscription with the specified ID.
 87    pub async fn get_billing_subscription_by_id(
 88        &self,
 89        id: BillingSubscriptionId,
 90    ) -> Result<Option<billing_subscription::Model>> {
 91        self.transaction(|tx| async move {
 92            Ok(billing_subscription::Entity::find_by_id(id)
 93                .one(&*tx)
 94                .await?)
 95        })
 96        .await
 97    }
 98
 99    /// Returns the billing subscription with the specified Stripe subscription ID.
100    pub async fn get_billing_subscription_by_stripe_subscription_id(
101        &self,
102        stripe_subscription_id: &str,
103    ) -> Result<Option<billing_subscription::Model>> {
104        self.transaction(|tx| async move {
105            Ok(billing_subscription::Entity::find()
106                .filter(
107                    billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
108                )
109                .one(&*tx)
110                .await?)
111        })
112        .await
113    }
114
115    pub async fn get_active_billing_subscription(
116        &self,
117        user_id: UserId,
118    ) -> Result<Option<billing_subscription::Model>> {
119        self.transaction(|tx| async move {
120            Ok(billing_subscription::Entity::find()
121                .inner_join(billing_customer::Entity)
122                .filter(billing_customer::Column::UserId.eq(user_id))
123                .filter(
124                    Condition::all()
125                        .add(
126                            Condition::any()
127                                .add(
128                                    billing_subscription::Column::StripeSubscriptionStatus
129                                        .eq(StripeSubscriptionStatus::Active),
130                                )
131                                .add(
132                                    billing_subscription::Column::StripeSubscriptionStatus
133                                        .eq(StripeSubscriptionStatus::Trialing),
134                                ),
135                        )
136                        .add(billing_subscription::Column::Kind.is_not_null()),
137                )
138                .one(&*tx)
139                .await?)
140        })
141        .await
142    }
143
144    /// Returns all of the billing subscriptions for the user with the specified ID.
145    ///
146    /// Note that this returns the subscriptions regardless of their status.
147    /// If you're wanting to check if a use has an active billing subscription,
148    /// use `get_active_billing_subscriptions` instead.
149    pub async fn get_billing_subscriptions(
150        &self,
151        user_id: UserId,
152    ) -> Result<Vec<billing_subscription::Model>> {
153        self.transaction(|tx| async move {
154            let subscriptions = billing_subscription::Entity::find()
155                .inner_join(billing_customer::Entity)
156                .filter(billing_customer::Column::UserId.eq(user_id))
157                .order_by_asc(billing_subscription::Column::Id)
158                .all(&*tx)
159                .await?;
160
161            Ok(subscriptions)
162        })
163        .await
164    }
165
166    pub async fn get_active_billing_subscriptions(
167        &self,
168        user_ids: HashSet<UserId>,
169    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
170        self.transaction(|tx| {
171            let user_ids = user_ids.clone();
172            async move {
173                let mut rows = billing_subscription::Entity::find()
174                    .inner_join(billing_customer::Entity)
175                    .select_also(billing_customer::Entity)
176                    .filter(billing_customer::Column::UserId.is_in(user_ids))
177                    .filter(
178                        billing_subscription::Column::StripeSubscriptionStatus
179                            .eq(StripeSubscriptionStatus::Active),
180                    )
181                    .filter(billing_subscription::Column::Kind.is_null())
182                    .order_by_asc(billing_subscription::Column::Id)
183                    .stream(&*tx)
184                    .await?;
185
186                let mut subscriptions = HashMap::default();
187                while let Some(row) = rows.next().await {
188                    if let (subscription, Some(customer)) = row? {
189                        subscriptions.insert(customer.user_id, (customer, subscription));
190                    }
191                }
192                Ok(subscriptions)
193            }
194        })
195        .await
196    }
197
198    pub async fn get_active_zed_pro_billing_subscriptions(
199        &self,
200        user_ids: HashSet<UserId>,
201    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
202        self.transaction(|tx| {
203            let user_ids = user_ids.clone();
204            async move {
205                let mut rows = billing_subscription::Entity::find()
206                    .inner_join(billing_customer::Entity)
207                    .select_also(billing_customer::Entity)
208                    .filter(billing_customer::Column::UserId.is_in(user_ids))
209                    .filter(
210                        billing_subscription::Column::StripeSubscriptionStatus
211                            .eq(StripeSubscriptionStatus::Active),
212                    )
213                    .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
214                    .order_by_asc(billing_subscription::Column::Id)
215                    .stream(&*tx)
216                    .await?;
217
218                let mut subscriptions = HashMap::default();
219                while let Some(row) = rows.next().await {
220                    if let (subscription, Some(customer)) = row? {
221                        subscriptions.insert(customer.user_id, (customer, subscription));
222                    }
223                }
224                Ok(subscriptions)
225            }
226        })
227        .await
228    }
229
230    /// Returns whether the user has an active billing subscription.
231    pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
232        Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
233    }
234
235    /// Returns the count of the active billing subscriptions for the user with the specified ID.
236    pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
237        self.transaction(|tx| async move {
238            let count = billing_subscription::Entity::find()
239                .inner_join(billing_customer::Entity)
240                .filter(
241                    billing_customer::Column::UserId.eq(user_id).and(
242                        billing_subscription::Column::StripeSubscriptionStatus
243                            .eq(StripeSubscriptionStatus::Active)
244                            .or(billing_subscription::Column::StripeSubscriptionStatus
245                                .eq(StripeSubscriptionStatus::Trialing)),
246                    ),
247                )
248                .count(&*tx)
249                .await?;
250
251            Ok(count as usize)
252        })
253        .await
254    }
255}