Introduce a separate backend service for LLM calls (#15831)

Max Brunsfeld, Marshall, and Marshall Bowers created this pull request

This PR introduces a separate backend service for making LLM calls.

It exposes an HTTP interface that can be called by Zed clients. To call
these endpoints, the client must provide a `Bearer` token. These tokens
are issued/refreshed by the collab service over RPC.

We're adding this in a backwards-compatible way. Right now the access
tokens can only be minted for Zed staff, and calling this separate LLM
service is behind the `llm-service` feature flag (which is not
automatically enabled for Zed staff).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

Cargo.lock                                  |   2 
crates/collab/.env.toml                     |   1 
crates/collab/Cargo.toml                    |   2 
crates/collab/src/api.rs                    |  10 
crates/collab/src/api/billing.rs            |   4 
crates/collab/src/api/events.rs             |  18 +-
crates/collab/src/api/extensions.rs         |   4 
crates/collab/src/auth.rs                   |  12 
crates/collab/src/lib.rs                    |  22 ++
crates/collab/src/llm.rs                    | 110 +++++++++++++++++
crates/collab/src/llm/token.rs              |  75 +++++++++++
crates/collab/src/main.rs                   |   4 
crates/collab/src/rpc.rs                    | 130 ++++++++++++--------
crates/collab/src/tests/test_server.rs      |   1 
crates/http_client/src/http_client.rs       |  16 ++
crates/language_model/src/provider/cloud.rs | 145 ++++++++++++++++++++--
crates/proto/proto/zed.proto                |  11 +
crates/proto/src/proto.rs                   |   3 
crates/rpc/src/llm.rs                       |   8 +
crates/rpc/src/rpc.rs                       |   2 
20 files changed, 478 insertions(+), 102 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2466,6 +2466,7 @@ dependencies = [
  "hex",
  "http_client",
  "indoc",
+ "jsonwebtoken",
  "language",
  "language_model",
  "live_kit_client",
@@ -2507,6 +2508,7 @@ dependencies = [
  "telemetry_events",
  "text",
  "theme",
+ "thiserror",
  "time",
  "tokio",
  "toml 0.8.16",

crates/collab/.env.toml 🔗

@@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
 BLOB_STORE_REGION = "the-region"
 ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
 SEED_PATH = "crates/collab/seed.default.json"
+LLM_API_SECRET = "llm-secret"
 
 # CLICKHOUSE_URL = ""
 # CLICKHOUSE_USER = "default"

crates/collab/Cargo.toml 🔗

@@ -37,6 +37,7 @@ futures.workspace = true
 google_ai.workspace = true
 hex.workspace = true
 http_client.workspace = true
+jsonwebtoken.workspace = true
 live_kit_server.workspace = true
 log.workspace = true
 nanoid.workspace = true
@@ -61,6 +62,7 @@ subtle.workspace = true
 rustc-demangle.workspace = true
 telemetry_events.workspace = true
 text.workspace = true
+thiserror.workspace = true
 time.workspace = true
 tokio.workspace = true
 toml.workspace = true

crates/collab/src/api.rs 🔗

@@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
         .get(http::header::AUTHORIZATION)
         .and_then(|header| header.to_str().ok())
         .ok_or_else(|| {
-            Error::Http(
+            Error::http(
                 StatusCode::BAD_REQUEST,
                 "missing authorization header".to_string(),
             )
         })?
         .strip_prefix("token ")
         .ok_or_else(|| {
-            Error::Http(
+            Error::http(
                 StatusCode::BAD_REQUEST,
                 "invalid authorization header".to_string(),
             )
@@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
     let state = req.extensions().get::<Arc<AppState>>().unwrap();
 
     if token != state.config.api_token {
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::UNAUTHORIZED,
             "invalid authorization token".to_string(),
         ))?
@@ -185,13 +185,13 @@ async fn create_access_token(
             if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
                 impersonated_user_id = Some(impersonated_user.id);
             } else {
-                return Err(Error::Http(
+                return Err(Error::http(
                     StatusCode::UNPROCESSABLE_ENTITY,
                     format!("user {impersonate} does not exist"),
                 ));
             }
         } else {
-            return Err(Error::Http(
+            return Err(Error::http(
                 StatusCode::UNAUTHORIZED,
                 "you do not have permission to impersonate other users".to_string(),
             ));

crates/collab/src/api/billing.rs 🔗

@@ -120,7 +120,7 @@ async fn create_billing_subscription(
         .zip(app.config.stripe_price_id.clone())
     else {
         log::error!("failed to retrieve Stripe client or price ID");
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
             "not supported".into(),
         ))?
@@ -201,7 +201,7 @@ async fn manage_billing_subscription(
 
     let Some(stripe_client) = app.stripe_client.clone() else {
         log::error!("failed to retrieve Stripe client");
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
             "not supported".into(),
         ))?

crates/collab/src/api/events.rs 🔗

@@ -206,14 +206,14 @@ pub async fn post_hang(
     body: Bytes,
 ) -> Result<()> {
     let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::INTERNAL_SERVER_ERROR,
             "events not enabled".into(),
         ))?;
     };
 
     if checksum != expected {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::BAD_REQUEST,
             "invalid checksum".into(),
         ))?;
@@ -265,25 +265,25 @@ pub async fn post_panic(
     body: Bytes,
 ) -> Result<()> {
     let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::INTERNAL_SERVER_ERROR,
             "events not enabled".into(),
         ))?;
     };
 
     if checksum != expected {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::BAD_REQUEST,
             "invalid checksum".into(),
         ))?;
     }
 
     let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
-        .map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
+        .map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
     let panic = report.panic;
 
     if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::BAD_REQUEST,
             "invalid os version".into(),
         ))?;
@@ -362,14 +362,14 @@ pub async fn post_events(
     body: Bytes,
 ) -> Result<()> {
     let Some(clickhouse_client) = app.clickhouse_client.clone() else {
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
             "not supported".into(),
         ))?
     };
 
     let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
