diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index ac8bdd462a9c4754ef42a6afa41f1bef8b5bbe6a..d5b3af394ea96b4269c13ab3259dd5a99ec6cf49 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -12,6 +12,9 @@ use uuid::Uuid; /// The name of the header used to indicate which version of Zed the client is running. pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; +/// The name of the header used to indicate which edit prediction experiment should be used. +pub const PREFERRED_EXPERIMENT_HEADER_NAME: &str = "x-zed-preferred-experiment"; + /// The name of the header used to indicate when a request failed due to an /// expired LLM token. /// diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 07ec5366db8c2d5f84c53f8ccfe44f84e393df6c..5efa626f20ba33f48f16427372487f46011a6e80 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -9,7 +9,8 @@ use cloud_llm_client::predict_edits_v3::{ use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, + PREFERRED_EXPERIMENT_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, + ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet}; use copilot::{Copilot, Reinstall, SignIn, SignOut}; @@ -2586,6 +2587,7 @@ impl EditPredictionStore { pub(crate) async fn send_v3_request( input: ZetaPromptInput, + preferred_experiment: Option, client: Arc, llm_token: LlmApiToken, organization_id: Option, @@ -2604,11 +2606,16 @@ impl EditPredictionStore { Self::send_api_request( |builder| { - let req = builder + let builder = builder .uri(url.as_ref()) .header("Content-Encoding", "zstd") - .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref()) - .body(compressed.clone().into()); + .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref()); + let builder = if let Some(preferred_experiment) = preferred_experiment.as_deref() { + builder.header(PREFERRED_EXPERIMENT_HEADER_NAME, preferred_experiment) + } else { + builder + }; + let req = builder.body(compressed.clone().into()); Ok(req?) }, client, diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 54dabf93f8da290d76c13222ae5a110e80d0b388..7a0c5f57992e1996be715c8220aaeb4e398f5601 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2481,7 +2481,6 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { excerpt_start_row: None, excerpt_ranges: Default::default(), syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 46586eb3796026c764ff8659734c564e368681b9..44a5b2541fbd4c65b15f929ee9fb5a5bd7fe929b 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -87,7 +87,6 @@ pub fn request_prediction( cursor_excerpt, excerpt_ranges: Default::default(), syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 155fd449904687081da0a9eae3d4731863f02254..8e9dfa6cee34d2b746db0578ebefa73fb6223ff5 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -109,7 +109,6 @@ impl Mercury { - excerpt_offset_range.start, cursor_path: full_path.clone(), cursor_excerpt, - experiment: None, excerpt_start_row: Some(excerpt_point_range.start.row), excerpt_ranges, syntax_ranges: Some(syntax_ranges), diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index ef2bf2deafb7309f4871a921061ab114fa280e2f..eb45832d4cccad52ae950c6d6ff685092270aaa0 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -155,7 +155,6 @@ mod tests { excerpt_start_row: None, excerpt_ranges: Default::default(), syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 1674de5c0a71cf9a63d2e1fc55a58645b9a9314a..60157781eb769553d8f4c266df3e05d75d335ec3 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -120,7 +120,6 @@ pub fn request_prediction_with_zeta( diagnostic_search_range, excerpt_path, cursor_offset, - preferred_experiment, is_open_source, can_collect_data, repo_url, @@ -274,6 +273,7 @@ pub fn request_prediction_with_zeta( // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( prompt_input.clone(), + preferred_experiment.clone(), client, llm_token, organization_id, @@ -535,7 +535,6 @@ pub fn zeta2_prompt_input( diagnostic_search_range: Range, excerpt_path: Arc, cursor_offset: usize, - preferred_experiment: Option, is_open_source: bool, can_collect_data: bool, repo_url: Option, @@ -567,7 +566,6 @@ pub fn zeta2_prompt_input( active_buffer_diagnostics, excerpt_ranges, syntax_ranges: Some(syntax_ranges), - experiment: preferred_experiment, in_open_source_repo: is_open_source, can_collect_data, repo_url, diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index d9138482767b2c49bb21bf7ed7c349ec6c9af3ff..f370f013ff061ef572c61605eebcfbf59158c4aa 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -108,7 +108,6 @@ pub async fn run_load_project( syntax_ranges: Some(syntax_ranges), in_open_source_repo: false, can_collect_data: false, - experiment: None, repo_url: None, }, language_name, diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 49b86404a8ad49c27e29bb2b887fb3fc8171c35c..3fa12a7a789b196b0219fadaec24f38b42a5b259 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -51,9 +51,6 @@ pub struct ZetaPromptInput { /// instead of `excerpt_ranges`. #[serde(default, skip_serializing_if = "Option::is_none")] pub syntax_ranges: Option>>, - /// The name of the edit prediction model experiment to use. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub experiment: Option, #[serde(default)] pub in_open_source_repo: bool, #[serde(default)] @@ -4503,7 +4500,6 @@ mod tests { ..Default::default() }, syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, @@ -4534,7 +4530,6 @@ mod tests { ..Default::default() }, syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, @@ -5163,7 +5158,6 @@ mod tests { ..Default::default() }, syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, @@ -5228,7 +5222,6 @@ mod tests { ..Default::default() }, syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None, @@ -5288,7 +5281,6 @@ mod tests { ..Default::default() }, syntax_ranges: None, - experiment: None, in_open_source_repo: false, can_collect_data: false, repo_url: None,