1use crate::db::billing_subscription::{
2 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
3};
4
5use super::*;
6
/// Parameters for [`Database::create_billing_subscription`].
///
/// There is no `stripe_cancel_at` field: a newly created subscription row
/// never has that column set by the create call.
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
    /// The billing customer the subscription belongs to.
    pub billing_customer_id: BillingCustomerId,
    /// The kind of subscription. `None` is stored as a NULL `kind`, which is
    /// what `get_active_billing_subscriptions` selects on.
    pub kind: Option<SubscriptionKind>,
    /// The subscription's ID in Stripe.
    pub stripe_subscription_id: String,
    /// The subscription's status as reported by Stripe.
    pub stripe_subscription_status: StripeSubscriptionStatus,
    /// Why Stripe cancelled the subscription, if it was cancelled.
    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
    /// Start of the current billing period.
    // NOTE(review): presumably a Unix timestamp as provided by Stripe — confirm units.
    pub stripe_current_period_start: Option<i64>,
    /// End of the current billing period.
    // NOTE(review): presumably a Unix timestamp as provided by Stripe — confirm units.
    pub stripe_current_period_end: Option<i64>,
}
17
/// Parameters for [`Database::update_billing_subscription`].
///
/// Every field is an `ActiveValue`. The derived `Default` leaves all fields
/// `ActiveValue::NotSet`, so a caller can set only the columns it wants to
/// change, e.g. `UpdateBillingSubscriptionParams { stripe_subscription_status:
/// ActiveValue::set(status), ..Default::default() }`. Fields that stay
/// `NotSet` are not written by the update.
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
    /// New owning billing customer.
    pub billing_customer_id: ActiveValue<BillingCustomerId>,
    /// New subscription kind; `Set(None)` stores a NULL `kind`.
    pub kind: ActiveValue<Option<SubscriptionKind>>,
    /// New Stripe subscription ID.
    pub stripe_subscription_id: ActiveValue<String>,
    /// New Stripe subscription status.
    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
    /// New scheduled cancellation time; `Set(None)` clears it.
    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
    /// New cancellation reason; `Set(None)` clears it.
    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
    /// New start of the current billing period.
    // NOTE(review): presumably a Unix timestamp as provided by Stripe — confirm units.
    pub stripe_current_period_start: ActiveValue<Option<i64>>,
    /// New end of the current billing period.
    // NOTE(review): presumably a Unix timestamp as provided by Stripe — confirm units.
    pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
