@@ -11,7 +11,9 @@ use crate::{
db::{User, UserId},
rpc,
};
+use ::rpc::proto;
use anyhow::Context as _;
+use axum::extract;
use axum::{
Extension, Json, Router,
body::Body,
@@ -23,6 +25,7 @@ use axum::{
routing::{get, post},
};
use axum_extra::response::ErasedJson;
+use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, OnceLock};
use tower::ServiceBuilder;
@@ -101,6 +104,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
.route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
+ .route("/users/:id/update_plan", post(update_plan))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(billing::router())
.merge(contributors::router())
@@ -347,3 +351,78 @@ async fn refresh_llm_tokens(
Ok(Json(RefreshLlmTokensResponse {}))
}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct UpdatePlanBody {
+ pub plan: zed_llm_client::Plan,
+ pub subscription_period: SubscriptionPeriod,
+ pub usage: zed_llm_client::CurrentUsage,
+ pub trial_started_at: Option<DateTime<Utc>>,
+ pub is_usage_based_billing_enabled: bool,
+ pub is_account_too_young: bool,
+ pub has_overdue_invoices: bool,
+}
+
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+struct SubscriptionPeriod {
+ pub started_at: DateTime<Utc>,
+ pub ended_at: DateTime<Utc>,
+}
+
+#[derive(Serialize)]
+struct UpdatePlanResponse {}
+
+async fn update_plan(
+ Path(user_id): Path<UserId>,
+ Extension(rpc_server): Extension<Arc<rpc::Server>>,
+ extract::Json(body): extract::Json<UpdatePlanBody>,
+) -> Result<Json<UpdatePlanResponse>> {
+ let plan = match body.plan {
+ zed_llm_client::Plan::ZedFree => proto::Plan::Free,
+ zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
+ zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
+ };
+
+ let update_user_plan = proto::UpdateUserPlan {
+ plan: plan.into(),
+ trial_started_at: body
+ .trial_started_at
+ .map(|trial_started_at| trial_started_at.timestamp() as u64),
+ is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
+ usage: Some(proto::SubscriptionUsage {
+ model_requests_usage_amount: body.usage.model_requests.used,
+ model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
+ edit_predictions_usage_amount: body.usage.edit_predictions.used,
+ edit_predictions_usage_limit: Some(usage_limit_to_proto(
+ body.usage.edit_predictions.limit,
+ )),
+ }),
+ subscription_period: Some(proto::SubscriptionPeriod {
+ started_at: body.subscription_period.started_at.timestamp() as u64,
+ ended_at: body.subscription_period.ended_at.timestamp() as u64,
+ }),
+ account_too_young: Some(body.is_account_too_young),
+ has_overdue_invoices: Some(body.has_overdue_invoices),
+ };
+
+ rpc_server
+ .update_plan_for_user(user_id, update_user_plan)
+ .await?;
+
+ Ok(Json(UpdatePlanResponse {}))
+}
+
+fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
+ proto::UsageLimit {
+ variant: Some(match limit {
+ zed_llm_client::UsageLimit::Limited(limit) => {
+ proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
+ limit: limit as u32,
+ })
+ }
+ zed_llm_client::UsageLimit::Unlimited => {
+ proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
+ }
+ }),
+ }
+}
@@ -1002,7 +1002,26 @@ impl Server {
Ok(())
}
- pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
+ pub async fn update_plan_for_user(
+ self: &Arc<Self>,
+ user_id: UserId,
+ update_user_plan: proto::UpdateUserPlan,
+ ) -> Result<()> {
+ let pool = self.connection_pool.lock();
+ for connection_id in pool.user_connection_ids(user_id) {
+ self.peer
+ .send(connection_id, update_user_plan.clone())
+ .trace_err();
+ }
+
+ Ok(())
+ }
+
+ /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
+ /// message on the Collab server.
+ ///
+ /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
+ pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
let user = self
.app_state
.db
@@ -1018,14 +1037,7 @@ impl Server {
)
.await?;
- let pool = self.connection_pool.lock();
- for connection_id in pool.user_connection_ids(user_id) {
- self.peer
- .send(connection_id, update_user_plan.clone())
- .trace_err();
- }
-
- Ok(())
+ self.update_plan_for_user(user_id, update_user_plan).await
}
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {