Detailed changes
@@ -15,8 +15,6 @@ use project::{Project, WorktreeId};
use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
-pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
-
pub fn capture_example(
project: Entity<Project>,
buffer: Entity<Buffer>,
@@ -156,6 +154,7 @@ pub fn capture_example(
excerpt_start_row: Some(0),
events: captured_events,
related_files: captured_related_files,
+ in_open_source_repo: false,
}
});
@@ -304,10 +303,6 @@ fn generate_timestamp_name() -> String {
}
}
-pub(crate) fn should_send_testing_zeta2_request() -> bool {
- rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -1,5 +1,81 @@
use language::{BufferSnapshot, Point};
use std::ops::Range;
+use zeta_prompt::ExcerptRanges;
+
+/// Pre-computed Point ranges for all editable/context budget combinations.
+pub struct ExcerptRangePoints {
+ pub editable_150: Range<Point>,
+ pub editable_180: Range<Point>,
+ pub editable_350: Range<Point>,
+ pub editable_150_context_350: Range<Point>,
+ pub editable_180_context_350: Range<Point>,
+ pub editable_350_context_150: Range<Point>,
+}
+
+/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350
+/// token budgets, plus their corresponding context expansions. Returns the full excerpt range
+/// (union of all context ranges) and the individual sub-ranges as Points.
+pub fn compute_excerpt_ranges(
+ position: Point,
+ snapshot: &BufferSnapshot,
+) -> (Range<Point>, ExcerptRangePoints) {
+ let editable_150 = compute_editable_range(snapshot, position, 150);
+ let editable_180 = compute_editable_range(snapshot, position, 180);
+ let editable_350 = compute_editable_range(snapshot, position, 350);
+
+ let editable_150_context_350 =
+ expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350);
+ let editable_180_context_350 =
+ expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350);
+ let editable_350_context_150 =
+ expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150);
+
+ let full_start_row = editable_150_context_350
+ .start
+ .row
+ .min(editable_180_context_350.start.row)
+ .min(editable_350_context_150.start.row);
+ let full_end_row = editable_150_context_350
+ .end
+ .row
+ .max(editable_180_context_350.end.row)
+ .max(editable_350_context_150.end.row);
+
+ let full_context =
+ Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row));
+
+ let ranges = ExcerptRangePoints {
+ editable_150,
+ editable_180,
+ editable_350,
+ editable_150_context_350,
+ editable_180_context_350,
+ editable_350_context_150,
+ };
+
+ (full_context, ranges)
+}
+
+/// Converts `ExcerptRangePoints` to byte-offset `ExcerptRanges` relative to `excerpt_start`.
+pub fn excerpt_ranges_to_byte_offsets(
+ ranges: &ExcerptRangePoints,
+ excerpt_start: usize,
+ snapshot: &BufferSnapshot,
+) -> ExcerptRanges {
+ let to_offset = |range: &Range<Point>| -> Range<usize> {
+ let start = range.start.to_offset(snapshot);
+ let end = range.end.to_offset(snapshot);
+ (start - excerpt_start)..(end - excerpt_start)
+ };
+ ExcerptRanges {
+ editable_150: to_offset(&ranges.editable_150),
+ editable_180: to_offset(&ranges.editable_180),
+ editable_350: to_offset(&ranges.editable_350),
+ editable_150_context_350: to_offset(&ranges.editable_150_context_350),
+ editable_180_context_350: to_offset(&ranges.editable_180_context_350),
+ editable_350_context_150: to_offset(&ranges.editable_350_context_150),
+ }
+}
pub fn editable_and_context_ranges_for_cursor_position(
position: Point,
@@ -312,6 +388,8 @@ fn expand_context_syntactically_then_linewise(
start..end
}
+use language::ToOffset as _;
+
#[cfg(test)]
mod tests {
use super::*;
@@ -72,7 +72,6 @@ pub mod zeta2;
#[cfg(test)]
mod edit_prediction_tests;
-use crate::capture_example::should_send_testing_zeta2_request;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
use crate::ollama::Ollama;
@@ -734,10 +733,19 @@ impl EditPredictionStore {
) -> Vec<RelatedFile> {
self.projects
.get(&project.entity_id())
- .map(|project| {
- project
- .context
- .update(cx, |context, cx| context.related_files(cx))
+ .map(|project_state| {
+ project_state.context.update(cx, |context, cx| {
+ context
+ .related_files_with_buffers(cx)
+ .map(|(mut related_file, buffer)| {
+ related_file.in_open_source_repo = buffer
+ .read(cx)
+ .file()
+ .map_or(false, |file| self.is_file_open_source(&project, file, cx));
+ related_file
+ })
+ .collect()
+ })
})
.unwrap_or_default()
}
@@ -785,9 +793,9 @@ impl EditPredictionStore {
self.projects
.get(&project.entity_id())
.map(|project| {
- project
- .context
- .update(cx, |context, cx| context.related_files_with_buffers(cx))
+ project.context.update(cx, |context, cx| {
+ context.related_files_with_buffers(cx).collect()
+ })
})
.unwrap_or_default()
}
@@ -1771,15 +1779,18 @@ impl EditPredictionStore {
};
let task = match &self.edit_prediction_model {
- EditPredictionModel::Zeta1 => {
- if should_send_testing_zeta2_request() {
- let mut zeta2_inputs = inputs.clone();
- zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
- zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach();
- }
- zeta1::request_prediction_with_zeta1(self, inputs, cx)
- }
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+ EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
+ self,
+ inputs,
+ Some(zeta_prompt::EditPredictionModelKind::Zeta1),
+ cx,
+ ),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
+ self,
+ inputs,
+ Some(zeta_prompt::EditPredictionModelKind::Zeta2),
+ cx,
+ ),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
@@ -2136,25 +2147,6 @@ impl EditPredictionStore {
.is_some_and(|watcher| watcher.is_project_open_source())
}
- fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
- self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
- }
-
- fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
- if !self.data_collection_choice.is_enabled(cx) {
- return false;
- }
- events.iter().all(|event| {
- matches!(
- event.as_ref(),
- zeta_prompt::Event::BufferChange {
- in_open_source_repo: true,
- ..
- }
- )
- })
- }
-
fn load_data_collection_choice() -> DataCollectionChoice {
let choice = KEY_VALUE_STORE
.read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
@@ -1,11 +1,10 @@
use super::*;
-use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
+use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
use client::{UserStore, test::FakeServer};
-use clock::{FakeSystemClock, ReplicaId};
+use clock::FakeSystemClock;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{
- EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
- RejectEditPredictionsBody,
+ EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
use futures::{
@@ -26,7 +25,7 @@ use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
-use util::{path, rel_path::rel_path};
+use util::path;
use uuid::Uuid;
use zeta_prompt::ZetaPromptInput;
@@ -1424,8 +1423,6 @@ fn init_test_with_fake_client(
})
}
-const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
-
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
@@ -1452,6 +1449,9 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
editable_range_in_excerpt: 0..0,
cursor_offset_in_excerpt: 0,
excerpt_start_row: None,
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1555,13 +1555,10 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) {
}
"},
indoc! {"
- <|editable_region_start|>
fn main() {
let word_1 = \"lorem\";
let range = word_1.len()..word_1.len();
}
-
- <|editable_region_end|>
"},
cx,
)
@@ -1582,12 +1579,9 @@ async fn test_clean_up_diff(cx: &mut TestAppContext) {
}
"},
indoc! {"
- <|editable_region_start|>
fn main() {
let story = \"the quick brown fox jumps over the lazy dog\";
}
-
- <|editable_region_end|>
"},
cx,
)
@@ -1605,18 +1599,11 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
init_test(cx);
let buffer_content = "lorem\n";
- let completion_response = indoc! {"
- ```animals.js
- <|start_of_file|>
- <|editable_region_start|>
- lorem
- ipsum
- <|editable_region_end|>
- ```"};
+ let completion_response = "lorem\nipsum\n";
assert_eq!(
apply_edit_prediction(buffer_content, completion_response, cx).await,
- "lorem\nipsum"
+ "lorem\nipsum\n"
);
}
@@ -1685,298 +1672,6 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
});
}
-#[gpui::test]
-async fn test_can_collect_data(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/project/src/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Disabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
-
- let buffer = cx.new(|_cx| {
- Buffer::remote(
- language::BufferId::new(1).unwrap(),
- ReplicaId::new(1),
- language::Capability::ReadWrite,
- "fn main() {\n println!(\"Hello\");\n}",
- )
- });
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/project"),
- json!({
- "LICENSE": BSD_0_TXT,
- ".env": "SECRET_KEY=secret"
- }),
- )
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/.env", cx)
- })
- .await
- .unwrap();
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
- let buffer = cx.new(|cx| Buffer::local("", cx));
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/main.rs", cx)
- })
- .await
- .unwrap();
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/open_source_worktree"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [
- path!("/open_source_worktree").as_ref(),
- path!("/closed_source_worktree").as_ref(),
- ],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- let closed_source_file = project
- .update(cx, |project, cx| {
- let worktree2 = project
- .worktree_for_root_name("closed_source_worktree", cx)
- .unwrap();
- worktree2.update(cx, |worktree2, cx| {
- worktree2.load_file(rel_path("main.rs"), cx)
- })
- })
- .await
- .unwrap()
- .file;
-
- buffer.update(cx, |buffer, cx| {
- buffer.file_updated(closed_source_file, cx);
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/worktree1"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree1/main.rs"), cx)
- })
- .await
- .unwrap();
- let private_buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree2/file.rs"), cx)
- })
- .await
- .unwrap();
-
- let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
- ep_store.update(cx, |ep_store, _cx| {
- ep_store.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- // this has a side effect of registering the buffer to watch for edits
- run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- private_buffer.update(cx, |private_buffer, cx| {
- private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
- // included
- buffer.update(cx, |buffer, cx| {
- buffer.edit(
- [(
- 0..0,
- " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
- )],
- None,
- cx,
- );
- });
-
- run_edit_prediction(&buffer, &project, &ep_store, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-}
-
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -1992,7 +1687,7 @@ async fn apply_edit_prediction(
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
- let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
+ let (ep_store, response) = make_test_ep_store(&project, cx).await;
*response.lock() = completion_response.to_string();
let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
buffer.update(cx, |buffer, cx| {
@@ -2021,28 +1716,13 @@ async fn run_edit_prediction(
async fn make_test_ep_store(
project: &Entity<Project>,
cx: &mut TestAppContext,
-) -> (
- Entity<EditPredictionStore>,
- Arc<Mutex<Option<PredictEditsBody>>>,
- Arc<Mutex<String>>,
-) {
- let default_response = indoc! {"
- ```main.rs
- <|start_of_file|>
- <|editable_region_start|>
- hello world
- <|editable_region_end|>
- ```"
- };
- let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
- let completion_response: Arc<Mutex<String>> =
- Arc::new(Mutex::new(default_response.to_string()));
+) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
+ let default_response = "hello world\n".to_string();
+ let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
let http_client = FakeHttpClient::create({
- let captured_request = captured_request.clone();
let completion_response = completion_response.clone();
let mut next_request_id = 0;
move |req| {
- let captured_request = captured_request.clone();
let completion_response = completion_response.clone();
async move {
match (req.method(), req.uri().path()) {
@@ -2056,24 +1736,6 @@ async fn make_test_ep_store(
.into(),
)
.unwrap()),
- (&Method::POST, "/predict_edits/v2") => {
- let mut request_body = String::new();
- req.into_body().read_to_string(&mut request_body).await?;
- *captured_request.lock() =
- Some(serde_json::from_str(&request_body).unwrap());
- next_request_id += 1;
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: format!("request-{next_request_id}"),
- output_excerpt: completion_response.lock().clone(),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- }
(&Method::POST, "/predict_edits/v3") => {
next_request_id += 1;
Ok(http_client::Response::builder()
@@ -2081,7 +1743,7 @@ async fn make_test_ep_store(
.body(
serde_json::to_string(&PredictEditsV3Response {
request_id: format!("request-{next_request_id}"),
- output: "hello world".to_string(),
+ output: completion_response.lock().clone(),
})
.unwrap()
.into(),
@@ -2120,7 +1782,7 @@ async fn make_test_ep_store(
ep_store
});
- (ep_store, captured_request, completion_response)
+ (ep_store, completion_response)
}
fn to_completion_edits(
@@ -66,6 +66,7 @@ pub struct CapturedPromptInput {
pub excerpt_start_row: Option<u32>,
pub events: Vec<CapturedEvent>,
pub related_files: Vec<CapturedRelatedFile>,
+ pub in_open_source_repo: bool,
}
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
@@ -101,6 +102,7 @@ impl CapturedRelatedFile {
zeta_prompt::RelatedFile {
path: self.path.clone(),
max_row: self.max_row,
+ in_open_source_repo: false,
excerpts: self
.excerpts
.iter()
@@ -97,6 +97,9 @@ impl Mercury {
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
excerpt_start_row: Some(context_start_row),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
};
let prompt = build_prompt(&inputs);
@@ -169,6 +169,9 @@ impl Ollama {
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
excerpt_start_row: Some(input_excerpt.context_range.start.row),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
};
(prompt, stop_tokens, Some(editable_offset_range), inputs)
@@ -195,6 +198,9 @@ impl Ollama {
.text_for_range(excerpt_range)
.collect::<String>()
.into(),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
};
let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
@@ -158,6 +158,9 @@ mod tests {
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
excerpt_start_row: None,
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -219,6 +219,9 @@ impl SweepAi {
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
excerpt_start_row: Some(0),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
};
send_started_event(
@@ -1,26 +1,13 @@
-use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
+use std::{fmt::Write, ops::Range, sync::Arc};
-use crate::{
- DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
- EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
- cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
- prediction::EditPredictionResult,
-};
+use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count};
use anyhow::Result;
-use cloud_llm_client::{
- PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
-};
+use cloud_llm_client::PredictEditsBody;
use edit_prediction_types::PredictedCursorPosition;
-use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
-use language::{
- Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
-};
-use project::{Project, ProjectPath};
-use release_channel::AppVersion;
+use language::{Anchor, BufferSnapshot, Point, text_diff};
use text::Bias;
-use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{
- Event, ZetaPromptInput,
+ Event,
zeta1::{
CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
START_OF_FILE_MARKER,
@@ -28,260 +15,8 @@ use zeta_prompt::{
};
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
-pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
-pub(crate) fn request_prediction_with_zeta1(
- store: &mut EditPredictionStore,
- EditPredictionModelInput {
- project,
- buffer,
- snapshot,
- position,
- events,
- trigger,
- debug_tx,
- ..
- }: EditPredictionModelInput,
- cx: &mut Context<EditPredictionStore>,
-) -> Task<Result<Option<EditPredictionResult>>> {
- let buffer_snapshotted_at = Instant::now();
- let client = store.client.clone();
- let llm_token = store.llm_token.clone();
- let app_version = AppVersion::global(cx);
-
- let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = store.can_collect_file(&project, file, cx);
- let git_info = if can_collect_file {
- git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
- } else {
- None
- };
- (git_info, can_collect_file)
- } else {
- (None, false)
- };
-
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|f| Arc::from(f.full_path(cx).as_path()))
- .unwrap_or_else(|| Arc::from(Path::new("untitled")));
- let full_path_str = full_path.to_string_lossy().into_owned();
- let cursor_point = position.to_point(&snapshot);
- let prompt_for_events = {
- let events = events.clone();
- move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
- };
- let gather_task = gather_context(
- full_path_str,
- &snapshot,
- cursor_point,
- prompt_for_events,
- trigger,
- cx,
- );
-
- let uri = match client
- .http_client()
- .build_zed_llm_url("/predict_edits/v2", &[])
- {
- Ok(url) => Arc::from(url),
- Err(err) => return Task::ready(Err(err)),
- };
-
- cx.spawn(async move |this, cx| {
- let GatherContextOutput {
- mut body,
- context_range,
- editable_range,
- included_events_count,
- } = gather_task.await?;
- let done_gathering_context_at = Instant::now();
-
- let included_events = &events[events.len() - included_events_count..events.len()];
- body.can_collect_data = can_collect_file
- && this
- .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
- .unwrap_or(false);
- if body.can_collect_data {
- body.git_info = git_info;
- }
-
- log::debug!(
- "Events:\n{}\nExcerpt:\n{:?}",
- body.input_events,
- body.input_excerpt
- );
-
- let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
- |request| {
- Ok(request
- .uri(uri.as_str())
- .body(serde_json::to_string(&body)?.into())?)
- },
- client,
- llm_token,
- app_version,
- true,
- )
- .await;
-
- let context_start_offset = context_range.start.to_offset(&snapshot);
- let context_start_row = context_range.start.row;
- let editable_offset_range = editable_range.to_offset(&snapshot);
-
- let inputs = ZetaPromptInput {
- events: included_events.into(),
- related_files: vec![],
- cursor_path: full_path,
- cursor_excerpt: snapshot
- .text_for_range(context_range)
- .collect::<String>()
- .into(),
- editable_range_in_excerpt: (editable_range.start - context_start_offset)
- ..(editable_offset_range.end - context_start_offset),
- cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
- excerpt_start_row: Some(context_start_row),
- };
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(DebugEvent::EditPredictionStarted(
- EditPredictionStartedDebugEvent {
- buffer: buffer.downgrade(),
- prompt: Some(serde_json::to_string(&inputs).unwrap()),
- position,
- },
- ))
- .ok();
- }
-
- let (response, usage) = match response {
- Ok(response) => response,
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |ep_store, _cx| {
- ep_store.update_required = true;
- })
- .ok();
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button("Update Zed", "https://zed.dev/releases")
- })
- },
- );
- });
- }
-
- return Err(err);
- }
- };
-
- let received_response_at = Instant::now();
- log::debug!("completion response: {}", &response.output_excerpt);
-
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(DebugEvent::EditPredictionFinished(
- EditPredictionFinishedDebugEvent {
- buffer: buffer.downgrade(),
- model_output: Some(response.output_excerpt.clone()),
- position,
- },
- ))
- .ok();
- }
-
- let edit_prediction = process_completion_response(
- response,
- buffer,
- &snapshot,
- editable_range,
- inputs,
- buffer_snapshotted_at,
- received_response_at,
- cx,
- )
- .await;
-
- let finished_at = Instant::now();
-
- // record latency for ~1% of requests
- if rand::random::<u8>() <= 2 {
- telemetry::event!(
- "Edit Prediction Request",
- context_latency = done_gathering_context_at
- .duration_since(buffer_snapshotted_at)
- .as_millis(),
- request_latency = received_response_at
- .duration_since(done_gathering_context_at)
- .as_millis(),
- process_latency = finished_at.duration_since(received_response_at).as_millis()
- );
- }
-
- edit_prediction.map(Some)
- })
-}
-
-fn process_completion_response(
- prediction_response: PredictEditsResponse,
- buffer: Entity<Buffer>,
- snapshot: &BufferSnapshot,
- editable_range: Range<usize>,
- inputs: ZetaPromptInput,
- buffer_snapshotted_at: Instant,
- received_response_at: Instant,
- cx: &AsyncApp,
-) -> Task<Result<EditPredictionResult>> {
- let snapshot = snapshot.clone();
- let request_id = prediction_response.request_id;
- let output_excerpt = prediction_response.output_excerpt;
- cx.spawn(async move |cx| {
- let output_excerpt: Arc<str> = output_excerpt.into();
-
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
- .background_spawn({
- let output_excerpt = output_excerpt.clone();
- let editable_range = editable_range.clone();
- let snapshot = snapshot.clone();
- async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
- })
- .await?
- .into();
-
- let id = EditPredictionId(request_id.into());
- Ok(EditPredictionResult::new(
- id,
- &buffer,
- &snapshot,
- edits,
- None,
- buffer_snapshotted_at,
- received_response_at,
- inputs,
- cx,
- )
- .await)
- })
-}
-
pub(crate) fn parse_edits(
output_excerpt: &str,
editable_range: Range<usize>,
@@ -434,35 +169,6 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
-fn git_info_for_file(
- project: &Entity<Project>,
- project_path: &ProjectPath,
- cx: &App,
-) -> Option<PredictEditsGitInfo> {
- let git_store = project.read(cx).git_store().read(cx);
- if let Some((repository, _repo_path)) =
- git_store.repository_and_path_for_project_path(project_path, cx)
- {
- let repository = repository.read(cx);
- let head_sha = repository
- .head_commit
- .as_ref()
- .map(|head_commit| head_commit.sha.to_string());
- let remote_origin_url = repository.remote_origin_url.clone();
- let remote_upstream_url = repository.remote_upstream_url.clone();
- if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
- return None;
- }
- Some(PredictEditsGitInfo {
- head_sha,
- remote_origin_url,
- remote_upstream_url,
- })
- } else {
- None
- }
-}
-
pub struct GatherContextOutput {
pub body: PredictEditsBody,
pub context_range: Range<Point>,
@@ -470,48 +176,6 @@ pub struct GatherContextOutput {
pub included_events_count: usize,
}
-pub fn gather_context(
- full_path_str: String,
- snapshot: &BufferSnapshot,
- cursor_point: language::Point,
- prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
- trigger: PredictEditsRequestTrigger,
- cx: &App,
-) -> Task<Result<GatherContextOutput>> {
- cx.background_spawn({
- let snapshot = snapshot.clone();
- async move {
- let input_excerpt = excerpt_for_cursor_position(
- cursor_point,
- &full_path_str,
- &snapshot,
- MAX_REWRITE_TOKENS,
- MAX_CONTEXT_TOKENS,
- );
- let (input_events, included_events_count) = prompt_for_events();
- let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
-
- let body = PredictEditsBody {
- input_events,
- input_excerpt: input_excerpt.prompt,
- can_collect_data: false,
- diagnostic_groups: None,
- git_info: None,
- outline: None,
- speculated_output: None,
- trigger,
- };
-
- Ok(GatherContextOutput {
- body,
- context_range: input_excerpt.context_range,
- editable_range,
- included_events_count,
- })
- }
- })
-}
-
pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
prompt_for_events_impl(events, max_tokens).0
}
@@ -638,6 +302,7 @@ mod tests {
use gpui::{App, AppContext};
use indoc::indoc;
use language::Buffer;
+ use text::OffsetRangeExt as _;
#[gpui::test]
fn test_excerpt_for_cursor_position(cx: &mut App) {
@@ -1,10 +1,11 @@
+use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
use crate::prediction::EditPredictionResult;
use crate::zeta1::compute_edits_and_cursor_position;
use crate::{
CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Result, anyhow};
+use anyhow::Result;
use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use gpui::{App, Task, prelude::*};
@@ -13,8 +14,10 @@ use release_channel::AppVersion;
use std::env;
use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output};
-use zeta_prompt::{format_zeta_prompt, get_prefill};
+use zeta_prompt::{
+ CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
+ format_zeta_prompt, get_prefill,
+};
pub const MAX_CONTEXT_TOKENS: usize = 350;
@@ -39,24 +42,30 @@ pub fn request_prediction_with_zeta2(
events,
debug_tx,
trigger,
+ project,
..
}: EditPredictionModelInput,
+ preferred_model: Option<EditPredictionModelKind>,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer_snapshotted_at = Instant::now();
let raw_config = store.zeta2_raw_config().cloned();
- let Some(excerpt_path) = snapshot
+ let excerpt_path: Arc<Path> = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
- else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
+ .unwrap_or_else(|| Arc::from(Path::new("untitled")));
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
+ let is_open_source = snapshot
+ .file()
+ .map_or(false, |file| store.is_file_open_source(&project, file, cx))
+ && events.iter().all(|event| event.in_open_source_repo())
+ && related_files.iter().all(|file| file.in_open_source_repo);
+
let request_task = cx.background_spawn({
async move {
let zeta_version = raw_config
@@ -72,6 +81,8 @@ pub fn request_prediction_with_zeta2(
excerpt_path,
cursor_offset,
zeta_version,
+ preferred_model,
+ is_open_source,
);
if let Some(debug_tx) = &debug_tx {
@@ -248,41 +259,52 @@ pub fn zeta2_prompt_input(
excerpt_path: Arc<Path>,
cursor_offset: usize,
zeta_format: ZetaFormat,
+ preferred_model: Option<EditPredictionModelKind>,
+ is_open_source: bool,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
- let (editable_range, context_range) =
- crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
- cursor_point,
- snapshot,
- max_editable_tokens(zeta_format),
- MAX_CONTEXT_TOKENS,
- );
+ let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
let related_files = crate::filter_redundant_excerpts(
related_files,
excerpt_path.as_ref(),
- context_range.start.row..context_range.end.row,
+ full_context.start.row..full_context.end.row,
);
- let context_start_offset = context_range.start.to_offset(snapshot);
- let context_start_row = context_range.start.row;
+ let full_context_start_offset = full_context.start.to_offset(snapshot);
+ let full_context_start_row = full_context.start.row;
+
+ let excerpt_ranges =
+ excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
+
+ let editable_range = match preferred_model {
+ Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
+ _ => match zeta_format {
+ ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
+ _ => &range_points.editable_180,
+ },
+ };
+
let editable_offset_range = editable_range.to_offset(snapshot);
- let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
- let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
- ..(editable_offset_range.end - context_start_offset);
+ let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
+ let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
+ ..(editable_offset_range.end - full_context_start_offset);
let prompt_input = zeta_prompt::ZetaPromptInput {
cursor_path: excerpt_path,
cursor_excerpt: snapshot
- .text_for_range(context_range)
+ .text_for_range(full_context)
.collect::<String>()
.into(),
editable_range_in_excerpt,
cursor_offset_in_excerpt,
- excerpt_start_row: Some(context_start_row),
+ excerpt_start_row: Some(full_context_start_row),
events,
related_files,
+ excerpt_ranges: Some(excerpt_ranges),
+ preferred_model,
+ in_open_source_repo: is_open_source,
};
(editable_offset_range, prompt_input)
}
@@ -93,6 +93,13 @@ pub async fn run_format_prompt(
excerpt_start_row: prompt_inputs.excerpt_start_row,
events: prompt_inputs.edit_history.clone(),
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: example
+ .spec
+ .captured_prompt_input
+ .as_ref()
+ .map_or(false, |input| input.in_open_source_repo),
};
let prompt = format_zeta_prompt(&input, version);
let prefill = zeta_prompt::get_prefill(&input, version);
@@ -1304,6 +1304,7 @@ fn build_example_from_snowflake(
excerpt_start_row: None,
events,
related_files,
+ in_open_source_repo: input.in_open_source_repo,
}),
telemetry: Some(TelemetrySource {
request_id,
@@ -136,11 +136,13 @@ impl RelatedExcerptStore {
.collect()
}
- pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity<Buffer>)> {
+ pub fn related_files_with_buffers(
+ &mut self,
+ cx: &App,
+ ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
self.related_buffers
.iter_mut()
.map(|related| (related.related_file(cx), related.buffer.clone()))
- .collect::<Vec<_>>()
}
pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
@@ -424,6 +426,7 @@ impl RelatedBuffer {
path,
excerpts: cached.excerpts.clone(),
max_row: buffer.max_point().row,
+ in_open_source_repo: false,
};
return related_file;
}
@@ -89,7 +89,6 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
let company_buffer = related_excerpt_store.update(cx, |store, cx| {
store
.related_files_with_buffers(cx)
- .into_iter()
.find(|(file, _)| file.path.to_str() == Some("root/src/company.rs"))
.map(|(_, buffer)| buffer)
.expect("company.rs buffer not found")
@@ -18,6 +18,32 @@ fn estimate_tokens(bytes: usize) -> usize {
bytes / 3
}
+/// The client's preferred edit prediction model. The server may override this.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub enum EditPredictionModelKind {
+ Zeta1,
+ Zeta2,
+}
+
+/// Pre-computed byte offset ranges within `cursor_excerpt` for different
+/// editable and context token budgets. Allows the server to select the
+/// appropriate ranges for whichever model it uses.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExcerptRanges {
+ /// Editable region computed with a 150-token budget.
+ pub editable_150: Range<usize>,
+ /// Editable region computed with a 180-token budget.
+ pub editable_180: Range<usize>,
+ /// Editable region computed with a 350-token budget.
+ pub editable_350: Range<usize>,
+ /// Context boundary when using editable_150 with 350 tokens of additional context.
+ pub editable_150_context_350: Range<usize>,
+ /// Context boundary when using editable_180 with 350 tokens of additional context.
+ pub editable_180_context_350: Range<usize>,
+ /// Context boundary when using editable_350 with 150 tokens of additional context.
+ pub editable_350_context_150: Range<usize>,
+}
+
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
pub cursor_path: Arc<Path>,
@@ -28,6 +54,17 @@ pub struct ZetaPromptInput {
pub excerpt_start_row: Option<u32>,
pub events: Vec<Arc<Event>>,
pub related_files: Vec<RelatedFile>,
+ /// When set, the excerpt was computed with a larger budget (~512 tokens)
+ /// and these ranges let the server select model-appropriate subsets.
+ /// When absent, the excerpt IS the context region and
+ /// `editable_range_in_excerpt` is the only editable range.
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub excerpt_ranges: Option<ExcerptRanges>,
+ /// Client's preferred model. The server may override.
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub preferred_model: Option<EditPredictionModelKind>,
+ #[serde(default)]
+ pub in_open_source_repo: bool,
}
#[derive(
@@ -103,6 +140,17 @@ pub enum Event {
},
}
+impl Event {
+ pub fn in_open_source_repo(&self) -> bool {
+ match self {
+ Event::BufferChange {
+ in_open_source_repo,
+ ..
+ } => *in_open_source_repo,
+ }
+ }
+}
+
pub fn write_event(prompt: &mut String, event: &Event) {
fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
for component in path.components() {
@@ -136,6 +184,8 @@ pub struct RelatedFile {
pub path: Arc<Path>,
pub max_row: u32,
pub excerpts: Vec<RelatedExcerpt>,
+ #[serde(default)]
+ pub in_open_source_repo: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -164,27 +214,96 @@ pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str {
}
}
+fn resolve_cursor_region(
+ input: &ZetaPromptInput,
+ format: ZetaFormat,
+) -> (&str, Range<usize>, usize) {
+ let Some(ranges) = &input.excerpt_ranges else {
+ return (
+ &input.cursor_excerpt,
+ input.editable_range_in_excerpt.clone(),
+ input.cursor_offset_in_excerpt,
+ );
+ };
+
+ let (editable_range, context_range) = match format {
+ ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => (
+ ranges.editable_150.clone(),
+ ranges.editable_150_context_350.clone(),
+ ),
+ ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix
+ | ZetaFormat::V0211Prefill
+ | ZetaFormat::V0211SeedCoder => (
+ ranges.editable_180.clone(),
+ ranges.editable_180_context_350.clone(),
+ ),
+ };
+
+ let context_start = context_range.start;
+ let context_text = &input.cursor_excerpt[context_range];
+ let adjusted_editable =
+ (editable_range.start - context_start)..(editable_range.end - context_start);
+ let adjusted_cursor = input.cursor_offset_in_excerpt - context_start;
+
+ (context_text, adjusted_editable, adjusted_cursor)
+}
+
fn format_zeta_prompt_with_budget(
input: &ZetaPromptInput,
format: ZetaFormat,
max_tokens: usize,
) -> String {
+ let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format);
+ let path = &*input.cursor_path;
+
let mut cursor_section = String::new();
match format {
ZetaFormat::V0112MiddleAtEnd => {
- v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
+ v0112_middle_at_end::write_cursor_excerpt_section(
+ &mut cursor_section,
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ );
}
ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
- v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
- }
- ZetaFormat::V0120GitMergeMarkers => {
- v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
+ v0113_ordered::write_cursor_excerpt_section(
+ &mut cursor_section,
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ )
}
+ ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::write_cursor_excerpt_section(
+ &mut cursor_section,
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ ),
ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
- v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input)
+ v0131_git_merge_markers_prefix::write_cursor_excerpt_section(
+ &mut cursor_section,
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ )
}
ZetaFormat::V0211SeedCoder => {
- return seed_coder::format_prompt_with_budget(input, max_tokens);
+ return seed_coder::format_prompt_with_budget(
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ &input.events,
+ &input.related_files,
+ max_tokens,
+ );
}
}
@@ -343,29 +462,29 @@ pub fn write_related_files(
mod v0112_middle_at_end {
use super::*;
- pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
- let path_str = input.cursor_path.to_string_lossy();
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_prefix|>\n");
- prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ prompt.push_str(&context[..editable_range.start]);
prompt.push_str("<|fim_suffix|>\n");
- prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ prompt.push_str(&context[editable_range.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>current\n");
- prompt.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
+ prompt.push_str(&context[editable_range.start..cursor_offset]);
prompt.push_str(CURSOR_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
+ prompt.push_str(&context[cursor_offset..editable_range.end]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
@@ -377,32 +496,32 @@ mod v0112_middle_at_end {
mod v0113_ordered {
use super::*;
- pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
- let path_str = input.cursor_path.to_string_lossy();
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_prefix|>\n");
- prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ prompt.push_str(&context[..editable_range.start]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>current\n");
- prompt.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
+ prompt.push_str(&context[editable_range.start..cursor_offset]);
prompt.push_str(CURSOR_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
+ prompt.push_str(&context[cursor_offset..editable_range.end]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_suffix|>\n");
- prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ prompt.push_str(&context[editable_range.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
@@ -441,30 +560,30 @@ pub mod v0120_git_merge_markers {
pub const SEPARATOR: &str = "=======\n";
pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
- pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
- let path_str = input.cursor_path.to_string_lossy();
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_prefix|>");
- prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ prompt.push_str(&context[..editable_range.start]);
prompt.push_str("<|fim_suffix|>");
- prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ prompt.push_str(&context[editable_range.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>");
prompt.push_str(START_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
+ prompt.push_str(&context[editable_range.start..cursor_offset]);
prompt.push_str(CURSOR_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
+ prompt.push_str(&context[cursor_offset..editable_range.end]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
@@ -502,29 +621,29 @@ pub mod v0131_git_merge_markers_prefix {
pub const SEPARATOR: &str = "=======\n";
pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
- pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
- let path_str = input.cursor_path.to_string_lossy();
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_prefix|>");
- prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ prompt.push_str(&context[..editable_range.start]);
prompt.push_str(START_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
+ prompt.push_str(&context[editable_range.start..cursor_offset]);
prompt.push_str(CURSOR_MARKER);
- prompt.push_str(
- &input.cursor_excerpt
- [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
+ prompt.push_str(&context[cursor_offset..editable_range.end]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str(SEPARATOR);
prompt.push_str("<|fim_suffix|>");
- prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ prompt.push_str(&context[editable_range.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
@@ -619,16 +738,25 @@ pub mod seed_coder {
pub const SEPARATOR: &str = "=======\n";
pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
- pub fn format_prompt_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
- let suffix_section = build_suffix_section(input);
- let cursor_prefix_section = build_cursor_prefix_section(input);
+ pub fn format_prompt_with_budget(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ events: &[Arc<Event>],
+ related_files: &[RelatedFile],
+ max_tokens: usize,
+ ) -> String {
+ let suffix_section = build_suffix_section(context, editable_range);
+ let cursor_prefix_section =
+ build_cursor_prefix_section(path, context, editable_range, cursor_offset);
let suffix_tokens = estimate_tokens(suffix_section.len());
let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len());
let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens);
let edit_history_section = super::format_edit_history_within_budget(
- &input.events,
+ events,
FILE_MARKER,
"edit_history",
budget_after_cursor,
@@ -637,7 +765,7 @@ pub mod seed_coder {
let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
let related_files_section = super::format_related_files_within_budget(
- &input.related_files,
+ related_files,
FILE_MARKER,
budget_after_edit_history,
);
@@ -658,32 +786,31 @@ pub mod seed_coder {
prompt
}
- fn build_suffix_section(input: &ZetaPromptInput) -> String {
+ fn build_suffix_section(context: &str, editable_range: &Range<usize>) -> String {
let mut section = String::new();
section.push_str(FIM_SUFFIX);
- section.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ section.push_str(&context[editable_range.end..]);
if !section.ends_with('\n') {
section.push('\n');
}
section
}
- fn build_cursor_prefix_section(input: &ZetaPromptInput) -> String {
+ fn build_cursor_prefix_section(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) -> String {
let mut section = String::new();
- let path_str = input.cursor_path.to_string_lossy();
+ let path_str = path.to_string_lossy();
write!(section, "{}{}\n", FILE_MARKER, path_str).ok();
- section.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ section.push_str(&context[..editable_range.start]);
section.push_str(START_MARKER);
- section.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
+ section.push_str(&context[editable_range.start..cursor_offset]);
section.push_str(CURSOR_MARKER);
- section.push_str(
- &input.cursor_excerpt
- [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
+ section.push_str(&context[cursor_offset..editable_range.end]);
if !section.ends_with('\n') {
section.push('\n');
}
@@ -694,6 +821,9 @@ pub mod seed_coder {
/// The zeta1 prompt format
pub mod zeta1 {
+ use super::*;
+ use std::fmt::Write;
+
pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
pub const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
@@ -725,6 +855,166 @@ pub mod zeta1 {
prompt.push_str(RESPONSE_HEADER);
prompt
}
+
+ /// Formats a complete zeta1 prompt from a `ZetaPromptInput` using the given
+ /// editable and context byte-offset ranges within `cursor_excerpt`.
+ pub fn format_zeta1_from_input(
+ input: &ZetaPromptInput,
+ editable_range: Range<usize>,
+ context_range: Range<usize>,
+ ) -> String {
+ let events = format_zeta1_events(&input.events);
+ let excerpt = format_zeta1_excerpt(input, editable_range, context_range);
+ format_zeta1_prompt(&events, &excerpt)
+ }
+
+ /// Formats events in zeta1 style (oldest first).
+ fn format_zeta1_events(events: &[Arc<Event>]) -> String {
+ let mut result = String::new();
+ for event in events {
+ let event_string = format_zeta1_event(event);
+ if event_string.is_empty() {
+ continue;
+ }
+ if !result.is_empty() {
+ result.push_str("\n\n");
+ }
+ result.push_str(&event_string);
+ }
+ result
+ }
+
+ fn format_zeta1_event(event: &Event) -> String {
+ match event {
+ Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ ..
+ } => {
+ let mut prompt = String::new();
+ if old_path != path {
+ writeln!(
+ prompt,
+ "User renamed {} to {}\n",
+ old_path.display(),
+ path.display()
+ )
+ .ok();
+ }
+ if !diff.is_empty() {
+ write!(
+ prompt,
+ "User edited {}:\n```diff\n{}\n```",
+ path.display(),
+ diff
+ )
+ .ok();
+ }
+ prompt
+ }
+ }
+ }
+
+ /// Formats the excerpt section of a zeta1 prompt using byte-offset ranges
+ /// within `cursor_excerpt`.
+ fn format_zeta1_excerpt(
+ input: &ZetaPromptInput,
+ editable_range: Range<usize>,
+ context_range: Range<usize>,
+ ) -> String {
+ let path_str = input.cursor_path.to_string_lossy();
+ let excerpt = &*input.cursor_excerpt;
+ let cursor_offset = input.cursor_offset_in_excerpt;
+
+ let mut prompt = String::new();
+ writeln!(&mut prompt, "```{path_str}").ok();
+
+ let starts_at_file_beginning =
+ input.excerpt_start_row == Some(0) && context_range.start == 0;
+ if starts_at_file_beginning {
+ writeln!(&mut prompt, "{START_OF_FILE_MARKER}").ok();
+ }
+
+ prompt.push_str(&excerpt[context_range.start..editable_range.start]);
+
+ writeln!(&mut prompt, "{EDITABLE_REGION_START_MARKER}").ok();
+ prompt.push_str(&excerpt[editable_range.start..cursor_offset]);
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(&excerpt[cursor_offset..editable_range.end]);
+ write!(&mut prompt, "\n{EDITABLE_REGION_END_MARKER}").ok();
+
+ prompt.push_str(&excerpt[editable_range.end..context_range.end]);
+ write!(prompt, "\n```").ok();
+
+ prompt
+ }
+
+ /// Cleans zeta1 model output by extracting content between editable region
+ /// markers and converting the zeta1 cursor marker to the universal one.
+ /// Returns `None` if the output doesn't contain the expected markers.
+ pub fn clean_zeta1_model_output(output: &str) -> Option<String> {
+ let content = output.replace(CURSOR_MARKER, "");
+
+ let content_start = content
+ .find(EDITABLE_REGION_START_MARKER)
+ .map(|pos| pos + EDITABLE_REGION_START_MARKER.len())
+ .map(|pos| {
+ if content.as_bytes().get(pos) == Some(&b'\n') {
+ pos + 1
+ } else {
+ pos
+ }
+ })
+ .unwrap_or(0);
+
+ let content_end = content
+ .find(EDITABLE_REGION_END_MARKER)
+ .map(|pos| {
+ if pos > 0 && content.as_bytes().get(pos - 1) == Some(&b'\n') {
+ pos - 1
+ } else {
+ pos
+ }
+ })
+ .unwrap_or(content.len());
+
+ if content_start > content_end {
+ return Some(String::new());
+ }
+
+ let extracted = &content[content_start..content_end];
+
+ let cursor_offset = output.find(CURSOR_MARKER).map(|zeta1_cursor_pos| {
+ let text_before_cursor = output[..zeta1_cursor_pos].replace(CURSOR_MARKER, "");
+ let text_before_cursor = text_before_cursor
+ .find(EDITABLE_REGION_START_MARKER)
+ .map(|pos| {
+ let after_marker = pos + EDITABLE_REGION_START_MARKER.len();
+ if text_before_cursor.as_bytes().get(after_marker) == Some(&b'\n') {
+ after_marker + 1
+ } else {
+ after_marker
+ }
+ })
+ .unwrap_or(0);
+ let offset_in_extracted = zeta1_cursor_pos
+ .saturating_sub(text_before_cursor)
+ .min(extracted.len());
+ offset_in_extracted
+ });
+
+ let mut result = String::with_capacity(extracted.len() + super::CURSOR_MARKER.len());
+ if let Some(offset) = cursor_offset {
+ result.push_str(&extracted[..offset]);
+ result.push_str(super::CURSOR_MARKER);
+ result.push_str(&extracted[offset..]);
+ } else {
+ result.push_str(extracted);
+ }
+
+ Some(result)
+ }
}
#[cfg(test)]
@@ -747,6 +1037,9 @@ mod tests {
excerpt_start_row: None,
events: events.into_iter().map(Arc::new).collect(),
related_files,
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
}
}
@@ -768,6 +1061,7 @@ mod tests {
row_range: 0..content.lines().count() as u32,
text: content.into(),
}],
+ in_open_source_repo: false,
}
}
@@ -869,6 +1163,7 @@ mod tests {
vec![RelatedFile {
path: Path::new("big.rs").into(),
max_row: 30,
+ in_open_source_repo: false,
excerpts: vec![
RelatedExcerpt {
row_range: 0..10,
@@ -1106,4 +1401,201 @@ mod tests {
"new code\n"
);
}
+
+ #[test]
+ fn test_format_zeta1_from_input_basic() {
+ let excerpt = "fn before() {}\nfn foo() {\n let x = 1;\n}\nfn after() {}\n";
+ let input = ZetaPromptInput {
+ cursor_path: Path::new("src/main.rs").into(),
+ cursor_excerpt: excerpt.into(),
+ editable_range_in_excerpt: 15..41,
+ cursor_offset_in_excerpt: 30,
+ excerpt_start_row: Some(0),
+ events: vec![Arc::new(make_event("other.rs", "-old\n+new\n"))],
+ related_files: vec![],
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
+ };
+
+ let prompt = zeta1::format_zeta1_from_input(&input, 15..41, 0..excerpt.len());
+
+ assert_eq!(
+ prompt,
+ concat!(
+ "### Instruction:\n",
+ "You are a code completion assistant and your task is to analyze user edits and then rewrite an ",
+ "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ",
+ "into account the cursor location.\n",
+ "\n",
+ "### User Edits:\n",
+ "\n",
+ "User edited other.rs:\n",
+ "```diff\n",
+ "-old\n",
+ "+new\n",
+ "\n",
+ "```\n",
+ "\n",
+ "### User Excerpt:\n",
+ "\n",
+ "```src/main.rs\n",
+ "<|start_of_file|>\n",
+ "fn before() {}\n",
+ "<|editable_region_start|>\n",
+ "fn foo() {\n",
+ " <|user_cursor_is_here|>let x = 1;\n",
+ "\n",
+ "<|editable_region_end|>}\n",
+ "fn after() {}\n",
+ "\n",
+ "```\n",
+ "\n",
+ "### Response:\n",
+ ),
+ );
+ }
+
+ #[test]
+ fn test_format_zeta1_from_input_no_start_of_file() {
+ let excerpt = "fn foo() {\n let x = 1;\n}\n";
+ let input = ZetaPromptInput {
+ cursor_path: Path::new("src/main.rs").into(),
+ cursor_excerpt: excerpt.into(),
+ editable_range_in_excerpt: 0..28,
+ cursor_offset_in_excerpt: 15,
+ excerpt_start_row: Some(10),
+ events: vec![],
+ related_files: vec![],
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
+ };
+
+ let prompt = zeta1::format_zeta1_from_input(&input, 0..28, 0..28);
+
+ assert_eq!(
+ prompt,
+ concat!(
+ "### Instruction:\n",
+ "You are a code completion assistant and your task is to analyze user edits and then rewrite an ",
+ "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ",
+ "into account the cursor location.\n",
+ "\n",
+ "### User Edits:\n",
+ "\n",
+ "\n",
+ "\n",
+ "### User Excerpt:\n",
+ "\n",
+ "```src/main.rs\n",
+ "<|editable_region_start|>\n",
+ "fn foo() {\n",
+ " <|user_cursor_is_here|>let x = 1;\n",
+ "}\n",
+ "\n",
+ "<|editable_region_end|>\n",
+ "```\n",
+ "\n",
+ "### Response:\n",
+ ),
+ );
+ }
+
+ #[test]
+ fn test_format_zeta1_from_input_with_sub_ranges() {
+ let excerpt = "// prefix\nfn foo() {\n let x = 1;\n}\n// suffix\n";
+ let editable_range = 10..37;
+ let context_range = 0..excerpt.len();
+
+ let input = ZetaPromptInput {
+ cursor_path: Path::new("test.rs").into(),
+ cursor_excerpt: excerpt.into(),
+ editable_range_in_excerpt: editable_range.clone(),
+ cursor_offset_in_excerpt: 25,
+ excerpt_start_row: Some(0),
+ events: vec![],
+ related_files: vec![],
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
+ };
+
+ let prompt = zeta1::format_zeta1_from_input(&input, editable_range, context_range);
+
+ assert_eq!(
+ prompt,
+ concat!(
+ "### Instruction:\n",
+ "You are a code completion assistant and your task is to analyze user edits and then rewrite an ",
+ "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ",
+ "into account the cursor location.\n",
+ "\n",
+ "### User Edits:\n",
+ "\n",
+ "\n",
+ "\n",
+ "### User Excerpt:\n",
+ "\n",
+ "```test.rs\n",
+ "<|start_of_file|>\n",
+ "// prefix\n",
+ "<|editable_region_start|>\n",
+ "fn foo() {\n",
+ " <|user_cursor_is_here|>let x = 1;\n",
+ "}\n",
+ "<|editable_region_end|>\n",
+ "// suffix\n",
+ "\n",
+ "```\n",
+ "\n",
+ "### Response:\n",
+ ),
+ );
+ }
+
+ #[test]
+ fn test_clean_zeta1_model_output_basic() {
+ let output = indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ println!(\"hello\");
+ }
+ <|editable_region_end|>
+ "};
+
+ let cleaned = zeta1::clean_zeta1_model_output(output).unwrap();
+ assert_eq!(cleaned, "fn main() {\n println!(\"hello\");\n}");
+ }
+
+ #[test]
+ fn test_clean_zeta1_model_output_with_cursor() {
+ let output = indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ <|user_cursor_is_here|>println!(\"hello\");
+ }
+ <|editable_region_end|>
+ "};
+
+ let cleaned = zeta1::clean_zeta1_model_output(output).unwrap();
+ assert_eq!(
+ cleaned,
+ "fn main() {\n <|user_cursor|>println!(\"hello\");\n}"
+ );
+ }
+
+ #[test]
+ fn test_clean_zeta1_model_output_no_markers() {
+ let output = "fn main() {}\n";
+ let cleaned = zeta1::clean_zeta1_model_output(output).unwrap();
+ assert_eq!(cleaned, "fn main() {}\n");
+ }
+
+ #[test]
+ fn test_clean_zeta1_model_output_empty_region() {
+ let output = "<|editable_region_start|>\n<|editable_region_end|>\n";
+ let cleaned = zeta1::clean_zeta1_model_output(output).unwrap();
+ assert_eq!(cleaned, "");
+ }
}