29
30impl Database {
31 /// Creates a new billing subscription.
32 pub async fn create_billing_subscription(
33 &self,
34 params: &CreateBillingSubscriptionParams,
35 ) -> Result<billing_subscription::Model> {
36 self.transaction(|tx| async move {
37 let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
38 billing_customer_id: ActiveValue::set(params.billing_customer_id),
39 kind: ActiveValue::set(params.kind),
40 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
41 stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
42 stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
43 stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
44 stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
45 ..Default::default()
46 })
47 .exec(&*tx)
48 .await?
49 .last_insert_id;
50
51 Ok(billing_subscription::Entity::find_by_id(id)
52 .one(&*tx)
53 .await?
54 .ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
55 })
56 .await
57 }
58
59 /// Updates the specified billing subscription.
60 pub async fn update_billing_subscription(
61 &self,
62 id: BillingSubscriptionId,
63 params: &UpdateBillingSubscriptionParams,
64 ) -> Result<()> {
65 self.transaction(|tx| async move {
66 billing_subscription::Entity::update(billing_subscription::ActiveModel {
67 id: ActiveValue::set(id),
68 billing_customer_id: params.billing_customer_id.clone(),
69 kind: params.kind.clone(),
70 stripe_subscription_id: params.stripe_subscription_id.clone(),
71 stripe_subscription_status: params.stripe_subscription_status.clone(),
72 stripe_cancel_at: params.stripe_cancel_at.clone(),
73 stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
74 stripe_current_period_start: params.stripe_current_period_start.clone(),
75 stripe_current_period_end: params.stripe_current_period_end.clone(),
76 created_at: ActiveValue::not_set(),
77 })
78 .exec(&*tx)
79 .await?;
80
81 Ok(())
82 })
83 .await
84 }
85
86 /// Returns the billing subscription with the specified ID.
87 pub async fn get_billing_subscription_by_id(
88 &self,
89 id: BillingSubscriptionId,
90 ) -> Result<Option<billing_subscription::Model>> {
91 self.transaction(|tx| async move {
92 Ok(billing_subscription::Entity::find_by_id(id)
93 .one(&*tx)
94 .await?)
95 })
96 .await
97 }
98
99 /// Returns the billing subscription with the specified Stripe subscription ID.
100 pub async fn get_billing_subscription_by_stripe_subscription_id(
101 &self,
102 stripe_subscription_id: &str,
103 ) -> Result<Option<billing_subscription::Model>> {
104 self.transaction(|tx| async move {
105 Ok(billing_subscription::Entity::find()
106 .filter(
107 billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
108 )
109 .one(&*tx)
110 .await?)
111 })
112 .await
113 }
114
115 pub async fn get_active_billing_subscription(
116 &self,
117 user_id: UserId,
118 ) -> Result<Option<billing_subscription::Model>> {
119 self.transaction(|tx| async move {
120 Ok(billing_subscription::Entity::find()
121 .inner_join(billing_customer::Entity)
122 .filter(billing_customer::Column::UserId.eq(user_id))
123 .filter(
124 Condition::all()
125 .add(
126 Condition::any()
127 .add(
128 billing_subscription::Column::StripeSubscriptionStatus
129 .eq(StripeSubscriptionStatus::Active),
130 )
131 .add(
132 billing_subscription::Column::StripeSubscriptionStatus
133 .eq(StripeSubscriptionStatus::Trialing),
134 ),
135 )
136 .add(billing_subscription::Column::Kind.is_not_null()),
137 )
138 .one(&*tx)
139 .await?)
140 })
141 .await
142 }
143
144 /// Returns all of the billing subscriptions for the user with the specified ID.
145 ///
146 /// Note that this returns the subscriptions regardless of their status.
147 /// If you're wanting to check if a use has an active billing subscription,
148 /// use `get_active_billing_subscriptions` instead.
149 pub async fn get_billing_subscriptions(
150 &self,
151 user_id: UserId,
152 ) -> Result<Vec<billing_subscription::Model>> {
153 self.transaction(|tx| async move {
154 let subscriptions = billing_subscription::Entity::find()
155 .inner_join(billing_customer::Entity)
156 .filter(billing_customer::Column::UserId.eq(user_id))
157 .order_by_asc(billing_subscription::Column::Id)
158 .all(&*tx)
159 .await?;
160
161 Ok(subscriptions)
162 })
163 .await
164 }
165
166 pub async fn get_active_billing_subscriptions(
167 &self,
168 user_ids: HashSet<UserId>,
169 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
170 self.transaction(|tx| {
171 let user_ids = user_ids.clone();
172 async move {
173 let mut rows = billing_subscription::Entity::find()
174 .inner_join(billing_customer::Entity)
175 .select_also(billing_customer::Entity)
176 .filter(billing_customer::Column::UserId.is_in(user_ids))
177 .filter(
178 billing_subscription::Column::StripeSubscriptionStatus
179 .eq(StripeSubscriptionStatus::Active),
180 )
181 .filter(billing_subscription::Column::Kind.is_null())
182 .order_by_asc(billing_subscription::Column::Id)
183 .stream(&*tx)
184 .await?;
185
186 let mut subscriptions = HashMap::default();
187 while let Some(row) = rows.next().await {
188 if let (subscription, Some(customer)) = row? {
189 subscriptions.insert(customer.user_id, (customer, subscription));
190 }
191 }
192 Ok(subscriptions)
193 }
194 })
195 .await
196 }
197
198 pub async fn get_active_zed_pro_billing_subscriptions(
199 &self,
200 user_ids: HashSet<UserId>,
201 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
202 self.transaction(|tx| {
203 let user_ids = user_ids.clone();
204 async move {
205 let mut rows = billing_subscription::Entity::find()
206 .inner_join(billing_customer::Entity)
207 .select_also(billing_customer::Entity)
208 .filter(billing_customer::Column::UserId.is_in(user_ids))
209 .filter(
210 billing_subscription::Column::StripeSubscriptionStatus
211 .eq(StripeSubscriptionStatus::Active),
212 )
213 .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
214 .order_by_asc(billing_subscription::Column::Id)
215 .stream(&*tx)
216 .await?;
217
218 let mut subscriptions = HashMap::default();
219 while let Some(row) = rows.next().await {
220 if let (subscription, Some(customer)) = row? {
221 subscriptions.insert(customer.user_id, (customer, subscription));
222 }
223 }
224 Ok(subscriptions)
225 }
226 })
227 .await
228 }
229
230 /// Returns whether the user has an active billing subscription.
231 pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
232 Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
233 }
234
235 /// Returns the count of the active billing subscriptions for the user with the specified ID.
236 pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
237 self.transaction(|tx| async move {
238 let count = billing_subscription::Entity::find()
239 .inner_join(billing_customer::Entity)
240 .filter(
241 billing_customer::Column::UserId.eq(user_id).and(
242 billing_subscription::Column::StripeSubscriptionStatus
243 .eq(StripeSubscriptionStatus::Active)
244 .or(billing_subscription::Column::StripeSubscriptionStatus
245 .eq(StripeSubscriptionStatus::Trialing)),
246 ),
247 )
248 .count(&*tx)
249 .await?;
250
251 Ok(count as usize)
252 })
253 .await
254 }
255}