billing_subscriptions.rs

  1use anyhow::Context as _;
  2
  3use crate::db::billing_subscription::{
  4    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  5};
  6
  7use super::*;
  8
  9#[derive(Debug)]
 10pub struct CreateBillingSubscriptionParams {
 11    pub billing_customer_id: BillingCustomerId,
 12    pub kind: Option<SubscriptionKind>,
 13    pub stripe_subscription_id: String,
 14    pub stripe_subscription_status: StripeSubscriptionStatus,
 15    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
 16    pub stripe_current_period_start: Option<i64>,
 17    pub stripe_current_period_end: Option<i64>,
 18}
 19
 20#[derive(Debug, Default)]
 21pub struct UpdateBillingSubscriptionParams {
 22    pub billing_customer_id: ActiveValue<BillingCustomerId>,
 23    pub kind: ActiveValue<Option<SubscriptionKind>>,
 24    pub stripe_subscription_id: ActiveValue<String>,
 25    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
 26    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
 27    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
 28    pub stripe_current_period_start: ActiveValue<Option<i64>>,
 29    pub stripe_current_period_end: ActiveValue<Option<i64>>,
 30}
 31
 32impl Database {
 33    /// Creates a new billing subscription.
 34    pub async fn create_billing_subscription(
 35        &self,
 36        params: &CreateBillingSubscriptionParams,
 37    ) -> Result<billing_subscription::Model> {
 38        self.transaction(|tx| async move {
 39            let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
 40                billing_customer_id: ActiveValue::set(params.billing_customer_id),
 41                kind: ActiveValue::set(params.kind),
 42                stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
 43                stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
 44                stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
 45                stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
 46                stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
 47                ..Default::default()
 48            })
 49            .exec(&*tx)
 50            .await?
 51            .last_insert_id;
 52
 53            Ok(billing_subscription::Entity::find_by_id(id)
 54                .one(&*tx)
 55                .await?
 56                .context("failed to retrieve inserted billing subscription")?)
 57        })
 58        .await
 59    }
 60
 61    /// Updates the specified billing subscription.
 62    pub async fn update_billing_subscription(
 63        &self,
 64        id: BillingSubscriptionId,
 65        params: &UpdateBillingSubscriptionParams,
 66    ) -> Result<()> {
 67        self.transaction(|tx| async move {
 68            billing_subscription::Entity::update(billing_subscription::ActiveModel {
 69                id: ActiveValue::set(id),
 70                billing_customer_id: params.billing_customer_id.clone(),
 71                kind: params.kind.clone(),
 72                stripe_subscription_id: params.stripe_subscription_id.clone(),
 73                stripe_subscription_status: params.stripe_subscription_status.clone(),
 74                stripe_cancel_at: params.stripe_cancel_at.clone(),
 75                stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
 76                stripe_current_period_start: params.stripe_current_period_start.clone(),
 77                stripe_current_period_end: params.stripe_current_period_end.clone(),
 78                created_at: ActiveValue::not_set(),
 79            })
 80            .exec(&*tx)
 81            .await?;
 82
 83            Ok(())
 84        })
 85        .await
 86    }
 87
 88    /// Returns the billing subscription with the specified ID.
 89    pub async fn get_billing_subscription_by_id(
 90        &self,
 91        id: BillingSubscriptionId,
 92    ) -> Result<Option<billing_subscription::Model>> {
 93        self.transaction(|tx| async move {
 94            Ok(billing_subscription::Entity::find_by_id(id)
 95                .one(&*tx)
 96                .await?)
 97        })
 98        .await
 99    }
100
101    /// Returns the billing subscription with the specified Stripe subscription ID.
102    pub async fn get_billing_subscription_by_stripe_subscription_id(
103        &self,
104        stripe_subscription_id: &str,
105    ) -> Result<Option<billing_subscription::Model>> {
106        self.transaction(|tx| async move {
107            Ok(billing_subscription::Entity::find()
108                .filter(
109                    billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
110                )
111                .one(&*tx)
112                .await?)
113        })
114        .await
115    }
116
117    pub async fn get_active_billing_subscription(
118        &self,
119        user_id: UserId,
120    ) -> Result<Option<billing_subscription::Model>> {
121        self.transaction(|tx| async move {
122            Ok(billing_subscription::Entity::find()
123                .inner_join(billing_customer::Entity)
124                .filter(billing_customer::Column::UserId.eq(user_id))
125                .filter(
126                    Condition::all()
127                        .add(
128                            Condition::any()
129                                .add(
130                                    billing_subscription::Column::StripeSubscriptionStatus
131                                        .eq(StripeSubscriptionStatus::Active),
132                                )
133                                .add(
134                                    billing_subscription::Column::StripeSubscriptionStatus
135                                        .eq(StripeSubscriptionStatus::Trialing),
136                                ),
137                        )
138                        .add(billing_subscription::Column::Kind.is_not_null()),
139                )
140                .one(&*tx)
141                .await?)
142        })
143        .await
144    }
145
146    /// Returns all of the billing subscriptions for the user with the specified ID.
147    ///
148    /// Note that this returns the subscriptions regardless of their status.
149    /// If you're wanting to check if a use has an active billing subscription,
150    /// use `get_active_billing_subscriptions` instead.
151    pub async fn get_billing_subscriptions(
152        &self,
153        user_id: UserId,
154    ) -> Result<Vec<billing_subscription::Model>> {
155        self.transaction(|tx| async move {
156            let subscriptions = billing_subscription::Entity::find()
157                .inner_join(billing_customer::Entity)
158                .filter(billing_customer::Column::UserId.eq(user_id))
159                .order_by_asc(billing_subscription::Column::Id)
160                .all(&*tx)
161                .await?;
162
163            Ok(subscriptions)
164        })
165        .await
166    }
167
168    pub async fn get_active_billing_subscriptions(
169        &self,
170        user_ids: HashSet<UserId>,
171    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
172        self.transaction(|tx| {
173            let user_ids = user_ids.clone();
174            async move {
175                let mut rows = billing_subscription::Entity::find()
176                    .inner_join(billing_customer::Entity)
177                    .select_also(billing_customer::Entity)
178                    .filter(billing_customer::Column::UserId.is_in(user_ids))
179                    .filter(
180                        billing_subscription::Column::StripeSubscriptionStatus
181                            .eq(StripeSubscriptionStatus::Active),
182                    )
183                    .filter(billing_subscription::Column::Kind.is_null())
184                    .order_by_asc(billing_subscription::Column::Id)
185                    .stream(&*tx)
186                    .await?;
187
188                let mut subscriptions = HashMap::default();
189                while let Some(row) = rows.next().await {
190                    if let (subscription, Some(customer)) = row? {
191                        subscriptions.insert(customer.user_id, (customer, subscription));
192                    }
193                }
194                Ok(subscriptions)
195            }
196        })
197        .await
198    }
199
200    pub async fn get_active_zed_pro_billing_subscriptions(
201        &self,
202    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
203        self.transaction(|tx| async move {
204            let mut rows = billing_subscription::Entity::find()
205                .inner_join(billing_customer::Entity)
206                .select_also(billing_customer::Entity)
207                .filter(
208                    billing_subscription::Column::StripeSubscriptionStatus
209                        .eq(StripeSubscriptionStatus::Active),
210                )
211                .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
212                .order_by_asc(billing_subscription::Column::Id)
213                .stream(&*tx)
214                .await?;
215
216            let mut subscriptions = HashMap::default();
217            while let Some(row) = rows.next().await {
218                if let (subscription, Some(customer)) = row? {
219                    subscriptions.insert(customer.user_id, (customer, subscription));
220                }
221            }
222            Ok(subscriptions)
223        })
224        .await
225    }
226
227    pub async fn get_active_zed_pro_billing_subscriptions_for_users(
228        &self,
229        user_ids: HashSet<UserId>,
230    ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
231        self.transaction(|tx| {
232            let user_ids = user_ids.clone();
233            async move {
234                let mut rows = billing_subscription::Entity::find()
235                    .inner_join(billing_customer::Entity)
236                    .select_also(billing_customer::Entity)
237                    .filter(billing_customer::Column::UserId.is_in(user_ids))
238                    .filter(
239                        billing_subscription::Column::StripeSubscriptionStatus
240                            .eq(StripeSubscriptionStatus::Active),
241                    )
242                    .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
243                    .order_by_asc(billing_subscription::Column::Id)
244                    .stream(&*tx)
245                    .await?;
246
247                let mut subscriptions = HashMap::default();
248                while let Some(row) = rows.next().await {
249                    if let (subscription, Some(customer)) = row? {
250                        subscriptions.insert(customer.user_id, (customer, subscription));
251                    }
252                }
253                Ok(subscriptions)
254            }
255        })
256        .await
257    }
258
259    /// Returns whether the user has an active billing subscription.
260    pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
261        Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
262    }
263
264    /// Returns the count of the active billing subscriptions for the user with the specified ID.
265    pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
266        self.transaction(|tx| async move {
267            let count = billing_subscription::Entity::find()
268                .inner_join(billing_customer::Entity)
269                .filter(
270                    billing_customer::Column::UserId.eq(user_id).and(
271                        billing_subscription::Column::StripeSubscriptionStatus
272                            .eq(StripeSubscriptionStatus::Active)
273                            .or(billing_subscription::Column::StripeSubscriptionStatus
274                                .eq(StripeSubscriptionStatus::Trialing)),
275                    ),
276                )
277                .count(&*tx)
278                .await?;
279
280            Ok(count as usize)
281        })
282        .await
283    }
284}