1use crate::db::billing_subscription::{
2 StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
3};
4
5use super::*;
6
/// Parameters for [`Database::create_billing_subscription`].
///
/// Every field here is written to the new row. Columns not listed here
/// (`id`, `stripe_cancel_at`, `created_at`) are left for the database to fill in.
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
    /// The billing customer that owns this subscription.
    pub billing_customer_id: BillingCustomerId,
    /// The kind of subscription, if any. Note that the "active subscription" queries
    /// in this module filter on this column differently (some require it to be NULL,
    /// others require it to be set).
    pub kind: Option<SubscriptionKind>,
    /// The ID of the subscription in Stripe.
    pub stripe_subscription_id: String,
    /// The subscription's status in Stripe.
    pub stripe_subscription_status: StripeSubscriptionStatus,
    /// The cancellation reason recorded for the subscription, if any.
    pub stripe_cancellation_reason: Option<StripeCancellationReason>,
    /// Start of the subscription's current billing period.
    /// NOTE(review): presumably a Unix timestamp (seconds) as reported by Stripe — confirm.
    pub stripe_current_period_start: Option<i64>,
    /// End of the subscription's current billing period.
    /// NOTE(review): presumably a Unix timestamp (seconds) as reported by Stripe — confirm.
    pub stripe_current_period_end: Option<i64>,
}
17
/// Parameters for [`Database::update_billing_subscription`].
///
/// Each field is an `ActiveValue`. Build this with `..Default::default()` and set
/// only the fields that should change. Fields left as `ActiveValue::NotSet` (the
/// `Default`) are not written by the update.
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
    /// The billing customer that owns this subscription.
    pub billing_customer_id: ActiveValue<BillingCustomerId>,
    /// The kind of subscription; `None` clears it.
    pub kind: ActiveValue<Option<SubscriptionKind>>,
    /// The ID of the subscription in Stripe.
    pub stripe_subscription_id: ActiveValue<String>,
    /// The subscription's status in Stripe.
    pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
    /// The time at which the subscription is set to cancel, if any.
    /// Unlike the other fields, this has no counterpart in
    /// `CreateBillingSubscriptionParams`, so it can only be set via an update.
    pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
    /// The cancellation reason recorded for the subscription, if any.
    pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
    /// Start of the subscription's current billing period.
    /// NOTE(review): presumably a Unix timestamp (seconds) as reported by Stripe — confirm.
    pub stripe_current_period_start: ActiveValue<Option<i64>>,
    /// End of the subscription's current billing period.
    /// NOTE(review): presumably a Unix timestamp (seconds) as reported by Stripe — confirm.
    pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
