Enable Zeta edit predictions with custom URL without authentication (#43236)

Dominic Burkart and Agus Zubiaga created

Enables using Zeta edit predictions with a custom
`ZED_PREDICT_EDITS_URL` without requiring authentication to Zed servers.
This is useful for:

- Development and testing workflows
- Self-hosted Zeta instances
- Custom AI model endpoints

Prior context on this usage of `ZED_PREDICT_EDITS_URL`:
https://github.com/zed-industries/zed/pull/30418

Release Notes:
- Improved self-hosted zeta UX. Users no longer have to log into Zed to
use custom or self-hosted zeta backends.

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/edit_prediction/src/edit_prediction.rs       |  97 +++++--
crates/edit_prediction/src/edit_prediction_tests.rs | 168 +++++++++++++++
crates/edit_prediction/src/zeta1.rs                 |  26 +
3 files changed, 249 insertions(+), 42 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -19,6 +19,7 @@ use futures::{
     select_biased,
 };
 use gpui::BackgroundExecutor;
+use gpui::http_client::Url;
 use gpui::{
     App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
     http_client::{self, AsyncBody, Method},
@@ -127,15 +128,6 @@ static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
     }
     .to_string()
 });
-static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
-    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
-        if *USE_OLLAMA {
-            Some("http://localhost:11434/v1/chat/completions".into())
-        } else {
-            None
-        }
-    })
-});
 
 pub struct Zeta2FeatureFlag;
 
@@ -170,6 +162,7 @@ pub struct EditPredictionStore {
     reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
     shown_predictions: VecDeque<EditPrediction>,
     rated_predictions: HashSet<EditPredictionId>,
+    custom_predict_edits_url: Option<Arc<Url>>,
 }
 
 #[derive(Copy, Clone, Default, PartialEq, Eq)]
@@ -568,6 +561,20 @@ impl EditPredictionStore {
             reject_predictions_tx: reject_tx,
             rated_predictions: Default::default(),
             shown_predictions: Default::default(),
+            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
+                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
+                Err(_) => {
+                    if *USE_OLLAMA {
+                        Some(
+                            Url::parse("http://localhost:11434/v1/chat/completions")
+                                .unwrap()
+                                .into(),
+                        )
+                    } else {
+                        None
+                    }
+                }
+            },
         };
 
         this.configure_context_retrieval(cx);
@@ -586,6 +593,11 @@ impl EditPredictionStore {
         this
     }
 
+    #[cfg(test)]
+    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
+        self.custom_predict_edits_url = Some(url.into());
+    }
+
     pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
         self.edit_prediction_model = model;
     }
@@ -1015,8 +1027,13 @@ impl EditPredictionStore {
     }
 
     fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+        let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
         match self.edit_prediction_model {
-            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+                if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
+                    return;
+                }
+            }
             EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
         }
 
