1use anyhow::Context as _;
2
3use crate::db::billing_subscription::{
4 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
5};
6
7use super::*;
8
9#[derive(Debug)]
10pub struct CreateBillingSubscriptionParams {
11 pub billing_customer_id: BillingCustomerId,
12 pub kind: Option<SubscriptionKind>,
13 pub stripe_subscription_id: String,
14 pub stripe_subscription_status: StripeSubscriptionStatus,
15 pub stripe_cancellation_reason: Option<StripeCancellationReason>,
16 pub stripe_current_period_start: Option<i64>,
17 pub stripe_current_period_end: Option<i64>,
18}
19
20#[derive(Debug, Default)]
21pub struct UpdateBillingSubscriptionParams {
22 pub billing_customer_id: ActiveValue<BillingCustomerId>,
23 pub kind: ActiveValue<Option<SubscriptionKind>>,
24 pub stripe_subscription_id: ActiveValue<String>,
25 pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
26 pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
27 pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
28 pub stripe_current_period_start: ActiveValue<Option<i64>>,
29 pub stripe_current_period_end: ActiveValue<Option<i64>>,
30}
31
32impl Database {
33 /// Creates a new billing subscription.
34 pub async fn create_billing_subscription(
35 &self,
36 params: &CreateBillingSubscriptionParams,
37 ) -> Result<billing_subscription::Model> {
38 self.transaction(|tx| async move {
39 let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
40 billing_customer_id: ActiveValue::set(params.billing_customer_id),
41 kind: ActiveValue::set(params.kind),
42 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
43 stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
44 stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
45 stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
46 stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
47 ..Default::default()
48 })
49 .exec(&*tx)
50 .await?
51 .last_insert_id;
52
53 Ok(billing_subscription::Entity::find_by_id(id)
54 .one(&*tx)
55 .await?
56 .context("failed to retrieve inserted billing subscription")?)
57 })
58 .await
59 }
60
61 /// Updates the specified billing subscription.
62 pub async fn update_billing_subscription(
63 &self,
64 id: BillingSubscriptionId,
65 params: &UpdateBillingSubscriptionParams,
66 ) -> Result<()> {
67 self.transaction(|tx| async move {
68 billing_subscription::Entity::update(billing_subscription::ActiveModel {
69 id: ActiveValue::set(id),
70 billing_customer_id: params.billing_customer_id.clone(),
71 kind: params.kind.clone(),
72 stripe_subscription_id: params.stripe_subscription_id.clone(),
73 stripe_subscription_status: params.stripe_subscription_status.clone(),
74 stripe_cancel_at: params.stripe_cancel_at.clone(),
75 stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
76 stripe_current_period_start: params.stripe_current_period_start.clone(),
77 stripe_current_period_end: params.stripe_current_period_end.clone(),
78 created_at: ActiveValue::not_set(),
79 })
80 .exec(&*tx)
81 .await?;
82
83 Ok(())
84 })
85 .await
86 }
87
88 /// Returns the billing subscription with the specified ID.
89 pub async fn get_billing_subscription_by_id(
90 &self,
91 id: BillingSubscriptionId,
92 ) -> Result<Option<billing_subscription::Model>> {
93 self.transaction(|tx| async move {
94 Ok(billing_subscription::Entity::find_by_id(id)
95 .one(&*tx)
96 .await?)
97 })
98 .await
99 }
100
101 /// Returns the billing subscription with the specified Stripe subscription ID.
102 pub async fn get_billing_subscription_by_stripe_subscription_id(
103 &self,
104 stripe_subscription_id: &str,
105 ) -> Result<Option<billing_subscription::Model>> {
106 self.transaction(|tx| async move {
107 Ok(billing_subscription::Entity::find()
108 .filter(
109 billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
110 )
111 .one(&*tx)
112 .await?)
113 })
114 .await
115 }
116
117 pub async fn get_active_billing_subscription(
118 &self,
119 user_id: UserId,
120 ) -> Result<Option<billing_subscription::Model>> {
121 self.transaction(|tx| async move {
122 Ok(billing_subscription::Entity::find()
123 .inner_join(billing_customer::Entity)
124 .filter(billing_customer::Column::UserId.eq(user_id))
125 .filter(
126 Condition::all()
127 .add(
128 Condition::any()
129 .add(
130 billing_subscription::Column::StripeSubscriptionStatus
131 .eq(StripeSubscriptionStatus::Active),
132 )
133 .add(
134 billing_subscription::Column::StripeSubscriptionStatus
135 .eq(StripeSubscriptionStatus::Trialing),
136 ),
137 )
138 .add(billing_subscription::Column::Kind.is_not_null()),
139 )
140 .one(&*tx)
141 .await?)
142 })
143 .await
144 }
145
146 /// Returns all of the billing subscriptions for the user with the specified ID.
147 ///
148 /// Note that this returns the subscriptions regardless of their status.
149 /// If you're wanting to check if a use has an active billing subscription,
150 /// use `get_active_billing_subscriptions` instead.
151 pub async fn get_billing_subscriptions(
152 &self,
153 user_id: UserId,
154 ) -> Result<Vec<billing_subscription::Model>> {
155 self.transaction(|tx| async move {
156 let subscriptions = billing_subscription::Entity::find()
157 .inner_join(billing_customer::Entity)
158 .filter(billing_customer::Column::UserId.eq(user_id))
159 .order_by_asc(billing_subscription::Column::Id)
160 .all(&*tx)
161 .await?;
162
163 Ok(subscriptions)
164 })
165 .await
166 }
167
168 pub async fn get_active_billing_subscriptions(
169 &self,
170 user_ids: HashSet<UserId>,
171 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
172 self.transaction(|tx| {
173 let user_ids = user_ids.clone();
174 async move {
175 let mut rows = billing_subscription::Entity::find()
176 .inner_join(billing_customer::Entity)
177 .select_also(billing_customer::Entity)
178 .filter(billing_customer::Column::UserId.is_in(user_ids))
179 .filter(
180 billing_subscription::Column::StripeSubscriptionStatus
181 .eq(StripeSubscriptionStatus::Active),
182 )
183 .filter(billing_subscription::Column::Kind.is_null())
184 .order_by_asc(billing_subscription::Column::Id)
185 .stream(&*tx)
186 .await?;
187
188 let mut subscriptions = HashMap::default();
189 while let Some(row) = rows.next().await {
190 if let (subscription, Some(customer)) = row? {
191 subscriptions.insert(customer.user_id, (customer, subscription));
192 }
193 }
194 Ok(subscriptions)
195 }
196 })
197 .await
198 }
199
200 pub async fn get_active_zed_pro_billing_subscriptions(
201 &self,
202 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
203 self.transaction(|tx| async move {
204 let mut rows = billing_subscription::Entity::find()
205 .inner_join(billing_customer::Entity)
206 .select_also(billing_customer::Entity)
207 .filter(
208 billing_subscription::Column::StripeSubscriptionStatus
209 .eq(StripeSubscriptionStatus::Active),
210 )
211 .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
212 .order_by_asc(billing_subscription::Column::Id)
213 .stream(&*tx)
214 .await?;
215
216 let mut subscriptions = HashMap::default();
217 while let Some(row) = rows.next().await {
218 if let (subscription, Some(customer)) = row? {
219 subscriptions.insert(customer.user_id, (customer, subscription));
220 }
221 }
222 Ok(subscriptions)
223 })
224 .await
225 }
226
227 pub async fn get_active_zed_pro_billing_subscriptions_for_users(
228 &self,
229 user_ids: HashSet<UserId>,
230 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
231 self.transaction(|tx| {
232 let user_ids = user_ids.clone();
233 async move {
234 let mut rows = billing_subscription::Entity::find()
235 .inner_join(billing_customer::Entity)
236 .select_also(billing_customer::Entity)
237 .filter(billing_customer::Column::UserId.is_in(user_ids))
238 .filter(
239 billing_subscription::Column::StripeSubscriptionStatus
240 .eq(StripeSubscriptionStatus::Active),
241 )
242 .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
243 .order_by_asc(billing_subscription::Column::Id)
244 .stream(&*tx)
245 .await?;
246
247 let mut subscriptions = HashMap::default();
248 while let Some(row) = rows.next().await {
249 if let (subscription, Some(customer)) = row? {
250 subscriptions.insert(customer.user_id, (customer, subscription));
251 }
252 }
253 Ok(subscriptions)
254 }
255 })
256 .await
257 }
258
259 /// Returns whether the user has an active billing subscription.
260 pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
261 Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
262 }
263
264 /// Returns the count of the active billing subscriptions for the user with the specified ID.
265 pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
266 self.transaction(|tx| async move {
267 let count = billing_subscription::Entity::find()
268 .inner_join(billing_customer::Entity)
269 .filter(
270 billing_customer::Column::UserId.eq(user_id).and(
271 billing_subscription::Column::StripeSubscriptionStatus
272 .eq(StripeSubscriptionStatus::Active)
273 .or(billing_subscription::Column::StripeSubscriptionStatus
274 .eq(StripeSubscriptionStatus::Trialing)),
275 ),
276 )
277 .count(&*tx)
278 .await?;
279
280 Ok(count as usize)
281 })
282 .await
283 }
284}