diff --git a/Cargo.lock b/Cargo.lock index 71ddd769e0d406735195903131523bf475fb0572..3ea3eb05237ff2fd40ad615a1d64e6bc407e5f57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5398,6 +5398,7 @@ dependencies = [ "smol", "sqlez", "sqlez_macros", + "strum 0.27.2", "telemetry_events", "tempfile", "terminal_view", diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index c0bd3bc3204f06eb85fc1e6db16095a3d0af3f44..9e7772ab7450cb47785d034b39d9c7c642b931c2 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -11,16 +11,14 @@ pub struct RawCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, pub stop: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub environment: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct PredictEditsV3Request { #[serde(flatten)] pub input: zeta_prompt::ZetaPromptInput, - #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, - #[serde(default)] - pub prompt_version: zeta_prompt::ZetaVersion, #[serde(default)] pub trigger: PredictEditsRequestTrigger, } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 065681adc87c3bd0f7a9a7ac94b985164b95fe9f..6385c6a2b6972740b19fc4009b8b3c241f3cef4e 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -36,18 +36,18 @@ use semver::Version; use serde::de::DeserializeOwned; use settings::{EditPredictionProvider, Settings as _, update_settings_file}; use std::collections::{VecDeque, hash_map}; +use std::env; use text::Edit; use workspace::Workspace; -use zeta_prompt::ZetaPromptInput; -use zeta_prompt::ZetaVersion; +use zeta_prompt::{ZetaFormat, ZetaPromptInput}; +use std::mem; use std::ops::Range; use std::path::Path; use std::rc::Rc; use std::str::FromStr as _; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use std::{env, mem}; use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; @@ -105,9 +105,6 @@ const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1); const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); -static EDIT_PREDICTIONS_MODEL_ID: LazyLock> = - LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok()); - pub struct Zeta2FeatureFlag; impl FeatureFlag for Zeta2FeatureFlag { @@ -133,6 +130,15 @@ struct EditPredictionStoreGlobal(Entity); impl Global for EditPredictionStoreGlobal {} +/// Configuration for using the raw Zeta2 endpoint. +/// When set, the client uses the raw endpoint and constructs the prompt itself. +/// The version is also used as the Baseten environment name (lowercased). +#[derive(Clone)] +pub struct Zeta2RawConfig { + pub model_id: Option, + pub format: ZetaFormat, +} + pub struct EditPredictionStore { client: Arc, user_store: Entity, @@ -141,6 +147,7 @@ pub struct EditPredictionStore { projects: HashMap, update_required: bool, edit_prediction_model: EditPredictionModel, + zeta2_raw_config: Option, pub sweep_ai: SweepAi, pub mercury: Mercury, pub ollama: Ollama, @@ -148,16 +155,13 @@ pub struct EditPredictionStore { reject_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, rated_predictions: HashSet, - custom_predict_edits_url: Option>, } #[derive(Copy, Clone, Default, PartialEq, Eq)] pub enum EditPredictionModel { #[default] Zeta1, - Zeta2 { - version: ZetaVersion, - }, + Zeta2, Sweep, Mercury, Ollama, @@ -631,9 +635,8 @@ impl EditPredictionStore { }, ), update_required: false, - edit_prediction_model: EditPredictionModel::Zeta2 { - version: Default::default(), - }, + edit_prediction_model: EditPredictionModel::Zeta2, + zeta2_raw_config: Self::zeta2_raw_config_from_env(), sweep_ai: SweepAi::new(cx), mercury: Mercury::new(cx), ollama: Ollama::new(), @@ -642,24 +645,30 @@ impl EditPredictionStore { reject_predictions_tx: reject_tx, rated_predictions: Default::default(), shown_predictions: Default::default(), - custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") { - Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into), - Err(_) => None, - }, }; this } - #[cfg(test)] - pub fn set_custom_predict_edits_url(&mut self, url: Url) { - self.custom_predict_edits_url = Some(url.into()); + fn zeta2_raw_config_from_env() -> Option { + let version_str = env::var("ZED_ZETA_FORMAT").ok()?; + let format = ZetaFormat::parse(&version_str).ok()?; + let model_id = env::var("ZED_ZETA_MODEL").ok(); + Some(Zeta2RawConfig { model_id, format }) } pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { self.edit_prediction_model = model; } + pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) { + self.zeta2_raw_config = Some(config); + } + + pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> { + self.zeta2_raw_config.as_ref() + } + pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet { use ui::IconName; match self.edit_prediction_model { @@ -673,7 +682,7 @@ impl EditPredictionStore { EditPredictionModel::Mercury => { edit_prediction_types::EditPredictionIconSet::new(IconName::Inception) } - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict) .with_disabled(IconName::ZedPredictDisabled) .with_up(IconName::ZedPredictUp) @@ -796,10 +805,7 @@ impl EditPredictionStore { } pub fn usage(&self, cx: &App) -> Option { - if matches!( - self.edit_prediction_model, - EditPredictionModel::Zeta2 { .. } - ) { + if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) { self.user_store.read(cx).edit_prediction_usage() } else { None @@ -1223,7 +1229,7 @@ impl EditPredictionStore { ); } EditPredictionModel::Ollama => {} - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { zeta2::edit_prediction_accepted(self, current_prediction, cx) } } @@ -1359,16 +1365,14 @@ impl EditPredictionStore { cx: &App, ) { match self.edit_prediction_model { - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { - if self.custom_predict_edits_url.is_none() { - self.reject_predictions_tx - .unbounded_send(EditPredictionRejection { - request_id: prediction_id.to_string(), - reason, - was_shown, - }) - .log_err(); - } + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { + self.reject_predictions_tx + .unbounded_send(EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + }) + .log_err(); } EditPredictionModel::Sweep | EditPredictionModel::Ollama => {} EditPredictionModel::Mercury => { @@ -1805,24 +1809,16 @@ impl EditPredictionStore { .detach_and_log_err(cx); } } - let task = match self.edit_prediction_model { + let task = match &self.edit_prediction_model { EditPredictionModel::Zeta1 => { if should_send_testing_zeta2_request() { let mut zeta2_inputs = inputs.clone(); zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing; - zeta2::request_prediction_with_zeta2( - self, - zeta2_inputs, - Default::default(), - cx, - ) - .detach(); + zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach(); } zeta1::request_prediction_with_zeta1(self, inputs, cx) } - EditPredictionModel::Zeta2 { version } => { - zeta2::request_prediction_with_zeta2(self, inputs, version, cx) - } + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx), @@ -1976,7 +1972,6 @@ impl EditPredictionStore { pub(crate) async fn send_v3_request( input: ZetaPromptInput, - prompt_version: ZetaVersion, client: Arc, llm_token: LlmApiToken, app_version: Version, @@ -1986,12 +1981,7 @@ impl EditPredictionStore { .http_client() .build_zed_llm_url("/predict_edits/v3", &[])?; - let request = PredictEditsV3Request { - input, - model: EDIT_PREDICTIONS_MODEL_ID.clone(), - prompt_version, - trigger, - }; + let request = PredictEditsV3Request { input, trigger }; Self::send_api_request( |builder| { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 906f5a76d6232cbf00a01801f598e49b804305b3..dad91c7fed96eb8f8abcdcf2ced33029e92f1861 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1343,7 +1343,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi } fn prompt_from_request(request: &PredictEditsV3Request) -> String { - zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version) + zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default()) } struct RequestChannels { @@ -2073,6 +2073,20 @@ async fn make_test_ep_store( ) .unwrap()) } + (&Method::POST, "/predict_edits/v3") => { + next_request_id += 1; + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsV3Response { + request_id: format!("request-{next_request_id}"), + output: "hello world".to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } _ => Ok(http_client::Response::builder() .status(404) .body("Not Found".into()) @@ -2200,116 +2214,6 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut ); } -#[gpui::test] -async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/project", - serde_json::json!({ - "main.rs": "fn main() {\n \n}\n" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let predict_called_clone = predict_called.clone(); - - let http_client = FakeHttpClient::create({ - move |req| { - let uri = req.uri().path().to_string(); - let predict_called = predict_called_clone.clone(); - async move { - if uri.contains("predict") { - predict_called.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(gpui::http_client::Response::builder() - .body( - serde_json::to_string(&open_ai::Response { - id: "test-123".to_string(), - object: "chat.completion".to_string(), - created: 0, - model: "test".to_string(), - usage: open_ai::Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain( - indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - fn main() { - println!(\"Hello, world!\"); - } - <|editable_region_end|> - ``` - "} - .to_string(), - )), - tool_calls: vec![], - }, - finish_reason: Some("stop".to_string()), - }], - }) - .unwrap() - .into(), - ) - .unwrap()) - } else { - Ok(gpui::http_client::Response::builder() - .status(401) - .body("Unauthorized".into()) - .unwrap()) - } - } - } - }); - - let client = - cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - language_model::RefreshLlmTokenListener::register(client.clone(), cx); - }); - - let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); - - let buffer = project - .update(cx, |project, cx| { - let path = project - .find_project_path(path!("/project/main.rs"), cx) - .unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4))); - ep_store.update(cx, |ep_store, cx| { - ep_store.register_buffer(&buffer, &project, cx) - }); - cx.background_executor.run_until_parked(); - - let completion_task = ep_store.update(cx, |ep_store, cx| { - ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap()); - ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); - ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx) - }); - - let _ = completion_task.await; - - assert!( - predict_called.load(std::sync::atomic::Ordering::SeqCst), - "With custom URL, predict endpoint should be called even without authentication" - ); -} - #[gpui::test] fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) { let buffer = cx.new(|cx| { diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index a92c1450ce7105fdea51d8c610254b9570a6e0d7..3f517a5caaa5bb5964403bd22fa89182a03e4363 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -70,7 +70,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { .with_down(IconName::SweepAiDown) .with_error(IconName::SweepAiError), EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception), - EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => { EditPredictionIconSet::new(IconName::ZedPredict) .with_disabled(IconName::ZedPredictDisabled) .with_up(IconName::ZedPredictUp) diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index c7b093edec197f6e161e755f8fd3c429528badfc..74c6cb568b26f19de21447d520d9569afa58b432 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -81,17 +81,12 @@ pub(crate) fn request_prediction_with_zeta1( cx, ); - let (uri, require_auth) = match &store.custom_predict_edits_url { - Some(custom_url) => (custom_url.clone(), false), - None => { - match client - .http_client() - .build_zed_llm_url("/predict_edits/v2", &[]) - { - Ok(url) => (url.into(), true), - Err(err) => return Task::ready(Err(err)), - } - } + let uri = match client + .http_client() + .build_zed_llm_url("/predict_edits/v2", &[]) + { + Ok(url) => Arc::from(url), + Err(err) => return Task::ready(Err(err)), }; cx.spawn(async move |this, cx| { @@ -127,7 +122,7 @@ pub(crate) fn request_prediction_with_zeta1( client, llm_token, app_version, - require_auth, + true, ) .await; diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 3d2ee06bc55d9c5b01268cc33b1c7404009f910d..2a3efa5c803aee1ed53572c506d238317fc9842a 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,9 +1,8 @@ use crate::prediction::EditPredictionResult; use crate::zeta1::compute_edits_and_cursor_position; use crate::{ - CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, - EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent, - EditPredictionStore, + CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; use anyhow::{Result, anyhow}; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; @@ -14,17 +13,16 @@ use release_channel::AppVersion; use std::env; use std::{path::Path, sync::Arc, time::Instant}; -use zeta_prompt::format_zeta_prompt; -use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt}; pub const MAX_CONTEXT_TOKENS: usize = 350; -pub fn max_editable_tokens(version: ZetaVersion) -> usize { - match version { - ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150, - ZetaVersion::V0114180EditableRegion => 180, - ZetaVersion::V0120GitMergeMarkers => 180, - ZetaVersion::V0131GitMergeMarkersPrefix => 180, +pub fn max_editable_tokens(format: ZetaFormat) -> usize { + match format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150, + ZetaFormat::V0114180EditableRegion => 180, + ZetaFormat::V0120GitMergeMarkers => 180, + ZetaFormat::V0131GitMergeMarkersPrefix => 180, } } @@ -40,11 +38,10 @@ pub fn request_prediction_with_zeta2( trigger, .. }: EditPredictionModelInput, - zeta_version: ZetaVersion, cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); - let custom_url = store.custom_predict_edits_url.clone(); + let raw_config = store.zeta2_raw_config().cloned(); let Some(excerpt_path) = snapshot .file() @@ -59,6 +56,11 @@ pub fn request_prediction_with_zeta2( let request_task = cx.background_spawn({ async move { + let zeta_version = raw_config + .as_ref() + .map(|config| config.format) + .unwrap_or(ZetaFormat::default()); + let cursor_offset = position.to_offset(&snapshot); let (editable_offset_range, prompt_input) = zeta2_prompt_input( &snapshot, @@ -84,33 +86,36 @@ pub fn request_prediction_with_zeta2( log::trace!("Sending edit prediction request"); - let (request_id, output_text, usage) = if let Some(custom_url) = custom_url { - // Use raw endpoint with custom URL - let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let (request_id, output_text, usage) = if let Some(config) = &raw_config { + let prompt = format_zeta_prompt(&prompt_input, config.format); let request = RawCompletionRequest { - model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(), + model: config.model_id.clone().unwrap_or_default(), prompt, temperature: None, stop: vec![], max_tokens: Some(2048), + environment: Some(config.format.to_string().to_lowercase()), }; let (mut response, usage) = EditPredictionStore::send_raw_llm_request( request, client, - Some(custom_url), + None, llm_token, app_version, ) .await?; let request_id = EditPredictionId(response.id.clone().into()); - let output_text = response.choices.pop().map(|choice| choice.text); + let output_text = response.choices.pop().map(|choice| { + clean_zeta2_model_output(&choice.text, config.format).to_string() + }); + (request_id, output_text, usage) } else { + // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( prompt_input.clone(), - zeta_version, client, llm_token, app_version, @@ -135,6 +140,13 @@ pub fn request_prediction_with_zeta2( return Ok((Some((request_id, None)), usage)); }; + // Client-side cursor marker processing (applies to both raw and v3 responses) + let cursor_offset_in_output = output_text.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_output { + log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); + output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + } + if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(DebugEvent::EditPredictionFinished( @@ -147,20 +159,6 @@ pub fn request_prediction_with_zeta2( .ok(); } - let cursor_offset_in_output = output_text.find(CURSOR_MARKER); - if let Some(offset) = cursor_offset_in_output { - log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); - output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - } - - if zeta_version == ZetaVersion::V0120GitMergeMarkers { - if let Some(stripped) = - output_text.strip_suffix(v0120_git_merge_markers::END_MARKER) - { - output_text = stripped.to_string(); - } - } - let mut old_text = snapshot .text_for_range(editable_offset_range.clone()) .collect::(); @@ -242,7 +240,7 @@ pub fn zeta2_prompt_input( events: Vec>, excerpt_path: Arc, cursor_offset: usize, - zeta_version: ZetaVersion, + zeta_format: ZetaFormat, ) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); @@ -250,7 +248,7 @@ pub fn zeta2_prompt_input( crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( cursor_point, snapshot, - max_editable_tokens(zeta_version), + max_editable_tokens(zeta_format), MAX_CONTEXT_TOKENS, ); @@ -288,7 +286,7 @@ pub(crate) fn edit_prediction_accepted( cx: &App, ) { let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok(); - if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() { + if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() { return; } diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 20451bdc7a7e2e96a9a1b48ed32180250f64b6b6..df5d742a3bb90d7e20a9cf6b4bd84314b14e206e 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -50,6 +50,7 @@ settings.workspace = true shellexpand.workspace = true smol.workspace = true sqlez.workspace = true +strum.workspace = true sqlez_macros.workspace = true terminal_view.workspace = true util.workspace = true diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 44df820d7611f6fa62aec66259bb203e09de428a..c0f078ed9af489c358695db80136dec854b0f532 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -12,7 +12,7 @@ use language::{Buffer, OffsetRangeExt, Point}; use similar::DiffableStr; use std::sync::Arc; use std::{fmt::Write as _, ops::Range}; -use zeta_prompt::ZetaVersion; +use zeta_prompt::ZetaFormat; use zeta_prompt::format_zeta_prompt; pub async fn run_format_prompt( @@ -54,7 +54,7 @@ pub async fn run_format_prompt( let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( cursor_point, &snapshot, - edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()), + edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()), edit_prediction::zeta2::MAX_CONTEXT_TOKENS, ); let editable_range = editable_range.to_offset(&snapshot); @@ -126,7 +126,7 @@ pub fn zeta2_output_for_patch( input: &zeta_prompt::ZetaPromptInput, patch: &str, cursor_offset: Option, - version: ZetaVersion, + version: ZetaFormat, ) -> Result { let mut old_editable_region = input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string(); @@ -155,7 +155,7 @@ pub fn zeta2_output_for_patch( } match version { - ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => { if !result.ends_with('\n') { result.push('\n'); } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 95e8332b44741ca7bdfb173282508f960d8d0303..0ade5e3f3fd30ef8139bc90e1c96e7b325395d5c 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -31,7 +31,7 @@ use edit_prediction::EditPredictionStore; use futures::channel::mpsc; use futures::{SinkExt as _, StreamExt as _}; use gpui::{AppContext as _, Application, BackgroundExecutor, Task}; -use zeta_prompt::ZetaVersion; +use zeta_prompt::ZetaFormat; use reqwest_client::ReqwestClient; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -207,6 +207,8 @@ enum Command { Qa(qa::QaArgs), /// Repair predictions that received poor QA scores by generating improved predictions Repair(repair::RepairArgs), + /// Print all valid zeta formats (lowercase, one per line) + PrintZetaFormats, } impl Display for Command { @@ -249,6 +251,9 @@ impl Display for Command { Command::Repair(_) => { write!(f, "repair") } + Command::PrintZetaFormats => { + write!(f, "print-zeta-formats") + } } } } @@ -321,7 +326,7 @@ enum PredictionProvider { Sweep, Mercury, Zeta1, - Zeta2(ZetaVersion), + Zeta2(ZetaFormat), Teacher(TeacherBackend), TeacherNonBatching(TeacherBackend), Repair, @@ -329,7 +334,7 @@ enum PredictionProvider { impl Default for PredictionProvider { fn default() -> Self { - PredictionProvider::Zeta2(ZetaVersion::default()) + PredictionProvider::Zeta2(ZetaFormat::default()) } } @@ -339,7 +344,7 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Sweep => write!(f, "sweep"), PredictionProvider::Mercury => write!(f, "mercury"), PredictionProvider::Zeta1 => write!(f, "zeta1"), - PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"), + PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"), PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"), PredictionProvider::TeacherNonBatching(backend) => { write!(f, "teacher-non-batching:{backend}") @@ -361,8 +366,8 @@ impl std::str::FromStr for PredictionProvider { "mercury" => Ok(PredictionProvider::Mercury), "zeta1" => Ok(PredictionProvider::Zeta1), "zeta2" => { - let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default(); - Ok(PredictionProvider::Zeta2(version)) + let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default(); + Ok(PredictionProvider::Zeta2(format)) } "teacher" => { let backend = arg @@ -385,7 +390,7 @@ impl std::str::FromStr for PredictionProvider { For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\ Available zeta versions:\n{}", - ZetaVersion::options_as_string() + ZetaFormat::options_as_string() ) } } @@ -719,6 +724,13 @@ fn main() { std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap(); return; } + Command::PrintZetaFormats => { + use strum::IntoEnumIterator as _; + for format in ZetaFormat::iter() { + println!("{}", format.to_string().to_lowercase()); + } + return; + } Command::Synthesize(synth_args) => { let Some(output_dir) = args.output else { panic!("output dir is required"); @@ -953,7 +965,8 @@ fn main() { | Command::Split(_) | Command::TruncatePatch(_) | Command::FilterLanguages(_) - | Command::ImportBatch(_) => { + | Command::ImportBatch(_) + | Command::PrintZetaFormats => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index a6e795c27d8b1d9352cb4a7c2accdc995aa429df..e45060924d07a992ec2e563e5b16c3f85938ee2d 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -5,7 +5,7 @@ use crate::{ repair, }; use anyhow::{Context as _, Result}; -use zeta_prompt::{CURSOR_MARKER, ZetaVersion}; +use zeta_prompt::{CURSOR_MARKER, ZetaFormat}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -49,13 +49,13 @@ pub fn parse_prediction_output( } } -fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result { - let (current_marker, end_marker) = match version { - ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), - ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { +fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { + let (current_marker, end_marker) = match format { + ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"), + ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { ("<|fim_middle|>current\n", "<|fim_suffix|>") } - ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => ( + ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => ( zeta_prompt::v0120_git_merge_markers::START_MARKER, zeta_prompt::v0120_git_merge_markers::SEPARATOR, ), @@ -82,7 +82,7 @@ fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result Result<(String, Option)> { let prompt = &example.prompt.as_ref().context("prompt required")?.input; let prompt_inputs = example @@ -90,7 +90,7 @@ fn parse_zeta2_output( .as_ref() .context("prompt_inputs required")?; - let old_text = extract_zeta2_current_region(prompt, version)?; + let old_text = extract_zeta2_current_region(prompt, format)?; let mut new_text = actual_output.to_string(); let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { @@ -100,11 +100,11 @@ fn parse_zeta2_output( None }; - let suffix = match version { - ZetaVersion::V0131GitMergeMarkersPrefix => { + let suffix = match format { + ZetaFormat::V0131GitMergeMarkersPrefix => { zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER } - ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, + ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, _ => "", }; if !suffix.is_empty() { @@ -184,7 +184,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -201,7 +201,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -218,7 +218,7 @@ mod tests { <|fim_middle|>updated "}; - let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap(); + let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -236,7 +236,7 @@ mod tests { "}; let region = - extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } @@ -254,7 +254,7 @@ mod tests { "}; let region = - extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap(); + extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap(); assert_eq!(region, "println!(\"hello\");\n"); } } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 63be8e8b70dfbc5204ab530d839df4e9cdc34e41..5979439a2a7f3a66bfe94881bd04b9d948fe3c7e 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -11,7 +11,7 @@ use crate::{ retrieve_context::run_context_retrieval, }; use anyhow::Context as _; -use edit_prediction::{DebugEvent, EditPredictionStore}; +use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig}; use futures::{FutureExt as _, StreamExt as _, future::Shared}; use gpui::{AppContext as _, AsyncApp, Task}; use std::{ @@ -21,6 +21,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; +use zeta_prompt::ZetaFormat; static ANTHROPIC_CLIENT: OnceLock = OnceLock::new(); static OPENAI_CLIENT: OnceLock = OnceLock::new(); @@ -103,9 +104,7 @@ pub async fn run_prediction( ep_store.update(&mut cx, |store, _cx| { let model = match provider { PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2(version) => { - edit_prediction::EditPredictionModel::Zeta2 { version } - } + PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2, PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher(..) @@ -115,6 +114,15 @@ pub async fn run_prediction( } }; store.set_edit_prediction_model(model); + + // If user specified a non-default Zeta2 version, configure raw endpoint. + // ZED_ZETA_MODEL env var is optional. + if let PredictionProvider::Zeta2(format) = provider { + if format != ZetaFormat::default() { + let model_id = std::env::var("ZED_ZETA_MODEL").ok(); + store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format }); + } + } }); step_progress.set_substatus("configuring model"); let state = example.state.as_ref().context("state must be set")?; diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 2347e27ccaa9ee5a94a9db4d262607ce126c3e57..3e3ed33fd6de460eca0a6e16e58751ff03118166 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -217,9 +217,7 @@ fn assign_edit_prediction_provider( if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() => { - edit_prediction::EditPredictionModel::Zeta2 { - version: Default::default(), - } + edit_prediction::EditPredictionModel::Zeta2 } EditPredictionProvider::Zed if user_store.read(cx).current_user().is_some() => diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 8799d680287aefbcd8a4740eb3b558f9cd62ccb8..73fdecbd134f2346f22304ae84c76ab53c1636c4 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -39,7 +39,7 @@ pub struct ZetaPromptInput { Deserialize, )] #[allow(non_camel_case_types)] -pub enum ZetaVersion { +pub enum ZetaFormat { V0112MiddleAtEnd, V0113Ordered, #[default] @@ -48,28 +48,28 @@ pub enum ZetaVersion { V0131GitMergeMarkersPrefix, } -impl std::fmt::Display for ZetaVersion { +impl std::fmt::Display for ZetaFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", <&'static str>::from(self)) } } -impl ZetaVersion { - pub fn parse(version_string: &str) -> Result { - let mut results = ZetaVersion::iter().filter(|version| { +impl ZetaFormat { + pub fn parse(format_name: &str) -> Result { + let mut results = ZetaFormat::iter().filter(|version| { <&'static str>::from(version) .to_lowercase() - .contains(&version_string.to_lowercase()) + .contains(&format_name.to_lowercase()) }); let Some(result) = results.next() else { anyhow::bail!( - "`{version_string}` did not match any of:\n{}", + "`{format_name}` did not match any of:\n{}", Self::options_as_string() ); }; if results.next().is_some() { anyhow::bail!( - "`{version_string}` matched more than one of:\n{}", + "`{format_name}` matched more than one of:\n{}", Self::options_as_string() ); } @@ -77,8 +77,8 @@ impl ZetaVersion { } pub fn options_as_string() -> String { - ZetaVersion::iter() - .map(|version| format!("- {}\n", <&'static str>::from(version))) + ZetaFormat::iter() + .map(|format| format!("- {}\n", <&'static str>::from(format))) .collect::>() .concat() } @@ -137,27 +137,40 @@ pub struct RelatedExcerpt { pub text: Arc, } -pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String { - format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS) +pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String { + format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS) +} + +/// Post-processes model output for the given zeta format by stripping format-specific suffixes. +pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { + match format { + ZetaFormat::V0120GitMergeMarkers => output + .strip_suffix(v0120_git_merge_markers::END_MARKER) + .unwrap_or(output), + ZetaFormat::V0131GitMergeMarkersPrefix => output + .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER) + .unwrap_or(output), + _ => output, + } } fn format_zeta_prompt_with_budget( input: &ZetaPromptInput, - version: ZetaVersion, + format: ZetaFormat, max_tokens: usize, ) -> String { let mut cursor_section = String::new(); - match version { - ZetaVersion::V0112MiddleAtEnd => { + match format { + ZetaFormat::V0112MiddleAtEnd => { v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input); } - ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { + ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input) } - ZetaVersion::V0120GitMergeMarkers => { + ZetaFormat::V0120GitMergeMarkers => { v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input) } - ZetaVersion::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0131GitMergeMarkersPrefix => { v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input) } } @@ -563,7 +576,7 @@ mod tests { } fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { - format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens) + format_zeta_prompt_with_budget(input, ZetaFormat::V0114180EditableRegion, max_tokens) } #[test]