@@ -1036,12 +1053,15 @@ impl EditPredictionStore {
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
         cx.spawn(async move |this, cx| {
-            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
-                http_client::Url::parse(&predict_edits_url)?
+            let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
+                (http_client::Url::parse(&accept_edits_url)?, false)
             } else {
-                client
-                    .http_client()
-                    .build_zed_llm_url("/predict_edits/accept", &[])?
+                (
+                    client
+                        .http_client()
+                        .build_zed_llm_url("/predict_edits/accept", &[])?,
+                    true,
+                )
             };
 
             let response = cx
@@ -1058,6 +1078,7 @@ impl EditPredictionStore {
                     client,
                     llm_token,
                     app_version,
+                    require_auth,
                 ))
                 .await;
 
@@ -1116,6 +1137,7 @@ impl EditPredictionStore {
                 client.clone(),
                 llm_token.clone(),
                 app_version.clone(),
+                true,
             )
             .await;
 
@@ -1161,7 +1183,11 @@ impl EditPredictionStore {
         was_shown: bool,
     ) {
         match self.edit_prediction_model {
-            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+                if self.custom_predict_edits_url.is_some() {
+                    return;
+                }
+            }
             EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
         }
 
@@ -1671,13 +1697,9 @@ impl EditPredictionStore {
         #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
         #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
-        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
-            http_client::Url::parse(&predict_edits_url)?
-        } else {
-            client
-                .http_client()
-                .build_zed_llm_url("/predict_edits/raw", &[])?
-        };
+        let url = client
+            .http_client()
+            .build_zed_llm_url("/predict_edits/raw", &[])?;
 
         #[cfg(feature = "cli-support")]
         let cache_key = if let Some(cache) = eval_cache {
@@ -1710,6 +1732,7 @@ impl EditPredictionStore {
             client,
             llm_token,
             app_version,
+            true,
         )
         .await?;
 
@@ -1770,23 +1793,34 @@ impl EditPredictionStore {
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: Version,
+        require_auth: bool,
     ) -> Result<(Res, Option<EditPredictionUsage>)>
     where
         Res: DeserializeOwned,
     {
         let http_client = client.http_client();
-        let mut token = llm_token.acquire(&client).await?;
+
+        let mut token = if require_auth {
+            Some(llm_token.acquire(&client).await?)
+        } else {
+            llm_token.acquire(&client).await.ok()
+        };
         let mut did_retry = false;
 
         loop {
             let request_builder = http_client::Request::builder().method(Method::POST);
 
-            let request = build(
-                request_builder
-                    .header("Content-Type", "application/json")
-                    .header("Authorization", format!("Bearer {}", token))
-                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
-            )?;
+            let mut request_builder = request_builder
+                .header("Content-Type", "application/json")
+                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
+
+            // Only add Authorization header if we have a token
+            if let Some(ref token_value) = token {
+                request_builder =
+                    request_builder.header("Authorization", format!("Bearer {}", token_value));
+            }
+
+            let request = build(request_builder)?;
 
             let mut response = http_client.send(request).await?;
 
@@ -1810,13 +1844,14 @@ impl EditPredictionStore {
                 response.body_mut().read_to_end(&mut body).await?;
                 return Ok((serde_json::from_slice(&body)?, usage));
             } else if !did_retry
+                && token.is_some()
                 && response
                     .headers()
                     .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                     .is_some()
             {
                 did_retry = true;
-                token = llm_token.refresh(&client).await?;
+                token = Some(llm_token.refresh(&client).await?);
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1914,6 +1914,174 @@ fn from_completion_edits(
         .collect()
 }
 
+#[gpui::test]
+async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
+    init_test(cx);
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/project",
+        serde_json::json!({
+            "main.rs": "fn main() {\n    \n}\n"
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+
+    let http_client = FakeHttpClient::create(|_req| async move {
+        Ok(gpui::http_client::Response::builder()
+            .status(401)
+            .body("Unauthorized".into())
+            .unwrap())
+    });
+
+    let client =
+        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+    cx.update(|cx| {
+        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+    });
+
+    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            let path = project
+                .find_project_path(path!("/project/main.rs"), cx)
+                .unwrap();
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+
+    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.register_buffer(&buffer, &project, cx)
+    });
+    cx.background_executor.run_until_parked();
+
+    let completion_task = ep_store.update(cx, |ep_store, cx| {
+        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
+    });
+
+    let result = completion_task.await;
+    assert!(
+        result.is_err(),
+        "Without authentication and without custom URL, prediction should fail"
+    );
+}
+
+#[gpui::test]
+async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
+    init_test(cx);
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/project",
+        serde_json::json!({
+            "main.rs": "fn main() {\n    \n}\n"
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+
+    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
+    let predict_called_clone = predict_called.clone();
+
+    let http_client = FakeHttpClient::create({
+        move |req| {
+            let uri = req.uri().path().to_string();
+            let predict_called = predict_called_clone.clone();
+            async move {
+                if uri.contains("predict") {
+                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
+                    Ok(gpui::http_client::Response::builder()
+                        .body(
+                            serde_json::to_string(&open_ai::Response {
+                                id: "test-123".to_string(),
+                                object: "chat.completion".to_string(),
+                                created: 0,
+                                model: "test".to_string(),
+                                usage: open_ai::Usage {
+                                    prompt_tokens: 0,
+                                    completion_tokens: 0,
+                                    total_tokens: 0,
+                                },
+                                choices: vec![open_ai::Choice {
+                                    index: 0,
+                                    message: open_ai::RequestMessage::Assistant {
+                                        content: Some(open_ai::MessageContent::Plain(
+                                            indoc! {"
+                                                ```main.rs
+                                                <|start_of_file|>
+                                                <|editable_region_start|>
+                                                fn main() {
+                                                    println!(\"Hello, world!\");
+                                                }
+                                                <|editable_region_end|>
+                                                ```
+                                            "}
+                                            .to_string(),
+                                        )),
+                                        tool_calls: vec![],
+                                    },
+                                    finish_reason: Some("stop".to_string()),
+                                }],
+                            })
+                            .unwrap()
+                            .into(),
+                        )
+                        .unwrap())
+                } else {
+                    Ok(gpui::http_client::Response::builder()
+                        .status(401)
+                        .body("Unauthorized".into())
+                        .unwrap())
+                }
+            }
+        }
+    });
+
+    let client =
+        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+    cx.update(|cx| {
+        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+    });
+
+    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            let path = project
+                .find_project_path(path!("/project/main.rs"), cx)
+                .unwrap();
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+
+    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.register_buffer(&buffer, &project, cx)
+    });
+    cx.background_executor.run_until_parked();
+
+    let completion_task = ep_store.update(cx, |ep_store, cx| {
+        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
+        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
+    });
+
+    let _ = completion_task.await;
+
+    assert!(
+        predict_called.load(std::sync::atomic::Ordering::SeqCst),
+        "With custom URL, predict endpoint should be called even without authentication"
+    );
+}
+
 #[ctor::ctor]
 fn init_logger() {
     zlog::init_test();

crates/edit_prediction/src/zeta1.rs 🔗

@@ -78,6 +78,19 @@ pub(crate) fn request_prediction_with_zeta1(
         cx,
     );
 
+    let (uri, require_auth) = match &store.custom_predict_edits_url {
+        Some(custom_url) => (custom_url.clone(), false),
+        None => {
+            match client
+                .http_client()
+                .build_zed_llm_url("/predict_edits/v2", &[])
+            {
+                Ok(url) => (url.into(), true),
+                Err(err) => return Task::ready(Err(err)),
+            }
+        }
+    };
+
     cx.spawn(async move |this, cx| {
         let GatherContextOutput {
             mut body,
@@ -102,25 +115,16 @@ pub(crate) fn request_prediction_with_zeta1(
             body.input_excerpt
         );
 
-        let http_client = client.http_client();
-
         let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
             |request| {
-                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
-                    predict_edits_url
-                } else {
-                    http_client
-                        .build_zed_llm_url("/predict_edits/v2", &[])?
-                        .as_str()
-                        .into()
-                };
                 Ok(request
-                    .uri(uri)
+                    .uri(uri.as_str())
                     .body(serde_json::to_string(&body)?.into())?)
             },
             client,
             llm_token,
             app_version,
+            require_auth,
         )
         .await;