29
30impl Database {
31 /// Creates a new billing subscription.
32 pub async fn create_billing_subscription(
33 &self,
34 params: &CreateBillingSubscriptionParams,
35 ) -> Result<()> {
36 self.transaction(|tx| async move {
37 billing_subscription::Entity::insert(billing_subscription::ActiveModel {
38 billing_customer_id: ActiveValue::set(params.billing_customer_id),
39 kind: ActiveValue::set(params.kind),
40 stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
41 stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
42 stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
43 stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
44 stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
45 ..Default::default()
46 })
47 .exec_without_returning(&*tx)
48 .await?;
49
50 Ok(())
51 })
52 .await
53 }
54
55 /// Updates the specified billing subscription.
56 pub async fn update_billing_subscription(
57 &self,
58 id: BillingSubscriptionId,
59 params: &UpdateBillingSubscriptionParams,
60 ) -> Result<()> {
61 self.transaction(|tx| async move {
62 billing_subscription::Entity::update(billing_subscription::ActiveModel {
63 id: ActiveValue::set(id),
64 billing_customer_id: params.billing_customer_id.clone(),
65 kind: params.kind.clone(),
66 stripe_subscription_id: params.stripe_subscription_id.clone(),
67 stripe_subscription_status: params.stripe_subscription_status.clone(),
68 stripe_cancel_at: params.stripe_cancel_at.clone(),
69 stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
70 stripe_current_period_start: params.stripe_current_period_start.clone(),
71 stripe_current_period_end: params.stripe_current_period_end.clone(),
72 created_at: ActiveValue::not_set(),
73 })
74 .exec(&*tx)
75 .await?;
76
77 Ok(())
78 })
79 .await
80 }
81
82 /// Returns the billing subscription with the specified ID.
83 pub async fn get_billing_subscription_by_id(
84 &self,
85 id: BillingSubscriptionId,
86 ) -> Result<Option<billing_subscription::Model>> {
87 self.transaction(|tx| async move {
88 Ok(billing_subscription::Entity::find_by_id(id)
89 .one(&*tx)
90 .await?)
91 })
92 .await
93 }
94
95 /// Returns the billing subscription with the specified Stripe subscription ID.
96 pub async fn get_billing_subscription_by_stripe_subscription_id(
97 &self,
98 stripe_subscription_id: &str,
99 ) -> Result<Option<billing_subscription::Model>> {
100 self.transaction(|tx| async move {
101 Ok(billing_subscription::Entity::find()
102 .filter(
103 billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
104 )
105 .one(&*tx)
106 .await?)
107 })
108 .await
109 }
110
111 pub async fn get_active_billing_subscription(
112 &self,
113 user_id: UserId,
114 ) -> Result<Option<billing_subscription::Model>> {
115 self.transaction(|tx| async move {
116 Ok(billing_subscription::Entity::find()
117 .inner_join(billing_customer::Entity)
118 .filter(billing_customer::Column::UserId.eq(user_id))
119 .filter(
120 Condition::all()
121 .add(
122 Condition::any()
123 .add(
124 billing_subscription::Column::StripeSubscriptionStatus
125 .eq(StripeSubscriptionStatus::Active),
126 )
127 .add(
128 billing_subscription::Column::StripeSubscriptionStatus
129 .eq(StripeSubscriptionStatus::Trialing),
130 ),
131 )
132 .add(billing_subscription::Column::Kind.is_not_null()),
133 )
134 .one(&*tx)
135 .await?)
136 })
137 .await
138 }
139
140 /// Returns all of the billing subscriptions for the user with the specified ID.
141 ///
142 /// Note that this returns the subscriptions regardless of their status.
143 /// If you're wanting to check if a use has an active billing subscription,
144 /// use `get_active_billing_subscriptions` instead.
145 pub async fn get_billing_subscriptions(
146 &self,
147 user_id: UserId,
148 ) -> Result<Vec<billing_subscription::Model>> {
149 self.transaction(|tx| async move {
150 let subscriptions = billing_subscription::Entity::find()
151 .inner_join(billing_customer::Entity)
152 .filter(billing_customer::Column::UserId.eq(user_id))
153 .order_by_asc(billing_subscription::Column::Id)
154 .all(&*tx)
155 .await?;
156
157 Ok(subscriptions)
158 })
159 .await
160 }
161
162 pub async fn get_active_billing_subscriptions(
163 &self,
164 user_ids: HashSet<UserId>,
165 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
166 self.transaction(|tx| {
167 let user_ids = user_ids.clone();
168 async move {
169 let mut rows = billing_subscription::Entity::find()
170 .inner_join(billing_customer::Entity)
171 .select_also(billing_customer::Entity)
172 .filter(billing_customer::Column::UserId.is_in(user_ids))
173 .filter(
174 billing_subscription::Column::StripeSubscriptionStatus
175 .eq(StripeSubscriptionStatus::Active),
176 )
177 .filter(billing_subscription::Column::Kind.is_null())
178 .order_by_asc(billing_subscription::Column::Id)
179 .stream(&*tx)
180 .await?;
181
182 let mut subscriptions = HashMap::default();
183 while let Some(row) = rows.next().await {
184 if let (subscription, Some(customer)) = row? {
185 subscriptions.insert(customer.user_id, (customer, subscription));
186 }
187 }
188 Ok(subscriptions)
189 }
190 })
191 .await
192 }
193
194 pub async fn get_active_zed_pro_billing_subscriptions(
195 &self,
196 user_ids: HashSet<UserId>,
197 ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
198 self.transaction(|tx| {
199 let user_ids = user_ids.clone();
200 async move {
201 let mut rows = billing_subscription::Entity::find()
202 .inner_join(billing_customer::Entity)
203 .select_also(billing_customer::Entity)
204 .filter(billing_customer::Column::UserId.is_in(user_ids))
205 .filter(
206 billing_subscription::Column::StripeSubscriptionStatus
207 .eq(StripeSubscriptionStatus::Active),
208 )
209 .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
210 .order_by_asc(billing_subscription::Column::Id)
211 .stream(&*tx)
212 .await?;
213
214 let mut subscriptions = HashMap::default();
215 while let Some(row) = rows.next().await {
216 if let (subscription, Some(customer)) = row? {
217 subscriptions.insert(customer.user_id, (customer, subscription));
218 }
219 }
220 Ok(subscriptions)
221 }
222 })
223 .await
224 }
225
226 /// Returns whether the user has an active billing subscription.
227 pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
228 Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
229 }
230
231 /// Returns the count of the active billing subscriptions for the user with the specified ID.
232 pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
233 self.transaction(|tx| async move {
234 let count = billing_subscription::Entity::find()
235 .inner_join(billing_customer::Entity)
236 .filter(
237 billing_customer::Column::UserId.eq(user_id).and(
238 billing_subscription::Column::StripeSubscriptionStatus
239 .eq(StripeSubscriptionStatus::Active),
240 ),
241 )
242 .count(&*tx)
243 .await?;
244
245 Ok(count as usize)
246 })
247 .await
248 }
249}