diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 00238983f800861cb6d94a4e49d8ca3a91d5bbaf..bfe56408dc5ea9c1017c8c77c54068e3ae0f99cf 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -15,8 +15,6 @@ use project::{Project, WorktreeId}; use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc}; use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _}; -pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500; - pub fn capture_example( project: Entity, buffer: Entity, @@ -156,6 +154,7 @@ pub fn capture_example( excerpt_start_row: Some(0), events: captured_events, related_files: captured_related_files, + in_open_source_repo: false, } }); @@ -304,10 +303,6 @@ fn generate_timestamp_name() -> String { } } -pub(crate) fn should_send_testing_zeta2_request() -> bool { - rand::random::() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs index 682b937c3d6094334edf7842abe8e6f80f9c3fa2..900d78945ca6ab4fab9c9c60bf13009368c7c77b 100644 --- a/crates/edit_prediction/src/cursor_excerpt.rs +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -1,5 +1,81 @@ use language::{BufferSnapshot, Point}; use std::ops::Range; +use zeta_prompt::ExcerptRanges; + +/// Pre-computed Point ranges for all editable/context budget combinations. +pub struct ExcerptRangePoints { + pub editable_150: Range, + pub editable_180: Range, + pub editable_350: Range, + pub editable_150_context_350: Range, + pub editable_180_context_350: Range, + pub editable_350_context_150: Range, +} + +/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350 +/// token budgets, plus their corresponding context expansions. Returns the full excerpt range +/// (union of all context ranges) and the individual sub-ranges as Points. +pub fn compute_excerpt_ranges( + position: Point, + snapshot: &BufferSnapshot, +) -> (Range, ExcerptRangePoints) { + let editable_150 = compute_editable_range(snapshot, position, 150); + let editable_180 = compute_editable_range(snapshot, position, 180); + let editable_350 = compute_editable_range(snapshot, position, 350); + + let editable_150_context_350 = + expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350); + let editable_180_context_350 = + expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350); + let editable_350_context_150 = + expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150); + + let full_start_row = editable_150_context_350 + .start + .row + .min(editable_180_context_350.start.row) + .min(editable_350_context_150.start.row); + let full_end_row = editable_150_context_350 + .end + .row + .max(editable_180_context_350.end.row) + .max(editable_350_context_150.end.row); + + let full_context = + Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row)); + + let ranges = ExcerptRangePoints { + editable_150, + editable_180, + editable_350, + editable_150_context_350, + editable_180_context_350, + editable_350_context_150, + }; + + (full_context, ranges) +} + +/// Converts `ExcerptRangePoints` to byte-offset `ExcerptRanges` relative to `excerpt_start`. +pub fn excerpt_ranges_to_byte_offsets( + ranges: &ExcerptRangePoints, + excerpt_start: usize, + snapshot: &BufferSnapshot, +) -> ExcerptRanges { + let to_offset = |range: &Range| -> Range { + let start = range.start.to_offset(snapshot); + let end = range.end.to_offset(snapshot); + (start - excerpt_start)..(end - excerpt_start) + }; + ExcerptRanges { + editable_150: to_offset(&ranges.editable_150), + editable_180: to_offset(&ranges.editable_180), + editable_350: to_offset(&ranges.editable_350), + editable_150_context_350: to_offset(&ranges.editable_150_context_350), + editable_180_context_350: to_offset(&ranges.editable_180_context_350), + editable_350_context_150: to_offset(&ranges.editable_350_context_150), + } +} pub fn editable_and_context_ranges_for_cursor_position( position: Point, @@ -312,6 +388,8 @@ fn expand_context_syntactically_then_linewise( start..end } +use language::ToOffset as _; + #[cfg(test)] mod tests { use super::*; diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 01d813173b1957eea1d900187af520510d9dfc09..1eb32c244dc561ef7dd31394051e459786bf8683 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -72,7 +72,6 @@ pub mod zeta2; #[cfg(test)] mod edit_prediction_tests; -use crate::capture_example::should_send_testing_zeta2_request; use crate::license_detection::LicenseDetectionWatcher; use crate::mercury::Mercury; use crate::ollama::Ollama; @@ -784,10 +783,19 @@ impl EditPredictionStore { ) -> Vec { self.projects .get(&project.entity_id()) - .map(|project| { - project - .context - .update(cx, |context, cx| context.related_files(cx)) + .map(|project_state| { + project_state.context.update(cx, |context, cx| { + context + .related_files_with_buffers(cx) + .map(|(mut related_file, buffer)| { + related_file.in_open_source_repo = buffer + .read(cx) + .file() + .map_or(false, |file| self.is_file_open_source(&project, file, cx)); + related_file + }) + .collect() + }) }) .unwrap_or_default() } @@ -835,9 +843,9 @@ impl EditPredictionStore { self.projects .get(&project.entity_id()) .map(|project| { - project - .context - .update(cx, |context, cx| context.related_files_with_buffers(cx)) + project.context.update(cx, |context, cx| { + context.related_files_with_buffers(cx).collect() + }) }) .unwrap_or_default() } @@ -1832,15 +1840,18 @@ impl EditPredictionStore { }; let task = match &self.edit_prediction_model { - EditPredictionModel::Zeta1 => { - if should_send_testing_zeta2_request() { - let mut zeta2_inputs = inputs.clone(); - zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing; - zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach(); - } - zeta1::request_prediction_with_zeta1(self, inputs, cx) - } - EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), + EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2( + self, + inputs, + Some(zeta_prompt::EditPredictionModelKind::Zeta1), + cx, + ), + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2( + self, + inputs, + Some(zeta_prompt::EditPredictionModelKind::Zeta2), + cx, + ), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx), @@ -2197,23 +2208,8 @@ impl EditPredictionStore { .is_some_and(|watcher| watcher.is_project_open_source()) } - fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { - self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx) - } - - fn can_collect_events(&self, events: &[Arc], cx: &App) -> bool { - if !self.data_collection_choice.is_enabled(cx) { - return false; - } - events.iter().all(|event| { - matches!( - event.as_ref(), - zeta_prompt::Event::BufferChange { - in_open_source_repo: true, - .. - } - ) - }) + pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool { + self.data_collection_choice.is_enabled(cx) } fn load_data_collection_choice() -> DataCollectionChoice { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 8e3ebb8a2219ec35e83487efcf449fe81fbd9713..242a2bf3fff5f0eb87b183ec6c65280cbe75256a 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,11 +1,10 @@ use super::*; -use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS}; +use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string}; use client::{UserStore, test::FakeServer}; -use clock::{FakeSystemClock, ReplicaId}; +use clock::FakeSystemClock; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use cloud_llm_client::{ - EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, - RejectEditPredictionsBody, + EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response}, }; use futures::{ @@ -26,7 +25,7 @@ use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use std::{path::Path, sync::Arc, time::Duration}; -use util::{path, rel_path::rel_path}; +use util::path; use uuid::Uuid; use zeta_prompt::ZetaPromptInput; @@ -1679,8 +1678,6 @@ fn init_test_with_fake_client( }) } -const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); - #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -1707,6 +1704,10 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { editable_range_in_excerpt: 0..0, cursor_offset_in_excerpt: 0, excerpt_start_row: None, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), @@ -1810,13 +1811,10 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) { } "}, indoc! {" - <|editable_region_start|> fn main() { let word_1 = \"lorem\"; let range = word_1.len()..word_1.len(); } - - <|editable_region_end|> "}, cx, ) @@ -1837,12 +1835,9 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) { } "}, indoc! {" - <|editable_region_start|> fn main() { let story = \"the quick brown fox jumps over the lazy dog\"; } - - <|editable_region_end|> "}, cx, ) @@ -1860,18 +1855,11 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { init_test(cx); let buffer_content = "lorem\n"; - let completion_response = indoc! {" - ```animals.js - <|start_of_file|> - <|editable_region_start|> - lorem - ipsum - <|editable_region_end|> - ```"}; + let completion_response = "lorem\nipsum\n"; assert_eq!( apply_edit_prediction(buffer_content, completion_response, cx).await, - "lorem\nipsum" + "lorem\nipsum\n" ); } @@ -1940,298 +1928,6 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte }); } -#[gpui::test] -async fn test_can_collect_data(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/project/src/main.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Disabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - - let buffer = cx.new(|_cx| { - Buffer::remote( - language::BufferId::new(1).unwrap(), - ReplicaId::new(1), - language::Capability::ReadWrite, - "fn main() {\n println!(\"Hello\");\n}", - ) - }); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "LICENSE": BSD_0_TXT, - ".env": "SECRET_KEY=secret" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/.env", cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - let buffer = cx.new(|cx| Buffer::local("", cx)); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer("/project/main.rs", cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/open_source_worktree"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), - ) - .await; - fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [ - path!("/open_source_worktree").as_ref(), - path!("/closed_source_worktree").as_ref(), - ], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - let closed_source_file = project - .update(cx, |project, cx| { - let worktree2 = project - .worktree_for_root_name("closed_source_worktree", cx) - .unwrap(); - worktree2.update(cx, |worktree2, cx| { - worktree2.load_file(rel_path("main.rs"), cx) - }) - }) - .await - .unwrap() - .file; - - buffer.update(cx, |buffer, cx| { - buffer.file_updated(closed_source_file, cx); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); -} - -#[gpui::test] -async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/worktree1"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), - ) - .await; - fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) - .await; - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree1/main.rs"), cx) - }) - .await - .unwrap(); - let private_buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree2/file.rs"), cx) - }) - .await - .unwrap(); - - let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; - ep_store.update(cx, |ep_store, _cx| { - ep_store.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - // this has a side effect of registering the buffer to watch for edits - run_edit_prediction(&private_buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - private_buffer.update(cx, |private_buffer, cx| { - private_buffer.edit([(0..0, "An edit for the history!")], None, cx); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - - // make an edit that uses too many bytes, causing private_buffer edit to not be able to be - // included - buffer.update(cx, |buffer, cx| { - buffer.edit( - [( - 0..0, - " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS), - )], - None, - cx, - ); - }); - - run_edit_prediction(&buffer, &project, &ep_store, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); -} - fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); @@ -2247,7 +1943,7 @@ async fn apply_edit_prediction( let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let (ep_store, _, response) = make_test_ep_store(&project, cx).await; + let (ep_store, response) = make_test_ep_store(&project, cx).await; *response.lock() = completion_response.to_string(); let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await; buffer.update(cx, |buffer, cx| { @@ -2276,28 +1972,13 @@ async fn run_edit_prediction( async fn make_test_ep_store( project: &Entity, cx: &mut TestAppContext, -) -> ( - Entity, - Arc>>, - Arc>, -) { - let default_response = indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - hello world - <|editable_region_end|> - ```" - }; - let captured_request: Arc>> = Arc::new(Mutex::new(None)); - let completion_response: Arc> = - Arc::new(Mutex::new(default_response.to_string())); +) -> (Entity, Arc>) { + let default_response = "hello world\n".to_string(); + let completion_response: Arc> = Arc::new(Mutex::new(default_response)); let http_client = FakeHttpClient::create({ - let captured_request = captured_request.clone(); let completion_response = completion_response.clone(); let mut next_request_id = 0; move |req| { - let captured_request = captured_request.clone(); let completion_response = completion_response.clone(); async move { match (req.method(), req.uri().path()) { @@ -2311,24 +1992,6 @@ async fn make_test_ep_store( .into(), ) .unwrap()), - (&Method::POST, "/predict_edits/v2") => { - let mut request_body = String::new(); - req.into_body().read_to_string(&mut request_body).await?; - *captured_request.lock() = - Some(serde_json::from_str(&request_body).unwrap()); - next_request_id += 1; - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: format!("request-{next_request_id}"), - output_excerpt: completion_response.lock().clone(), - }) - .unwrap() - .into(), - ) - .unwrap()) - } (&Method::POST, "/predict_edits/v3") => { next_request_id += 1; Ok(http_client::Response::builder() @@ -2336,7 +1999,7 @@ async fn make_test_ep_store( .body( serde_json::to_string(&PredictEditsV3Response { request_id: format!("request-{next_request_id}"), - output: "hello world".to_string(), + output: completion_response.lock().clone(), }) .unwrap() .into(), @@ -2375,7 +2038,7 @@ async fn make_test_ep_store( ep_store }); - (ep_store, captured_request, completion_response) + (ep_store, completion_response) } fn to_completion_edits( diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 5b9c98b83074cf5d4ead8af2bb974ff591c86e95..c6609e5f1f42f21eb165488f85575f2c50fcd1e0 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -66,6 +66,7 @@ pub struct CapturedPromptInput { pub excerpt_start_row: Option, pub events: Vec, pub related_files: Vec, + pub in_open_source_repo: bool, } #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] @@ -101,6 +102,7 @@ impl CapturedRelatedFile { zeta_prompt::RelatedFile { path: self.path.clone(), max_row: self.max_row, + in_open_source_repo: false, excerpts: self .excerpts .iter() diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index eba5f05f7b228c7468ecb8fbfde60feff568cebf..9d546ff8272c0d8acd4a22285c3aa069ea4c525a 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -97,6 +97,10 @@ impl Mercury { - context_offset_range.start) ..(editable_offset_range.end - context_offset_range.start), excerpt_start_row: Some(context_start_row), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }; let prompt = build_prompt(&inputs); diff --git a/crates/edit_prediction/src/ollama.rs b/crates/edit_prediction/src/ollama.rs index a79b61559cbcd7a74ae7619ee54b115eb576a637..8de90ba67ee2a9ecac3e52d7fee101b8db84c54e 100644 --- a/crates/edit_prediction/src/ollama.rs +++ b/crates/edit_prediction/src/ollama.rs @@ -169,6 +169,10 @@ impl Ollama { - context_offset_range.start) ..(editable_offset_range.end - context_offset_range.start), excerpt_start_row: Some(input_excerpt.context_range.start.row), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }; (prompt, stop_tokens, Some(editable_offset_range), inputs) @@ -195,6 +199,10 @@ impl Ollama { .text_for_range(excerpt_range) .collect::() .into(), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }; let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string(); diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 8d4a40d8b9ddf7a2ed8a68773da83a9498c4d516..750b1a435ae4a7a281ef41973e1f6d0d2158445e 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -158,6 +158,10 @@ mod tests { cursor_excerpt: "".into(), editable_range_in_excerpt: 0..0, excerpt_start_row: None, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index b42f54b7a89ea3f858501529d785c9013d490c99..1253916487894d757c74293c21f4ace1c681cd11 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -219,6 +219,10 @@ impl SweepAi { editable_range_in_excerpt: 0..inputs.snapshot.len(), cursor_offset_in_excerpt: request_body.cursor_position, excerpt_start_row: Some(0), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, }; send_started_event( diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index d95a244e105678f45c8a465b9831578396a9e8f0..b3102455d7d4ac9640307ed706ca4cacc14d8592 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,26 +1,13 @@ -use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; +use std::{fmt::Write, ops::Range, sync::Arc}; -use crate::{ - DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, - EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError, - cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, - prediction::EditPredictionResult, -}; +use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}; use anyhow::Result; -use cloud_llm_client::{ - PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse, -}; +use cloud_llm_client::PredictEditsBody; use edit_prediction_types::PredictedCursorPosition; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; -use language::{ - Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff, -}; -use project::{Project, ProjectPath}; -use release_channel::AppVersion; +use language::{Anchor, BufferSnapshot, Point, text_diff}; use text::Bias; -use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; use zeta_prompt::{ - Event, ZetaPromptInput, + Event, zeta1::{ CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, @@ -28,260 +15,8 @@ use zeta_prompt::{ }; pub(crate) const MAX_CONTEXT_TOKENS: usize = 150; -pub(crate) const MAX_REWRITE_TOKENS: usize = 350; pub(crate) const MAX_EVENT_TOKENS: usize = 500; -pub(crate) fn request_prediction_with_zeta1( - store: &mut EditPredictionStore, - EditPredictionModelInput { - project, - buffer, - snapshot, - position, - events, - trigger, - debug_tx, - .. - }: EditPredictionModelInput, - cx: &mut Context, -) -> Task>> { - let buffer_snapshotted_at = Instant::now(); - let client = store.client.clone(); - let llm_token = store.llm_token.clone(); - let app_version = AppVersion::global(cx); - - let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = store.can_collect_file(&project, file, cx); - let git_info = if can_collect_file { - git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx) - } else { - None - }; - (git_info, can_collect_file) - } else { - (None, false) - }; - - let full_path: Arc = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let full_path_str = full_path.to_string_lossy().into_owned(); - let cursor_point = position.to_point(&snapshot); - let prompt_for_events = { - let events = events.clone(); - move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) - }; - let gather_task = gather_context( - full_path_str, - &snapshot, - cursor_point, - prompt_for_events, - trigger, - cx, - ); - - let uri = match client - .http_client() - .build_zed_llm_url("/predict_edits/v2", &[]) - { - Ok(url) => Arc::from(url), - Err(err) => return Task::ready(Err(err)), - }; - - cx.spawn(async move |this, cx| { - let GatherContextOutput { - mut body, - context_range, - editable_range, - included_events_count, - } = gather_task.await?; - let done_gathering_context_at = Instant::now(); - - let included_events = &events[events.len() - included_events_count..events.len()]; - body.can_collect_data = can_collect_file - && this - .read_with(cx, |this, cx| this.can_collect_events(included_events, cx)) - .unwrap_or(false); - if body.can_collect_data { - body.git_info = git_info; - } - - log::debug!( - "Events:\n{}\nExcerpt:\n{:?}", - body.input_events, - body.input_excerpt - ); - - let response = EditPredictionStore::send_api_request::( - |request| { - Ok(request - .uri(uri.as_str()) - .body(serde_json::to_string(&body)?.into())?) - }, - client, - llm_token, - app_version, - true, - ) - .await; - - let context_start_offset = context_range.start.to_offset(&snapshot); - let context_start_row = context_range.start.row; - let editable_offset_range = editable_range.to_offset(&snapshot); - - let inputs = ZetaPromptInput { - events: included_events.into(), - related_files: vec![], - cursor_path: full_path, - cursor_excerpt: snapshot - .text_for_range(context_range) - .collect::() - .into(), - editable_range_in_excerpt: (editable_range.start - context_start_offset) - ..(editable_offset_range.end - context_start_offset), - cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset, - excerpt_start_row: Some(context_start_row), - }; - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(DebugEvent::EditPredictionStarted( - EditPredictionStartedDebugEvent { - buffer: buffer.downgrade(), - prompt: Some(serde_json::to_string(&inputs).unwrap()), - position, - }, - )) - .ok(); - } - - let (response, usage) = match response { - Ok(response) => response, - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |ep_store, _cx| { - ep_store.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button("Update Zed", "https://zed.dev/releases") - }) - }, - ); - }); - } - - return Err(err); - } - }; - - let received_response_at = Instant::now(); - log::debug!("completion response: {}", &response.output_excerpt); - - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(DebugEvent::EditPredictionFinished( - EditPredictionFinishedDebugEvent { - buffer: buffer.downgrade(), - model_output: Some(response.output_excerpt.clone()), - position, - }, - )) - .ok(); - } - - let edit_prediction = process_completion_response( - response, - buffer, - &snapshot, - editable_range, - inputs, - buffer_snapshotted_at, - received_response_at, - cx, - ) - .await; - - let finished_at = Instant::now(); - - // record latency for ~1% of requests - if rand::random::() <= 2 { - telemetry::event!( - "Edit Prediction Request", - context_latency = done_gathering_context_at - .duration_since(buffer_snapshotted_at) - .as_millis(), - request_latency = received_response_at - .duration_since(done_gathering_context_at) - .as_millis(), - process_latency = finished_at.duration_since(received_response_at).as_millis() - ); - } - - edit_prediction.map(Some) - }) -} - -fn process_completion_response( - prediction_response: PredictEditsResponse, - buffer: Entity, - snapshot: &BufferSnapshot, - editable_range: Range, - inputs: ZetaPromptInput, - buffer_snapshotted_at: Instant, - received_response_at: Instant, - cx: &AsyncApp, -) -> Task> { - let snapshot = snapshot.clone(); - let request_id = prediction_response.request_id; - let output_excerpt = prediction_response.output_excerpt; - cx.spawn(async move |cx| { - let output_excerpt: Arc = output_excerpt.into(); - - let edits: Arc<[(Range, Arc)]> = cx - .background_spawn({ - let output_excerpt = output_excerpt.clone(); - let editable_range = editable_range.clone(); - let snapshot = snapshot.clone(); - async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) } - }) - .await? - .into(); - - let id = EditPredictionId(request_id.into()); - Ok(EditPredictionResult::new( - id, - &buffer, - &snapshot, - edits, - None, - buffer_snapshotted_at, - received_response_at, - inputs, - cx, - ) - .await) - }) -} - pub(crate) fn parse_edits( output_excerpt: &str, editable_range: Range, @@ -452,35 +187,6 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } -fn git_info_for_file( - project: &Entity, - project_path: &ProjectPath, - cx: &App, -) -> Option { - let git_store = project.read(cx).git_store().read(cx); - if let Some((repository, _repo_path)) = - git_store.repository_and_path_for_project_path(project_path, cx) - { - let repository = repository.read(cx); - let head_sha = repository - .head_commit - .as_ref() - .map(|head_commit| head_commit.sha.to_string()); - let remote_origin_url = repository.remote_origin_url.clone(); - let remote_upstream_url = repository.remote_upstream_url.clone(); - if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { - return None; - } - Some(PredictEditsGitInfo { - head_sha, - remote_origin_url, - remote_upstream_url, - }) - } else { - None - } -} - pub struct GatherContextOutput { pub body: PredictEditsBody, pub context_range: Range, @@ -488,48 +194,6 @@ pub struct GatherContextOutput { pub included_events_count: usize, } -pub fn gather_context( - full_path_str: String, - snapshot: &BufferSnapshot, - cursor_point: language::Point, - prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, - trigger: PredictEditsRequestTrigger, - cx: &App, -) -> Task> { - cx.background_spawn({ - let snapshot = snapshot.clone(); - async move { - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &full_path_str, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let (input_events, included_events_count) = prompt_for_events(); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - - let body = PredictEditsBody { - input_events, - input_excerpt: input_excerpt.prompt, - can_collect_data: false, - diagnostic_groups: None, - git_info: None, - outline: None, - speculated_output: None, - trigger, - }; - - Ok(GatherContextOutput { - body, - context_range: input_excerpt.context_range, - editable_range, - included_events_count, - }) - } - }) -} - pub(crate) fn prompt_for_events(events: &[Arc], max_tokens: usize) -> String { prompt_for_events_impl(events, max_tokens).0 } @@ -656,6 +320,7 @@ mod tests { use gpui::{App, AppContext}; use indoc::indoc; use language::Buffer; + use text::OffsetRangeExt as _; #[gpui::test] fn test_excerpt_for_cursor_position(cx: &mut App) { diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 2a3efa5c803aee1ed53572c506d238317fc9842a..78cbfa2082751e36cb54021a586fd913669d79a1 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,10 +1,11 @@ +use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets}; use crate::prediction::EditPredictionResult; use crate::zeta1::compute_edits_and_cursor_position; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; -use anyhow::{Result, anyhow}; +use anyhow::Result; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; use gpui::{App, Task, prelude::*}; @@ -13,7 +14,10 @@ use release_channel::AppVersion; use std::env; use std::{path::Path, sync::Arc, time::Instant}; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt}; +use zeta_prompt::{ + CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output, + format_zeta_prompt, get_prefill, +}; pub const MAX_CONTEXT_TOKENS: usize = 350; @@ -23,6 +27,8 @@ pub fn max_editable_tokens(format: ZetaFormat) -> usize { ZetaFormat::V0114180EditableRegion => 180, ZetaFormat::V0120GitMergeMarkers => 180, ZetaFormat::V0131GitMergeMarkersPrefix => 180, + ZetaFormat::V0211Prefill => 180, + ZetaFormat::V0211SeedCoder => 180, } } @@ -36,24 +42,32 @@ pub fn request_prediction_with_zeta2( events, debug_tx, trigger, + project, .. }: EditPredictionModelInput, + preferred_model: Option, cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); let raw_config = store.zeta2_raw_config().cloned(); - let Some(excerpt_path) = snapshot + let excerpt_path: Arc = snapshot .file() .map(|file| -> Arc { file.full_path(cx).into() }) - else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); let client = store.client.clone(); let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); + let is_open_source = snapshot + .file() + .map_or(false, |file| store.is_file_open_source(&project, file, cx)) + && events.iter().all(|event| event.in_open_source_repo()) + && related_files.iter().all(|file| file.in_open_source_repo); + + let can_collect_data = is_open_source && store.is_data_collection_enabled(cx); + let request_task = cx.background_spawn({ async move { let zeta_version = raw_config @@ -69,6 +83,9 @@ pub fn request_prediction_with_zeta2( excerpt_path, cursor_offset, zeta_version, + preferred_model, + is_open_source, + can_collect_data, ); if let Some(debug_tx) = &debug_tx { @@ -88,6 +105,8 @@ pub fn request_prediction_with_zeta2( let (request_id, output_text, usage) = if let Some(config) = &raw_config { let prompt = format_zeta_prompt(&prompt_input, config.format); + let prefill = get_prefill(&prompt_input, config.format); + let prompt = format!("{prompt}{prefill}"); let request = RawCompletionRequest { model: config.model_id.clone().unwrap_or_default(), prompt, @@ -108,7 +127,9 @@ pub fn request_prediction_with_zeta2( let request_id = EditPredictionId(response.id.clone().into()); let output_text = response.choices.pop().map(|choice| { - clean_zeta2_model_output(&choice.text, config.format).to_string() + let response = &choice.text; + let output = format!("{prefill}{response}"); + clean_zeta2_model_output(&output, config.format).to_string() }); (request_id, output_text, usage) @@ -241,41 +262,54 @@ pub fn zeta2_prompt_input( excerpt_path: Arc, cursor_offset: usize, zeta_format: ZetaFormat, + preferred_model: Option, + is_open_source: bool, + can_collect_data: bool, ) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); - let (editable_range, context_range) = - crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( - cursor_point, - snapshot, - max_editable_tokens(zeta_format), - MAX_CONTEXT_TOKENS, - ); + let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot); let related_files = crate::filter_redundant_excerpts( related_files, excerpt_path.as_ref(), - context_range.start.row..context_range.end.row, + full_context.start.row..full_context.end.row, ); - let context_start_offset = context_range.start.to_offset(snapshot); - let context_start_row = context_range.start.row; + let full_context_start_offset = full_context.start.to_offset(snapshot); + let full_context_start_row = full_context.start.row; + + let excerpt_ranges = + excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot); + + let editable_range = match preferred_model { + Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350, + _ => match zeta_format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150, + _ => &range_points.editable_180, + }, + }; + let editable_offset_range = editable_range.to_offset(snapshot); - let cursor_offset_in_excerpt = cursor_offset - context_start_offset; - let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset) - ..(editable_offset_range.end - context_start_offset); + let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset; + let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset) + ..(editable_offset_range.end - full_context_start_offset); let prompt_input = zeta_prompt::ZetaPromptInput { cursor_path: excerpt_path, cursor_excerpt: snapshot - .text_for_range(context_range) + .text_for_range(full_context) .collect::() .into(), editable_range_in_excerpt, cursor_offset_in_excerpt, - excerpt_start_row: Some(context_start_row), + excerpt_start_row: Some(full_context_start_row), events, related_files, + excerpt_ranges: Some(excerpt_ranges), + preferred_model, + in_open_source_repo: is_open_source, + can_collect_data, }; (editable_offset_range, prompt_input) } diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index d0627d0fadfe019dd10da0ab237a0f8829e32058..5fd81afd30a6e3f9e643702361a8cf80b8b47b60 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -76,6 +76,8 @@ pub struct ExamplePrompt { pub input: String, pub expected_output: String, pub rejected_output: Option, // For DPO + #[serde(default)] + pub prefill: Option, pub provider: PredictionProvider, } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index c0f078ed9af489c358695db80136dec854b0f532..c52fe4d4b47dd6454dbdb540b17bbfca9a7a7ce4 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -65,6 +65,7 @@ pub async fn run_format_prompt( input: prompt, expected_output: String::new(), rejected_output: None, + prefill: None, provider: args.provider, }); } @@ -92,8 +93,17 @@ pub async fn run_format_prompt( excerpt_start_row: prompt_inputs.excerpt_start_row, events: prompt_inputs.edit_history.clone(), related_files: prompt_inputs.related_files.clone().unwrap_or_default(), + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: example + .spec + .captured_prompt_input + .as_ref() + .map_or(false, |input| input.in_open_source_repo), + can_collect_data: false, }; let prompt = format_zeta_prompt(&input, version); + let prefill = zeta_prompt::get_prefill(&input, version); let (expected_patch, expected_cursor_offset) = example .spec .expected_patches_with_cursor_positions() @@ -113,6 +123,7 @@ pub async fn run_format_prompt( expected_output, rejected_output, provider: args.provider, + prefill: Some(prefill), }); } _ => { @@ -155,7 +166,9 @@ pub fn zeta2_output_for_patch( } match version { - ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211SeedCoder => { if !result.ends_with('\n') { result.push('\n'); } diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index e45060924d07a992ec2e563e5b16c3f85938ee2d..1eda4c94d6f78499eb185002a197107e373d5bb8 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -55,10 +55,16 @@ fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result { ("<|fim_middle|>current\n", "<|fim_suffix|>") } - ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => ( + ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211Prefill => ( zeta_prompt::v0120_git_merge_markers::START_MARKER, zeta_prompt::v0120_git_merge_markers::SEPARATOR, ), + ZetaFormat::V0211SeedCoder => ( + zeta_prompt::seed_coder::START_MARKER, + zeta_prompt::seed_coder::SEPARATOR, + ), }; let start = prompt.find(current_marker).with_context(|| { @@ -101,11 +107,14 @@ fn parse_zeta2_output( }; let suffix = match format { - ZetaFormat::V0131GitMergeMarkersPrefix => { + ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER } ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER, - _ => "", + ZetaFormat::V0112MiddleAtEnd + | ZetaFormat::V0113Ordered + | ZetaFormat::V0114180EditableRegion => "", + ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER, }; if !suffix.is_empty() { new_text = new_text diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 5979439a2a7f3a66bfe94881bd04b9d948fe3c7e..075d5749b82103de8a2cd9951cc5f1f8b6160f6a 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -159,6 +159,7 @@ pub async fn run_prediction( expected_output: String::new(), rejected_output: None, provider, + prefill: None, }); } } diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index b48cc09e13b02cac85033786e780533304fa6de4..46ee3ba590ed98aad0e05aac527cf671018fd162 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -1304,6 +1304,7 @@ fn build_example_from_snowflake( excerpt_start_row: None, events, related_files, + in_open_source_repo: input.in_open_source_repo, }), telemetry: Some(TelemetrySource { request_id, diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 79bfdfa192a7d52d7f1189b93e164290380c71ea..0ae9253a49c81b50183c10cdce3877d8e41b64a8 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -136,11 +136,13 @@ impl RelatedExcerptStore { .collect() } - pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity)> { + pub fn related_files_with_buffers( + &mut self, + cx: &App, + ) -> impl Iterator)> { self.related_buffers .iter_mut() .map(|related| (related.related_file(cx), related.buffer.clone())) - .collect::>() } pub fn set_related_files(&mut self, files: Vec, cx: &App) { @@ -424,6 +426,7 @@ impl RelatedBuffer { path, excerpts: cached.excerpts.clone(), max_row: buffer.max_point().row, + in_open_source_repo: false, }; return related_file; } diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs index 078bf0c56192b7ab5ea13b76d0940710ece2378d..79c53aea2a2fb5de9c137cbba4f5fa751db1f170 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -89,7 +89,6 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { let company_buffer = related_excerpt_store.update(cx, |store, cx| { store .related_files_with_buffers(cx) - .into_iter() .find(|(file, _)| file.path.to_str() == Some("root/src/company.rs")) .map(|(_, buffer)| buffer) .expect("company.rs buffer not found") diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 73fdecbd134f2346f22304ae84c76ab53c1636c4..a779fbcc7166e5b70bdf431e0649d50101f2cf3a 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -9,10 +9,41 @@ use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr}; pub const CURSOR_MARKER: &str = "<|user_cursor|>"; pub const MAX_PROMPT_TOKENS: usize = 4096; +/// Use up to this amount of the editable region for prefill. +/// Larger values may result in more robust generation, but +/// this region becomes non-editable. +pub const PREFILL_RATIO: f64 = 0.1; // 10% + fn estimate_tokens(bytes: usize) -> usize { bytes / 3 } +/// The client's preferred edit prediction model. The server may override this. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum EditPredictionModelKind { + Zeta1, + Zeta2, +} + +/// Pre-computed byte offset ranges within `cursor_excerpt` for different +/// editable and context token budgets. Allows the server to select the +/// appropriate ranges for whichever model it uses. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExcerptRanges { + /// Editable region computed with a 150-token budget. + pub editable_150: Range, + /// Editable region computed with a 180-token budget. + pub editable_180: Range, + /// Editable region computed with a 350-token budget. + pub editable_350: Range, + /// Context boundary when using editable_150 with 350 tokens of additional context. + pub editable_150_context_350: Range, + /// Context boundary when using editable_180 with 350 tokens of additional context. + pub editable_180_context_350: Range, + /// Context boundary when using editable_350 with 150 tokens of additional context. + pub editable_350_context_150: Range, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ZetaPromptInput { pub cursor_path: Arc, @@ -23,6 +54,19 @@ pub struct ZetaPromptInput { pub excerpt_start_row: Option, pub events: Vec>, pub related_files: Vec, + /// When set, the excerpt was computed with a larger budget (~512 tokens) + /// and these ranges let the server select model-appropriate subsets. + /// When absent, the excerpt IS the context region and + /// `editable_range_in_excerpt` is the only editable range. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excerpt_ranges: Option, + /// Client's preferred model. The server may override. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub preferred_model: Option, + #[serde(default)] + pub in_open_source_repo: bool, + #[serde(default)] + pub can_collect_data: bool, } #[derive( @@ -46,6 +90,8 @@ pub enum ZetaFormat { V0114180EditableRegion, V0120GitMergeMarkers, V0131GitMergeMarkersPrefix, + V0211Prefill, + V0211SeedCoder, } impl std::fmt::Display for ZetaFormat { @@ -96,6 +142,17 @@ pub enum Event { }, } +impl Event { + pub fn in_open_source_repo(&self) -> bool { + match self { + Event::BufferChange { + in_open_source_repo, + .. + } => *in_open_source_repo, + } + } +} + pub fn write_event(prompt: &mut String, event: &Event) { fn write_path_as_unix_str(prompt: &mut String, path: &Path) { for component in path.components() { @@ -129,6 +186,8 @@ pub struct RelatedFile { pub path: Arc, pub max_row: u32, pub excerpts: Vec, + #[serde(default)] + pub in_open_source_repo: bool, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -150,41 +209,123 @@ pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str { ZetaFormat::V0131GitMergeMarkersPrefix => output .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER) .unwrap_or(output), + ZetaFormat::V0211SeedCoder => output + .strip_suffix(seed_coder::END_MARKER) + .unwrap_or(output), _ => output, } } +fn resolve_cursor_region( + input: &ZetaPromptInput, + format: ZetaFormat, +) -> (&str, Range, usize) { + let Some(ranges) = &input.excerpt_ranges else { + return ( + &input.cursor_excerpt, + input.editable_range_in_excerpt.clone(), + input.cursor_offset_in_excerpt, + ); + }; + + let (editable_range, context_range) = match format { + ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => ( + ranges.editable_150.clone(), + ranges.editable_150_context_350.clone(), + ), + ZetaFormat::V0114180EditableRegion + | ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211Prefill + | ZetaFormat::V0211SeedCoder => ( + ranges.editable_180.clone(), + ranges.editable_180_context_350.clone(), + ), + }; + + let context_start = context_range.start; + let context_text = &input.cursor_excerpt[context_range]; + let adjusted_editable = + (editable_range.start - context_start)..(editable_range.end - context_start); + let adjusted_cursor = input.cursor_offset_in_excerpt - context_start; + + (context_text, adjusted_editable, adjusted_cursor) +} + fn format_zeta_prompt_with_budget( input: &ZetaPromptInput, format: ZetaFormat, max_tokens: usize, ) -> String { + let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format); + let path = &*input.cursor_path; + let mut cursor_section = String::new(); match format { ZetaFormat::V0112MiddleAtEnd => { - v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input); + v0112_middle_at_end::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ); } ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => { - v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input) + v0113_ordered::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ) } - ZetaFormat::V0120GitMergeMarkers => { - v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input) + ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ), + ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { + v0131_git_merge_markers_prefix::write_cursor_excerpt_section( + &mut cursor_section, + path, + context, + &editable_range, + cursor_offset, + ) } - ZetaFormat::V0131GitMergeMarkersPrefix => { - v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input) + ZetaFormat::V0211SeedCoder => { + return seed_coder::format_prompt_with_budget( + path, + context, + &editable_range, + cursor_offset, + &input.events, + &input.related_files, + max_tokens, + ); } } let cursor_tokens = estimate_tokens(cursor_section.len()); let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens); - let edit_history_section = - format_edit_history_within_budget(&input.events, budget_after_cursor); + let edit_history_section = format_edit_history_within_budget( + &input.events, + "<|file_sep|>", + "edit history", + budget_after_cursor, + ); let edit_history_tokens = estimate_tokens(edit_history_section.len()); let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); - let related_files_section = - format_related_files_within_budget(&input.related_files, budget_after_edit_history); + let related_files_section = format_related_files_within_budget( + &input.related_files, + "<|file_sep|>", + budget_after_edit_history, + ); let mut prompt = String::new(); prompt.push_str(&related_files_section); @@ -193,8 +334,25 @@ fn format_zeta_prompt_with_budget( prompt } -fn format_edit_history_within_budget(events: &[Arc], max_tokens: usize) -> String { - let header = "<|file_sep|>edit history\n"; +pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String { + match format { + ZetaFormat::V0112MiddleAtEnd + | ZetaFormat::V0113Ordered + | ZetaFormat::V0114180EditableRegion + | ZetaFormat::V0120GitMergeMarkers + | ZetaFormat::V0131GitMergeMarkersPrefix + | ZetaFormat::V0211SeedCoder => String::new(), + ZetaFormat::V0211Prefill => v0211_prefill::get_prefill(input), + } +} + +fn format_edit_history_within_budget( + events: &[Arc], + file_marker: &str, + edit_history_name: &str, + max_tokens: usize, +) -> String { + let header = format!("{}{}\n", file_marker, edit_history_name); let header_tokens = estimate_tokens(header.len()); if header_tokens >= max_tokens { return String::new(); @@ -219,21 +377,25 @@ fn format_edit_history_within_budget(events: &[Arc], max_tokens: usize) - return String::new(); } - let mut result = String::from(header); + let mut result = header; for event_str in event_strings.iter().rev() { - result.push_str(&event_str); + result.push_str(event_str); } result } -fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String { +fn format_related_files_within_budget( + related_files: &[RelatedFile], + file_marker: &str, + max_tokens: usize, +) -> String { let mut result = String::new(); let mut total_tokens = 0; for file in related_files { let path_str = file.path.to_string_lossy(); - let header_len = "<|file_sep|>".len() + path_str.len() + 1; - let header_tokens = estimate_tokens(header_len); + let header = format!("{}{}\n", file_marker, path_str); + let header_tokens = estimate_tokens(header.len()); if total_tokens + header_tokens > max_tokens { break; @@ -246,12 +408,8 @@ fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: let needs_newline = !excerpt.text.ends_with('\n'); let needs_ellipsis = excerpt.row_range.end < file.max_row; let excerpt_len = excerpt.text.len() - + if needs_newline { "\n".len() } else { "".len() } - + if needs_ellipsis { - "...\n".len() - } else { - "".len() - }; + + if needs_newline { "\n".len() } else { 0 } + + if needs_ellipsis { "...\n".len() } else { 0 }; let excerpt_tokens = estimate_tokens(excerpt_len); if total_tokens + file_tokens + excerpt_tokens > max_tokens { @@ -263,7 +421,7 @@ fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: if excerpts_to_include > 0 { total_tokens += file_tokens; - write!(result, "<|file_sep|>{}\n", path_str).ok(); + result.push_str(&header); for excerpt in file.excerpts.iter().take(excerpts_to_include) { result.push_str(&excerpt.text); if !result.ends_with('\n') { @@ -306,29 +464,29 @@ pub fn write_related_files( mod v0112_middle_at_end { use super::*; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>\n"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str("<|fim_suffix|>\n"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>current\n"); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -340,32 +498,32 @@ mod v0112_middle_at_end { mod v0113_ordered { use super::*; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>\n"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>current\n"); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_suffix|>\n"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -404,30 +562,30 @@ pub mod v0120_git_merge_markers { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str("<|fim_suffix|>"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str("<|fim_middle|>"); prompt.push_str(START_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -465,29 +623,29 @@ pub mod v0131_git_merge_markers_prefix { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; - pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { - let path_str = input.cursor_path.to_string_lossy(); + pub fn write_cursor_excerpt_section( + prompt: &mut String, + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) { + let path_str = path.to_string_lossy(); write!(prompt, "<|file_sep|>{}\n", path_str).ok(); prompt.push_str("<|fim_prefix|>"); - prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + prompt.push_str(&context[..editable_range.start]); prompt.push_str(START_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], - ); + prompt.push_str(&context[editable_range.start..cursor_offset]); prompt.push_str(CURSOR_MARKER); - prompt.push_str( - &input.cursor_excerpt - [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], - ); + prompt.push_str(&context[cursor_offset..editable_range.end]); if !prompt.ends_with('\n') { prompt.push('\n'); } prompt.push_str(SEPARATOR); prompt.push_str("<|fim_suffix|>"); - prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + prompt.push_str(&context[editable_range.end..]); if !prompt.ends_with('\n') { prompt.push('\n'); } @@ -496,8 +654,178 @@ pub mod v0131_git_merge_markers_prefix { } } +pub mod v0211_prefill { + use super::*; + + pub fn get_prefill(input: &ZetaPromptInput) -> String { + let editable_region = &input.cursor_excerpt + [input.editable_range_in_excerpt.start..input.editable_range_in_excerpt.end]; + + let prefill_len = (editable_region.len() as f64 * PREFILL_RATIO) as usize; + let prefill_len = editable_region.floor_char_boundary(prefill_len); + + // Find a token boundary to avoid splitting tokens in the prefill. + // In Qwen2.5-Coder, \n is always the END of a token (e.g. `;\n`, + // ` {\n`), and \n\n / \n\n\n are single tokens, so we must include + // the \n and consume any consecutive \n characters after it. + let prefill = &editable_region[..prefill_len]; + match prefill.rfind('\n') { + Some(pos) => { + let mut end = pos + 1; + while end < editable_region.len() + && editable_region.as_bytes().get(end) == Some(&b'\n') + { + end += 1; + } + editable_region[..end].to_string() + } + // No newline found. Fall back to splitting before the last space + // (word-level boundary) + None => match prefill.rfind(' ') { + Some(pos) => prefill[..pos].to_string(), + None => prefill.to_string(), + }, + } + } +} + +pub mod seed_coder { + //! Seed-Coder prompt format using SPM (Suffix-Prefix-Middle) FIM mode. + //! + //! Seed-Coder uses different FIM tokens and order than Qwen: + //! - SPM order: suffix comes FIRST, then prefix, then middle + //! - Tokens: `<[fim-suffix]>`, `<[fim-prefix]>`, `<[fim-middle]>` + //! - File markers: StarCoder-style `path` (single token + path) + //! + //! All context (related files, edit history) goes in the PREFIX section. + //! The suffix contains only code after the editable region. + //! + //! Example prompt: + //! + //! <[fim-suffix]> + //! code after editable region + //! <[fim-prefix]>related/file.py + //! related file content + //! + //! edit_history + //! --- a/some_file.py + //! +++ b/some_file.py + //! -old + //! +new + //! + //! path/to/target_file.py + //! code before editable region + //! <<<<<<< CURRENT + //! code that + //! needs to<|user_cursor|> + //! be rewritten + //! ======= + //! <[fim-middle]> + //! + //! Expected output (model generates): + //! + //! updated + //! code with + //! changes applied + //! >>>>>>> UPDATED + + use super::*; + + pub const FIM_SUFFIX: &str = "<[fim-suffix]>"; + pub const FIM_PREFIX: &str = "<[fim-prefix]>"; + pub const FIM_MIDDLE: &str = "<[fim-middle]>"; + pub const FILE_MARKER: &str = ""; + + pub const START_MARKER: &str = "<<<<<<< CURRENT\n"; + pub const SEPARATOR: &str = "=======\n"; + pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; + + pub fn format_prompt_with_budget( + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + events: &[Arc], + related_files: &[RelatedFile], + max_tokens: usize, + ) -> String { + let suffix_section = build_suffix_section(context, editable_range); + let cursor_prefix_section = + build_cursor_prefix_section(path, context, editable_range, cursor_offset); + + let suffix_tokens = estimate_tokens(suffix_section.len()); + let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len()); + let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens); + + let edit_history_section = super::format_edit_history_within_budget( + events, + FILE_MARKER, + "edit_history", + budget_after_cursor, + ); + let edit_history_tokens = estimate_tokens(edit_history_section.len()); + let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); + + let related_files_section = super::format_related_files_within_budget( + related_files, + FILE_MARKER, + budget_after_edit_history, + ); + + let mut prompt = String::new(); + prompt.push_str(&suffix_section); + prompt.push_str(FIM_PREFIX); + prompt.push_str(&related_files_section); + if !related_files_section.is_empty() { + prompt.push('\n'); + } + prompt.push_str(&edit_history_section); + if !edit_history_section.is_empty() { + prompt.push('\n'); + } + prompt.push_str(&cursor_prefix_section); + prompt.push_str(FIM_MIDDLE); + prompt + } + + fn build_suffix_section(context: &str, editable_range: &Range) -> String { + let mut section = String::new(); + section.push_str(FIM_SUFFIX); + section.push_str(&context[editable_range.end..]); + if !section.ends_with('\n') { + section.push('\n'); + } + section + } + + fn build_cursor_prefix_section( + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, + ) -> String { + let mut section = String::new(); + let path_str = path.to_string_lossy(); + write!(section, "{}{}\n", FILE_MARKER, path_str).ok(); + + section.push_str(&context[..editable_range.start]); + section.push_str(START_MARKER); + section.push_str(&context[editable_range.start..cursor_offset]); + section.push_str(CURSOR_MARKER); + section.push_str(&context[cursor_offset..editable_range.end]); + if !section.ends_with('\n') { + section.push('\n'); + } + section.push_str(SEPARATOR); + section + } +} + /// The zeta1 prompt format pub mod zeta1 { + use super::*; + use std::fmt::Write; + pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; pub const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>"; @@ -529,6 +857,166 @@ pub mod zeta1 { prompt.push_str(RESPONSE_HEADER); prompt } + + /// Formats a complete zeta1 prompt from a `ZetaPromptInput` using the given + /// editable and context byte-offset ranges within `cursor_excerpt`. + pub fn format_zeta1_from_input( + input: &ZetaPromptInput, + editable_range: Range, + context_range: Range, + ) -> String { + let events = format_zeta1_events(&input.events); + let excerpt = format_zeta1_excerpt(input, editable_range, context_range); + format_zeta1_prompt(&events, &excerpt) + } + + /// Formats events in zeta1 style (oldest first). + fn format_zeta1_events(events: &[Arc]) -> String { + let mut result = String::new(); + for event in events { + let event_string = format_zeta1_event(event); + if event_string.is_empty() { + continue; + } + if !result.is_empty() { + result.push_str("\n\n"); + } + result.push_str(&event_string); + } + result + } + + fn format_zeta1_event(event: &Event) -> String { + match event { + Event::BufferChange { + path, + old_path, + diff, + .. + } => { + let mut prompt = String::new(); + if old_path != path { + writeln!( + prompt, + "User renamed {} to {}\n", + old_path.display(), + path.display() + ) + .ok(); + } + if !diff.is_empty() { + write!( + prompt, + "User edited {}:\n```diff\n{}\n```", + path.display(), + diff + ) + .ok(); + } + prompt + } + } + } + + /// Formats the excerpt section of a zeta1 prompt using byte-offset ranges + /// within `cursor_excerpt`. + fn format_zeta1_excerpt( + input: &ZetaPromptInput, + editable_range: Range, + context_range: Range, + ) -> String { + let path_str = input.cursor_path.to_string_lossy(); + let excerpt = &*input.cursor_excerpt; + let cursor_offset = input.cursor_offset_in_excerpt; + + let mut prompt = String::new(); + writeln!(&mut prompt, "```{path_str}").ok(); + + let starts_at_file_beginning = + input.excerpt_start_row == Some(0) && context_range.start == 0; + if starts_at_file_beginning { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").ok(); + } + + prompt.push_str(&excerpt[context_range.start..editable_range.start]); + + writeln!(&mut prompt, "{EDITABLE_REGION_START_MARKER}").ok(); + prompt.push_str(&excerpt[editable_range.start..cursor_offset]); + prompt.push_str(CURSOR_MARKER); + prompt.push_str(&excerpt[cursor_offset..editable_range.end]); + write!(&mut prompt, "\n{EDITABLE_REGION_END_MARKER}").ok(); + + prompt.push_str(&excerpt[editable_range.end..context_range.end]); + write!(prompt, "\n```").ok(); + + prompt + } + + /// Cleans zeta1 model output by extracting content between editable region + /// markers and converting the zeta1 cursor marker to the universal one. + /// Returns `None` if the output doesn't contain the expected markers. + pub fn clean_zeta1_model_output(output: &str) -> Option { + let content = output.replace(CURSOR_MARKER, ""); + + let content_start = content + .find(EDITABLE_REGION_START_MARKER) + .map(|pos| pos + EDITABLE_REGION_START_MARKER.len()) + .map(|pos| { + if content.as_bytes().get(pos) == Some(&b'\n') { + pos + 1 + } else { + pos + } + }) + .unwrap_or(0); + + let content_end = content + .find(EDITABLE_REGION_END_MARKER) + .map(|pos| { + if pos > 0 && content.as_bytes().get(pos - 1) == Some(&b'\n') { + pos - 1 + } else { + pos + } + }) + .unwrap_or(content.len()); + + if content_start > content_end { + return Some(String::new()); + } + + let extracted = &content[content_start..content_end]; + + let cursor_offset = output.find(CURSOR_MARKER).map(|zeta1_cursor_pos| { + let text_before_cursor = output[..zeta1_cursor_pos].replace(CURSOR_MARKER, ""); + let text_before_cursor = text_before_cursor + .find(EDITABLE_REGION_START_MARKER) + .map(|pos| { + let after_marker = pos + EDITABLE_REGION_START_MARKER.len(); + if text_before_cursor.as_bytes().get(after_marker) == Some(&b'\n') { + after_marker + 1 + } else { + after_marker + } + }) + .unwrap_or(0); + let offset_in_extracted = zeta1_cursor_pos + .saturating_sub(text_before_cursor) + .min(extracted.len()); + offset_in_extracted + }); + + let mut result = String::with_capacity(extracted.len() + super::CURSOR_MARKER.len()); + if let Some(offset) = cursor_offset { + result.push_str(&extracted[..offset]); + result.push_str(super::CURSOR_MARKER); + result.push_str(&extracted[offset..]); + } else { + result.push_str(extracted); + } + + Some(result) + } } #[cfg(test)] @@ -551,6 +1039,10 @@ mod tests { excerpt_start_row: None, events: events.into_iter().map(Arc::new).collect(), related_files, + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, } } @@ -572,6 +1064,7 @@ mod tests { row_range: 0..content.lines().count() as u32, text: content.into(), }], + in_open_source_repo: false, } } @@ -673,6 +1166,7 @@ mod tests { vec![RelatedFile { path: Path::new("big.rs").into(), max_row: 30, + in_open_source_repo: false, excerpts: vec![ RelatedExcerpt { row_range: 0..10, @@ -792,4 +1286,322 @@ mod tests { "#} ); } + + fn format_seed_coder(input: &ZetaPromptInput) -> String { + format_zeta_prompt_with_budget(input, ZetaFormat::V0211SeedCoder, 10000) + } + + fn format_seed_coder_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { + format_zeta_prompt_with_budget(input, ZetaFormat::V0211SeedCoder, max_tokens) + } + + #[test] + fn test_seed_coder_basic_format() { + let input = make_input( + "prefix\neditable\nsuffix", + 7..15, + 10, + vec![make_event("a.rs", "-old\n+new\n")], + vec![make_related_file("related.rs", "fn helper() {}\n")], + ); + + assert_eq!( + format_seed_coder(&input), + indoc! {r#" + <[fim-suffix]> + suffix + <[fim-prefix]>related.rs + fn helper() {} + + edit_history + --- a/a.rs + +++ b/a.rs + -old + +new + + test.rs + prefix + <<<<<<< CURRENT + edi<|user_cursor|>table + ======= + <[fim-middle]>"#} + ); + } + + #[test] + fn test_seed_coder_no_context() { + let input = make_input("before\nmiddle\nafter", 7..13, 10, vec![], vec![]); + + assert_eq!( + format_seed_coder(&input), + indoc! {r#" + <[fim-suffix]> + after + <[fim-prefix]>test.rs + before + <<<<<<< CURRENT + mid<|user_cursor|>dle + ======= + <[fim-middle]>"#} + ); + } + + #[test] + fn test_seed_coder_truncation_drops_context() { + let input = make_input( + "code", + 0..4, + 2, + vec![make_event("a.rs", "-x\n+y\n")], + vec![make_related_file("r1.rs", "content\n")], + ); + + // With large budget, everything is included + assert_eq!( + format_seed_coder(&input), + indoc! {r#" + <[fim-suffix]> + <[fim-prefix]>r1.rs + content + + edit_history + --- a/a.rs + +++ b/a.rs + -x + +y + + test.rs + <<<<<<< CURRENT + co<|user_cursor|>de + ======= + <[fim-middle]>"#} + ); + + // With tight budget, context is dropped but cursor section remains + assert_eq!( + format_seed_coder_with_budget(&input, 30), + indoc! {r#" + <[fim-suffix]> + <[fim-prefix]>test.rs + <<<<<<< CURRENT + co<|user_cursor|>de + ======= + <[fim-middle]>"#} + ); + } + + #[test] + fn test_seed_coder_clean_output() { + let output_with_marker = "new code\n>>>>>>> UPDATED\n"; + let output_without_marker = "new code\n"; + + assert_eq!( + clean_zeta2_model_output(output_with_marker, ZetaFormat::V0211SeedCoder), + "new code\n" + ); + assert_eq!( + clean_zeta2_model_output(output_without_marker, ZetaFormat::V0211SeedCoder), + "new code\n" + ); + } + + #[test] + fn test_format_zeta1_from_input_basic() { + let excerpt = "fn before() {}\nfn foo() {\n let x = 1;\n}\nfn after() {}\n"; + let input = ZetaPromptInput { + cursor_path: Path::new("src/main.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: 15..41, + cursor_offset_in_excerpt: 30, + excerpt_start_row: Some(0), + events: vec![Arc::new(make_event("other.rs", "-old\n+new\n"))], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, 15..41, 0..excerpt.len()); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "User edited other.rs:\n", + "```diff\n", + "-old\n", + "+new\n", + "\n", + "```\n", + "\n", + "### User Excerpt:\n", + "\n", + "```src/main.rs\n", + "<|start_of_file|>\n", + "fn before() {}\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "\n", + "<|editable_region_end|>}\n", + "fn after() {}\n", + "\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_format_zeta1_from_input_no_start_of_file() { + let excerpt = "fn foo() {\n let x = 1;\n}\n"; + let input = ZetaPromptInput { + cursor_path: Path::new("src/main.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: 0..28, + cursor_offset_in_excerpt: 15, + excerpt_start_row: Some(10), + events: vec![], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, 0..28, 0..28); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "\n", + "\n", + "### User Excerpt:\n", + "\n", + "```src/main.rs\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "}\n", + "\n", + "<|editable_region_end|>\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_format_zeta1_from_input_with_sub_ranges() { + let excerpt = "// prefix\nfn foo() {\n let x = 1;\n}\n// suffix\n"; + let editable_range = 10..37; + let context_range = 0..excerpt.len(); + + let input = ZetaPromptInput { + cursor_path: Path::new("test.rs").into(), + cursor_excerpt: excerpt.into(), + editable_range_in_excerpt: editable_range.clone(), + cursor_offset_in_excerpt: 25, + excerpt_start_row: Some(0), + events: vec![], + related_files: vec![], + excerpt_ranges: None, + preferred_model: None, + in_open_source_repo: false, + can_collect_data: false, + }; + + let prompt = zeta1::format_zeta1_from_input(&input, editable_range, context_range); + + assert_eq!( + prompt, + concat!( + "### Instruction:\n", + "You are a code completion assistant and your task is to analyze user edits and then rewrite an ", + "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ", + "into account the cursor location.\n", + "\n", + "### User Edits:\n", + "\n", + "\n", + "\n", + "### User Excerpt:\n", + "\n", + "```test.rs\n", + "<|start_of_file|>\n", + "// prefix\n", + "<|editable_region_start|>\n", + "fn foo() {\n", + " <|user_cursor_is_here|>let x = 1;\n", + "}\n", + "<|editable_region_end|>\n", + "// suffix\n", + "\n", + "```\n", + "\n", + "### Response:\n", + ), + ); + } + + #[test] + fn test_clean_zeta1_model_output_basic() { + let output = indoc! {" + <|editable_region_start|> + fn main() { + println!(\"hello\"); + } + <|editable_region_end|> + "}; + + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, "fn main() {\n println!(\"hello\");\n}"); + } + + #[test] + fn test_clean_zeta1_model_output_with_cursor() { + let output = indoc! {" + <|editable_region_start|> + fn main() { + <|user_cursor_is_here|>println!(\"hello\"); + } + <|editable_region_end|> + "}; + + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!( + cleaned, + "fn main() {\n <|user_cursor|>println!(\"hello\");\n}" + ); + } + + #[test] + fn test_clean_zeta1_model_output_no_markers() { + let output = "fn main() {}\n"; + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, "fn main() {}\n"); + } + + #[test] + fn test_clean_zeta1_model_output_empty_region() { + let output = "<|editable_region_start|>\n<|editable_region_end|>\n"; + let cleaned = zeta1::clean_zeta1_model_output(output).unwrap(); + assert_eq!(cleaned, ""); + } }