collab: Add `POST /users/:id/update_plan` endpoint (#34953)

Marshall Bowers created

This PR adds a new `POST /users/:id/update_plan` endpoint to Collab to
allow Cloud to push down plan updates to users.

Release Notes:

- N/A

Change summary

crates/collab/src/api.rs         | 79 ++++++++++++++++++++++++++++++++++
crates/collab/src/api/billing.rs |  2 
crates/collab/src/rpc.rs         | 30 +++++++++---
3 files changed, 101 insertions(+), 10 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -11,7 +11,9 @@ use crate::{
     db::{User, UserId},
     rpc,
 };
+use ::rpc::proto;
 use anyhow::Context as _;
+use axum::extract;
 use axum::{
     Extension, Json, Router,
     body::Body,
@@ -23,6 +25,7 @@ use axum::{
     routing::{get, post},
 };
 use axum_extra::response::ErasedJson;
+use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
 use std::sync::{Arc, OnceLock};
 use tower::ServiceBuilder;
@@ -101,6 +104,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
         .route("/users/look_up", get(look_up_user))
         .route("/users/:id/access_tokens", post(create_access_token))
         .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
+        .route("/users/:id/update_plan", post(update_plan))
         .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
         .merge(billing::router())
         .merge(contributors::router())
@@ -347,3 +351,78 @@ async fn refresh_llm_tokens(
 
     Ok(Json(RefreshLlmTokensResponse {}))
 }
+
+#[derive(Debug, Serialize, Deserialize)]
+struct UpdatePlanBody {
+    pub plan: zed_llm_client::Plan,
+    pub subscription_period: SubscriptionPeriod,
+    pub usage: zed_llm_client::CurrentUsage,
+    pub trial_started_at: Option<DateTime<Utc>>,
+    pub is_usage_based_billing_enabled: bool,
+    pub is_account_too_young: bool,
+    pub has_overdue_invoices: bool,
+}
+
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+struct SubscriptionPeriod {
+    pub started_at: DateTime<Utc>,
+    pub ended_at: DateTime<Utc>,
+}
+
+#[derive(Serialize)]
+struct UpdatePlanResponse {}
+
+async fn update_plan(
+    Path(user_id): Path<UserId>,
+    Extension(rpc_server): Extension<Arc<rpc::Server>>,
+    extract::Json(body): extract::Json<UpdatePlanBody>,
+) -> Result<Json<UpdatePlanResponse>> {
+    let plan = match body.plan {
+        zed_llm_client::Plan::ZedFree => proto::Plan::Free,
+        zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
+        zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
+    };
+
+    let update_user_plan = proto::UpdateUserPlan {
+        plan: plan.into(),
+        trial_started_at: body
+            .trial_started_at
+            .map(|trial_started_at| trial_started_at.timestamp() as u64),
+        is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
+        usage: Some(proto::SubscriptionUsage {
+            model_requests_usage_amount: body.usage.model_requests.used,
+            model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
+            edit_predictions_usage_amount: body.usage.edit_predictions.used,
+            edit_predictions_usage_limit: Some(usage_limit_to_proto(
+                body.usage.edit_predictions.limit,
+            )),
+        }),
+        subscription_period: Some(proto::SubscriptionPeriod {
+            started_at: body.subscription_period.started_at.timestamp() as u64,
+            ended_at: body.subscription_period.ended_at.timestamp() as u64,
+        }),
+        account_too_young: Some(body.is_account_too_young),
+        has_overdue_invoices: Some(body.has_overdue_invoices),
+    };
+
+    rpc_server
+        .update_plan_for_user(user_id, update_user_plan)
+        .await?;
+
+    Ok(Json(UpdatePlanResponse {}))
+}
+
+fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
+    proto::UsageLimit {
+        variant: Some(match limit {
+            zed_llm_client::UsageLimit::Limited(limit) => {
+                proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+                    limit: limit as u32,
+                })
+            }
+            zed_llm_client::UsageLimit::Unlimited => {
+                proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
+            }
+        }),
+    }
+}

crates/collab/src/api/billing.rs 🔗

@@ -785,7 +785,7 @@ async fn handle_customer_subscription_event(
 
     // When the user's subscription changes, push down any changes to their plan.
     rpc_server
-        .update_plan_for_user(billing_customer.user_id)
+        .update_plan_for_user_legacy(billing_customer.user_id)
         .await
         .trace_err();
 

crates/collab/src/rpc.rs 🔗

@@ -1002,7 +1002,26 @@ impl Server {
         Ok(())
     }
 
-    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
+    pub async fn update_plan_for_user(
+        self: &Arc<Self>,
+        user_id: UserId,
+        update_user_plan: proto::UpdateUserPlan,
+    ) -> Result<()> {
+        let pool = self.connection_pool.lock();
+        for connection_id in pool.user_connection_ids(user_id) {
+            self.peer
+                .send(connection_id, update_user_plan.clone())
+                .trace_err();
+        }
+
+        Ok(())
+    }
+
+    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
+    /// message on the Collab server.
+    ///
+    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
+    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
         let user = self
             .app_state
             .db
@@ -1018,14 +1037,7 @@ impl Server {
         )
         .await?;
 
-        let pool = self.connection_pool.lock();
-        for connection_id in pool.user_connection_ids(user_id) {
-            self.peer
-                .send(connection_id, update_user_plan.clone())
-                .trace_err();
-        }
-
-        Ok(())
+        self.update_plan_for_user(user_id, update_user_plan).await
     }
 
     pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {