From c4306812117d9a5047b5b7194275e9644d920c23 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 6 Feb 2026 00:31:24 -0800 Subject: [PATCH] Do not pass zeta prompt format in production endpoint (#48541) This allows us to switch the prompt format without client-side changes. If we want to experiment with prompt formats or models other than the currently-deployed one, we can use the raw endpoint, and do prompt construction and output processing on the client. This also adds an optional environment parameter to the raw endpoint, so that we can use that endpoint in the new scheme where we're deploying to separate environments for different zeta prompt versions. Release Notes: - N/A --- Cargo.lock | 1 + .../cloud_llm_client/src/predict_edits_v3.rs | 6 +- crates/edit_prediction/src/edit_prediction.rs | 100 +++++++------- .../src/edit_prediction_tests.rs | 126 +++--------------- .../src/zed_edit_prediction_delegate.rs | 2 +- crates/edit_prediction/src/zeta1.rs | 19 +-- crates/edit_prediction/src/zeta2.rs | 72 +++++----- crates/edit_prediction_cli/Cargo.toml | 1 + .../edit_prediction_cli/src/format_prompt.rs | 8 +- crates/edit_prediction_cli/src/main.rs | 29 ++-- .../edit_prediction_cli/src/parse_output.rs | 32 ++--- crates/edit_prediction_cli/src/predict.rs | 16 ++- .../zed/src/zed/edit_prediction_registry.rs | 4 +- crates/zeta_prompt/src/zeta_prompt.rs | 51 ++++--- 14 files changed, 193 insertions(+), 274 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 71ddd769e0d406735195903131523bf475fb0572..3ea3eb05237ff2fd40ad615a1d64e6bc407e5f57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5398,6 +5398,7 @@ dependencies = [ "smol", "sqlez", "sqlez_macros", + "strum 0.27.2", "telemetry_events", "tempfile", "terminal_view", diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index c0bd3bc3204f06eb85fc1e6db16095a3d0af3f44..9e7772ab7450cb47785d034b39d9c7c642b931c2 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -11,16 +11,14 @@ pub struct RawCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, pub stop: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub environment: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct PredictEditsV3Request { #[serde(flatten)] pub input: zeta_prompt::ZetaPromptInput, - #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, - #[serde(default)] - pub prompt_version: zeta_prompt::ZetaVersion, #[serde(default)] pub trigger: PredictEditsRequestTrigger, } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 065681adc87c3bd0f7a9a7ac94b985164b95fe9f..6385c6a2b6972740b19fc4009b8b3c241f3cef4e 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -36,18 +36,18 @@ use semver::Version; use serde::de::DeserializeOwned; use settings::{EditPredictionProvider, Settings as _, update_settings_file}; use std::collections::{VecDeque, hash_map}; +use std::env; use text::Edit; use workspace::Workspace; -use zeta_prompt::ZetaPromptInput; -use zeta_prompt::ZetaVersion; +use zeta_prompt::{ZetaFormat, ZetaPromptInput}; +use std::mem; use std::ops::Range; use std::path::Path; use std::rc::Rc; use std::str::FromStr as _; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use std::{env, mem}; use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; @@ -105,9 +105,6 @@ const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1); const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); -static EDIT_PREDICTIONS_MODEL_ID: LazyLock> = - LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok()); - pub struct Zeta2FeatureFlag; impl FeatureFlag for Zeta2FeatureFlag { @@ -133,6 +130,15 @@ struct EditPredictionStoreGlobal(Entity); impl Global for EditPredictionStoreGlobal {} +/// Configuration for using the raw Zeta2 endpoint. +/// When set, the client uses the raw endpoint and constructs the prompt itself. +/// The version is also used as the Baseten environment name (lowercased). +#[derive(Clone)] +pub struct Zeta2RawConfig { + pub model_id: Option, + pub format: ZetaFormat, +} + pub struct EditPredictionStore { client: Arc, user_store: Entity, @@ -141,6 +147,7 @@ pub struct EditPredictionStore { projects: HashMap, update_required: bool, edit_prediction_model: EditPredictionModel, + zeta2_raw_config: Option, pub sweep_ai: SweepAi, pub mercury: Mercury, pub ollama: Ollama, @@ -148,16 +155,13 @@ pub struct EditPredictionStore { reject_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, rated_predictions: HashSet, - custom_predict_edits_url: Option>, } #[derive(Copy, Clone, Default, PartialEq, Eq)] pub enum EditPredictionModel { #[default] Zeta1, - Zeta2 { - version: ZetaVersion, - }, + Zeta2, Sweep, Mercury, Ollama, @@ -631,9 +635,8 @@ impl EditPredictionStore { }, ), update_required: false, - edit_prediction_model: EditPredictionModel::Zeta2 { - version: Default::default(), - }, + edit_prediction_model: EditPredictionModel::Zeta2, + zeta2_raw_config: Self::zeta2_raw_config_from_env(), sweep_ai: SweepAi::new(cx), mercury: Mercury::new(cx), ollama: Ollama::new(), @@ -642,24 +645,30 @@ impl EditPredictionStore { reject_predictions_tx: reject_tx, rated_predictions: Default::default(), shown_predictions: Default::default(), - custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") { - Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into), - Err(_) => None, - }, }; this } - #[cfg(test)] - pub fn set_custom_predict_edits_url(&mut self, url: Url) { - self.custom_predict_edits_url = Some(url.into()); + fn zeta2_raw_config_from_env() -> Option { + let version_str = env::var("ZED_ZETA_FORMAT").ok()?; + let format = ZetaFormat::parse(&version_str).ok()?; + let model_id = env::var("ZED_ZETA_MODEL").ok(); + Some(Zeta2RawConfig { model_id, format }) } pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { self.edit_prediction_model = model; } + pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) { + self.zeta2_raw_config = Some(config); + } + + pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> { + self.zeta2_raw_config.as_ref() + } + pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet { use ui::IconName; match self.edit_prediction_model { @@ -673,7 +682,7 @@ impl EditPredictionStore { EditPredictionModel::Mercury => { edit_prediction_types::EditPredictionIconSet::new(IconName::Inception) } - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict) .with_disabled(IconName::ZedPredictDisabled) .with_up(IconName::ZedPredictUp) @@ -796,10 +805,7 @@ impl EditPredictionStore { } pub fn usage(&self, cx: &App) -> Option { - if matches!( - self.edit_prediction_model, - EditPredictionModel::Zeta2 { .. } - ) { + if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) { self.user_store.read(cx).edit_prediction_usage() } else { None @@ -1223,7 +1229,7 @@ impl EditPredictionStore { ); } EditPredictionModel::Ollama => {} - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { zeta2::edit_prediction_accepted(self, current_prediction, cx) } } @@ -1359,16 +1365,14 @@ impl EditPredictionStore { cx: &App, ) { match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { - if self.custom_predict_edits_url.is_none() { - self.reject_predictions_tx - .unbounded_send(EditPredictionRejection { - request_id: prediction_id.to_string(), - reason, - was_shown, - }) - .log_err(); - } + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + self.reject_predictions_tx + .unbounded_send(EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + }) + .log_err(); } EditPredictionModel::Sweep | EditPredictionModel::Ollama => {} EditPredictionModel::Mercury => { @@ -1805,24 +1809,16 @@ impl EditPredictionStore { .detach_and_log_err(cx); } } - let task = match self.edit_prediction_model { + let task = match &self.edit_prediction_model { EditPredictionModel::Zeta1 => { if should_send_testing_zeta2_request() { let mut zeta2_inputs = inputs.clone(); zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing; - zeta2::request_prediction_with_zeta2( - self, - zeta2_inputs, - Default::default(), - cx, - ) - .detach(); + zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach(); } zeta1::request_prediction_with_zeta1(self, inputs, cx) } - EditPredictionModel::Zeta2 { version } => { - zeta2::request_prediction_with_zeta2(self, inputs, version, cx) - } + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx), @@ -1976,7 +1972,6 @@ impl EditPredictionStore { pub(crate) async fn send_v3_request( input: ZetaPromptInput, - prompt_version: ZetaVersion, client: Arc, llm_token: LlmApiToken, app_version: Version, @@ -1986,12 +1981,7 @@ impl EditPredictionStore { .http_client() .build_zed_llm_url("/predict_edits/v3", &[])?; - let request = PredictEditsV3Request { - input, - model: EDIT_PREDICTIONS_MODEL_ID.clone(), - prompt_version, - trigger, - }; + let request = PredictEditsV3Request { input, trigger }; Self::send_api_request( |builder| { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 906f5a76d6232cbf00a01801f598e49b804305b3..dad91c7fed96eb8f8abcdcf2ced33029e92f1861 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1343,7 +1343,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi } fn prompt_from_request(request: &PredictEditsV3Request) -> String { - zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version) + zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default()) } struct RequestChannels { @@ -2073,6 +2073,20 @@ async fn make_test_ep_store( ) .unwrap()) } + (&Method::POST, "/predict_edits/v3") => { + next_request_id += 1; + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsV3Response { + request_id: format!("request-{next_request_id}"), + output: "hello world".to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } _ => Ok(http_client::Response::builder() .status(404) .body("Not Found".into()) @@ -2200,116 +2214,6 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut ); } -#[gpui::test] -async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/project", - serde_json::json!({ - "main.rs": "fn main() {\n \n}\n" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let predict_called_clone = predict_called.clone(); - - let http_client = FakeHttpClient::create({ - move |req| { - let uri = req.uri().path().to_string(); - let predict_called = predict_called_clone.clone(); - async move { - if uri.contains("predict") { - predict_called.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(gpui::http_client::Response::builder() - .body( - serde_json::to_string(&open_ai::Response { - id: "test-123".to_string(), - object: "chat.completion".to_string(), - created: 0, - model: "test".to_string(), - usage: open_ai::Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain( - indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - fn main() { - println!(\"Hello, world!\"); - } - <|editable_region_end|> - ``` - "} - .to_string(), - )), - tool_calls: vec![], - }, - finish_reason: Some("stop".to_string()), - }], - }) - .unwrap() - .into(), - ) - .unwrap()) - } else { - Ok(gpui::http_client::Response::builder() - .status(401) - .body("Unauthorized".into()) - .unwrap()) - } - } - } - }); - - let client = - cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - language_model::RefreshLlmTokenListener::register(client.clone(), cx); - }); - - let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); - - let buffer = project - .update(cx, |project, cx| { - let path = project - .find_project_path(path!("/project/main.rs"), cx) - .unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4))); - ep_store.update(cx, |ep_store, cx| { - ep_store.register_buffer(&buffer, &project, cx) - }); - cx.background_executor.run_until_parked(); - - let completion_task = ep_store.update(cx, |ep_store, cx| { - ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap()); - ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); - ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx) - }); - - let _ = completion_task.await; - - assert!( - predict_called.load(std::sync::atomic::Ordering::SeqCst), - "With custom URL, predict endpoint should be called even without authentication" - ); -} - #[gpui::test] fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) { let buffer = cx.new(|cx| { diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index a92c1450ce7105fdea51d8c610254b9570a6e0d7..3f517a5caaa5bb5964403bd22fa89182a03e4363 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -70,7 +70,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { .with_down(IconName::SweepAiDown) .with_error(IconName::SweepAiError), EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception), - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { EditPredictionIconSet::new(IconName::ZedPredict) .with_disabled(IconName::ZedPredictDisabled) .with_up(IconName::ZedPredictUp) diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index c7b093edec197f6e161e755f8fd3c429528badfc..74c6cb568b26f19de21447d520d9569afa58b432 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -81,17 +81,12 @@ pub(crate) fn request_prediction_with_zeta1( cx, ); - let (uri, require_auth) = match &store.custom_predict_edits_url { - Some(custom_url) => (custom_url.clone(), false), - None => { - match client - .http_client() - .build_zed_llm_url("/predict_edits/v2", &[]) - { - Ok(url) => (url.into(), true), - Err(err) => return Task::ready(Err(err)), - } - } + let uri = match client + .http_client() + .build_zed_llm_url("/predict_edits/v2", &[]) + { + Ok(url) => Arc::from(url), + Err(err) => return Task::ready(Err(err)), }; cx.spawn(async move |this, cx| { @@ -127,7 +122,7 @@ pub(crate) fn request_prediction_with_zeta1( client, llm_token, app_version, - require_auth, + true, ) .await; diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 3d2ee06bc55d9c5b01268cc33b1c7404009f910d..2a3efa5c803aee1ed53572c506d238317fc9842a 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,9 +1,8 @@ use crate::prediction::EditPredictionResult; use crate::zeta1::compute_edits_and_cursor_position; use crate::{ - CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, - EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent, - EditPredictionStore, + CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; use anyhow::{Result, anyhow}; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; @@ -14,17 +13,16 @@ use release_channel::AppVersion; use std::env; use std::{path::Path, sync::Arc, time::Instant}; -use zeta_prompt::format_zeta_prompt; -use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt}; pub const MAX_CONTEXT_TOKENS: usize = 350; -pub fn max_editable_tokens(version: ZetaVersion) -> usize { - match version { - ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150, - ZetaVersion::V0114180EditableRegion => 180, - ZetaVersion::V0120GitMergeMarkers => 180, - ZetaVersion::V0131GitMergeMarkersPrefix => 180, +pub fn max_editable_tokens(format: ZetaFormat) -> usize { + match format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150, + ZetaFormat::V0114180EditableRegion => 180, + ZetaFormat::V0120GitMergeMarkers => 180, + ZetaFormat::V0131GitMergeMarkersPrefix => 180, } } @@ -40,11 +38,10 @@ pub fn request_prediction_with_zeta2( trigger, .. }: EditPredictionModelInput, - zeta_version: ZetaVersion, cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); - let custom_url = store.custom_predict_edits_url.clone(); + let raw_config = store.zeta2_raw_config().cloned(); let Some(excerpt_path) = snapshot .file() @@ -59,6 +56,11 @@ pub fn request_prediction_with_zeta2( let request_task = cx.background_spawn({ async move { + let zeta_version = raw_config + .as_ref() + .map(|config| config.format) + .unwrap_or(ZetaFormat::default()); + let cursor_offset = position.to_offset(&snapshot); let (editable_offset_range, prompt_input) = zeta2_prompt_input( &snapshot, @@ -84,33 +86,36 @@ pub fn request_prediction_with_zeta2( log::trace!("Sending edit prediction request"); - let (request_id, output_text, usage) = if let Some(custom_url) = custom_url { - // Use raw endpoint with custom URL - let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let (request_id, output_text, usage) = if let Some(config) = &raw_config { + let prompt = format_zeta_prompt(&prompt_input, config.format); let request = RawCompletionRequest { - model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(), + model: config.model_id.clone().unwrap_or_default(), prompt, temperature: None, stop: vec![], max_tokens: Some(2048), + environment: Some(config.format.to_string().to_lowercase()), }; let (mut response, usage) = EditPredictionStore::send_raw_llm_request( request, client, - Some(custom_url), + None, llm_token, app_version, ) .await?; let request_id = EditPredictionId(response.id.clone().into()); - let output_text = response.choices.pop().map(|choice| choice.text); + let output_text = response.choices.pop().map(|choice| { + clean_zeta2_model_output(&choice.text, config.format).to_string() + }); + (request_id, output_text, usage) } else { + // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( prompt_input.clone(), - zeta_version, client, llm_token, app_version, @@ -135,6 +140,13 @@ pub fn request_prediction_with_zeta2( return Ok((Some((request_id, None)), usage)); }; + // Client-side cursor marker processing (applies to both raw and v3 responses) + let cursor_offset_in_output = output_text.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_output { + log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); + output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + } + if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(DebugEvent::EditPredictionFinished( @@ -147,20 +159,6 @@ pub fn request_prediction_with_zeta2( .ok(); } - let cursor_offset_in_output = output_text.find(CURSOR_MARKER); - if let Some(offset) = cursor_offset_in_output { - log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); - output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - } - - if zeta_version == ZetaVersion::V0120GitMergeMarkers { - if let Some(stripped) = - output_text.strip_suffix(v0120_git_merge_markers::END_MARKER) - { - output_text = stripped.to_string(); - } - } - let mut old_text = snapshot .text_for_range(editable_offset_range.clone()) .collect::(); @@ -242,7 +240,7 @@ pub fn zeta2_prompt_input( events: Vec>, excerpt_path: Arc, cursor_offset: usize, - zeta_version: ZetaVersion, + zeta_format: ZetaFormat, ) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); @@ -250,7 +248,7 @@ pub fn zeta2_prompt_input( crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( cursor_point, snapshot, - max_editable_tokens(zeta_version), + max_editable_tokens(zeta_format), MAX_CONTEXT_TOKENS, ); @@ -288,7 +286,7 @@ pub(crate) fn edit_prediction_accepted( cx: &App, ) { let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok(); - if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() { + if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() { return; } diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 20451bdc7a7e2e96a9a1b48ed32180250f64b6b6..df5d742a3bb90d7e20a9cf6b4bd84314b14e206e 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -50,6 +50,7 @@ settings.workspace = true shellexpand.workspace = true smol.workspace = true sqlez.workspace = true +strum.workspace = true sqlez_macros.workspace = true terminal_view.workspace = true util.workspace = true diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 44df820d7611f6fa62aec66259bb203e09de428a..c0f078ed9af489c358695db80136dec854b0f532 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -12,7 +12,7 @@ use language::{Buffer, OffsetRangeExt, Point}; use similar::DiffableStr; use std::sync::Arc; use std::{fmt::Write as _, ops::Range}; -use zeta_prompt::ZetaVersion; +use zeta_prompt::ZetaFormat; use zeta_prompt::format_zeta_prompt; pub async fn run_format_prompt( @@ -54,7 +54,7 @@ pub async fn run_format_prompt( let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( cursor_point, &snapshot, - edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()), + edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()), edit_prediction::zeta2::MAX_CONTEXT_TOKENS, ); let editable_range = editable_range.to_offset(&snapshot); @@ -126,7 +126,7 @@ pub fn zeta2_output_for_patch( input: &zeta_prompt::ZetaPromptInput, patch: &str, cursor_offset: Option, - version: ZetaVersion, + version: ZetaFormat, ) -> Result { let mut old_editable_region = input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string(); @@ -155,7 +155,7 @@ pub fn zeta2_output_for_patch( } match version { - ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => { if !result.ends_with('\n') { result.push('\n'); } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 95e8332b44741ca7bdfb173282508f960d8d0303..0ade5e3f3fd30ef8139bc90e1c96e7b325395d5c 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -31,7 +31,7 @@ use edit_prediction::EditPredictionStore; use futures::channel::mpsc; use futures::{SinkExt as _, StreamExt as _}; use gpui::{AppContext as _, Application, BackgroundExecutor, Task}; -use zeta_prompt::ZetaVersion; +use zeta_prompt::ZetaFormat; use reqwest_client::ReqwestClient; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -207,6 +207,8 @@ enum Command { Qa(qa::QaArgs), /// Repair predictions that received poor QA scores by generating improved predictions Repair(repair::RepairArgs), + /// Print all valid zeta formats (lowercase, one per line) + PrintZetaFormats, } impl Display for Command { @@ -249,6 +251,9 @@ impl Display for Command { Command::Repair(_) => { write!(f, "repair") } + Command::PrintZetaFormats => { + write!(f, "print-zeta-formats") + } } } } @@ -321,7 +326,7 @@ enum PredictionProvider { Sweep, Mercury, Zeta1, - Zeta2(ZetaVersion), + Zeta2(ZetaFormat), Teacher(TeacherBackend), TeacherNonBatching(TeacherBackend), Repair, @@ -329,7 +334,7 @@ enum PredictionProvider { impl Default for PredictionProvider { fn default() -> Self { - PredictionProvider::Zeta2(ZetaVersion::default()) + PredictionProvider::Zeta2(ZetaFormat::default()) } } @@ -339,7 +344,7 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Sweep => write!(f, "sweep"), PredictionProvider::Mercury => write!(f, "mercury"), PredictionProvider::Zeta1 => write!(f, "zeta1"), - PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"), + PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"), PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"), PredictionProvider::TeacherNonBatching(backend) => { write!(f, "teacher-non-batching:{backend}") @@ -361,8 +366,8 @@ impl std::str::FromStr for PredictionProvider { "mercury" => Ok(PredictionProvider::Mercury), "zeta1" => Ok(PredictionProvider::Zeta1), "zeta2" => { - let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default(); - Ok(PredictionProvider::Zeta2(version)) + let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default(); + Ok(PredictionProvider::Zeta2(format)) } "teacher" => { let backend = arg @@ -385,7 +390,7 @@ impl std::str::FromStr for PredictionProvider { For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\ Available zeta versions:\n{}", - ZetaVersion::options_as_string() + ZetaFormat::options_as_string() ) } } @@ -719,6 +724,13 @@ fn main() { std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap(); return; } + Command::PrintZetaFormats => { + use strum::IntoEnumIterator as _; + for format in ZetaFormat::iter() { + println!("{}", format.to_string().to_lowercase()); + } + return; + } Command::Synthesize(synth_args) => { let Some(output_dir) = args.output else { panic!("output dir is required"); @@ -953,7 +965,8 @@ fn main() { | Command::Split(_) | Command::TruncatePatch(_) | Command::FilterLanguages(_) - | Command::ImportBatch(_) => { + | Command::ImportBatch(_) + | Command::PrintZetaFormats => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index a6e795c27d8b1d9352cb4a7c2accdc995aa429df..e45060924d07a992ec2e563e5b16c3f85938ee2d 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -5,7 +5,7 @@ use crate::{ repair, }; use anyhow::{Context as _, Result}; -use zeta_prompt::{CURSOR_MARKER, ZetaVersion}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -49,13 +49,13 @@ pub fn parse_prediction_output( } } -fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result { - let (current_marker, end_marker) = match version { - ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), - ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { +fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { + let (current_marker, end_marker) = match format { + ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), + ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { ("<|fim_middle|>current\n", "<|fim_suffix|>") } - ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => ( + ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => ( zeta_prompt::v0120_git_merge_markers::START_MARKER, zeta_prompt::v0120_git_merge_markers::SEPARATOR, ), @@ -82,7 +82,7 @@ fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result Result<(String, Option)> { let prompt = &example.prompt.as_ref().context("prompt required")?.input; let prompt_inputs = example @@ -90,7 +90,7 @@ fn parse_zeta2_output( .as_ref() .context("prompt_inputs required")?; - let old_text = extract_zeta2_current_region(prompt, version)?; + let old_text = extract_zeta2_current_region(prompt, format)?; let mut new_text = actual_output.to_string(); let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { @@ -100,11 +100,11 @@ fn parse_zeta2_output( None }; - let suffix = match version { - ZetaVersion::V0131GitMergeMarkersPrefix => { + let suffix = match format { + ZetaFormat::V0131GitMergeMarkersPrefix => { zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER } - ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, + ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, _ => "", }; if !suffix.is_empty() { @@ -184,7 +184,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -201,7 +201,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -218,7 +218,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -236,7 +236,7 @@ mod tests { "}; let region = - extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -254,7 +254,7 @@ mod tests { "}; let region = - extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 63be8e8b70dfbc5204ab530d839df4e9cdc34e41..5979439a2a7f3a66bfe94881bd04b9d948fe3c7e 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -11,7 +11,7 @@ use crate::{ retrieve_context::run_context_retrieval, }; use anyhow::Context as _; -use edit_prediction::{DebugEvent, EditPredictionStore}; +use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig}; use futures::{FutureExt as _, StreamExt as _, future::Shared}; use gpui::{AppContext as _, AsyncApp, Task}; use std::{ @@ -21,6 +21,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; +use zeta_prompt::ZetaFormat; static ANTHROPIC_CLIENT: OnceLock = OnceLock::new(); static OPENAI_CLIENT: OnceLock = OnceLock::new(); @@ -103,9 +104,7 @@ pub async fn run_prediction( ep_store.update(&mut cx, |store, _cx| { let model = match provider { PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2(version) => { - edit_prediction::EditPredictionModel::Zeta2 { version } - } + PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2, PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher(..) @@ -115,6 +114,15 @@ pub async fn run_prediction( } }; store.set_edit_prediction_model(model); + + // If user specified a non-default Zeta2 version, configure raw endpoint. + // ZED_ZETA_MODEL env var is optional. + if let PredictionProvider::Zeta2(format) = provider { + if format != ZetaFormat::default() { + let model_id = std::env::var("ZED_ZETA_MODEL").ok(); + store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format }); + } + } }); step_progress.set_substatus("configuring model"); let state = example.state.as_ref().context("state must be set")?; diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 2347e27ccaa9ee5a94a9db4d262607ce126c3e57..3e3ed33fd6de460eca0a6e16e58751ff03118166 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -217,9 +217,7 @@ fn assign_edit_prediction_provider( if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() => { - edit_prediction::EditPredictionModel::Zeta2 { - version: Default::default(), - } + edit_prediction::EditPredictionModel::Zeta2 } EditPredictionProvider::Zed if user_store.read(cx).current_user().is_some() => diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 8799d680287aefbcd8a4740eb3b558f9cd62ccb8..73fdecbd134f2346f22304ae84c76ab53c1636c4 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -39,7 +39,7 @@ pub struct ZetaPromptInput { Deserialize, )] #[allow(non_camel_case_types)] -pub enum ZetaVersion { +pub enum ZetaFormat { V0112MiddleAtEnd, V0113Ordered, #[default] @@ -48,28 +48,28 @@ pub enum ZetaVersion { V0131GitMergeMarkersPrefix, } -impl std::fmt::Display for ZetaVersion { +impl std::fmt::Display for ZetaFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", <&'static str>::from(self)) } } -impl ZetaVersion { - pub fn parse(version_string: &str) -> Result { - let mut results = ZetaVersion::iter().filter(|version| { +impl ZetaFormat { + pub fn parse(format_name: &str) -> Result { + let mut results = ZetaFormat::iter().filter(|version| { <&'static str>::from(version) .to_lowercase() - .contains(&version_string.to_lowercase()) + .contains(&format_name.to_lowercase()) }); let Some(result) = results.next() else { anyhow::bail!( - "`{version_string}` did not match any of:\n{}", + "`{format_name}` did not match any of:\n{}", Self::options_as_string() ); }; if results.next().is_some() { anyhow::bail!( - "`{version_string}` matched more than one of:\n{}", + "`{format_name}` matched more than one of:\n{}", Self::options_as_string() ); } @@ -77,8 +77,8 @@ impl ZetaVersion { } pub fn options_as_string() -> String { - ZetaVersion::iter() - .map(|version| format!("- {}\n", <&'static str>::from(version))) + ZetaFormat::iter() + .map(|format| format!("- {}\n", <&'static str>::from(format))) .collect::>() .concat() } @@ -137,27 +137,40 @@ pub struct RelatedExcerpt { pub text: Arc, } -pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String { - format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS) +pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String { + format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS) +} + +/// Post-processes model output for the given zeta format by stripping format-specific suffixes. +pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { + match format { + ZetaFormat::V0120GitMergeMarkers => output + .strip_suffix(v0120_git_merge_markers::END_MARKER) + .unwrap_or(output), + ZetaFormat::V0131GitMergeMarkersPrefix => output + .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER) + .unwrap_or(output), + _ => output, + } } fn format_zeta_prompt_with_budget( input: &ZetaPromptInput, - version: ZetaVersion, + format: ZetaFormat, max_tokens: usize, ) -> String { let mut cursor_section = String::new(); - match version { - ZetaVersion::V0112MiddleAtEnd => { + match format { + ZetaFormat::V0112MiddleAtEnd => { v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input); } - ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { + ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input) } - ZetaVersion::V0120GitMergeMarkers => { + ZetaFormat::V0120GitMergeMarkers => { v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input) } - ZetaVersion::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0131GitMergeMarkersPrefix => { v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input) } } @@ -563,7 +576,7 @@ mod tests { } fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { - format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens) + format_zeta_prompt_with_budget(input, ZetaFormat::V0114180EditableRegion, max_tokens) } #[test]