From 0d891bd3e5fef696e920e72f48049383113ca055 Mon Sep 17 00:00:00 2001 From: Dominic Burkart Date: Mon, 15 Dec 2025 18:59:40 +0100 Subject: [PATCH] Enable Zeta edit predictions with custom URL without authentication (#43236) Enables using Zeta edit predictions with a custom `ZED_PREDICT_EDITS_URL` without requiring authentication to Zed servers. This is useful for: - Development and testing workflows - Self-hosted Zeta instances - Custom AI model endpoints Prior context on this usage of `ZED_PREDICT_EDITS_URL`: https://github.com/zed-industries/zed/pull/30418 Release Notes: - Improved self-hosted zeta UX. Users no longer have to log into Zed to use custom or self-hosted zeta backends. --------- Co-authored-by: Agus Zubiaga --- crates/edit_prediction/src/edit_prediction.rs | 97 ++++++---- .../src/edit_prediction_tests.rs | 168 ++++++++++++++++++ crates/edit_prediction/src/zeta1.rs | 26 +-- 3 files changed, 249 insertions(+), 42 deletions(-) diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index ff15d04cc1c0f8e7bbeb7f2a29b520a8ec32097a..f5ea7590fcba97ee916af985824e21cdf4ea725f 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -19,6 +19,7 @@ use futures::{ select_biased, }; use gpui::BackgroundExecutor; +use gpui::http_client::Url; use gpui::{ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, http_client::{self, AsyncBody, Method}, @@ -127,15 +128,6 @@ static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { } .to_string() }); -static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { - env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { - if *USE_OLLAMA { - Some("http://localhost:11434/v1/chat/completions".into()) - } else { - None - } - }) -}); pub struct Zeta2FeatureFlag; @@ -170,6 +162,7 @@ pub struct EditPredictionStore { reject_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, rated_predictions: HashSet, + custom_predict_edits_url: Option>, } #[derive(Copy, Clone, Default, PartialEq, Eq)] @@ -568,6 +561,20 @@ impl EditPredictionStore { reject_predictions_tx: reject_tx, rated_predictions: Default::default(), shown_predictions: Default::default(), + custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") { + Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into), + Err(_) => { + if *USE_OLLAMA { + Some( + Url::parse("http://localhost:11434/v1/chat/completions") + .unwrap() + .into(), + ) + } else { + None + } + } + }, }; this.configure_context_retrieval(cx); @@ -586,6 +593,11 @@ impl EditPredictionStore { this } + #[cfg(test)] + pub fn set_custom_predict_edits_url(&mut self, url: Url) { + self.custom_predict_edits_url = Some(url.into()); + } + pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { self.edit_prediction_model = model; } @@ -1015,8 +1027,13 @@ impl EditPredictionStore { } fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok(); match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() { + return; + } + } EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } @@ -1036,12 +1053,15 @@ impl EditPredictionStore { let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); cx.spawn(async move |this, cx| { - let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { - http_client::Url::parse(&predict_edits_url)? + let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url { + (http_client::Url::parse(&accept_edits_url)?, false) } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/accept", &[])? + ( + client + .http_client() + .build_zed_llm_url("/predict_edits/accept", &[])?, + true, + ) }; let response = cx @@ -1058,6 +1078,7 @@ impl EditPredictionStore { client, llm_token, app_version, + require_auth, )) .await; @@ -1116,6 +1137,7 @@ impl EditPredictionStore { client.clone(), llm_token.clone(), app_version.clone(), + true, ) .await; @@ -1161,7 +1183,11 @@ impl EditPredictionStore { was_shown: bool, ) { match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + if self.custom_predict_edits_url.is_some() { + return; + } + } EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } @@ -1671,13 +1697,9 @@ impl EditPredictionStore { #[cfg(feature = "cli-support")] eval_cache: Option>, #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind, ) -> Result<(open_ai::Response, Option)> { - let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { - http_client::Url::parse(&predict_edits_url)? - } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/raw", &[])? - }; + let url = client + .http_client() + .build_zed_llm_url("/predict_edits/raw", &[])?; #[cfg(feature = "cli-support")] let cache_key = if let Some(cache) = eval_cache { @@ -1710,6 +1732,7 @@ impl EditPredictionStore { client, llm_token, app_version, + true, ) .await?; @@ -1770,23 +1793,34 @@ impl EditPredictionStore { client: Arc, llm_token: LlmApiToken, app_version: Version, + require_auth: bool, ) -> Result<(Res, Option)> where Res: DeserializeOwned, { let http_client = client.http_client(); - let mut token = llm_token.acquire(&client).await?; + + let mut token = if require_auth { + Some(llm_token.acquire(&client).await?) + } else { + llm_token.acquire(&client).await.ok() + }; let mut did_retry = false; loop { let request_builder = http_client::Request::builder().method(Method::POST); - let request = build( - request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), - )?; + let mut request_builder = request_builder + .header("Content-Type", "application/json") + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()); + + // Only add Authorization header if we have a token + if let Some(ref token_value) = token { + request_builder = + request_builder.header("Authorization", format!("Bearer {}", token_value)); + } + + let request = build(request_builder)?; let mut response = http_client.send(request).await?; @@ -1810,13 +1844,14 @@ impl EditPredictionStore { response.body_mut().read_to_end(&mut body).await?; return Ok((serde_json::from_slice(&body)?, usage)); } else if !did_retry + && token.is_some() && response .headers() .get(EXPIRED_LLM_TOKEN_HEADER_NAME) .is_some() { did_retry = true; - token = llm_token.refresh(&client).await?; + token = Some(llm_token.refresh(&client).await?); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 5067aa0050d7a0831ca7668d17188fa6d41637b9..eee3f1f79e93b60ee3ea7c80bd987af22d613833 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1914,6 +1914,174 @@ fn from_completion_edits( .collect() } +#[gpui::test] +async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + serde_json::json!({ + "main.rs": "fn main() {\n \n}\n" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + let http_client = FakeHttpClient::create(|_req| async move { + Ok(gpui::http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()) + }); + + let client = + cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + cx.update(|cx| { + language_model::RefreshLlmTokenListener::register(client.clone(), cx); + }); + + let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("/project/main.rs"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4))); + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx) + }); + cx.background_executor.run_until_parked(); + + let completion_task = ep_store.update(cx, |ep_store, cx| { + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx) + }); + + let result = completion_task.await; + assert!( + result.is_err(), + "Without authentication and without custom URL, prediction should fail" + ); +} + +#[gpui::test] +async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + serde_json::json!({ + "main.rs": "fn main() {\n \n}\n" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let predict_called_clone = predict_called.clone(); + + let http_client = FakeHttpClient::create({ + move |req| { + let uri = req.uri().path().to_string(); + let predict_called = predict_called_clone.clone(); + async move { + if uri.contains("predict") { + predict_called.store(true, std::sync::atomic::Ordering::SeqCst); + Ok(gpui::http_client::Response::builder() + .body( + serde_json::to_string(&open_ai::Response { + id: "test-123".to_string(), + object: "chat.completion".to_string(), + created: 0, + model: "test".to_string(), + usage: open_ai::Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain( + indoc! {" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + fn main() { + println!(\"Hello, world!\"); + } + <|editable_region_end|> + ``` + "} + .to_string(), + )), + tool_calls: vec![], + }, + finish_reason: Some("stop".to_string()), + }], + }) + .unwrap() + .into(), + ) + .unwrap()) + } else { + Ok(gpui::http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()) + } + } + } + }); + + let client = + cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + cx.update(|cx| { + language_model::RefreshLlmTokenListener::register(client.clone(), cx); + }); + + let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("/project/main.rs"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4))); + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx) + }); + cx.background_executor.run_until_parked(); + + let completion_task = ep_store.update(cx, |ep_store, cx| { + ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap()); + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx) + }); + + let _ = completion_task.await; + + assert!( + predict_called.load(std::sync::atomic::Ordering::SeqCst), + "With custom URL, predict endpoint should be called even without authentication" + ); +} + #[ctor::ctor] fn init_logger() { zlog::init_test(); diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index ed531749cb39d10d71d18947990dd1972f23a986..01c26573307e66cd6ca3bf8ab748ba8d082ea688 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -78,6 +78,19 @@ pub(crate) fn request_prediction_with_zeta1( cx, ); + let (uri, require_auth) = match &store.custom_predict_edits_url { + Some(custom_url) => (custom_url.clone(), false), + None => { + match client + .http_client() + .build_zed_llm_url("/predict_edits/v2", &[]) + { + Ok(url) => (url.into(), true), + Err(err) => return Task::ready(Err(err)), + } + } + }; + cx.spawn(async move |this, cx| { let GatherContextOutput { mut body, @@ -102,25 +115,16 @@ pub(crate) fn request_prediction_with_zeta1( body.input_excerpt ); - let http_client = client.http_client(); - let response = EditPredictionStore::send_api_request::( |request| { - let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { - predict_edits_url - } else { - http_client - .build_zed_llm_url("/predict_edits/v2", &[])? - .as_str() - .into() - }; Ok(request - .uri(uri) + .uri(uri.as_str()) .body(serde_json::to_string(&body)?.into())?) }, client, llm_token, app_version, + require_auth, ) .await;