1use std::sync::Arc;
2
3use crate::db::billing_subscription::StripeSubscriptionStatus;
4use crate::db::tests::new_test_user;
5use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
6use crate::test_both_dbs;
7
8use super::Database;
9
// Registers `test_get_active_billing_subscriptions` with the project's `test_both_dbs!` macro.
// NOTE(review): the macro is defined elsewhere in the crate. Judging by the argument names, it
// presumably generates one test per database backend, named by the 2nd (Postgres) and 3rd
// (SQLite) arguments. Confirm against the macro definition.
test_both_dbs!(
    test_get_active_billing_subscriptions,
    test_get_active_billing_subscriptions_postgres,
    test_get_active_billing_subscriptions_sqlite
);
15
16async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
17 // A user with no subscription has no active billing subscriptions.
18 {
19 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
20 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
21
22 assert_eq!(subscriptions.len(), 0);
23 }
24
25 // A user with an active subscription has one active billing subscription.
26 {
27 let user_id = new_test_user(db, "active-user@example.com").await;
28 let customer = db
29 .create_billing_customer(&CreateBillingCustomerParams {
30 user_id,
31 stripe_customer_id: "cus_active_user".into(),
32 last_stripe_event_id: None,
33 })
34 .await
35 .unwrap();
36 assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
37
38 db.create_billing_subscription(&CreateBillingSubscriptionParams {
39 billing_customer_id: customer.id,
40 stripe_subscription_id: "sub_active_user".into(),
41 stripe_subscription_status: StripeSubscriptionStatus::Active,
42 last_stripe_event_id: None,
43 })
44 .await
45 .unwrap();
46
47 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
48 assert_eq!(subscriptions.len(), 1);
49
50 let subscription = &subscriptions[0];
51 assert_eq!(
52 subscription.stripe_subscription_id,
53 "sub_active_user".to_string()
54 );
55 assert_eq!(
56 subscription.stripe_subscription_status,
57 StripeSubscriptionStatus::Active
58 );
59 }
60
61 // A user with a past-due subscription has no active billing subscriptions.
62 {
63 let user_id = new_test_user(db, "past-due-user@example.com").await;
64 let customer = db
65 .create_billing_customer(&CreateBillingCustomerParams {
66 user_id,
67 stripe_customer_id: "cus_past_due_user".into(),
68 last_stripe_event_id: None,
69 })
70 .await
71 .unwrap();
72 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
73
74 db.create_billing_subscription(&CreateBillingSubscriptionParams {
75 billing_customer_id: customer.id,
76 stripe_subscription_id: "sub_past_due_user".into(),
77 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
78 last_stripe_event_id: None,
79 })
80 .await
81 .unwrap();
82
83 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
84 assert_eq!(subscriptions.len(), 0);
85 }
86}