1use std::sync::Arc;
2
3use crate::db::billing_subscription::StripeSubscriptionStatus;
4use crate::db::tests::new_test_user;
5use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
6use crate::test_both_dbs;
7
8use super::Database;
9
// Registers `test_get_active_billing_subscriptions` under two test names, one
// per database backend. NOTE(review): presumably `test_both_dbs!` expands to a
// `_postgres` and a `_sqlite` `#[test]` that each build a fresh `Database` and
// call the shared async fn — confirm against the macro definition in
// `crate::test_both_dbs`.
test_both_dbs!(
    test_get_active_billing_subscriptions,
    test_get_active_billing_subscriptions_postgres,
    test_get_active_billing_subscriptions_sqlite
);
15
16async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
17 // A user with no subscription has no active billing subscriptions.
18 {
19 let user_id = new_test_user(db, "no-subscription-user@example.com").await;
20 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
21
22 assert_eq!(subscriptions.len(), 0);
23 }
24
25 // A user with an active subscription has one active billing subscription.
26 {
27 let user_id = new_test_user(db, "active-user@example.com").await;
28 let customer = db
29 .create_billing_customer(&CreateBillingCustomerParams {
30 user_id,
31 stripe_customer_id: "cus_active_user".into(),
32 })
33 .await
34 .unwrap();
35 assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
36
37 db.create_billing_subscription(&CreateBillingSubscriptionParams {
38 billing_customer_id: customer.id,
39 stripe_subscription_id: "sub_active_user".into(),
40 stripe_subscription_status: StripeSubscriptionStatus::Active,
41 })
42 .await
43 .unwrap();
44
45 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
46 assert_eq!(subscriptions.len(), 1);
47
48 let subscription = &subscriptions[0];
49 assert_eq!(
50 subscription.stripe_subscription_id,
51 "sub_active_user".to_string()
52 );
53 assert_eq!(
54 subscription.stripe_subscription_status,
55 StripeSubscriptionStatus::Active
56 );
57 }
58
59 // A user with a past-due subscription has no active billing subscriptions.
60 {
61 let user_id = new_test_user(db, "past-due-user@example.com").await;
62 let customer = db
63 .create_billing_customer(&CreateBillingCustomerParams {
64 user_id,
65 stripe_customer_id: "cus_past_due_user".into(),
66 })
67 .await
68 .unwrap();
69 assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
70
71 db.create_billing_subscription(&CreateBillingSubscriptionParams {
72 billing_customer_id: customer.id,
73 stripe_subscription_id: "sub_past_due_user".into(),
74 stripe_subscription_status: StripeSubscriptionStatus::PastDue,
75 })
76 .await
77 .unwrap();
78
79 let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
80 assert_eq!(subscriptions.len(), 0);
81 }
82}