From 880b2e512b31158a43637a1bfe47cb4492884a5f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 2 Mar 2026 22:24:55 -0800 Subject: [PATCH] Remove Zeta1 logic, allow choosing zeta2 experiment (#50560) Release Notes: - N/A --- crates/edit_prediction/src/edit_prediction.rs | 105 +++++-- .../src/edit_prediction_tests.rs | 17 +- crates/edit_prediction/src/fim.rs | 6 +- crates/edit_prediction/src/mercury.rs | 34 ++- crates/edit_prediction/src/prediction.rs | 5 +- crates/edit_prediction/src/sweep_ai.rs | 14 +- crates/edit_prediction/src/zeta.rs | 260 ++++++++---------- .../edit_prediction_cli/src/format_prompt.rs | 6 +- .../edit_prediction_cli/src/load_project.rs | 7 +- crates/edit_prediction_cli/src/predict.rs | 4 +- .../edit_prediction_cli/src/pull_examples.rs | 9 +- .../src/reversal_tracking.rs | 14 +- .../src/edit_prediction_button.rs | 50 ++++ crates/settings_content/src/language.rs | 5 +- .../zed/src/zed/edit_prediction_registry.rs | 44 ++- crates/zeta_prompt/src/zeta_prompt.rs | 85 +++--- 16 files changed, 379 insertions(+), 286 deletions(-) diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index b25ccee37970f2dc0dfa8bcbec4b1cdcdfe6d506..e6e3a9abdf83deb785cd56d358b065973682b8cc 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -137,10 +137,13 @@ pub struct EditPredictionStore { user_store: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, + _fetch_experiments_task: Task<()>, projects: HashMap, update_required: bool, edit_prediction_model: EditPredictionModel, zeta2_raw_config: Option, + preferred_experiment: Option, + available_experiments: Vec, pub sweep_ai: SweepAi, pub mercury: Mercury, data_collection_choice: DataCollectionChoice, @@ -154,8 +157,7 @@ pub struct EditPredictionStore { #[derive(Copy, Clone, PartialEq, Eq)] pub enum EditPredictionModel { - Zeta1, - Zeta2, + Zeta, Fim { format: EditPredictionPromptFormat }, Sweep, Mercury, @@ -699,11 +701,23 @@ impl EditPredictionStore { }) .detach(); + let mut current_user = user_store.read(cx).watch_current_user(); + let fetch_experiments_task = cx.spawn(async move |this, cx| { + while current_user.borrow().is_none() { + current_user.next().await; + } + this.update(cx, |this, cx| { + this.refresh_available_experiments(cx); + }) + .log_err(); + }); + let this = Self { projects: HashMap::default(), client, user_store, llm_token, + _fetch_experiments_task: fetch_experiments_task, _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, |this, _listener, _event, cx| { @@ -717,8 +731,10 @@ impl EditPredictionStore { }, ), update_required: false, - edit_prediction_model: EditPredictionModel::Zeta2, + edit_prediction_model: EditPredictionModel::Zeta, zeta2_raw_config: Self::zeta2_raw_config_from_env(), + preferred_experiment: None, + available_experiments: Vec::new(), sweep_ai: SweepAi::new(cx), mercury: Mercury::new(cx), @@ -753,6 +769,60 @@ impl EditPredictionStore { self.zeta2_raw_config.as_ref() } + pub fn preferred_experiment(&self) -> Option<&str> { + self.preferred_experiment.as_deref() + } + + pub fn set_preferred_experiment(&mut self, experiment: Option) { + self.preferred_experiment = experiment; + } + + pub fn available_experiments(&self) -> &[String] { + &self.available_experiments + } + + pub fn refresh_available_experiments(&mut self, cx: &mut Context) { + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + cx.spawn(async move |this, cx| { + let experiments = cx + .background_spawn(async move { + let http_client = client.http_client(); + let token = llm_token.acquire(&client).await?; + let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?; + let request = http_client::Request::builder() + .method(Method::GET) + .uri(url.as_ref()) + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + .body(Default::default())?; + let mut response = http_client.send(request).await?; + if response.status().is_success() { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + let experiments: Vec = serde_json::from_slice(&body)?; + Ok(experiments) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "Failed to fetch experiments: {:?}\nBody: {}", + response.status(), + body + ); + } + }) + .await?; + this.update(cx, |this, cx| { + this.available_experiments = experiments; + cx.notify(); + })?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet { use ui::IconName; match self.edit_prediction_model { @@ -766,7 +836,7 @@ impl EditPredictionStore { EditPredictionModel::Mercury => { edit_prediction_types::EditPredictionIconSet::new(IconName::Inception) } - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + EditPredictionModel::Zeta => { edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict) .with_disabled(IconName::ZedPredictDisabled) .with_up(IconName::ZedPredictUp) @@ -895,10 +965,7 @@ impl EditPredictionStore { } pub fn usage(&self, cx: &App) -> Option { - if matches!( - self.edit_prediction_model, - EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1 - ) { + if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) { self.user_store.read(cx).edit_prediction_usage() } else { None @@ -1347,7 +1414,7 @@ impl EditPredictionStore { cx, ); } - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + EditPredictionModel::Zeta => { let is_cloud = !matches!( all_language_settings(None, cx).edit_predictions.provider, EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi @@ -1608,7 +1675,7 @@ impl EditPredictionStore { cx: &App, ) { match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + EditPredictionModel::Zeta => { let is_cloud = !matches!( all_language_settings(None, cx).edit_predictions.provider, EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi @@ -2103,10 +2170,7 @@ impl EditPredictionStore { let can_collect_data = !cfg!(test) && is_open_source && self.is_data_collection_enabled(cx) - && matches!( - self.edit_prediction_model, - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 - ); + && matches!(self.edit_prediction_model, EditPredictionModel::Zeta); let inputs = EditPredictionModelInput { project: project.clone(), @@ -2138,18 +2202,7 @@ impl EditPredictionStore { } let task = match self.edit_prediction_model { - EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta( - self, - inputs, - Some(zeta_prompt::EditPredictionModelKind::Zeta1), - cx, - ), - EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta( - self, - inputs, - Some(zeta_prompt::EditPredictionModelKind::Zeta2), - cx, - ), + EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx), EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index beeb855c7b84bae53ea2f8f8bd6a117403e77db1..cc3bb84808981fd1430f9e71aa796e590cc78169 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1704,12 +1704,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { // Generate a model response that would apply the given diff to the active file. fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response { - let editable_range = request - .input - .excerpt_ranges - .as_ref() - .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1) - .unwrap_or(request.input.editable_range_in_excerpt.clone()); + let editable_range = + zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1; let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string(); let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap(); @@ -1846,11 +1842,10 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { related_files: Default::default(), cursor_path: Path::new("").into(), cursor_excerpt: "".into(), - editable_range_in_excerpt: 0..0, cursor_offset_in_excerpt: 0, excerpt_start_row: None, - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: Default::default(), + experiment: None, in_open_source_repo: false, can_collect_data: false, }, @@ -2183,7 +2178,7 @@ async fn make_test_ep_store( let ep_store = cx.new(|cx| { let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx); - ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta); let worktrees = project.read(cx).worktrees(cx).collect::>(); for worktree in worktrees { @@ -2282,7 +2277,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut cx.background_executor.run_until_parked(); let completion_task = ep_store.update(cx, |ep_store, cx| { - ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta); ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx) }); diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index dda008133d3726f5e7ba32ec05c770878d16585f..66f2e58a3b01b4fbf49b11864db4daec6b4dc1c2 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -72,16 +72,14 @@ pub fn request_prediction( events, related_files: Vec::new(), cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start, - editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start - ..cursor_offset - excerpt_offset_range.start, cursor_path: full_path.clone(), excerpt_start_row: Some(excerpt_range.start.row), cursor_excerpt: snapshot .text_for_range(excerpt_range) .collect::() .into(), - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: Default::default(), + experiment: None, in_open_source_repo: false, can_collect_data: false, }; diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index f3adba55e620e77ffd7bb12b0e950fd4d3f011fc..bf9b43d528db1717f54143e4805e41aefc81f64a 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -16,7 +16,7 @@ use release_channel::AppVersion; use serde::Serialize; use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant}; -use zeta_prompt::ZetaPromptInput; +use zeta_prompt::{ExcerptRanges, ZetaPromptInput}; const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; const MAX_REWRITE_TOKENS: usize = 150; @@ -83,6 +83,12 @@ impl Mercury { let editable_offset_range = editable_range.to_offset(&snapshot); + let editable_range_in_excerpt = (editable_offset_range.start + - context_offset_range.start) + ..(editable_offset_range.end - context_offset_range.start); + let context_range_in_excerpt = + 0..(context_offset_range.end - context_offset_range.start); + let inputs = zeta_prompt::ZetaPromptInput { events, related_files, @@ -93,12 +99,17 @@ impl Mercury { .text_for_range(context_range) .collect::() .into(), - editable_range_in_excerpt: (editable_offset_range.start - - context_offset_range.start) - ..(editable_offset_range.end - context_offset_range.start), + experiment: None, excerpt_start_row: Some(context_start_row), - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: editable_range_in_excerpt.clone(), + editable_180: editable_range_in_excerpt.clone(), + editable_350: editable_range_in_excerpt.clone(), + editable_150_context_350: context_range_in_excerpt.clone(), + editable_180_context_350: context_range_in_excerpt.clone(), + editable_350_context_150: context_range_in_excerpt.clone(), + ..Default::default() + }, in_open_source_repo: false, can_collect_data: false, }; @@ -273,19 +284,18 @@ fn build_prompt(inputs: &ZetaPromptInput) -> String { prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref()); prompt.push('\n'); - prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]); + let editable_range = &inputs.excerpt_ranges.editable_350; + prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]); push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { prompt.push_str( - &inputs.cursor_excerpt - [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt], + &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt], ); prompt.push_str(CURSOR_TAG); prompt.push_str( - &inputs.cursor_excerpt - [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end], + &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end], ); }); - prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]); + prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]); }, ); diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 9c17f29fe29bc711f6750cf6fe24586067bfc619..0dd33c03a95d77ec680d47d96daa8e6a44f51b62 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -160,10 +160,9 @@ mod tests { cursor_path: Path::new("path.txt").into(), cursor_offset_in_excerpt: 0, cursor_excerpt: "".into(), - editable_range_in_excerpt: 0..0, excerpt_start_row: None, - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: Default::default(), + experiment: None, in_open_source_repo: false, can_collect_data: false, }, diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index 5a9fcf0e6ce7bfa5476d6c48245068994178f7bc..d88a159a47aa7633a5b064e72a75dd61604710e1 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -215,12 +215,18 @@ impl SweepAi { related_files: inputs.related_files.clone(), cursor_path: full_path.clone(), cursor_excerpt: request_body.file_contents.clone().into(), - // we actually don't know - editable_range_in_excerpt: 0..inputs.snapshot.len(), cursor_offset_in_excerpt: request_body.cursor_position, excerpt_start_row: Some(0), - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: zeta_prompt::ExcerptRanges { + editable_150: 0..inputs.snapshot.len(), + editable_180: 0..inputs.snapshot.len(), + editable_350: 0..inputs.snapshot.len(), + editable_150_context_350: 0..inputs.snapshot.len(), + editable_180_context_350: 0..inputs.snapshot.len(), + editable_350_context_150: 0..inputs.snapshot.len(), + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, }; diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index abcfeabec44b26405153c10c43e6c2739e5e802e..f6a786572736908556535b9131c1cf7814a6126f 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -13,14 +13,15 @@ use gpui::{App, AppContext as _, Task, http_client, prelude::*}; use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings}; use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff}; use release_channel::AppVersion; +use settings::EditPredictionPromptFormat; use text::{Anchor, Bias}; use std::env; use std::ops::Range; use std::{path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ - CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output, - format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens, + CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill, + prompt_input_contains_special_tokens, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; @@ -39,7 +40,6 @@ pub fn request_prediction_with_zeta( is_open_source, .. }: EditPredictionModelInput, - preferred_model: Option, cx: &mut Context, ) -> Task>> { let settings = &all_language_settings(None, cx).edit_predictions; @@ -55,6 +55,7 @@ pub fn request_prediction_with_zeta( let http_client = cx.http_client(); let buffer_snapshotted_at = Instant::now(); let raw_config = store.zeta2_raw_config().cloned(); + let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned()); let excerpt_path: Arc = snapshot .file() @@ -80,8 +81,7 @@ pub fn request_prediction_with_zeta( events, excerpt_path, cursor_offset, - zeta_version, - preferred_model, + preferred_experiment, is_open_source, can_collect_data, ); @@ -90,22 +90,8 @@ pub fn request_prediction_with_zeta( return Ok((None, None)); } - let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1); - let excerpt_ranges = prompt_input - .excerpt_ranges - .as_ref() - .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?; - if let Some(debug_tx) = &debug_tx { - let prompt = if is_zeta1 { - zeta1::format_zeta1_from_input( - &prompt_input, - excerpt_ranges.editable_350.clone(), - excerpt_ranges.editable_350_context_150.clone(), - ) - } else { - format_zeta_prompt(&prompt_input, zeta_version) - }; + let prompt = format_zeta_prompt(&prompt_input, zeta_version); debug_tx .unbounded_send(DebugEvent::EditPredictionStarted( EditPredictionStartedDebugEvent { @@ -119,130 +105,133 @@ pub fn request_prediction_with_zeta( log::trace!("Sending edit prediction request"); - let (request_id, output_text, model_version, usage) = if let Some(custom_settings) = - &custom_server_settings - { - let max_tokens = custom_settings.max_output_tokens * 4; - - if is_zeta1 { - let ranges = excerpt_ranges; - let prompt = zeta1::format_zeta1_from_input( - &prompt_input, - ranges.editable_350.clone(), - ranges.editable_350_context_150.clone(), - ); - editable_range_in_excerpt = ranges.editable_350.clone(); - let stop_tokens = vec![ - EDITABLE_REGION_END_MARKER.to_string(), - format!("{EDITABLE_REGION_END_MARKER}\n"), - format!("{EDITABLE_REGION_END_MARKER}\n\n"), - format!("{EDITABLE_REGION_END_MARKER}\n\n\n"), - ]; - - let (response_text, request_id) = send_custom_server_request( - provider, - custom_settings, + let (request_id, output_text, model_version, usage) = + if let Some(custom_settings) = &custom_server_settings { + let max_tokens = custom_settings.max_output_tokens * 4; + + match custom_settings.prompt_format { + EditPredictionPromptFormat::Zeta => { + let ranges = &prompt_input.excerpt_ranges; + let prompt = zeta1::format_zeta1_from_input( + &prompt_input, + ranges.editable_350.clone(), + ranges.editable_350_context_150.clone(), + ); + editable_range_in_excerpt = ranges.editable_350.clone(); + let stop_tokens = vec![ + EDITABLE_REGION_END_MARKER.to_string(), + format!("{EDITABLE_REGION_END_MARKER}\n"), + format!("{EDITABLE_REGION_END_MARKER}\n\n"), + format!("{EDITABLE_REGION_END_MARKER}\n\n\n"), + ]; + + let (response_text, request_id) = send_custom_server_request( + provider, + custom_settings, + prompt, + max_tokens, + stop_tokens, + &http_client, + ) + .await?; + + let request_id = EditPredictionId(request_id.into()); + let output_text = zeta1::clean_zeta1_model_output(&response_text); + + (request_id, output_text, None, None) + } + EditPredictionPromptFormat::Zeta2 => { + let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let prefill = get_prefill(&prompt_input, zeta_version); + let prompt = format!("{prompt}{prefill}"); + + editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( + zeta_version, + &prompt_input.excerpt_ranges, + ) + .0; + + let (response_text, request_id) = send_custom_server_request( + provider, + custom_settings, + prompt, + max_tokens, + vec![], + &http_client, + ) + .await?; + + let request_id = EditPredictionId(request_id.into()); + let output_text = if response_text.is_empty() { + None + } else { + let output = format!("{prefill}{response_text}"); + Some(clean_zeta2_model_output(&output, zeta_version).to_string()) + }; + + (request_id, output_text, None, None) + } + _ => anyhow::bail!("unsupported prompt format"), + } + } else if let Some(config) = &raw_config { + let prompt = format_zeta_prompt(&prompt_input, config.format); + let prefill = get_prefill(&prompt_input, config.format); + let prompt = format!("{prompt}{prefill}"); + let request = RawCompletionRequest { + model: config.model_id.clone().unwrap_or_default(), prompt, - max_tokens, - stop_tokens, - &http_client, + temperature: None, + stop: vec![], + max_tokens: Some(2048), + environment: Some(config.format.to_string().to_lowercase()), + }; + + editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( + config.format, + &prompt_input.excerpt_ranges, + ) + .1; + + let (mut response, usage) = EditPredictionStore::send_raw_llm_request( + request, + client, + None, + llm_token, + app_version, ) .await?; - let request_id = EditPredictionId(request_id.into()); - let output_text = zeta1::clean_zeta1_model_output(&response_text); + let request_id = EditPredictionId(response.id.clone().into()); + let output_text = response.choices.pop().map(|choice| { + let response = &choice.text; + let output = format!("{prefill}{response}"); + clean_zeta2_model_output(&output, config.format).to_string() + }); - (request_id, output_text, None, None) + (request_id, output_text, None, usage) } else { - let prompt = format_zeta_prompt(&prompt_input, zeta_version); - let prefill = get_prefill(&prompt_input, zeta_version); - let prompt = format!("{prompt}{prefill}"); - - editable_range_in_excerpt = prompt_input - .excerpt_ranges - .as_ref() - .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0) - .unwrap_or(prompt_input.editable_range_in_excerpt.clone()); - - let (response_text, request_id) = send_custom_server_request( - provider, - custom_settings, - prompt, - max_tokens, - vec![], - &http_client, + // Use V3 endpoint - server handles model/version selection and suffix stripping + let (response, usage) = EditPredictionStore::send_v3_request( + prompt_input.clone(), + client, + llm_token, + app_version, + trigger, ) .await?; - let request_id = EditPredictionId(request_id.into()); - let output_text = if response_text.is_empty() { + let request_id = EditPredictionId(response.request_id.into()); + let output_text = if response.output.is_empty() { None } else { - let output = format!("{prefill}{response_text}"); - Some(clean_zeta2_model_output(&output, zeta_version).to_string()) + Some(response.output) }; + editable_range_in_excerpt = response.editable_range; + let model_version = response.model_version; - (request_id, output_text, None, None) - } - } else if let Some(config) = &raw_config { - let prompt = format_zeta_prompt(&prompt_input, config.format); - let prefill = get_prefill(&prompt_input, config.format); - let prompt = format!("{prompt}{prefill}"); - let request = RawCompletionRequest { - model: config.model_id.clone().unwrap_or_default(), - prompt, - temperature: None, - stop: vec![], - max_tokens: Some(2048), - environment: Some(config.format.to_string().to_lowercase()), + (request_id, output_text, model_version, usage) }; - editable_range_in_excerpt = prompt_input - .excerpt_ranges - .as_ref() - .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1) - .unwrap_or(prompt_input.editable_range_in_excerpt.clone()); - - let (mut response, usage) = EditPredictionStore::send_raw_llm_request( - request, - client, - None, - llm_token, - app_version, - ) - .await?; - - let request_id = EditPredictionId(response.id.clone().into()); - let output_text = response.choices.pop().map(|choice| { - let response = &choice.text; - let output = format!("{prefill}{response}"); - clean_zeta2_model_output(&output, config.format).to_string() - }); - - (request_id, output_text, None, usage) - } else { - // Use V3 endpoint - server handles model/version selection and suffix stripping - let (response, usage) = EditPredictionStore::send_v3_request( - prompt_input.clone(), - client, - llm_token, - app_version, - trigger, - ) - .await?; - - let request_id = EditPredictionId(response.request_id.into()); - let output_text = if response.output.is_empty() { - None - } else { - Some(response.output) - }; - editable_range_in_excerpt = response.editable_range; - let model_version = response.model_version; - - (request_id, output_text, model_version, usage) - }; - let received_response_at = Instant::now(); log::trace!("Got edit prediction response"); @@ -373,8 +362,7 @@ pub fn zeta2_prompt_input( events: Vec>, excerpt_path: Arc, cursor_offset: usize, - zeta_format: ZetaFormat, - preferred_model: Option, + preferred_experiment: Option, is_open_source: bool, can_collect_data: bool, ) -> (Range, zeta_prompt::ZetaPromptInput) { @@ -392,11 +380,6 @@ pub fn zeta2_prompt_input( let full_context_start_offset = full_context_offset_range.start; let full_context_start_row = full_context.start.row; - let editable_offset_range = match preferred_model { - Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(), - _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0, - }; - let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset; let prompt_input = zeta_prompt::ZetaPromptInput { @@ -405,13 +388,12 @@ pub fn zeta2_prompt_input( .text_for_range(full_context) .collect::() .into(), - editable_range_in_excerpt: editable_offset_range, cursor_offset_in_excerpt, excerpt_start_row: Some(full_context_start_row), events, related_files, - excerpt_ranges: Some(excerpt_ranges), - preferred_model, + excerpt_ranges, + experiment: preferred_experiment, in_open_source_repo: is_open_source, can_collect_data, }; diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 6cdfeef8f569df9277d3417c0134b2c7047bee30..ecacd963023d7d113ea5ad77b61fd1d88306fc95 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -36,12 +36,8 @@ pub async fn run_format_prompt( step_progress.set_substatus("formatting teacher prompt"); let zeta_format = ZetaFormat::default(); - let excerpt_ranges = prompt_inputs - .excerpt_ranges - .as_ref() - .context("prompt_inputs must have excerpt_ranges")?; let (editable_range, context_range) = - excerpt_range_for_format(zeta_format, excerpt_ranges); + excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); example.prompt = Some(ExamplePrompt { diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 680af6f0168c766c6066a91a8f57fe4573b46403..dcf417c2e8cc70dfcaffdf4b96dbe3b17daa61d4 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -93,21 +93,18 @@ pub async fn run_load_project( let cursor_offset_in_excerpt = cursor_offset - full_context_offset_range.start; let excerpt_start_row = Some(full_context_point_range.start.row); - let editable_range_in_excerpt = excerpt_ranges.editable_350.clone(); - ( ZetaPromptInput { cursor_path: example.spec.cursor_path.clone(), cursor_excerpt, - editable_range_in_excerpt, cursor_offset_in_excerpt, excerpt_start_row, events, related_files: existing_related_files, - excerpt_ranges: Some(excerpt_ranges), - preferred_model: None, + excerpt_ranges, in_open_source_repo: false, can_collect_data: false, + experiment: None, }, language_name, ) diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index e02fcbdb425a62fb478b8be36fdd034eede27622..02ba24b8a4f2627b9542254e3d118981737f8318 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -110,8 +110,8 @@ pub async fn run_prediction( ep_store.update(&mut cx, |store, _cx| { let model = match provider { - PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta, + PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta, PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher(..) diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index b53a3d5546e1a5697550ed24715f049c36c98178..2f371675b29015795beef550ce5e3956c63751f9 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -1115,11 +1115,8 @@ fn build_settled_example( requested_format: ZetaFormat, zed_version: Option, ) -> Example { - let requested_editable_range = input - .excerpt_ranges - .as_ref() - .map(|ranges| excerpt_range_for_format(requested_format, ranges).0) - .unwrap_or_else(|| input.editable_range_in_excerpt.clone()); + let requested_editable_range = + excerpt_range_for_format(requested_format, &input.excerpt_ranges).0; let base_cursor_excerpt = input.cursor_excerpt.to_string(); @@ -1268,7 +1265,7 @@ fn build_rejected_example( let rejected_patch = build_output_patch( &input.cursor_path, input.cursor_excerpt.as_ref(), - &input.editable_range_in_excerpt, + &input.excerpt_ranges.editable_350, &output, ); let mut example = build_example_from_snowflake( diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index cc558939e9aecf826afce77d6205b0ff49ab87bc..2d578c8666f217365ed2ed24ff766ed6f19566d7 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -655,6 +655,7 @@ mod tests { use super::*; use edit_prediction::udiff::apply_diff_to_string; use indoc::indoc; + use zeta_prompt::ExcerptRanges; fn make_test_prompt_inputs( content: &str, @@ -664,13 +665,20 @@ mod tests { ZetaPromptInput { cursor_path: Arc::from(Path::new("src/test.rs")), cursor_excerpt: content.into(), - editable_range_in_excerpt: 0..content.len(), cursor_offset_in_excerpt: 0, excerpt_start_row, events, related_files: Vec::new(), - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: 0..content.len(), + editable_180: 0..content.len(), + editable_350: 0..content.len(), + editable_150_context_350: 0..content.len(), + editable_180_context_350: 0..content.len(), + editable_350_context_150: 0..content.len(), + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, } diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index c1fcd78f3f0cee24e6e8d936bf6af56f8d1ebda0..6339c7d6cd9fa1cc40101cc1bf14650a6904b3c7 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -1194,6 +1194,56 @@ impl EditPredictionButton { menu = self.build_language_settings_menu(menu, window, cx); } menu = self.add_provider_switching_section(menu, provider, cx); + + if cx.is_staff() { + if let Some(store) = EditPredictionStore::try_global(cx) { + let store = store.read(cx); + let experiments = store.available_experiments().to_vec(); + let preferred = store.preferred_experiment().map(|s| s.to_owned()); + + let preferred_for_submenu = preferred.clone(); + menu = menu + .separator() + .submenu("Experiment", move |menu, _window, _cx| { + let mut menu = menu.toggleable_entry( + "Default", + preferred_for_submenu.is_none(), + IconPosition::Start, + None, + { + move |_window, cx| { + if let Some(store) = EditPredictionStore::try_global(cx) { + store.update(cx, |store, _cx| { + store.set_preferred_experiment(None); + }); + } + } + }, + ); + for experiment in &experiments { + let is_selected = preferred.as_deref() == Some(experiment.as_str()); + let experiment_name = experiment.clone(); + menu = menu.toggleable_entry( + experiment.clone(), + is_selected, + IconPosition::Start, + None, + move |_window, cx| { + if let Some(store) = EditPredictionStore::try_global(cx) { + store.update(cx, |store, _cx| { + store.set_preferred_experiment(Some( + experiment_name.clone(), + )); + }); + } + }, + ); + } + menu + }); + } + } + menu = menu.separator().item( ContextMenuEntry::new("Configure Providers") .icon(IconName::Settings) diff --git a/crates/settings_content/src/language.rs b/crates/settings_content/src/language.rs index db22f3a9e1448dbc529c133fb0195c422f02bc40..d429f53824fd0f4f0a5810bce01b05badcfb9a51 100644 --- a/crates/settings_content/src/language.rs +++ b/crates/settings_content/src/language.rs @@ -123,9 +123,7 @@ impl<'de> Deserialize<'de> for EditPredictionProvider { Content::Experimental(name) if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME => { - EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - ) + EditPredictionProvider::Zed } Content::Experimental(name) => { return Err(D::Error::custom(format!( @@ -240,6 +238,7 @@ pub enum EditPredictionPromptFormat { #[default] Infer, Zeta, + Zeta2, CodeLlama, StarCoder, DeepseekCoder, diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 79b33093d86b306c3b0420f919bd555d9ea4ca7a..67b0d26c88cf0bd254a776834de09fb89d6ea195 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -15,6 +15,8 @@ use std::{cell::RefCell, rc::Rc, sync::Arc}; use ui::Window; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { + edit_prediction::EditPredictionStore::global(&client, &user_store, cx); + let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); @@ -131,9 +133,9 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option None, EditPredictionProvider::Copilot => Some(EditPredictionProviderConfig::Copilot), - EditPredictionProvider::Zed => Some(EditPredictionProviderConfig::Zed( - EditPredictionModel::Zeta1, - )), + EditPredictionProvider::Zed => { + Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta)) + } EditPredictionProvider::Codestral => Some(EditPredictionProviderConfig::Codestral), EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi => { let custom_settings = if provider == EditPredictionProvider::Ollama { @@ -153,9 +155,7 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option Option() { - Some(EditPredictionProviderConfig::Zed( - EditPredictionModel::Zeta2, - )) + Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta)) } else { None } @@ -212,8 +210,7 @@ impl EditPredictionProviderConfig { EditPredictionProviderConfig::Copilot => "Copilot", EditPredictionProviderConfig::Codestral => "Codestral", EditPredictionProviderConfig::Zed(model) => match model { - EditPredictionModel::Zeta1 => "Zeta1", - EditPredictionModel::Zeta2 => "Zeta2", + EditPredictionModel::Zeta => "Zeta", EditPredictionModel::Fim { .. } => "FIM", EditPredictionModel::Sweep => "Sweep", EditPredictionModel::Mercury => "Mercury", @@ -311,26 +308,23 @@ fn assign_edit_prediction_provider( let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx); if let Some(project) = editor.project() { - let has_model = ep_store.update(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { ep_store.set_edit_prediction_model(model); if let Some(buffer) = &singleton_buffer { ep_store.register_buffer(buffer, project, cx); } - true }); - if has_model { - let provider = cx.new(|cx| { - ZedEditPredictionDelegate::new( - project.clone(), - singleton_buffer, - &client, - &user_store, - cx, - ) - }); - editor.set_edit_prediction_provider(Some(provider), window, cx); - } + let provider = cx.new(|cx| { + ZedEditPredictionDelegate::new( + project.clone(), + singleton_buffer, + &client, + &user_store, + cx, + ) + }); + editor.set_edit_prediction_provider(Some(provider), window, cx); } } } diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index abb1c3ddc74d58d0b300e5e64d77a60a48b83283..0cd37a455397334933dbfa2464c2dbcb72bba456 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -18,17 +18,10 @@ fn estimate_tokens(bytes: usize) -> usize { bytes / 3 } -/// The client's preferred edit prediction model. The server may override this. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum EditPredictionModelKind { - Zeta1, - Zeta2, -} - /// Pre-computed byte offset ranges within `cursor_excerpt` for different /// editable and context token budgets. Allows the server to select the /// appropriate ranges for whichever model it uses. -#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Hash, Serialize, Deserialize)] pub struct ExcerptRanges { /// Editable region computed with a 150-token budget. pub editable_150: Range, @@ -54,21 +47,16 @@ pub struct ExcerptRanges { pub struct ZetaPromptInput { pub cursor_path: Arc, pub cursor_excerpt: Arc, - pub editable_range_in_excerpt: Range, pub cursor_offset_in_excerpt: usize, #[serde(default, skip_serializing_if = "Option::is_none")] pub excerpt_start_row: Option, pub events: Vec>, pub related_files: Vec, - /// When set, the excerpt was computed with a larger budget (~512 tokens) - /// and these ranges let the server select model-appropriate subsets. - /// When absent, the excerpt IS the context region and - /// `editable_range_in_excerpt` is the only editable range. + /// These ranges let the server select model-appropriate subsets. + pub excerpt_ranges: ExcerptRanges, + /// The name of the edit prediction model experiment to use. #[serde(default, skip_serializing_if = "Option::is_none")] - pub excerpt_ranges: Option, - /// Client's preferred model. The server may override. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub preferred_model: Option, + pub experiment: Option, #[serde(default)] pub in_open_source_repo: bool, #[serde(default)] @@ -274,15 +262,7 @@ pub fn resolve_cursor_region( input: &ZetaPromptInput, format: ZetaFormat, ) -> (&str, Range, usize) { - let Some(ranges) = &input.excerpt_ranges else { - return ( - &input.cursor_excerpt, - input.editable_range_in_excerpt.clone(), - input.cursor_offset_in_excerpt, - ); - }; - - let (editable_range, context_range) = excerpt_range_for_format(format, ranges); + let (editable_range, context_range) = excerpt_range_for_format(format, &input.excerpt_ranges); let context_start = context_range.start; let context_text = &input.cursor_excerpt[context_range]; let adjusted_editable = @@ -1159,16 +1139,24 @@ mod tests { events: Vec, related_files: Vec, ) -> ZetaPromptInput { + let context_range = 0..cursor_excerpt.len(); ZetaPromptInput { cursor_path: Path::new("test.rs").into(), cursor_excerpt: cursor_excerpt.into(), - editable_range_in_excerpt: editable_range, cursor_offset_in_excerpt: cursor_offset, excerpt_start_row: None, events: events.into_iter().map(Arc::new).collect(), related_files, - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: editable_range.clone(), + editable_180: editable_range.clone(), + editable_350: editable_range, + editable_150_context_350: context_range.clone(), + editable_180_context_350: context_range.clone(), + editable_350_context_150: context_range, + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, } @@ -1752,13 +1740,20 @@ mod tests { let input = ZetaPromptInput { cursor_path: Path::new("src/main.rs").into(), cursor_excerpt: excerpt.into(), - editable_range_in_excerpt: 15..41, cursor_offset_in_excerpt: 30, excerpt_start_row: Some(0), events: vec![Arc::new(make_event("other.rs", "-old\n+new\n"))], related_files: vec![], - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: 15..41, + editable_180: 15..41, + editable_350: 15..41, + editable_150_context_350: 0..excerpt.len(), + editable_180_context_350: 0..excerpt.len(), + editable_350_context_150: 0..excerpt.len(), + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, }; @@ -1807,13 +1802,20 @@ mod tests { let input = ZetaPromptInput { cursor_path: Path::new("src/main.rs").into(), cursor_excerpt: excerpt.into(), - editable_range_in_excerpt: 0..28, cursor_offset_in_excerpt: 15, excerpt_start_row: Some(10), events: vec![], related_files: vec![], - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: 0..28, + editable_180: 0..28, + editable_350: 0..28, + editable_150_context_350: 0..28, + editable_180_context_350: 0..28, + editable_350_context_150: 0..28, + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, }; @@ -1857,13 +1859,20 @@ mod tests { let input = ZetaPromptInput { cursor_path: Path::new("test.rs").into(), cursor_excerpt: excerpt.into(), - editable_range_in_excerpt: editable_range.clone(), cursor_offset_in_excerpt: 25, excerpt_start_row: Some(0), events: vec![], related_files: vec![], - excerpt_ranges: None, - preferred_model: None, + excerpt_ranges: ExcerptRanges { + editable_150: editable_range.clone(), + editable_180: editable_range.clone(), + editable_350: editable_range.clone(), + editable_150_context_350: context_range.clone(), + editable_180_context_350: context_range.clone(), + editable_350_context_150: context_range.clone(), + ..Default::default() + }, + experiment: None, in_open_source_repo: false, can_collect_data: false, };