-        return Err(Error::Http(
+        return Err(Error::http(
             StatusCode::INTERNAL_SERVER_ERROR,
             "events not enabled".into(),
         ))?;
@@ -385,7 +385,7 @@ pub async fn post_events(
 
     let mut to_upload = ToUpload::default();
     let Some(last_event) = request_body.events.last() else {
-        return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
+        return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
     };
     let country_code = country_code_header.map(|h| h.to_string());
 

crates/collab/src/api/extensions.rs 🔗

@@ -185,7 +185,7 @@ async fn download_extension(
         .clone()
         .zip(app.config.blob_store_bucket.clone())
     else {
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::NOT_IMPLEMENTED,
             "not supported".into(),
         ))?
@@ -202,7 +202,7 @@ async fn download_extension(
         .await?;
 
     if !version_exists {
-        Err(Error::Http(
+        Err(Error::http(
             StatusCode::NOT_FOUND,
             "unknown extension version".into(),
         ))?;

crates/collab/src/auth.rs 🔗

@@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         .get(http::header::AUTHORIZATION)
         .and_then(|header| header.to_str().ok())
         .ok_or_else(|| {
-            Error::Http(
+            Error::http(
                 StatusCode::UNAUTHORIZED,
                 "missing authorization header".to_string(),
             )
@@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
     let first = auth_header.next().unwrap_or("");
     if first == "dev-server-token" {
         let dev_server_token = auth_header.next().ok_or_else(|| {
-            Error::Http(
+            Error::http(
                 StatusCode::BAD_REQUEST,
                 "missing dev-server-token token in authorization header".to_string(),
             )
         })?;
         let dev_server = verify_dev_server_token(dev_server_token, &state.db)
             .await
-            .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
+            .map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
 
         req.extensions_mut()
             .insert(Principal::DevServer(dev_server));
@@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
     }
 
     let user_id = UserId(first.parse().map_err(|_| {
-        Error::Http(
+        Error::http(
             StatusCode::BAD_REQUEST,
             "missing user id in authorization header".to_string(),
         )
     })?);
 
     let access_token = auth_header.next().ok_or_else(|| {
-        Error::Http(
+        Error::http(
             StatusCode::BAD_REQUEST,
             "missing access token in authorization header".to_string(),
         )
@@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         }
     }
 
-    Err(Error::Http(
+    Err(Error::http(
         StatusCode::UNAUTHORIZED,
         "invalid credentials".to_string(),
     ))

crates/collab/src/lib.rs 🔗

@@ -13,7 +13,10 @@ mod tests;
 
 use anyhow::anyhow;
 use aws_config::{BehaviorVersion, Region};
-use axum::{http::StatusCode, response::IntoResponse};
+use axum::{
+    http::{HeaderMap, StatusCode},
+    response::IntoResponse,
+};
 use db::{ChannelId, Database};
 use executor::Executor;
 pub use rate_limiter::*;
@@ -24,7 +27,7 @@ use util::ResultExt;
 pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 pub enum Error {
-    Http(StatusCode, String),
+    Http(StatusCode, String, HeaderMap),
     Database(sea_orm::error::DbErr),
     Internal(anyhow::Error),
     Stripe(stripe::StripeError),
@@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
     }
 }
 
+impl Error {
+    fn http(code: StatusCode, message: String) -> Self {
+        Self::Http(code, message, HeaderMap::default())
+    }
+}
+
 impl IntoResponse for Error {
     fn into_response(self) -> axum::response::Response {
         match self {
-            Error::Http(code, message) => {
+            Error::Http(code, message, headers) => {
                 log::error!("HTTP error {}: {}", code, &message);
-                (code, message).into_response()
+                (code, headers, message).into_response()
             }
             Error::Database(error) => {
                 log::error!(
@@ -104,7 +113,7 @@ impl IntoResponse for Error {
 impl std::fmt::Debug for Error {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Error::Http(code, message) => (code, message).fmt(f),
+            Error::Http(code, message, _headers) => (code, message).fmt(f),
             Error::Database(error) => error.fmt(f),
             Error::Internal(error) => error.fmt(f),
             Error::Stripe(error) => error.fmt(f),
@@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
 impl std::fmt::Display for Error {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Error::Http(code, message) => write!(f, "{code}: {message}"),
+            Error::Http(code, message, _) => write!(f, "{code}: {message}"),
             Error::Database(error) => error.fmt(f),
             Error::Internal(error) => error.fmt(f),
             Error::Stripe(error) => error.fmt(f),
@@ -141,6 +150,7 @@ pub struct Config {
     pub live_kit_server: Option<String>,
     pub live_kit_key: Option<String>,
     pub live_kit_secret: Option<String>,
+    pub llm_api_secret: Option<String>,
     pub rust_log: Option<String>,
     pub log_json: Option<bool>,
     pub blob_store_url: Option<String>,

crates/collab/src/llm.rs 🔗

@@ -1,16 +1,122 @@
+mod token;
+
+use crate::{executor::Executor, Config, Error, Result};
+use anyhow::Context as _;
+use axum::{
+    body::Body,
+    http::{self, HeaderName, HeaderValue, Request, StatusCode},
+    middleware::{self, Next},
+    response::{IntoResponse, Response},
+    routing::post,
+    Extension, Json, Router,
+};
+use futures::StreamExt as _;
+use http_client::IsahcHttpClient;
+use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
 
-use crate::{executor::Executor, Config, Result};
+pub use token::*;
 
 pub struct LlmState {
     pub config: Config,
     pub executor: Executor,
+    pub http_client: IsahcHttpClient,
 }
 
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
-        let this = Self { config, executor };
+        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
+        let http_client = IsahcHttpClient::builder()
+            .default_header("User-Agent", user_agent)
+            .build()
+            .context("failed to construct http client")?;
+
+        let this = Self {
+            config,
+            executor,
+            http_client,
+        };
 
         Ok(Arc::new(this))
     }
 }
+
+pub fn routes() -> Router<(), Body> {
+    Router::new()
+        .route("/completion", post(perform_completion))
+        .layer(middleware::from_fn(validate_api_token))
+}
+
+async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
+    let token = req
+        .headers()
+        .get(http::header::AUTHORIZATION)
+        .and_then(|header| header.to_str().ok())
+        .ok_or_else(|| {
+            Error::http(
+                StatusCode::BAD_REQUEST,
+                "missing authorization header".to_string(),
+            )
+        })?
+        .strip_prefix("Bearer ")
+        .ok_or_else(|| {
+            Error::http(
+                StatusCode::BAD_REQUEST,
+                "invalid authorization header".to_string(),
+            )
+        })?;
+
+    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
+    match LlmTokenClaims::validate(&token, &state.config) {
+        Ok(claims) => {
+            req.extensions_mut().insert(claims);
+            Ok::<_, Error>(next.run(req).await.into_response())
+        }
+        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
+            StatusCode::UNAUTHORIZED,
+            "unauthorized".to_string(),
+            [(
+                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
+                HeaderValue::from_static("true"),
+            )]
+            .into_iter()
+            .collect(),
+        )),
+        Err(_err) => Err(Error::http(
+            StatusCode::UNAUTHORIZED,
+            "unauthorized".to_string(),
+        )),
+    }
+}
+
+async fn perform_completion(
+    Extension(state): Extension<Arc<LlmState>>,
+    Extension(_claims): Extension<LlmTokenClaims>,
+    Json(params): Json<PerformCompletionParams>,
+) -> Result<impl IntoResponse> {
+    let api_key = state
+        .config
+        .anthropic_api_key
+        .as_ref()
+        .context("no Anthropic AI API key configured on the server")?;
+    let chunks = anthropic::stream_completion(
+        &state.http_client,
+        anthropic::ANTHROPIC_API_URL,
+        api_key,
+        serde_json::from_str(&params.provider_request.get())?,
+        None,
+    )
+    .await?;
+
+    let stream = chunks.map(|event| {
+        let mut buffer = Vec::new();
+        event.map(|chunk| {
+            buffer.clear();
+            serde_json::to_writer(&mut buffer, &chunk).unwrap();
+            buffer.push(b'\n');
+            buffer
+        })
+    });
+
+    Ok(Response::new(Body::wrap_stream(stream)))
+}

crates/collab/src/llm/token.rs 🔗

@@ -0,0 +1,75 @@
+use crate::{db::UserId, Config};
+use anyhow::{anyhow, Result};
+use chrono::Utc;
+use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use thiserror::Error;
+
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LlmTokenClaims {
+    pub iat: u64,
+    pub exp: u64,
+    pub jti: String,
+    pub user_id: u64,
+    pub plan: rpc::proto::Plan,
+}
+
+const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
+
+impl LlmTokenClaims {
+    pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
+        let secret = config
+            .llm_api_secret
+            .as_ref()
+            .ok_or_else(|| anyhow!("no LLM API secret"))?;
+
+        let now = Utc::now();
+        let claims = Self {
+            iat: now.timestamp() as u64,
+            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
+            jti: uuid::Uuid::new_v4().to_string(),
+            user_id: user_id.to_proto(),
+            plan,
+        };
+
+        Ok(jsonwebtoken::encode(
+            &Header::default(),
+            &claims,
+            &EncodingKey::from_secret(secret.as_ref()),
+        )?)
+    }
+
+    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
+        let secret = config
+            .llm_api_secret
+            .as_ref()
+            .ok_or_else(|| anyhow!("no LLM API secret"))?;
+
+        match jsonwebtoken::decode::<Self>(
+            token,
+            &DecodingKey::from_secret(secret.as_ref()),
+            &Validation::default(),
+        ) {
+            Ok(token) => Ok(token.claims),
+            Err(e) => {
+                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
+                    Err(ValidateLlmTokenError::Expired)
+                } else {
+                    Err(ValidateLlmTokenError::JwtError(e))
+                }
+            }
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum ValidateLlmTokenError {
+    #[error("access token is expired")]
+    Expired,
+    #[error("access token validation error: {0}")]
+    JwtError(#[from] jsonwebtoken::errors::Error),
+    #[error("{0}")]
+    Other(#[from] anyhow::Error),
+}

crates/collab/src/main.rs 🔗

@@ -83,7 +83,9 @@ async fn main() -> Result<()> {
             if mode.is_llm() {
                 let state = LlmState::new(config.clone(), Executor::Production).await?;
 
-                app = app.layer(Extension(state.clone()));
+                app = app
+                    .merge(collab::llm::routes())
+                    .layer(Extension(state.clone()));
             }
 
             if mode.is_collab() || mode.is_api() {

crates/collab/src/rpc.rs 🔗

@@ -1,6 +1,7 @@
 mod connection_pool;
 
 use crate::api::CloudflareIpCountryHeader;
+use crate::llm::LlmTokenClaims;
 use crate::{
     auth,
     db::{
@@ -11,7 +12,7 @@ use crate::{
         ServerId, UpdatedChannelMessage, User, UserId,
     },
     executor::Executor,
-    AppState, Config, Error, RateLimit, RateLimiter, Result,
+    AppState, Config, Error, RateLimit, Result,
 };
 use anyhow::{anyhow, bail, Context as _};
 use async_tungstenite::tungstenite::{
@@ -149,10 +150,9 @@ struct Session {
     db: Arc<tokio::sync::Mutex<DbHandle>>,
     peer: Arc<Peer>,
     connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
-    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
+    app_state: Arc<AppState>,
     supermaven_client: Option<Arc<SupermavenAdminApi>>,
     http_client: Arc<IsahcHttpClient>,
-    rate_limiter: Arc<RateLimiter>,
     /// The GeoIP country code for the user.
     #[allow(unused)]
     geoip_country_code: Option<String>,
@@ -615,6 +615,7 @@ impl Server {
             .add_message_handler(user_message_handler(unfollow))
             .add_message_handler(user_message_handler(update_followers))
             .add_request_handler(user_handler(get_private_user_info))
+            .add_request_handler(user_handler(get_llm_api_token))
             .add_message_handler(user_message_handler(acknowledge_channel_message))
             .add_message_handler(user_message_handler(acknowledge_buffer_version))
             .add_request_handler(user_handler(get_supermaven_api_key))
@@ -1046,9 +1047,8 @@ impl Server {
                 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                 peer: this.peer.clone(),
                 connection_pool: this.connection_pool.clone(),
-                live_kit_client: this.app_state.live_kit_client.clone(),
+                app_state: this.app_state.clone(),
                 http_client,
-                rate_limiter: this.app_state.rate_limiter.clone(),
                 geoip_country_code,
                 _executor: executor.clone(),
                 supermaven_client,
@@ -1559,7 +1559,7 @@ async fn create_room(
     let live_kit_room = nanoid::nanoid!(30);
 
     let live_kit_connection_info = util::maybe!(async {
-        let live_kit = session.live_kit_client.as_ref();
+        let live_kit = session.app_state.live_kit_client.as_ref();
         let live_kit = live_kit?;
         let user_id = session.user_id().to_string();
 
@@ -1630,25 +1630,26 @@ async fn join_room(
             .trace_err();
     }
 
-    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
-        if let Some(token) = live_kit
-            .room_token(
-                &joined_room.room.live_kit_room,
-                &session.user_id().to_string(),
-            )
-            .trace_err()
-        {
-            Some(proto::LiveKitConnectionInfo {
-                server_url: live_kit.url().into(),
-                token,
-                can_publish: true,
-            })
+    let live_kit_connection_info =
+        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
+            if let Some(token) = live_kit
+                .room_token(
+                    &joined_room.room.live_kit_room,
+                    &session.user_id().to_string(),
+                )
+                .trace_err()
+            {
+                Some(proto::LiveKitConnectionInfo {
+                    server_url: live_kit.url().into(),
+                    token,
+                    can_publish: true,
+                })
+            } else {
+                None
+            }
         } else {
             None
-        }
-    } else {
-        None
-    };
+        };
 
     response.send(proto::JoinRoomResponse {
         room: Some(joined_room.room),
@@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
         (live_kit_room, can_publish)
     };
 
-    if let Some(live_kit) = session.live_kit_client.as_ref() {
+    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
         live_kit
             .update_participant(
                 live_kit_room.clone(),
@@ -4048,35 +4049,40 @@ async fn join_channel_internal(
             .join_channel(channel_id, session.user_id(), session.connection_id)
             .await?;
 
-        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
-            let (can_publish, token) = if role == ChannelRole::Guest {
-                (
-                    false,
-                    live_kit
-                        .guest_token(
-                            &joined_room.room.live_kit_room,
-                            &session.user_id().to_string(),
+        let live_kit_connection_info =
+            session
+                .app_state
+                .live_kit_client
+                .as_ref()
+                .and_then(|live_kit| {
+                    let (can_publish, token) = if role == ChannelRole::Guest {
+                        (
+                            false,
+                            live_kit
+                                .guest_token(
+                                    &joined_room.room.live_kit_room,
+                                    &session.user_id().to_string(),
+                                )
+                                .trace_err()?,
                         )
-                        .trace_err()?,
-                )
-            } else {
-                (
-                    true,
-                    live_kit
-                        .room_token(
-                            &joined_room.room.live_kit_room,
-                            &session.user_id().to_string(),
+                    } else {
+                        (
+                            true,
+                            live_kit
+                                .room_token(
+                                    &joined_room.room.live_kit_room,
+                                    &session.user_id().to_string(),
+                                )
+                                .trace_err()?,
                         )
-                        .trace_err()?,
-                )
-            };
+                    };
 
-            Some(LiveKitConnectionInfo {
-                server_url: live_kit.url().into(),
-                token,
-                can_publish,
-            })
-        });
+                    Some(LiveKitConnectionInfo {
+                        server_url: live_kit.url().into(),
+                        token,
+                        can_publish,
+                    })
+                });
 
         response.send(proto::JoinRoomResponse {
             room: Some(joined_room.room.clone()),
@@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
     };
 
     session
+        .app_state
         .rate_limiter
         .check(&*rate_limit, session.user_id())
         .await?;
@@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
     };
 
     session
+        .app_state
         .rate_limiter
         .check(&*rate_limit, session.user_id())
         .await?;
@@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
     };
 
     session
+        .app_state
         .rate_limiter
         .check(&*rate_limit, session.user_id())
         .await?;
@@ -4885,6 +4894,7 @@ async fn compute_embeddings(
     };
 
     session
+        .app_state
         .rate_limiter
         .check(&*rate_limit, session.user_id())
         .await?;
@@ -5143,6 +5153,24 @@ async fn get_private_user_info(
     Ok(())
 }
 
+async fn get_llm_api_token(
+    _request: proto::GetLlmToken,
+    response: Response<proto::GetLlmToken>,
+    session: UserSession,
+) -> Result<()> {
+    if !session.is_staff() {
+        Err(anyhow!("permission denied"))?
+    }
+
+    let token = LlmTokenClaims::create(
+        session.user_id(),
+        session.current_plan().await?,
+        &session.app_state.config,
+    )?;
+    response.send(proto::GetLlmTokenResponse { token })?;
+    Ok(())
+}
+
 fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
     let message = match message {
         TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
@@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
         update_user_contacts(contact_user_id, &session).await?;
     }
 
-    if let Some(live_kit) = session.live_kit_client.as_ref() {
+    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
         live_kit
             .remove_participant(live_kit_room.clone(), session.user_id().to_string())
             .await

crates/collab/src/tests/test_server.rs 🔗

@@ -651,6 +651,7 @@ impl TestServer {
                 live_kit_server: None,
                 live_kit_key: None,
                 live_kit_secret: None,
+                llm_api_secret: None,
                 rust_log: None,
                 log_json: None,
                 zed_environment: "test".into(),

crates/http_client/src/http_client.rs 🔗

@@ -175,6 +175,22 @@ impl HttpClientWithUrl {
             query,
         )?)
     }
+
+    /// Builds a Zed LLM URL using the given path.
+    pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
+        let base_url = self.base_url();
+        let base_api_url = match base_url.as_ref() {
+            "https://zed.dev" => "https://llm.zed.dev",
+            "https://staging.zed.dev" => "https://llm-staging.zed.dev",
+            "http://localhost:3000" => "http://localhost:8080",
+            other => other,
+        };
+
+        Ok(Url::parse_with_params(
+            &format!("{}{}", base_api_url, path),
+            query,
+        )?)
+    }
 }
 
 impl HttpClient for Arc<HttpClientWithUrl> {

crates/language_model/src/provider/cloud.rs 🔗

@@ -5,13 +5,20 @@ use crate::{
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anyhow::{anyhow, Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
+use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
+use http_client::{HttpClient, Method};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
+use smol::{
+    io::BufReader,
+    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
+};
 use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -46,6 +53,7 @@ pub struct AvailableModel {
 
 pub struct CloudLanguageModelProvider {
     client: Arc<Client>,
+    llm_api_token: LlmApiToken,
     state: gpui::Model<State>,
     _maintain_client_status: Task<()>,
 }
@@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
         Self {
             client,
             state,
+            llm_api_token: LlmApiToken::default(),
             _maintain_client_status: maintain_client_status,
         }
     }
@@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 Arc::new(CloudLanguageModel {
                     id: LanguageModelId::from(model.id().to_string()),
                     model,
+                    llm_api_token: self.llm_api_token.clone(),
                     client: self.client.clone(),
                     request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
@@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }
 
+struct LlmServiceFeatureFlag;
+
+impl FeatureFlag for LlmServiceFeatureFlag {
+    const NAME: &'static str = "llm-service";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
+    llm_api_token: LlmApiToken,
     client: Arc<Client>,
     request_limiter: RateLimiter,
 }
 
+#[derive(Clone, Default)]
+struct LlmApiToken(Arc<RwLock<Option<String>>>);
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        _: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let client = self.client.clone();
                 let request = request.into_anthropic(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Anthropic as i32,
-                            request,
-                        })
-                        .await?;
-                    Ok(anthropic::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                let client = self.client.clone();
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let http_client = self.client.http_client();
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let mut token = llm_api_token.acquire(&client).await?;
+                        let mut did_retry = false;
+
+                        let response = loop {
+                            let request = http_client::Request::builder()
+                                .method(Method::POST)
+                                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+                                .header("Content-Type", "application/json")
+                                .header("Authorization", format!("Bearer {token}"))
+                                .body(
+                                    serde_json::to_string(&PerformCompletionParams {
+                                        provider_request: RawValue::from_string(request.clone())?,
+                                    })?
+                                    .into(),
+                                )?;
+                            let response = http_client.send(request).await?;
+                            if response.status().is_success() {
+                                break response;
+                            } else if !did_retry
+                                && response
+                                    .headers()
+                                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                                    .is_some()
+                            {
+                                did_retry = true;
+                                token = llm_api_token.refresh(&client).await?;
+                            } else {
+                                break Err(anyhow!(
+                                    "cloud language model completion failed with status {}",
+                                    response.status()
+                                ))?;
+                            }
+                        };
+
+                        let body = BufReader::new(response.into_body());
+
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: anthropic::Event =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(anthropic::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Anthropic as i32,
+                                request,
+                            })
+                            .await?
+                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
+                        Ok(anthropic::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
@@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
+impl LlmApiToken {
+    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+        let lock = self.0.upgradable_read().await;
+        if let Some(token) = lock.as_ref() {
+            Ok(token.to_string())
+        } else {
+            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
+        }
+    }
+
+    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
+        Self::fetch(self.0.write().await, &client).await
+    }
+
+    async fn fetch<'a>(
+        mut lock: RwLockWriteGuard<'a, Option<String>>,
+        client: &Arc<Client>,
+    ) -> Result<String> {
+        let response = client.request(proto::GetLlmToken {}).await?;
+        *lock = Some(response.token.clone());
+        Ok(response.token.clone())
+    }
+}
+
 struct ConfigurationView {
     state: gpui::Model<State>,
 }

crates/proto/proto/zed.proto 🔗

@@ -126,7 +126,7 @@ message Envelope {
         Unfollow unfollow = 101;
         GetPrivateUserInfo get_private_user_info = 102;
         GetPrivateUserInfoResponse get_private_user_info_response = 103;
-        UpdateUserPlan update_user_plan = 234; // current max
+        UpdateUserPlan update_user_plan = 234;
         UpdateDiffBase update_diff_base = 104;
 
         OnTypeFormatting on_type_formatting = 105;
@@ -270,6 +270,9 @@ message Envelope {
 
         AddWorktree add_worktree = 222;
         AddWorktreeResponse add_worktree_response = 223;
+
+        GetLlmToken get_llm_token = 235;
+        GetLlmTokenResponse get_llm_token_response = 236; // current max
     }
 
     reserved 158 to 161;
@@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse {
     repeated ContextVersion contexts = 1;
 }
 
+message GetLlmToken {}
+
+message GetLlmTokenResponse {
+    string token = 1;
+}
+
 // Remote FS
 
 message AddWorktree {

crates/proto/src/proto.rs 🔗

@@ -259,6 +259,8 @@ messages!(
     (GetTypeDefinitionResponse, Background),
     (GetImplementation, Background),
     (GetImplementationResponse, Background),
+    (GetLlmToken, Background),
+    (GetLlmTokenResponse, Background),
     (GetUsers, Foreground),
     (Hello, Foreground),
     (IncomingCall, Foreground),
@@ -438,6 +440,7 @@ request_messages!(
     (GetImplementation, GetImplementationResponse),
     (GetDocumentHighlights, GetDocumentHighlightsResponse),
     (GetHover, GetHoverResponse),
+    (GetLlmToken, GetLlmTokenResponse),
     (GetNotifications, GetNotificationsResponse),
     (GetPrivateUserInfo, GetPrivateUserInfoResponse),
     (GetProjectSymbols, GetProjectSymbolsResponse),

crates/rpc/src/llm.rs 🔗

@@ -0,0 +1,8 @@
+use serde::{Deserialize, Serialize};
+
+pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
+
+#[derive(Serialize, Deserialize)]
+pub struct PerformCompletionParams {
+    pub provider_request: Box<serde_json::value::RawValue>,
+}

crates/rpc/src/rpc.rs 🔗

@@ -1,12 +1,14 @@
 pub mod auth;
 mod conn;
 mod extension;
+mod llm;
 mod notification;
 mod peer;
 pub mod proto;
 
 pub use conn::Connection;
 pub use extension::*;
+pub use llm::*;
 pub use notification::*;
 pub use peer::*;
 pub use proto::{error::*, Receipt, TypedEnvelope};