From 25b1377d1dfb16d951ef57e3ab16b8be029a1230 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Thu, 12 Feb 2026 14:24:24 -0600 Subject: [PATCH] Unify zeta endpoints (#48900) - [ ] Tests or screenshots needed? - [ ] Code Reviewed - [ ] Manual QA Release Notes: - N/A --------- Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> Co-authored-by: Max Co-authored-by: Max Brunsfeld --- crates/edit_prediction/src/capture_example.rs | 7 +- crates/edit_prediction/src/cursor_excerpt.rs | 78 +++ crates/edit_prediction/src/edit_prediction.rs | 64 +- .../src/edit_prediction_tests.rs | 368 +--------- crates/edit_prediction/src/example_spec.rs | 2 + crates/edit_prediction/src/mercury.rs | 3 + crates/edit_prediction/src/ollama.rs | 6 + crates/edit_prediction/src/prediction.rs | 3 + crates/edit_prediction/src/sweep_ai.rs | 3 + crates/edit_prediction/src/zeta1.rs | 347 +--------- crates/edit_prediction/src/zeta2.rs | 66 +- .../edit_prediction_cli/src/format_prompt.rs | 7 + .../edit_prediction_cli/src/pull_examples.rs | 1 + .../src/edit_prediction_context.rs | 7 +- .../src/edit_prediction_context_tests.rs | 1 - crates/zeta_prompt/src/zeta_prompt.rs | 638 ++++++++++++++++-- 16 files changed, 767 insertions(+), 834 deletions(-) diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 33d7d12f1e0eb07ae2e9f13efd7447997c46463a..aa4ffd21f63695d679d7da35bb2f75012854fa85 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -15,8 +15,6 @@ use project::{Project, WorktreeId}; use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc}; use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _}; -pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500; - pub fn capture_example( project: Entity, buffer: Entity, @@ -156,6 +154,7 @@ pub fn capture_example( excerpt_start_row: Some(0), events: captured_events, related_files: captured_related_files, + in_open_source_repo: false, } }); @@ -304,10 +303,6 @@ fn generate_timestamp_name() -> String { } } -pub(crate) fn should_send_testing_zeta2_request() -> bool { - rand::random::() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs index 682b937c3d6094334edf7842abe8e6f80f9c3fa2..900d78945ca6ab4fab9c9c60bf13009368c7c77b 100644 --- a/crates/edit_prediction/src/cursor_excerpt.rs +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -1,5 +1,81 @@ use language::{BufferSnapshot, Point}; use std::ops::Range; +use zeta_prompt::ExcerptRanges; + +/// Pre-computed Point ranges for all editable/context budget combinations. +pub struct ExcerptRangePoints { + pub editable_150: Range, + pub editable_180: Range, + pub editable_350: Range, + pub editable_150_context_350: Range, + pub editable_180_context_350: Range, + pub editable_350_context_150: Range, +} + +/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350 +/// token budgets, plus their corresponding context expansions. Returns the full excerpt range +/// (union of all context ranges) and the individual sub-ranges as Points. +pub fn compute_excerpt_ranges( + position: Point, + snapshot: &BufferSnapshot, +) -> (Range, ExcerptRangePoints) { + let editable_150 = compute_editable_range(snapshot, position, 150); + let editable_180 = compute_editable_range(snapshot, position, 180); + let editable_350 = compute_editable_range(snapshot, position, 350); + + let editable_150_context_350 = + expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350); + let editable_180_context_350 = + expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350); + let editable_350_context_150 = + expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150); + + let full_start_row = editable_150_context_350 + .start + .row + .min(editable_180_context_350.start.row) + .min(editable_350_context_150.start.row); + let full_end_row = editable_150_context_350 + .end + .row + .max(editable_180_context_350.end.row) + .max(editable_350_context_150.end.row); + + let full_context = + Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row)); + + let ranges = ExcerptRangePoints { + editable_150, + editable_180, + editable_350, + editable_150_context_350, + editable_180_context_350, + editable_350_context_150, + }; + + (full_context, ranges) +} + +/// Converts `ExcerptRangePoints` to byte-offset `ExcerptRanges` relative to `excerpt_start`. +pub fn excerpt_ranges_to_byte_offsets( + ranges: &ExcerptRangePoints, + excerpt_start: usize, + snapshot: &BufferSnapshot, +) -> ExcerptRanges { + let to_offset = |range: &Range| -> Range { + let start = range.start.to_offset(snapshot); + let end = range.end.to_offset(snapshot); + (start - excerpt_start)..(end - excerpt_start) + }; + ExcerptRanges { + editable_150: to_offset(&ranges.editable_150), + editable_180: to_offset(&ranges.editable_180), + editable_350: to_offset(&ranges.editable_350), + editable_150_context_350: to_offset(&ranges.editable_150_context_350), + editable_180_context_350: to_offset(&ranges.editable_180_context_350), + editable_350_context_150: to_offset(&ranges.editable_350_context_150), + } +} pub fn editable_and_context_ranges_for_cursor_position( position: Point, @@ -312,6 +388,8 @@ fn expand_context_syntactically_then_linewise( start..end } +use language::ToOffset as _; + #[cfg(test)] mod tests { use super::*; diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 1ec3c7ac44fc8f592fa094f668b3bfd84245eb5a..13f7b46cb301ed95668bf021f36050f2e5da408e 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -72,7 +72,6 @@ pub mod zeta2; #[cfg(test)] mod edit_prediction_tests; -use crate::capture_example::should_send_testing_zeta2_request; use crate::license_detection::LicenseDetectionWatcher; use crate::mercury::Mercury; use crate::ollama::Ollama; @@ -734,10 +733,19 @@ impl EditPredictionStore { ) -> Vec { self.projects .get(&project.entity_id()) - .map(|project| { - project - .context - .update(cx, |context, cx| context.related_files(cx)) + .map(|project_state| { + project_state.context.update(cx, |context, cx| { + context + .related_files_with_buffers(cx) + .map(|(mut related_file, buffer)| { + related_file.in_open_source_repo = buffer + .read(cx) + .file() + .map_or(false, |file| self.is_file_open_source(&project, file, cx)); + related_file + }) + .collect() + }) }) .unwrap_or_default() } @@ -785,9 +793,9 @@ impl EditPredictionStore { self.projects .get(&project.entity_id()) .map(|project| { - project - .context - .update(cx, |context, cx| context.related_files_with_buffers(cx)) + project.context.update(cx, |context, cx| { + context.related_files_with_buffers(cx).collect() + }) }) .unwrap_or_default() } @@ -1771,15 +1779,18 @@ impl EditPredictionStore { }; let task = match &self.edit_prediction_model { - EditPredictionModel::Zeta1 => { - if should_send_testing_zeta2_request() { - let mut zeta2_inputs = inputs.clone(); - zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing; - zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach(); - } - zeta1::request_prediction_with_zeta1(self, inputs, cx) - } - EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), + EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2( + self, + inputs, + Some(zeta_prompt::EditPredictionModelKind::Zeta1), + cx, + ), + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2( + self, + inputs, + Some(zeta_prompt::EditPredictionModelKind::Zeta2), + cx, + ), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx), @@ -2136,25 +2147,6 @@ impl EditPredictionStore { .is_some_and(|watcher| watcher.is_project_open_source()) } - fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { - self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx) - } - - fn can_collect_events(&self, events: &[Arc], cx: &App) -> bool { - if !self.data_collection_choice.is_enabled(cx) { - return false; - } - events.iter().all(|event| { - matches!( - event.as_ref(), - zeta_prompt::Event::BufferChange { - in_open_source_repo: true, - .. - } - ) - }) - } - fn load_data_collection_choice() -> DataCollectionChoice { let choice = KEY_VALUE_STORE .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 19d2532de094b849952ca16c100cf2c8b4a598dc..978ece1a75a18770798246d5ac38a8109ce05cc1 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,11 +1,10 @@ use super::*; -use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS}; +use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string}; use client::{UserStore, test::FakeServer}; -use clock::{FakeSystemClock, ReplicaId}; +use clock::FakeSystemClock; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use cloud_llm_client::{ - EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, - RejectEditPredictionsBody, + EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response}, }; use futures::{ @@ -26,7 +25,7 @@ use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use std::{path::Path, sync::Arc, time::Duration}; -use util::{path, rel_path::rel_path}; +use util::path; use uuid::Uuid; use zeta_prompt::ZetaPromptInput; @@ -1424,8 +1423,6 @@ fn init_test_with_fake_client( }) } -const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); - #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -1452,6 +1449,9 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { editable_range_in_excerpt: 0..0, cursor_offset_in_excerpt: 0, excerpt_start_row: None, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), @@ -1555,13 +1555,10 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) { } "}, indoc! {" - <|editable_region_start|> fn main() { let word_1 = \"lorem\"; let range = word_1.len()..word_1.len(); } - - <|editable_region_end|> "}, cx, ) @@ -1582,12 +1579,9 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) { } "}, indoc! {" - <|editable_region_start|> fn main() { let story = \"the quick brown fox jumps over the lazy dog\"; } - - <|editable_region_end|> "}, cx, ) @@ -1605,18 +1599,11 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { init_test(cx); let buffer_content = "lorem\n"; - let completion_response = indoc! {" - ```animals.js - <|start_of_file|> - <|editable_region_start|> - lorem - ipsum - <|editable_region_end|> - ```"}; + let completion_response = "lorem\nipsum\n"; assert_eq!( apply_edit_prediction(buffer_content, completion_response, cx).await, - "lorem\nipsum" + "lorem\nipsum\n" ); } @@ -1685,298 +1672,6 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte }); } -#[gpui::test] -async fn test_can_collect_data(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/project/src/main.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Disabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - - let buffer = cx.new(|_cx| { - Buffer::remote( - language::BufferId::new(1).unwrap(), - ReplicaId::new(1), - language::Capability::ReadWrite, - "fn main() {\n println!(\"Hello\");\n}", - ) - }); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "LICENSE": BSD_0_TXT, - ".env": "SECRET_KEY=secret" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/.env", cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - let buffer = cx.new(|cx| Buffer::local("", cx)); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/main.rs", cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/open_source_worktree"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), - ) - .await; - fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [ - path!("/open_source_worktree").as_ref(), - path!("/closed_source_worktree").as_ref(), - ], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - let closed_source_file = project - .update(cx, |project, cx| { - let worktree2 = project - .worktree_for_root_name("closed_source_worktree", cx) - .unwrap(); - worktree2.update(cx, |worktree2, cx| { - worktree2.load_file(rel_path("main.rs"), cx) - }) - }) - .await - .unwrap() - .file; - - buffer.update(cx, |buffer, cx| { - buffer.file_updated(closed_source_file, cx); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/worktree1"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), - ) - .await; - fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree1/main.rs"), cx) - }) - .await - .unwrap(); - let private_buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree2/file.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - // this has a side effect of registering the buffer to watch for edits - run_edit_prediction(&private_buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - private_buffer.update(cx, |private_buffer, cx| { - private_buffer.edit([(0..0, "An edit for the history!")], None, cx); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - // make an edit that uses too many bytes, causing private_buffer edit to not be able to be - // included - buffer.update(cx, |buffer, cx| { - buffer.edit( - [( - 0..0, - " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS), - )], - None, - cx, - ); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); -} - fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); @@ -1992,7 +1687,7 @@ async fn apply_edit_prediction( let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let (ep_store, _, response) = make_test_ep_store(&project, cx).await; + let (ep_store, response) = make_test_ep_store(&project, cx).await; *response.lock() = completion_response.to_string(); let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await; buffer.update(cx, |buffer, cx| { @@ -2021,28 +1716,13 @@ async fn run_edit_prediction( async fn make_test_ep_store( project: &Entity, cx: &mut TestAppContext, -) -> ( - Entity, - Arc>>, - Arc>, -) { - let default_response = indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - hello world - <|editable_region_end|> - ```" - }; - let captured_request: Arc>> = Arc::new(Mutex::new(None)); - let completion_response: Arc> = - Arc::new(Mutex::new(default_response.to_string())); +) -> (Entity, Arc>) { + let default_response = "hello world\n".to_string(); + let completion_response: Arc> = Arc::new(Mutex::new(default_response)); let http_client = FakeHttpClient::create({ - let captured_request = captured_request.clone(); let completion_response = completion_response.clone(); let mut next_request_id = 0; move |req| { - let captured_request = captured_request.clone(); let completion_response = completion_response.clone(); async move { match (req.method(), req.uri().path()) { @@ -2056,24 +1736,6 @@ async fn make_test_ep_store( .into(), ) .unwrap()), - (&Method::POST, "/predict_edits/v2") => { - let mut request_body = String::new(); - req.into_body().read_to_string(&mut request_body).await?; - *captured_request.lock() = - Some(serde_json::from_str(&request_body).unwrap()); - next_request_id += 1; - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: format!("request-{next_request_id}"), - output_excerpt: completion_response.lock().clone(), - }) - .unwrap() - .into(), - ) - .unwrap()) - } (&Method::POST, "/predict_edits/v3") => { next_request_id += 1; Ok(http_client::Response::builder() @@ -2081,7 +1743,7 @@ async fn make_test_ep_store( .body( serde_json::to_string(&PredictEditsV3Response { request_id: format!("request-{next_request_id}"), - output: "hello world".to_string(), + output: completion_response.lock().clone(), }) .unwrap() .into(), @@ -2120,7 +1782,7 @@ async fn make_test_ep_store( ep_store }); - (ep_store, captured_request, completion_response) + (ep_store, completion_response) } fn to_completion_edits( diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 5b9c98b83074cf5d4ead8af2bb974ff591c86e95..c6609e5f1f42f21eb165488f85575f2c50fcd1e0 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -66,6 +66,7 @@ pub struct CapturedPromptInput { pub excerpt_start_row: Option, pub events: Vec, pub related_files: Vec, + pub in_open_source_repo: bool, } #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] @@ -101,6 +102,7 @@ impl CapturedRelatedFile { zeta_prompt::RelatedFile { path: self.path.clone(), max_row: self.max_row, + in_open_source_repo: false, excerpts: self .excerpts .iter() diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index eba5f05f7b228c7468ecb8fbfde60feff568cebf..91c33f0fb663fa54cb94b302fb23f3db16378222 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -97,6 +97,9 @@ impl Mercury { - context_offset_range.start) ..(editable_offset_range.end - context_offset_range.start), excerpt_start_row: Some(context_start_row), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }; let prompt = build_prompt(&inputs); diff --git a/crates/edit_prediction/src/ollama.rs b/crates/edit_prediction/src/ollama.rs index a79b61559cbcd7a74ae7619ee54b115eb576a637..c372c73a01990596db7a7d4551808788739fd9d8 100644 --- a/crates/edit_prediction/src/ollama.rs +++ b/crates/edit_prediction/src/ollama.rs @@ -169,6 +169,9 @@ impl Ollama { - context_offset_range.start) ..(editable_offset_range.end - context_offset_range.start), excerpt_start_row: Some(input_excerpt.context_range.start.row), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }; (prompt, stop_tokens, Some(editable_offset_range), inputs) @@ -195,6 +198,9 @@ impl Ollama { .text_for_range(excerpt_range) .collect::() .into(), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }; let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string(); diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 8d4a40d8b9ddf7a2ed8a68773da83a9498c4d516..3d87edb14ab775ef7ee8da2a8faa31efb79ec899 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -158,6 +158,9 @@ mod tests { cursor_excerpt: "".into(), editable_range_in_excerpt: 0..0, excerpt_start_row: None, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index b42f54b7a89ea3f858501529d785c9013d490c99..eb8ee8fe68c9b4458663e196cfb45e1ffadaa0ce 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -219,6 +219,9 @@ impl SweepAi { editable_range_in_excerpt: 0..inputs.snapshot.len(), cursor_offset_in_excerpt: request_body.cursor_position, excerpt_start_row: Some(0), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, }; send_started_event( diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 43d467950fd388fb5a771e8c101a005df57c6897..9baa9d8fef03e3f9c87b9a6f178e8acf3e222f8c 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,26 +1,13 @@ -use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; +use std::{fmt::Write, ops::Range, sync::Arc}; -use crate::{ - DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, - EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError, - cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, - prediction::EditPredictionResult, -}; +use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}; use anyhow::Result; -use cloud_llm_client::{ - PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse, -}; +use cloud_llm_client::PredictEditsBody; use edit_prediction_types::PredictedCursorPosition; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; -use language::{ - Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff, -}; -use project::{Project, ProjectPath}; -use release_channel::AppVersion; +use language::{Anchor, BufferSnapshot, Point, text_diff}; use text::Bias; -use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; use zeta_prompt::{ - Event, ZetaPromptInput, + Event, zeta1::{ CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, @@ -28,260 +15,8 @@ use zeta_prompt::{ }; pub(crate) const MAX_CONTEXT_TOKENS: usize = 150; -pub(crate) const MAX_REWRITE_TOKENS: usize = 350; pub(crate) const MAX_EVENT_TOKENS: usize = 500; -pub(crate) fn request_prediction_with_zeta1( - store: &mut EditPredictionStore, - EditPredictionModelInput { - project, - buffer, - snapshot, - position, - events, - trigger, - debug_tx, - .. - }: EditPredictionModelInput, - cx: &mut Context, -) -> Task>> { - let buffer_snapshotted_at = Instant::now(); - let client = store.client.clone(); - let llm_token = store.llm_token.clone(); - let app_version = AppVersion::global(cx); - - let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = store.can_collect_file(&project, file, cx); - let git_info = if can_collect_file { - git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx) - } else { - None - }; - (git_info, can_collect_file) - } else { - (None, false) - }; - - let full_path: Arc = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let full_path_str = full_path.to_string_lossy().into_owned(); - let cursor_point = position.to_point(&snapshot); - let prompt_for_events = { - let events = events.clone(); - move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) - }; - let gather_task = gather_context( - full_path_str, - &snapshot, - cursor_point, - prompt_for_events, - trigger, - cx, - ); - - let uri = match client - .http_client() - .build_zed_llm_url("/predict_edits/v2", &[]) - { - Ok(url) => Arc::from(url), - Err(err) => return Task::ready(Err(err)), - }; - - cx.spawn(async move |this, cx| { - let GatherContextOutput { - mut body, - context_range, - editable_range, - included_events_count, - } = gather_task.await?; - let done_gathering_context_at = Instant::now(); - - let included_events = &events[events.len() - included_events_count..events.len()]; - body.can_collect_data = can_collect_file - && this - .read_with(cx, |this, cx| this.can_collect_events(included_events, cx)) - .unwrap_or(false); - if body.can_collect_data { - body.git_info = git_info; - } - - log::debug!( - "Events:\n{}\nExcerpt:\n{:?}", - body.input_events, - body.input_excerpt - ); - - let response = EditPredictionStore::send_api_request::( - |request| { - Ok(request - .uri(uri.as_str()) - .body(serde_json::to_string(&body)?.into())?) - }, - client, - llm_token, - app_version, - true, - ) - .await; - - let context_start_offset = context_range.start.to_offset(&snapshot); - let context_start_row = context_range.start.row; - let editable_offset_range = editable_range.to_offset(&snapshot); - - let inputs = ZetaPromptInput { - events: included_events.into(), - related_files: vec![], - cursor_path: full_path, - cursor_excerpt: snapshot - .text_for_range(context_range) - .collect::() - .into(), - editable_range_in_excerpt: (editable_range.start - context_start_offset) - ..(editable_offset_range.end - context_start_offset), - cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset, - excerpt_start_row: Some(context_start_row), - }; - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(DebugEvent::EditPredictionStarted( - EditPredictionStartedDebugEvent { - buffer: buffer.downgrade(), - prompt: Some(serde_json::to_string(&inputs).unwrap()), - position, - }, - )) - .ok(); - } - - let (response, usage) = match response { - Ok(response) => response, - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |ep_store, _cx| { - ep_store.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button("Update Zed", "https://zed.dev/releases") - }) - }, - ); - }); - } - - return Err(err); - } - }; - - let received_response_at = Instant::now(); - log::debug!("completion response: {}", &response.output_excerpt); - - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(DebugEvent::EditPredictionFinished( - EditPredictionFinishedDebugEvent { - buffer: buffer.downgrade(), - model_output: Some(response.output_excerpt.clone()), - position, - }, - )) - .ok(); - } - - let edit_prediction = process_completion_response( - response, - buffer, - &snapshot, - editable_range, - inputs, - buffer_snapshotted_at, - received_response_at, - cx, - ) - .await; - - let finished_at = Instant::now(); - - // record latency for ~1% of requests - if rand::random::() <= 2 { - telemetry::event!( - "Edit Prediction Request", - context_latency = done_gathering_context_at - .duration_since(buffer_snapshotted_at) - .as_millis(), - request_latency = received_response_at - .duration_since(done_gathering_context_at) - .as_millis(), - process_latency = finished_at.duration_since(received_response_at).as_millis() - ); - } - - edit_prediction.map(Some) - }) -} - -fn process_completion_response( - prediction_response: PredictEditsResponse, - buffer: Entity, - snapshot: &BufferSnapshot, - editable_range: Range, - inputs: ZetaPromptInput, - buffer_snapshotted_at: Instant, - received_response_at: Instant, - cx: &AsyncApp, -) -> Task> { - let snapshot = snapshot.clone(); - let request_id = prediction_response.request_id; - let output_excerpt = prediction_response.output_excerpt; - cx.spawn(async move |cx| { - let output_excerpt: Arc = output_excerpt.into(); - - let edits: Arc<[(Range, Arc)]> = cx - .background_spawn({ - let output_excerpt = output_excerpt.clone(); - let editable_range = editable_range.clone(); - let snapshot = snapshot.clone(); - async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) } - }) - .await? - .into(); - - let id = EditPredictionId(request_id.into()); - Ok(EditPredictionResult::new( - id, - &buffer, - &snapshot, - edits, - None, - buffer_snapshotted_at, - received_response_at, - inputs, - cx, - ) - .await) - }) -} - pub(crate) fn parse_edits( output_excerpt: &str, editable_range: Range, @@ -434,35 +169,6 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } -fn git_info_for_file( - project: &Entity, - project_path: &ProjectPath, - cx: &App, -) -> Option { - let git_store = project.read(cx).git_store().read(cx); - if let Some((repository, _repo_path)) = - git_store.repository_and_path_for_project_path(project_path, cx) - { - let repository = repository.read(cx); - let head_sha = repository - .head_commit - .as_ref() - .map(|head_commit| head_commit.sha.to_string()); - let remote_origin_url = repository.remote_origin_url.clone(); - let remote_upstream_url = repository.remote_upstream_url.clone(); - if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { - return None; - } - Some(PredictEditsGitInfo { - head_sha, - remote_origin_url, - remote_upstream_url, - }) - } else { - None - } -} - pub struct GatherContextOutput { pub body: PredictEditsBody, pub context_range: Range, @@ -470,48 +176,6 @@ pub struct GatherContextOutput { pub included_events_count: usize, } -pub fn gather_context( - full_path_str: String, - snapshot: &BufferSnapshot, - cursor_point: language::Point, - prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, - trigger: PredictEditsRequestTrigger, - cx: &App, -) -> Task> { - cx.background_spawn({ - let snapshot = snapshot.clone(); - async move { - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &full_path_str, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let (input_events, included_events_count) = prompt_for_events(); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - - let body = PredictEditsBody { - input_events, - input_excerpt: input_excerpt.prompt, - can_collect_data: false, - diagnostic_groups: None, - git_info: None, - outline: None, - speculated_output: None, - trigger, - }; - - Ok(GatherContextOutput { - body, - context_range: input_excerpt.context_range, - editable_range, - included_events_count, - }) - } - }) -} - pub(crate) fn prompt_for_events(events: &[Arc], max_tokens: usize) -> String { prompt_for_events_impl(events, max_tokens).0 } @@ -638,6 +302,7 @@ mod tests { use gpui::{App, AppContext}; use indoc::indoc; use language::Buffer; + use text::OffsetRangeExt as _; #[gpui::test] fn test_excerpt_for_cursor_position(cx: &mut App) { diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 36f70c6d9a85a0e2ac840f3655e48fdab9166252..874644b7605776364b3455092443263de05d84cd 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,10 +1,11 @@ +use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets}; use crate::prediction::EditPredictionResult; use crate::zeta1::compute_edits_and_cursor_position; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; -use anyhow::{Result, anyhow}; +use anyhow::Result; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; use gpui::{App, Task, prelude::*}; @@ -13,8 +14,10 @@ use release_channel::AppVersion; use std::env; use std::{path::Path, sync::Arc, time::Instant}; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output}; -use zeta_prompt::{format_zeta_prompt, get_prefill}; +use zeta_prompt::{ + CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output, + format_zeta_prompt, get_prefill, +}; pub const MAX_CONTEXT_TOKENS: usize = 350; @@ -39,24 +42,30 @@ pub fn request_prediction_with_zeta2( events, debug_tx, trigger, + project, .. }: EditPredictionModelInput, + preferred_model: Option, cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); let raw_config = store.zeta2_raw_config().cloned(); - let Some(excerpt_path) = snapshot + let excerpt_path: Arc = snapshot .file() .map(|file| -> Arc { file.full_path(cx).into() }) - else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); let client = store.client.clone(); let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); + let is_open_source = snapshot + .file() + .map_or(false, |file| store.is_file_open_source(&project, file, cx)) + && events.iter().all(|event| event.in_open_source_repo()) + && related_files.iter().all(|file| file.in_open_source_repo); + let request_task = cx.background_spawn({ async move { let zeta_version = raw_config @@ -72,6 +81,8 @@ pub fn request_prediction_with_zeta2( excerpt_path, cursor_offset, zeta_version, + preferred_model, + is_open_source, ); if let Some(debug_tx) = &debug_tx { @@ -248,41 +259,52 @@ pub fn zeta2_prompt_input( excerpt_path: Arc, cursor_offset: usize, zeta_format: ZetaFormat, + preferred_model: Option, + is_open_source: bool, ) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); - let (editable_range, context_range) = - crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( - cursor_point, - snapshot, - max_editable_tokens(zeta_format), - MAX_CONTEXT_TOKENS, - ); + let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot); let related_files = crate::filter_redundant_excerpts( related_files, excerpt_path.as_ref(), - context_range.start.row..context_range.end.row, + full_context.start.row..full_context.end.row, ); - let context_start_offset = context_range.start.to_offset(snapshot); - let context_start_row = context_range.start.row; + let full_context_start_offset = full_context.start.to_offset(snapshot); + let full_context_start_row = full_context.start.row; + + let excerpt_ranges = + excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot); + + let editable_range = match preferred_model { + Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350, + _ => match zeta_format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150, + _ => &range_points.editable_180, + }, + }; + let editable_offset_range = editable_range.to_offset(snapshot); - let cursor_offset_in_excerpt = cursor_offset - context_start_offset; - let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset) - ..(editable_offset_range.end - context_start_offset); + let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset; + let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset) + ..(editable_offset_range.end - full_context_start_offset); let prompt_input = zeta_prompt::ZetaPromptInput { cursor_path: excerpt_path, cursor_excerpt: snapshot - .text_for_range(context_range) + .text_for_range(full_context) .collect::() .into(), editable_range_in_excerpt, cursor_offset_in_excerpt, - excerpt_start_row: Some(context_start_row), + excerpt_start_row: Some(full_context_start_row), events, related_files, + excerpt_ranges: Some(excerpt_ranges), + preferred_model, + in_open_source_repo: is_open_source, }; (editable_offset_range, prompt_input) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index aaa5b2307f7f6df9a3e5a2c584d7d815ffb5cb53..dbdc4ab19b8310ca1b653bfad3977adc8717f926 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -93,6 +93,13 @@ pub async fn run_format_prompt( excerpt_start_row: prompt_inputs.excerpt_start_row, events: prompt_inputs.edit_history.clone(), related_files: prompt_inputs.related_files.clone().unwrap_or_default(), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: example + .spec + .captured_prompt_input + .as_ref() + .map_or(false, |input| input.in_open_source_repo), }; let prompt = format_zeta_prompt(&input, version); let prefill = zeta_prompt::get_prefill(&input, version); diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index b48cc09e13b02cac85033786e780533304fa6de4..46ee3ba590ed98aad0e05aac527cf671018fd162 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -1304,6 +1304,7 @@ fn build_example_from_snowflake( excerpt_start_row: None, events, related_files, + in_open_source_repo: input.in_open_source_repo, }), telemetry: Some(TelemetrySource { request_id, diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 79bfdfa192a7d52d7f1189b93e164290380c71ea..0ae9253a49c81b50183c10cdce3877d8e41b64a8 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -136,11 +136,13 @@ impl RelatedExcerptStore { .collect() } - pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity)> { + pub fn related_files_with_buffers( + &mut self, + cx: &App, + ) -> impl Iterator)> { self.related_buffers .iter_mut() .map(|related| (related.related_file(cx), related.buffer.clone())) - .collect::>() } pub fn set_related_files(&mut self, files: Vec, cx: &App) { @@ -424,6 +426,7 @@ impl RelatedBuffer { path, excerpts: cached.excerpts.clone(), max_row: buffer.max_point().row, + in_open_source_repo: false, }; return related_file; } diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs index 078bf0c56192b7ab5ea13b76d0940710ece2378d..79c53aea2a2fb5de9c137cbba4f5fa751db1f170 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -89,7 +89,6 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { let company_buffer = related_excerpt_store.update(cx, |store, cx| { store .related_files_with_buffers(cx) - .into_iter() .find(|(file, _)| file.path.to_str() == Some("root/src/company.rs")) .map(|(_, buffer)| buffer) .expect("company.rs buffer not found") diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 407ed5f561080065fe5737e0a8b4b7c578284184..fa6f7ce8f03bf7a9534017b99f503ebd6041f827 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -18,6 +18,32 @@ fn estimate_tokens(bytes: usize) -> usize { bytes / 3 } +/// The client's preferred edit prediction model. The server may override this. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum EditPredictionModelKind { + Zeta1, + Zeta2, +} + +/// Pre-computed byte offset ranges within `cursor_excerpt` for different +/// editable and context token budgets. Allows the server to select the +/// appropriate ranges for whichever model it uses. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExcerptRanges { + /// Editable region computed with a 150-token budget. + pub editable_150: Range, + /// Editable region computed with a 180-token budget. + pub editable_180: Range, + /// Editable region computed with a 350-token budget. + pub editable_350: Range, + /// Context boundary when using editable_150 with 350 tokens of additional context. + pub editable_150_context_350: Range, + /// Context boundary when using editable_180 with 350 tokens of additional context. + pub editable_180_context_350: Range, + /// Context boundary when using editable_350 with 150 tokens of additional context. + pub editable_350_context_150: Range, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ZetaPromptInput { pub cursor_path: Arc, @@ -28,6 +54,17 @@ pub struct ZetaPromptInput { pub excerpt_start_row: Option, pub events: Vec>, pub related_files: Vec, + /// When set, the excerpt was computed with a larger budget (~512 tokens) + /// and these ranges let the server select model-appropriate subsets. + /// When absent, the excerpt IS the context region and + /// `editable_range_in_excerpt` is the only editable range. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excerpt_ranges: Option, + /// Client's preferred model. The server may override. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub preferred_model: Option, + #[serde(default)] + pub in_open_source_repo: bool, } #[derive( @@ -103,6 +140,17 @@ pub enum Event { }, } +impl Event { + pub fn in_open_source_repo(&self) -> bool { + match self { + Event::BufferChange { + in_open_source_repo, + .. + } => *in_open_source_repo, + } + } +} + pub fn write_event(prompt: &mut String, event: &Event) { fn write_path_as_unix_str(prompt: &mut String, path: &Path) { for component in path.components() { @@ -136,6 +184,8 @@ pub struct RelatedFile { pub path: Arc, pub max_row: u32, pub excerpts: Vec, + #[serde(default)] + pub in_open_source_repo: bool, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -164,27 +214,96 @@ pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { } } +fn resolve_cursor_region( + input: &ZetaPromptInput, + format: ZetaFormat, +) -> (&str, Range, usize) { + let Some(ranges) = &input.excerpt_ranges else { + return ( + &input.cursor_excerpt, + input.editable_range_in_excerpt.clone(), + input.cursor_offset_in_excerpt, + ); + }; + + let (editable_range, context_range) = match format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => ( + ranges.editable_150.clone(), + ranges.editable_150_context_350.clone(), + ), + ZetaFormat::V0114180EditableRegion + | ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211Prefill + | ZetaFormat::V0211SeedCoder => ( + ranges.editable_180.clone(), + ranges.editable_180_context_350.clone(), + ), + }; + + let context_start = context_range.start; + let context_text = &input.cursor_excerpt[context_range]; + let adjusted_editable = + (editable_range.start - context_start)..(editable_range.end - context_start); + let adjusted_cursor = input.cursor_offset_in_excerpt - context_start; + + (context_text, adjusted_editable, adjusted_cursor) +} + fn format_zeta_prompt_with_budget( input: &ZetaPromptInput, format: ZetaFormat, max_tokens: usize, ) -> String { + let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format); + let path = &*input.cursor_path; + let mut cursor_section = String::new(); match format { ZetaFormat::V0112MiddleAtEnd => { - v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input); + v0112_middle_at_end::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ); } ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { - v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input) - } - ZetaFormat::V0120GitMergeMarkers => { - v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input) + v0113_ordered::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ) } + ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ), ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { - v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input) + v0131_git_merge_markers_prefix::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ) } ZetaFormat::V0211SeedCoder => { - return seed_coder::format_prompt_with_budget(input, max_tokens); + return seed_coder::format_prompt_with_budget( + path, + context, + &editable_range, + cursor_offset, + &input.events, + &input.related_files, + max_tokens, + ); } } @@ -343,29 +462,29 @@ pub fn write_related_files( mod v0112_middle_at_end { use super::*; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>\n"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str("<|fim_suffix|>\n"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>current\n"); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -377,32 +496,32 @@ mod v0112_middle_at_end { mod v0113_ordered { use super::*; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>\n"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>current\n"); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_suffix|>\n"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -441,30 +560,30 @@ pub mod v0120_git_merge_markers { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str("<|fim_suffix|>"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>"); prompt.push_str(START_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -502,29 +621,29 @@ pub mod v0131_git_merge_markers_prefix { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str(START_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str(SEPARATOR); prompt.push_str("<|fim_suffix|>"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -619,16 +738,25 @@ pub mod seed_coder { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; - pub fn format_prompt_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { - let suffix_section = build_suffix_section(input); - let cursor_prefix_section = build_cursor_prefix_section(input); + pub fn format_prompt_with_budget( + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + events: &[Arc], + related_files: &[RelatedFile], + max_tokens: usize, + ) -> String { + let suffix_section = build_suffix_section(context, editable_range); + let cursor_prefix_section = + build_cursor_prefix_section(path, context, editable_range, cursor_offset); let suffix_tokens = estimate_tokens(suffix_section.len()); let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len()); let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens); let edit_history_section = super::format_edit_history_within_budget( - &input.events, + events, FILE_MARKER, "edit_history", budget_after_cursor, @@ -637,7 +765,7 @@ pub mod seed_coder { let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); let related_files_section = super::format_related_files_within_budget( - &input.related_files, + related_files, FILE_MARKER, budget_after_edit_history, ); @@ -658,32 +786,31 @@ pub mod seed_coder { prompt } - fn build_suffix_section(input: &ZetaPromptInput) -> String { + fn build_suffix_section(context: &str, editable_range: &Range) -> String { let mut section = String::new(); section.push_str(FIM_SUFFIX); - section.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + section.push_str(&context[editable_range.end..]); if !section.ends_with('\n') { section.push('\n'); } section } - fn build_cursor_prefix_section(input: &ZetaPromptInput) -> String { + fn build_cursor_prefix_section( + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) -> String { let mut section = String::new(); - let path_str = input.cursor_path.to_string_lossy(); + let path_str = path.to_string_lossy(); write!(section, "{}{}\n", FILE_MARKER, path_str).ok(); - section.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + section.push_str(&context[..editable_range.start]); section.push_str(START_MARKER); - section.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + section.push_str(&context[editable_range.start..cursor_offset]); section.push_str(CURSOR_MARKER); - section.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + section.push_str(&context[cursor_offset..editable_range.end]); if !section.ends_with('\n') { section.push('\n'); } @@ -694,6 +821,9 @@ pub mod seed_coder { /// The zeta1 prompt format pub mod zeta1 { + use super::*; + use std::fmt::Write; + pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; pub const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>"; @@ -725,6 +855,166 @@ pub mod zeta1 { prompt.push_str(RESPONSE_HEADER); prompt } + + /// Formats a complete zeta1 prompt from a `ZetaPromptInput` using the given + /// editable and context byte-offset ranges within `cursor_excerpt`. + pub fn format_zeta1_from_input( + input: &ZetaPromptInput, + editable_range: Range, + context_range: Range, + ) -> String { + let events = format_zeta1_events(&input.events); + let excerpt = format_zeta1_excerpt(input, editable_range, context_range); + format_zeta1_prompt(&events, &excerpt) + } + + /// Formats events in zeta1 style (oldest first). + fn format_zeta1_events(events: &[Arc]) -> String { + let mut result = String::new(); + for event in events { + let event_string = format_zeta1_event(event); + if event_string.is_empty() { + continue; + } + if !result.is_empty() { + result.push_str("\n\n"); + } + result.push_str(&event_string); + } + result + } + + fn format_zeta1_event(event: &Event) -> String { + match event { + Event::BufferChange { + path, + old_path, + diff, + .. + } => { + let mut prompt = String::new(); + if old_path != path { + writeln!( + prompt, + "User renamed {} to {}\n", + old_path.display(), + path.display() + ) + .ok(); + } + if !diff.is_empty() { + write!( + prompt, + "User edited {}:\n```diff\n{}\n```", + path.display(), + diff + ) + .ok(); + } + prompt + } + } + } + + /// Formats the excerpt section of a zeta1 prompt using byte-offset ranges + /// within `cursor_excerpt`. + fn format_zeta1_excerpt( + input: &ZetaPromptInput, + editable_range: Range, + context_range: Range, + ) -> String { + let path_str = input.cursor_path.to_string_lossy(); + let excerpt = &*input.cursor_excerpt; + let cursor_offset = input.cursor_offset_in_excerpt; + + let mut prompt = String::new(); + writeln!(&mut prompt, "```{path_str}").ok(); + + let starts_at_file_beginning = + input.excerpt_start_row == Some(0) && context_range.start == 0; + if starts_at_file_beginning { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").ok(); + } + + prompt.push_str(&excerpt[context_range.start..editable_range.start]); + + writeln!(&mut prompt, "{EDITABLE_REGION_START_MARKER}").ok(); + prompt.push_str(&excerpt[editable_range.start..cursor_offset]); + prompt.push_str(CURSOR_MARKER); + prompt.push_str(&excerpt[cursor_offset..editable_range.end]); + write!(&mut prompt, "\n{EDITABLE_REGION_END_MARKER}").ok(); + + prompt.push_str(&excerpt[editable_range.end..context_range.end]); + write!(prompt, "\n```").ok(); + + prompt + } + + /// Cleans zeta1 model output by extracting content between editable region + /// markers and converting the zeta1 cursor marker to the universal one. + /// Returns `None` if the output doesn't contain the expected markers. + pub fn clean_zeta1_model_output(output: &str) -> Option { + let content = output.replace(CURSOR_MARKER, ""); + + let content_start = content + .find(EDITABLE_REGION_START_MARKER) + .map(|pos| pos + EDITABLE_REGION_START_MARKER.len()) + .map(|pos| { + if content.as_bytes().get(pos) == Some(&b'\n') { + pos + 1 + } else { + pos + } + }) + .unwrap_or(0); + + let content_end = content + .find(EDITABLE_REGION_END_MARKER) + .map(|pos| { + if pos > 0 && content.as_bytes().get(pos - 1) == Some(&b'\n') { + pos - 1 + } else { + pos + } + }) + .unwrap_or(content.len()); + + if content_start > content_end { + return Some(String::new()); + } + + let extracted = &content[content_start..content_end]; + + let cursor_offset = output.find(CURSOR_MARKER).map(|zeta1_cursor_pos| { + let text_before_cursor = output[..zeta1_cursor_pos].replace(CURSOR_MARKER, ""); + let text_before_cursor = text_before_cursor + .find(EDITABLE_REGION_START_MARKER) + .map(|pos| { + let after_marker = pos + EDITABLE_REGION_START_MARKER.len(); + if text_before_cursor.as_bytes().get(after_marker) == Some(&b'\n') { + after_marker + 1 + } else { + after_marker + } + }) + .unwrap_or(0); + let offset_in_extracted = zeta1_cursor_pos + .saturating_sub(text_before_cursor) + .min(extracted.len()); + offset_in_extracted + }); + + let mut result = String::with_capacity(extracted.len() + super::CURSOR_MARKER.len()); + if let Some(offset) = cursor_offset { + result.push_str(&extracted[..offset]); + result.push_str(super::CURSOR_MARKER); + result.push_str(&extracted[offset..]); + } else { + result.push_str(extracted); + } + + Some(result) + } } #[cfg(test)] @@ -747,6 +1037,9 @@ mod tests { excerpt_start_row: None, events: events.into_iter().map(Arc::new).collect(), related_files, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, } } @@ -768,6 +1061,7 @@ mod tests { row_range: 0..content.lines().count() as u32, text: content.into(), }], + in_open_source_repo: false, } } @@ -869,6 +1163,7 @@ mod tests { vec![RelatedFile { path: Path::new("big.rs").into(), max_row: 30, + in_open_source_repo: false, excerpts: vec![ RelatedExcerpt { row_range: 0..10, @@ -1106,4 +1401,201 @@ mod tests { "new code\n" ); } + + #[test] + fn test_format_zeta1_from_input_basic() { + let excerpt = "fn before() {}\nfn foo() {\n let x = 1;\n}\nfn after() {}\n"; + let input = ZetaPromptInput { + cursor_path: Path::new("src/main.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: 15..41, + cursor_offset_in_excerpt: 30, + excerpt_start_row: Some(0), + events: vec![Arc::new(make_event("other.rs", "-old\n+new\n"))], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, 15..41, 0..excerpt.len()); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "User edited other.rs:\n", + "```diff\n", + "-old\n", + "+new\n", + "\n", + "```\n", + "\n", + "### User Excerpt:\n", + "\n", + "```src/main.rs\n", + "<|start_of_file|>\n", + "fn before() {}\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "\n", + "<|editable_region_end|>}\n", + "fn after() {}\n", + "\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_format_zeta1_from_input_no_start_of_file() { + let excerpt = "fn foo() {\n let x = 1;\n}\n"; + let input = ZetaPromptInput { + cursor_path: Path::new("src/main.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: 0..28, + cursor_offset_in_excerpt: 15, + excerpt_start_row: Some(10), + events: vec![], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, 0..28, 0..28); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "\n", + "\n", + "### User Excerpt:\n", + "\n", + "```src/main.rs\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "}\n", + "\n", + "<|editable_region_end|>\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_format_zeta1_from_input_with_sub_ranges() { + let excerpt = "// prefix\nfn foo() {\n let x = 1;\n}\n// suffix\n"; + let editable_range = 10..37; + let context_range = 0..excerpt.len(); + + let input = ZetaPromptInput { + cursor_path: Path::new("test.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: editable_range.clone(), + cursor_offset_in_excerpt: 25, + excerpt_start_row: Some(0), + events: vec![], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, editable_range, context_range); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "\n", + "\n", + "### User Excerpt:\n", + "\n", + "```test.rs\n", + "<|start_of_file|>\n", + "// prefix\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "}\n", + "<|editable_region_end|>\n", + "// suffix\n", + "\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_clean_zeta1_model_output_basic() { + let output = indoc! {" + <|editable_region_start|> + fn main() { + println!(\"hello\"); + } + <|editable_region_end|> + "}; + + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, "fn main() {\n println!(\"hello\");\n}"); + } + + #[test] + fn test_clean_zeta1_model_output_with_cursor() { + let output = indoc! {" + <|editable_region_start|> + fn main() { + <|user_cursor_is_here|>println!(\"hello\"); + } + <|editable_region_end|> + "}; + + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!( + cleaned, + "fn main() {\n <|user_cursor|>println!(\"hello\");\n}" + ); + } + + #[test] + fn test_clean_zeta1_model_output_no_markers() { + let output = "fn main() {}\n"; + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, "fn main() {}\n"); + } + + #[test] + fn test_clean_zeta1_model_output_empty_region() { + let output = "<|editable_region_start|>\n<|editable_region_end|>\n"; + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, ""); + } }