Detailed changes
@@ -2466,6 +2466,7 @@ dependencies = [
"hex",
"http_client",
"indoc",
+ "jsonwebtoken",
"language",
"language_model",
"live_kit_client",
@@ -2507,6 +2508,7 @@ dependencies = [
"telemetry_events",
"text",
"theme",
+ "thiserror",
"time",
"tokio",
"toml 0.8.16",
@@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
SEED_PATH = "crates/collab/seed.default.json"
+LLM_API_SECRET = "llm-secret"
# CLICKHOUSE_URL = ""
# CLICKHOUSE_USER = "default"
@@ -37,6 +37,7 @@ futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
+jsonwebtoken.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid.workspace = true
@@ -61,6 +62,7 @@ subtle.workspace = true
rustc-demangle.workspace = true
telemetry_events.workspace = true
text.workspace = true
+thiserror.workspace = true
time.workspace = true
tokio.workspace = true
toml.workspace = true
@@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
- Error::Http(
+ Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("token ")
.ok_or_else(|| {
- Error::Http(
+ Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
@@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
let state = req.extensions().get::<Arc<AppState>>().unwrap();
if token != state.config.api_token {
- Err(Error::Http(
+ Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid authorization token".to_string(),
))?
@@ -185,13 +185,13 @@ async fn create_access_token(
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
impersonated_user_id = Some(impersonated_user.id);
} else {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::UNPROCESSABLE_ENTITY,
format!("user {impersonate} does not exist"),
));
}
} else {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::UNAUTHORIZED,
"you do not have permission to impersonate other users".to_string(),
));
@@ -120,7 +120,7 @@ async fn create_billing_subscription(
.zip(app.config.stripe_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
- Err(Error::Http(
+ Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@@ -201,7 +201,7 @@ async fn manage_billing_subscription(
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
- Err(Error::Http(
+ Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@@ -206,14 +206,14 @@ pub async fn post_hang(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
@@ -265,25 +265,25 @@ pub async fn post_panic(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
}
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
- .map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
+ .map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
let panic = report.panic;
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid os version".into(),
))?;
@@ -362,14 +362,14 @@ pub async fn post_events(
body: Bytes,
) -> Result<()> {
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
- Err(Error::Http(
+ Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
- return Err(Error::Http(
+ return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
@@ -385,7 +385,7 @@ pub async fn post_events(
let mut to_upload = ToUpload::default();
let Some(last_event) = request_body.events.last() else {
- return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
+ return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
};
let country_code = country_code_header.map(|h| h.to_string());
@@ -185,7 +185,7 @@ async fn download_extension(
.clone()
.zip(app.config.blob_store_bucket.clone())
else {
- Err(Error::Http(
+ Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@@ -202,7 +202,7 @@ async fn download_extension(
.await?;
if !version_exists {
- Err(Error::Http(
+ Err(Error::http(
StatusCode::NOT_FOUND,
"unknown extension version".into(),
))?;
@@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
- Error::Http(
+ Error::http(
StatusCode::UNAUTHORIZED,
"missing authorization header".to_string(),
)
@@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
let first = auth_header.next().unwrap_or("");
if first == "dev-server-token" {
let dev_server_token = auth_header.next().ok_or_else(|| {
- Error::Http(
+ Error::http(
StatusCode::BAD_REQUEST,
"missing dev-server-token token in authorization header".to_string(),
)
})?;
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
.await
- .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
+ .map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
req.extensions_mut()
.insert(Principal::DevServer(dev_server));
@@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
let user_id = UserId(first.parse().map_err(|_| {
- Error::Http(
+ Error::http(
StatusCode::BAD_REQUEST,
"missing user id in authorization header".to_string(),
)
})?);
let access_token = auth_header.next().ok_or_else(|| {
- Error::Http(
+ Error::http(
StatusCode::BAD_REQUEST,
"missing access token in authorization header".to_string(),
)
@@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
}
- Err(Error::Http(
+ Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))
@@ -13,7 +13,10 @@ mod tests;
use anyhow::anyhow;
use aws_config::{BehaviorVersion, Region};
-use axum::{http::StatusCode, response::IntoResponse};
+use axum::{
+ http::{HeaderMap, StatusCode},
+ response::IntoResponse,
+};
use db::{ChannelId, Database};
use executor::Executor;
pub use rate_limiter::*;
@@ -24,7 +27,7 @@ use util::ResultExt;
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
- Http(StatusCode, String),
+ Http(StatusCode, String, HeaderMap),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
@@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
}
}
+impl Error {
+ fn http(code: StatusCode, message: String) -> Self {
+ Self::Http(code, message, HeaderMap::default())
+ }
+}
+
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
match self {
- Error::Http(code, message) => {
+ Error::Http(code, message, headers) => {
log::error!("HTTP error {}: {}", code, &message);
- (code, message).into_response()
+ (code, headers, message).into_response()
}
Error::Database(error) => {
log::error!(
@@ -104,7 +113,7 @@ impl IntoResponse for Error {
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- Error::Http(code, message) => (code, message).fmt(f),
+ Error::Http(code, message, _headers) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
@@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- Error::Http(code, message) => write!(f, "{code}: {message}"),
+ Error::Http(code, message, _) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
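Error::Http now carries a HeaderMap so a handler can attach response headers, while the new Error::http constructor keeps the common no-headers call sites short. A minimal sketch (not part of the diff) of both forms, using the expired-LLM-token header introduced later in this PR:

fn example_errors() -> (Error, Error) {
    use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};

    // Common case: Error::http fills in an empty HeaderMap.
    let plain = Error::http(StatusCode::BAD_REQUEST, "invalid checksum".to_string());

    // Header-carrying case, mirroring the expired LLM token response below.
    let mut headers = HeaderMap::new();
    headers.insert(
        HeaderName::from_static("x-zed-expired-token"),
        HeaderValue::from_static("true"),
    );
    let expired = Error::Http(StatusCode::UNAUTHORIZED, "unauthorized".to_string(), headers);

    (plain, expired)
}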
@@ -141,6 +150,7 @@ pub struct Config {
pub live_kit_server: Option<String>,
pub live_kit_key: Option<String>,
pub live_kit_secret: Option<String>,
+ pub llm_api_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
pub blob_store_url: Option<String>,
@@ -1,16 +1,122 @@
+mod token;
+
+use crate::{executor::Executor, Config, Error, Result};
+use anyhow::Context as _;
+use axum::{
+ body::Body,
+ http::{self, HeaderName, HeaderValue, Request, StatusCode},
+ middleware::{self, Next},
+ response::{IntoResponse, Response},
+ routing::post,
+ Extension, Json, Router,
+};
+use futures::StreamExt as _;
+use http_client::IsahcHttpClient;
+use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
-use crate::{executor::Executor, Config, Result};
+pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
+ pub http_client: IsahcHttpClient,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
- let this = Self { config, executor };
+ let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
+ let http_client = IsahcHttpClient::builder()
+ .default_header("User-Agent", user_agent)
+ .build()
+ .context("failed to construct http client")?;
+
+ let this = Self {
+ config,
+ executor,
+ http_client,
+ };
Ok(Arc::new(this))
}
}
+
+pub fn routes() -> Router<(), Body> {
+ Router::new()
+ .route("/completion", post(perform_completion))
+ .layer(middleware::from_fn(validate_api_token))
+}
+
+async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
+ let token = req
+ .headers()
+ .get(http::header::AUTHORIZATION)
+ .and_then(|header| header.to_str().ok())
+ .ok_or_else(|| {
+ Error::http(
+ StatusCode::BAD_REQUEST,
+ "missing authorization header".to_string(),
+ )
+ })?
+ .strip_prefix("Bearer ")
+ .ok_or_else(|| {
+ Error::http(
+ StatusCode::BAD_REQUEST,
+ "invalid authorization header".to_string(),
+ )
+ })?;
+
+ let state = req.extensions().get::<Arc<LlmState>>().unwrap();
+ match LlmTokenClaims::validate(&token, &state.config) {
+ Ok(claims) => {
+ req.extensions_mut().insert(claims);
+ Ok::<_, Error>(next.run(req).await.into_response())
+ }
+ Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
+ StatusCode::UNAUTHORIZED,
+ "unauthorized".to_string(),
+ [(
+ HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
+ HeaderValue::from_static("true"),
+ )]
+ .into_iter()
+ .collect(),
+ )),
+ Err(_err) => Err(Error::http(
+ StatusCode::UNAUTHORIZED,
+ "unauthorized".to_string(),
+ )),
+ }
+}
+
+async fn perform_completion(
+ Extension(state): Extension<Arc<LlmState>>,
+ Extension(_claims): Extension<LlmTokenClaims>,
+ Json(params): Json<PerformCompletionParams>,
+) -> Result<impl IntoResponse> {
+ let api_key = state
+ .config
+ .anthropic_api_key
+ .as_ref()
+ .context("no Anthropic AI API key configured on the server")?;
+ let chunks = anthropic::stream_completion(
+ &state.http_client,
+ anthropic::ANTHROPIC_API_URL,
+ api_key,
+ serde_json::from_str(&params.provider_request.get())?,
+ None,
+ )
+ .await?;
+
+ let stream = chunks.map(|event| {
+ let mut buffer = Vec::new();
+ event.map(|chunk| {
+ buffer.clear();
+ serde_json::to_writer(&mut buffer, &chunk).unwrap();
+ buffer.push(b'\n');
+ buffer
+ })
+ });
+
+ Ok(Response::new(Body::wrap_stream(stream)))
+}
@@ -0,0 +1,75 @@
+use crate::{db::UserId, Config};
+use anyhow::{anyhow, Result};
+use chrono::Utc;
+use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use thiserror::Error;
+
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LlmTokenClaims {
+ pub iat: u64,
+ pub exp: u64,
+ pub jti: String,
+ pub user_id: u64,
+ pub plan: rpc::proto::Plan,
+}
+
+const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
+
+impl LlmTokenClaims {
+ pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
+ let secret = config
+ .llm_api_secret
+ .as_ref()
+ .ok_or_else(|| anyhow!("no LLM API secret"))?;
+
+ let now = Utc::now();
+ let claims = Self {
+ iat: now.timestamp() as u64,
+ exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
+ jti: uuid::Uuid::new_v4().to_string(),
+ user_id: user_id.to_proto(),
+ plan,
+ };
+
+ Ok(jsonwebtoken::encode(
+ &Header::default(),
+ &claims,
+ &EncodingKey::from_secret(secret.as_ref()),
+ )?)
+ }
+
+ pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
+ let secret = config
+ .llm_api_secret
+ .as_ref()
+ .ok_or_else(|| anyhow!("no LLM API secret"))?;
+
+ match jsonwebtoken::decode::<Self>(
+ token,
+ &DecodingKey::from_secret(secret.as_ref()),
+ &Validation::default(),
+ ) {
+ Ok(token) => Ok(token.claims),
+ Err(e) => {
+ if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
+ Err(ValidateLlmTokenError::Expired)
+ } else {
+ Err(ValidateLlmTokenError::JwtError(e))
+ }
+ }
+ }
+ }
+}
+
+#[derive(Error, Debug)]
+pub enum ValidateLlmTokenError {
+ #[error("access token is expired")]
+ Expired,
+ #[error("access token validation error: {0}")]
+ JwtError(#[from] jsonwebtoken::errors::Error),
+ #[error("{0}")]
+ Other(#[from] anyhow::Error),
+}
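For orientation, a hedged sketch of the round trip these claims enable: collab mints a one-hour JWT while handling GetLlmToken, and the LLM service validates it on every request, mapping an expired signature to the 401-plus-expired-header response in its middleware. The helper below is illustrative only; the plan value comes from the user's subscription.

fn llm_token_round_trip(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> anyhow::Result<()> {
    // Minted on the collab side while handling GetLlmToken.
    let token = LlmTokenClaims::create(user_id, plan, config)?;

    // Checked on the LLM service side by validate_api_token.
    match LlmTokenClaims::validate(&token, config) {
        Ok(claims) => {
            debug_assert_eq!(claims.user_id, user_id.to_proto());
            Ok(())
        }
        // Surfaces to the client as 401 + x-zed-expired-token, prompting a refresh.
        Err(ValidateLlmTokenError::Expired) => anyhow::bail!("token expired"),
        Err(err) => Err(err.into()),
    }
}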
@@ -83,7 +83,9 @@ async fn main() -> Result<()> {
if mode.is_llm() {
let state = LlmState::new(config.clone(), Executor::Production).await?;
- app = app.layer(Extension(state.clone()));
+ app = app
+ .merge(collab::llm::routes())
+ .layer(Extension(state.clone()));
}
if mode.is_collab() || mode.is_api() {
@@ -1,6 +1,7 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
+use crate::llm::LlmTokenClaims;
use crate::{
auth,
db::{
@@ -11,7 +12,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
- AppState, Config, Error, RateLimit, RateLimiter, Result,
+ AppState, Config, Error, RateLimit, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@@ -149,10 +150,9 @@ struct Session {
db: Arc<tokio::sync::Mutex<DbHandle>>,
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
- live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
+ app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
- rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@@ -615,6 +615,7 @@ impl Server {
.add_message_handler(user_message_handler(unfollow))
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
+ .add_request_handler(user_handler(get_llm_api_token))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@@ -1046,9 +1047,8 @@ impl Server {
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
- live_kit_client: this.app_state.live_kit_client.clone(),
+ app_state: this.app_state.clone(),
http_client,
- rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
@@ -1559,7 +1559,7 @@ async fn create_room(
let live_kit_room = nanoid::nanoid!(30);
let live_kit_connection_info = util::maybe!(async {
- let live_kit = session.live_kit_client.as_ref();
+ let live_kit = session.app_state.live_kit_client.as_ref();
let live_kit = live_kit?;
let user_id = session.user_id().to_string();
@@ -1630,25 +1630,26 @@ async fn join_room(
.trace_err();
}
- let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
- if let Some(token) = live_kit
- .room_token(
- &joined_room.room.live_kit_room,
- &session.user_id().to_string(),
- )
- .trace_err()
- {
- Some(proto::LiveKitConnectionInfo {
- server_url: live_kit.url().into(),
- token,
- can_publish: true,
- })
+ let live_kit_connection_info =
+ if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
+ if let Some(token) = live_kit
+ .room_token(
+ &joined_room.room.live_kit_room,
+ &session.user_id().to_string(),
+ )
+ .trace_err()
+ {
+ Some(proto::LiveKitConnectionInfo {
+ server_url: live_kit.url().into(),
+ token,
+ can_publish: true,
+ })
+ } else {
+ None
+ }
} else {
None
- }
- } else {
- None
- };
+ };
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room),
@@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
(live_kit_room, can_publish)
};
- if let Some(live_kit) = session.live_kit_client.as_ref() {
+ if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.update_participant(
live_kit_room.clone(),
@@ -4048,35 +4049,40 @@ async fn join_channel_internal(
.join_channel(channel_id, session.user_id(), session.connection_id)
.await?;
- let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
- let (can_publish, token) = if role == ChannelRole::Guest {
- (
- false,
- live_kit
- .guest_token(
- &joined_room.room.live_kit_room,
- &session.user_id().to_string(),
+ let live_kit_connection_info =
+ session
+ .app_state
+ .live_kit_client
+ .as_ref()
+ .and_then(|live_kit| {
+ let (can_publish, token) = if role == ChannelRole::Guest {
+ (
+ false,
+ live_kit
+ .guest_token(
+ &joined_room.room.live_kit_room,
+ &session.user_id().to_string(),
+ )
+ .trace_err()?,
)
- .trace_err()?,
- )
- } else {
- (
- true,
- live_kit
- .room_token(
- &joined_room.room.live_kit_room,
- &session.user_id().to_string(),
+ } else {
+ (
+ true,
+ live_kit
+ .room_token(
+ &joined_room.room.live_kit_room,
+ &session.user_id().to_string(),
+ )
+ .trace_err()?,
)
- .trace_err()?,
- )
- };
+ };
- Some(LiveKitConnectionInfo {
- server_url: live_kit.url().into(),
- token,
- can_publish,
- })
- });
+ Some(LiveKitConnectionInfo {
+ server_url: live_kit.url().into(),
+ token,
+ can_publish,
+ })
+ });
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room.clone()),
@@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
};
session
+ .app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
};
session
+ .app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
};
session
+ .app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@@ -4885,6 +4894,7 @@ async fn compute_embeddings(
};
session
+ .app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@@ -5143,6 +5153,24 @@ async fn get_private_user_info(
Ok(())
}
+async fn get_llm_api_token(
+ _request: proto::GetLlmToken,
+ response: Response<proto::GetLlmToken>,
+ session: UserSession,
+) -> Result<()> {
+ if !session.is_staff() {
+ Err(anyhow!("permission denied"))?
+ }
+
+ let token = LlmTokenClaims::create(
+ session.user_id(),
+ session.current_plan().await?,
+ &session.app_state.config,
+ )?;
+ response.send(proto::GetLlmTokenResponse { token })?;
+ Ok(())
+}
+
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
@@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
update_user_contacts(contact_user_id, &session).await?;
}
- if let Some(live_kit) = session.live_kit_client.as_ref() {
+ if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
.await
@@ -651,6 +651,7 @@ impl TestServer {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
+ llm_api_secret: None,
rust_log: None,
log_json: None,
zed_environment: "test".into(),
@@ -175,6 +175,22 @@ impl HttpClientWithUrl {
query,
)?)
}
+
+ /// Builds a Zed LLM URL using the given path.
+ pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
+ let base_url = self.base_url();
+ let base_api_url = match base_url.as_ref() {
+ "https://zed.dev" => "https://llm.zed.dev",
+ "https://staging.zed.dev" => "https://llm-staging.zed.dev",
+ "http://localhost:3000" => "http://localhost:8080",
+ other => other,
+ };
+
+ Ok(Url::parse_with_params(
+ &format!("{}{}", base_api_url, path),
+ query,
+ )?)
+ }
}
impl HttpClient for Arc<HttpClientWithUrl> {
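A small usage sketch of the new helper, assuming a client whose base URL is https://zed.dev (illustrative; the completion request in this PR passes an empty query):

fn completion_url(http_client: &HttpClientWithUrl) -> anyhow::Result<Url> {
    // With base_url https://zed.dev this resolves against https://llm.zed.dev.
    let url = http_client.build_zed_llm_url("/completion", &[])?;
    Ok(url)
}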
@@ -5,13 +5,20 @@ use crate::{
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
+use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
+use http_client::{HttpClient, Method};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
+use smol::{
+ io::BufReader,
+ lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
+};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -46,6 +53,7 @@ pub struct AvailableModel {
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
+ llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
@@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
Self {
client,
state,
+ llm_api_token: LlmApiToken::default(),
_maintain_client_status: maintain_client_status,
}
}
@@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
+ llm_api_token: self.llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
}
+struct LlmServiceFeatureFlag;
+
+impl FeatureFlag for LlmServiceFeatureFlag {
+ const NAME: &'static str = "llm-service";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
+ llm_api_token: LlmApiToken,
client: Arc<Client>,
request_limiter: RateLimiter,
}
+#[derive(Clone, Default)]
+struct LlmApiToken(Arc<RwLock<Option<String>>>);
+
impl LanguageModel for CloudLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
- _: &AsyncAppContext,
+ cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
- let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
- let future = self.request_limiter.stream(async move {
- let request = serde_json::to_string(&request)?;
- let stream = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Anthropic as i32,
- request,
- })
- .await?;
- Ok(anthropic::extract_text_from_events(
- stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
+ let client = self.client.clone();
+
+ if cx
+ .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+ .unwrap_or(false)
+ {
+ let http_client = self.client.http_client();
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream(async move {
+ let request = serde_json::to_string(&request)?;
+ let mut token = llm_api_token.acquire(&client).await?;
+ let mut did_retry = false;
+
+ let response = loop {
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .body(
+ serde_json::to_string(&PerformCompletionParams {
+ provider_request: RawValue::from_string(request.clone())?,
+ })?
+ .into(),
+ )?;
+ let response = http_client.send(request).await?;
+ if response.status().is_success() {
+ break response;
+ } else if !did_retry
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ did_retry = true;
+ token = llm_api_token.refresh(&client).await?;
+ } else {
+ break Err(anyhow!(
+ "cloud language model completion failed with status {}",
+ response.status()
+ ))?;
+ }
+ };
+
+ let body = BufReader::new(response.into_body());
+
+ let stream =
+ futures::stream::try_unfold(body, move |mut body| async move {
+ let mut buffer = String::new();
+ match body.read_line(&mut buffer).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: anthropic::Event =
+ serde_json::from_str(&buffer)?;
+ Ok(Some((event, body)))
+ }
+ Err(e) => Err(e.into()),
+ }
+ });
+
+ Ok(anthropic::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ } else {
+ let future = self.request_limiter.stream(async move {
+ let request = serde_json::to_string(&request)?;
+ let stream = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::Anthropic as i32,
+ request,
+ })
+ .await?
+ .map(|event| Ok(serde_json::from_str(&event?.event)?));
+ Ok(anthropic::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
@@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
}
}
+impl LlmApiToken {
+ async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+ let lock = self.0.upgradable_read().await;
+ if let Some(token) = lock.as_ref() {
+ Ok(token.to_string())
+ } else {
+ Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
+ }
+ }
+
+ async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
+ Self::fetch(self.0.write().await, &client).await
+ }
+
+ async fn fetch<'a>(
+ mut lock: RwLockWriteGuard<'a, Option<String>>,
+ client: &Arc<Client>,
+ ) -> Result<String> {
+ let response = client.request(proto::GetLlmToken {}).await?;
+ *lock = Some(response.token.clone());
+ Ok(response.token.clone())
+ }
+}
+
struct ConfigurationView {
state: gpui::Model<State>,
}
@@ -126,7 +126,7 @@ message Envelope {
Unfollow unfollow = 101;
GetPrivateUserInfo get_private_user_info = 102;
GetPrivateUserInfoResponse get_private_user_info_response = 103;
- UpdateUserPlan update_user_plan = 234; // current max
+ UpdateUserPlan update_user_plan = 234;
UpdateDiffBase update_diff_base = 104;
OnTypeFormatting on_type_formatting = 105;
@@ -270,6 +270,9 @@ message Envelope {
AddWorktree add_worktree = 222;
AddWorktreeResponse add_worktree_response = 223;
+
+ GetLlmToken get_llm_token = 235;
+ GetLlmTokenResponse get_llm_token_response = 236; // current max
}
reserved 158 to 161;
@@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse {
repeated ContextVersion contexts = 1;
}
+message GetLlmToken {}
+
+message GetLlmTokenResponse {
+ string token = 1;
+}
+
// Remote FS
message AddWorktree {
@@ -259,6 +259,8 @@ messages!(
(GetTypeDefinitionResponse, Background),
(GetImplementation, Background),
(GetImplementationResponse, Background),
+ (GetLlmToken, Background),
+ (GetLlmTokenResponse, Background),
(GetUsers, Foreground),
(Hello, Foreground),
(IncomingCall, Foreground),
@@ -438,6 +440,7 @@ request_messages!(
(GetImplementation, GetImplementationResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetHover, GetHoverResponse),
+ (GetLlmToken, GetLlmTokenResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
@@ -0,0 +1,8 @@
+use serde::{Deserialize, Serialize};
+
+pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
+
+#[derive(Serialize, Deserialize)]
+pub struct PerformCompletionParams {
+ pub provider_request: Box<serde_json::value::RawValue>,
+}
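The provider request travels as raw JSON, so the LLM service can hand it straight to the provider-specific deserializer (Anthropic in this PR). A hypothetical helper sketch for wrapping a request in this envelope:

use serde_json::value::RawValue;

fn to_completion_params<T: serde::Serialize>(request: &T) -> anyhow::Result<PerformCompletionParams> {
    Ok(PerformCompletionParams {
        // RawValue keeps the provider payload opaque to the envelope itself.
        provider_request: RawValue::from_string(serde_json::to_string(request)?)?,
    })
}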
@@ -1,12 +1,14 @@
pub mod auth;
mod conn;
mod extension;
+mod llm;
mod notification;
mod peer;
pub mod proto;
pub use conn::Connection;
pub use extension::*;
+pub use llm::*;
pub use notification::*;
pub use peer::*;
pub use proto::{error::*, Receipt, TypedEnvelope};