From 6cbeb848803e291a1d8aa8ca6245d88e611b0e4f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 16 Feb 2026 16:38:43 -0800 Subject: [PATCH] Tune edit prediction teacher to leave fewer blank spots in predictions (#49315) Release Notes: - N/A --- .../evals/.zed/settings.json | 3 + .../evals/codex-acp--add-derive.md | 69 ++++++++ .../flask--add-and-rename-test-function.md | 80 +++++++++ .../evals/flask--add-import-statement.md | 76 ++++++++ .../evals/flask--add-test-function.md | 164 ++++++++++++++++++ .../evals/terraform--add-comment.md | 140 +++++++++++++++ .../evals/tree-sitter--if-let-to-match.md | 112 ++++++++++++ ...tree-sitter--tuple-to-struct-definition.md | 105 +++++++++++ ...e-sitter--tuple-to-struct-destructuring.md | 131 ++++++++++++++ ...ee-sitter--tuple-to-struct-field-access.md | 62 +++++++ .../tree-sitter--tuple-to-struct-for-loop.md | 106 +++++++++++ .../tree-sitter--tuple-to-struct-literal.md | 144 +++++++++++++++ .../evals/zed--add-eprintln.md | 56 ++++++ crates/edit_prediction_cli/src/predict.rs | 66 ++++++- .../src/prompts/teacher.md | 62 ++++++- 15 files changed, 1367 insertions(+), 9 deletions(-) create mode 100644 crates/edit_prediction_cli/evals/.zed/settings.json create mode 100644 crates/edit_prediction_cli/evals/codex-acp--add-derive.md create mode 100644 crates/edit_prediction_cli/evals/flask--add-and-rename-test-function.md create mode 100644 crates/edit_prediction_cli/evals/flask--add-import-statement.md create mode 100644 crates/edit_prediction_cli/evals/flask--add-test-function.md create mode 100644 crates/edit_prediction_cli/evals/terraform--add-comment.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--if-let-to-match.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-definition.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-destructuring.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-field-access.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-for-loop.md create mode 100644 crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-literal.md create mode 100644 crates/edit_prediction_cli/evals/zed--add-eprintln.md diff --git a/crates/edit_prediction_cli/evals/.zed/settings.json b/crates/edit_prediction_cli/evals/.zed/settings.json new file mode 100644 index 0000000000000000000000000000000000000000..f1e74a3aee3b9cd6bb41ec3a87a30c7ad016e379 --- /dev/null +++ b/crates/edit_prediction_cli/evals/.zed/settings.json @@ -0,0 +1,3 @@ +{ + "remove_trailing_whitespace_on_save": false, +} diff --git a/crates/edit_prediction_cli/evals/codex-acp--add-derive.md b/crates/edit_prediction_cli/evals/codex-acp--add-derive.md new file mode 100644 index 0000000000000000000000000000000000000000..9d57c75d208851f5cfddc2c7e16854ed3d0fe72e --- /dev/null +++ b/crates/edit_prediction_cli/evals/codex-acp--add-derive.md @@ -0,0 +1,69 @@ ++++ +repository_url = "https://github.com/zed-industries/codex-acp" +revision = "c3d24ee70928fc9da08c131fc632d624413ccc43" ++++ + +## Edit History + +```diff +--- a/src/prompt_args.rs ++++ b/src/prompt_args.rs +@@ -28,7 +28,7 @@ impl PromptArgsError { + } + } + +-#[derive(Debug)] ++#[derive(Debug, Serialize)] + pub enum PromptExpansionError { + Args { + command: String, +``` + +## Cursor Position + +```src/prompt_args.rs +#[derive(Debug)] +pub enum PromptArgsError { +// ^[CURSOR_POSITION] + MissingAssignment { token: String }, + MissingKey { token: String }, +} +``` + +## Expected Patch + +```diff +--- a/src/prompt_args.rs ++++ b/src/prompt_args.rs +@@ -9,7 +9,7 @@ use std::sync::LazyLock; + static PROMPT_ARG_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"\$[A-Z][A-Z0-9_]*").unwrap_or_else(|_| std::process::abort())); + +-#[derive(Debug)] ++#[derive(Debug, Serialize)] + pub enum PromptArgsError { + MissingAssignment { token: String }, + MissingKey { token: String }, +``` + +```diff +--- a/src/prompt_args.rs ++++ b/src/prompt_args.rs +@@ -3,19 +3,20 @@ + use regex_lite::Regex; ++use serde::Serialize; + use shlex::Shlex; + use std::collections::HashMap; + use std::collections::HashSet; + use std::sync::LazyLock; + + static PROMPT_ARG_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"\$[A-Z][A-Z0-9_]*").unwrap_or_else(|_| std::process::abort())); + +-#[derive(Debug)] ++#[derive(Debug, Serialize)] + pub enum PromptArgsError { + MissingAssignment { token: String }, + MissingKey { token: String }, + } +``` diff --git a/crates/edit_prediction_cli/evals/flask--add-and-rename-test-function.md b/crates/edit_prediction_cli/evals/flask--add-and-rename-test-function.md new file mode 100644 index 0000000000000000000000000000000000000000..2946013deb10a7d1082c35545b41164dacf2d40e --- /dev/null +++ b/crates/edit_prediction_cli/evals/flask--add-and-rename-test-function.md @@ -0,0 +1,80 @@ ++++ +repository_url = "https://github.com/pallets/flask" +revision = "2fec0b206c6e83ea813ab26597e15c96fab08be7" ++++ + +## Edit History + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1376,5 +1376,8 @@ + def test_static_files(app, client): + rv = client.get("/static/index.html") + assert rv.status_code == 200 + assert rv.data.strip() == b"

Hello World!

" + with app.test_request_context(): + assert flask.url_for("static", filename="index.html") == "/static/index.html" + rv.close() + + ++de ++ ++ + def test_static_url_path(): +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_(): ++ pass + + + def test_static_url_path(): +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-def test_(): ++def test_static_file_not_found(): + pass + + + def test_static_url_path(): +``` + +## Cursor Position + +```tests/test_basic.py +def test_static_file_not_found(): +# ^[CURSOR_POSITION] + pass + + +def test_static_url_path(): +``` + +## Expected Patch + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-def test_static_file_not_found(): +- pass ++def test_static_file_not_found(app, client): ++ rv = client.get("/static/non_existent.html") ++ assert rv.status_code == 404 ++ rv.close() +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-def test_static_file_not_found(): +- pass ++def test_static_file_not_found(app, client): ++ rv = client.get("/static/not_found.html") ++ assert rv.status_code == 404 ++ rv.close() +``` diff --git a/crates/edit_prediction_cli/evals/flask--add-import-statement.md b/crates/edit_prediction_cli/evals/flask--add-import-statement.md new file mode 100644 index 0000000000000000000000000000000000000000..2bffde527f0e5c0e12eb98508dfe9da015cd5cc5 --- /dev/null +++ b/crates/edit_prediction_cli/evals/flask--add-import-statement.md @@ -0,0 +1,76 @@ ++++ +repository_url = "https://github.com/pallets/flask" +revision = "2fec0b206c6e83ea813ab26597e15c96fab08be7" ++++ + +## Edit History + +```diff +--- a/src/flask/logging.py ++++ b/src/flask/logging.py +@@ -4,7 +4,7 @@ + import sys + import typing as t + +-from werkzeug.local import LocalProxy ++imfrom werkzeug.local import LocalProxy + + from .globals import request + +``` + +## Cursor Position + +```src/flask/logging.py +from __future__ import annotations + +import logging +import sys +import typing as t + +imfrom werkzeug.local import LocalProxy +# ^[CURSOR_POSITION] + +from .globals import request + +if t.TYPE_CHECKING: # pragma: no cover + from .sansio.app import App +``` + +## Expected Patch + +```diff +--- a/src/flask/logging.py ++++ b/src/flask/logging.py +@@ -1,21 +1,21 @@ + from __future__ import annotations + + import logging + import sys + import typing as t + +-imfrom werkzeug.local import LocalProxy ++import +# ^[CURSOR_POSITION] ++from werkzeug.local import LocalProxy + + from .globals import request +``` + +```diff +--- a/src/flask/logging.py ++++ b/src/flask/logging.py +@@ -1,21 +1,21 @@ + from __future__ import annotations + + import logging + import sys + import typing as t +- +-imfrom werkzeug.local import LocalProxy ++import werkzeug +# ^[CURSOR_POSITION] ++from werkzeug.local import LocalProxy + + from .globals import request +``` diff --git a/crates/edit_prediction_cli/evals/flask--add-test-function.md b/crates/edit_prediction_cli/evals/flask--add-test-function.md new file mode 100644 index 0000000000000000000000000000000000000000..ea9d47a6db4dfbd4ac48e55618417ce574699ec3 --- /dev/null +++ b/crates/edit_prediction_cli/evals/flask--add-test-function.md @@ -0,0 +1,164 @@ ++++ +repository_url = "https://github.com/pallets/flask" +revision = "2fec0b206c6e83ea813ab26597e15c96fab08be7" ++++ + +## Edit History + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1376,5 +1376,8 @@ + def test_static_files(app, client): + rv = client.get("/static/index.html") + assert rv.status_code == 200 + assert rv.data.strip() == b"

Hello World!

" + with app.test_request_context(): + assert flask.url_for("static", filename="index.html") == "/static/index.html" + rv.close() + + ++de ++ ++ + def test_static_url_path(): +``` + +## Cursor Position + +```tests/test_basic.py +def test_static_files(app, client): + rv = client.get("/static/index.html") + assert rv.status_code == 200 + assert rv.data.strip() == b"

Hello World!

" + with app.test_request_context(): + assert flask.url_for("static", filename="index.html") == "/static/index.html" + rv.close() + + +de +# ^[CURSOR_POSITION] + + +def test_static_url_path(): +``` + +## Expected Patch + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_(): +# ^[CURSOR_POSITION] ++ + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_(): +# ^[CURSOR_POSITION] ++ pass + + +def test_static_url_path(): +``` + + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_(app, client): +# ^[CURSOR_POSITION] ++ + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_(app, client): +# ^[CURSOR_POSITION] ++ pass + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_static_(): +# ^[CURSOR_POSITION] ++ + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_static_(): +# ^[CURSOR_POSITION] ++ pass + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_static_(app, client): +# ^[CURSOR_POSITION] ++ + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_static_(app, client): +# ^[CURSOR_POSITION] ++ pass + + +def test_static_url_path(): +``` + +```diff +--- a/tests/test_basic.py ++++ b/tests/test_basic.py +@@ -1372,15 +1372,15 @@ +-de ++def test_static_route_with_host_matching(app, client): ++ +# ^[CURSOR_POSITION] + + +def test_static_url_path(): +``` diff --git a/crates/edit_prediction_cli/evals/terraform--add-comment.md b/crates/edit_prediction_cli/evals/terraform--add-comment.md new file mode 100644 index 0000000000000000000000000000000000000000..1fe85575217c1df1082db17dd159b02e60926ce6 --- /dev/null +++ b/crates/edit_prediction_cli/evals/terraform--add-comment.md @@ -0,0 +1,140 @@ ++++ +repository_url = "https://github.com/hashicorp/terraform" +revision = "a3dc571150a7651a1a4a8b302342d26089c97795" ++++ + +## Edit History + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go +@@ -63,6 +63,7 @@ + a.mu.Lock() + defer a.mu.Unlock() + ++ / + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { +``` + +## Cursor Position + +```internal/actions/actions.go + defer a.mu.Unlock() + + data, ok := a.actionInstances.GetOk(addr) + + if !ok { + return nil, false + } + + return &data, true +} + +func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + + / + // <[CURSOR_POSITION] + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } + + return result +} +``` + +## Expected Patch + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go +@@ -51,26 +51,26 @@ + func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + +- / ++ // Filter action instances by the given action. + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } +``` + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go +@@ -54,25 +54,25 @@ + func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + +- / ++ // Filter action instances that belong to the given action + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } +``` + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go +@@ -54,25 +54,25 @@ + func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + +- / ++ // Iterate through all action instances and filter by the containing action + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } +``` + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go + func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + +- / ++ // Iterate through all action instances and return those that belong to the given action + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } +``` + +```diff +--- a/internal/actions/actions.go ++++ b/internal/actions/actions.go + func (a *Actions) GetActionInstanceKeys(addr addrs.AbsAction) []addrs.AbsActionInstance { + a.mu.Lock() + defer a.mu.Unlock() + +- / ++ // Collect all action instances that belong to the given action + result := []addrs.AbsActionInstance{} + for _, data := range a.actionInstances.Elements() { + if data.Key.ContainingAction().Equal(addr) { + result = append(result, data.Key) + } + } +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--if-let-to-match.md b/crates/edit_prediction_cli/evals/tree-sitter--if-let-to-match.md new file mode 100644 index 0000000000000000000000000000000000000000..b1de3bb85207d07f4f70116d45ef2e25124092b7 --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--if-let-to-match.md @@ -0,0 +1,112 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "17e3c7a5c56527a179fa6e37ce7ee934493e5047" ++++ + +## Edit History + +```diff +--- a/crates/loader/src/loader.rs ++++ b/crates/loader/src/loader.rs +@@ -729,15 +729,16 @@ + )); + } + for parser_container_dir in &config.parser_directories { +- if let Ok(entries) = fs::read_dir(parser_container_dir) { +- for entry in entries { +- let entry = entry.map_err(|e| LoaderError::IO(IoError::new(e, None)))?; +- if let Some(parser_dir_name) = entry.file_name().to_str() { +- if parser_dir_name.starts_with("tree-sitter-") { +- self.find_language_configurations_at_path( +- &parser_container_dir.join(parser_dir_name), +- false, +- ) ++ match fs::read_dir(parser_container_dir) { ++ Ok(entries) => { ++ for entry in entries { ++ let entry = entry.map_err(|e| LoaderError::IO(IoError::new(e, None)))?; ++ if let Some(parser_dir_name) = entry.file_name().to_str() { ++ if parser_dir_name.starts_with("tree-sitter-") { ++ self.find_language_configurations_at_path( ++ &parser_container_dir.join(parser_dir_name), ++ false, ++ ) + .ok(); + } + } +--- a/crates/loader/src/loader.rs ++++ b/crates/loader/src/loader.rs +@@ -739,7 +739,8 @@ + &parser_container_dir.join(parser_dir_name), + false, + ) +- .ok(); ++ .ok(); ++ } + } + } + } +``` + +## Cursor Position + +```crates/loader/src/loader.rs + if let Some(parser_dir_name) = entry.file_name().to_str() { + if parser_dir_name.starts_with("tree-sitter-") { + self.find_language_configurations_at_path( + &parser_container_dir.join(parser_dir_name), + false, + ) + .ok(); + } +// ^[CURSOR_POSITION] + } + } + } + } + } +``` + +## Expected Patch + +```diff +--- a/crates/loader/src/loader.rs ++++ b/crates/loader/src/loader.rs +@@ -736,13 +736,13 @@ + if let Some(parser_dir_name) = entry.file_name().to_str() { + if parser_dir_name.starts_with("tree-sitter-") { + self.find_language_configurations_at_path( + &parser_container_dir.join(parser_dir_name), + false, + ) + .ok(); + } + } + } + } ++ Err(error) => {} +# ^[CURSOR_POSITION] + } + } +``` + +```diff +--- a/crates/loader/src/loader.rs ++++ b/crates/loader/src/loader.rs +@@ -736,13 +736,13 @@ + if let Some(parser_dir_name) = entry.file_name().to_str() { + if parser_dir_name.starts_with("tree-sitter-") { + self.find_language_configurations_at_path( + &parser_container_dir.join(parser_dir_name), + false, + ) + .ok(); + } + } + } + } ++ Err(_) => {} +# ^[CURSOR_POSITION] + } + } +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-definition.md b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-definition.md new file mode 100644 index 0000000000000000000000000000000000000000..f603c00ea34e543927d1452b8cd1361f8c2bf147 --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-definition.md @@ -0,0 +1,105 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "24007727d42b4caceda3095ac685c463fae1ba1a" ++++ + +## Edit History + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -604,7 +604,7 @@ + + pub struct Loader { + pub parser_lib_path: PathBuf, +- languages_by_id: Vec<(PathBuf, OnceCell, Option>)>, ++ languages_by_id: Vec, + language_configurations: Vec>, + language_configuration_ids_by_file_type: HashMap>, + language_configuration_in_current_path: Option, +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -621,6 +621,8 @@ + wasm_store: Mutex>, + } + ++str + pub struct CompileConfig<'a> { + pub src_path: &'a Path, + pub header_paths: Vec<&'a Path>, +``` + +## Cursor Position + +```tree-sitter/crates/loader/src/loader.rs + sanitize_build: bool, + force_rebuild: bool, + + #[cfg(feature = "wasm")] + wasm_store: Mutex>, +} + +str +// ^[CURSOR_POSITION] +pub struct CompileConfig<'a> { + pub src_path: &'a Path, + pub header_paths: Vec<&'a Path>, + pub parser_path: PathBuf, + pub scanner_path: Option, + pub external_files: Option<&'a [PathBuf]>, +``` + +## Expected Patch + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -621,6 +621,8 @@ + wasm_store: Mutex>, + } + +-str ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ external_files: Option>, ++} ++ + pub struct CompileConfig<'a> { + pub src_path: &'a Path, + pub header_paths: Vec<&'a Path>, +``` + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -621,6 +621,8 @@ + wasm_store: Mutex>, + } + +-str ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ dependencies: Option>, ++} ++ + pub struct CompileConfig<'a> { + pub src_path: &'a Path, + pub header_paths: Vec<&'a Path>, +``` + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -621,6 +621,8 @@ + wasm_store: Mutex>, + } + +-str ++struct LanguageEntry(PathBuf, OnceCell, Option>); ++ + pub struct CompileConfig<'a> { + pub src_path: &'a Path, + pub header_paths: Vec<&'a Path>, +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-destructuring.md b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-destructuring.md new file mode 100644 index 0000000000000000000000000000000000000000..b4c7a9d5b7d0e8bd8d715abc90a9d687a0abd050 --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-destructuring.md @@ -0,0 +1,131 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "24007727d42b4caceda3095ac685c463fae1ba1a" ++++ + +## Edit History + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -604,7 +604,7 @@ + + pub struct Loader { + pub parser_lib_path: PathBuf, +- languages_by_id: Vec<(PathBuf, OnceCell, Option>)>, ++ languages_by_id: Vec, + language_configurations: Vec>, + language_configuration_ids_by_file_type: HashMap>, + language_configuration_in_current_path: Option, +@@ -619,6 +619,12 @@ + + #[cfg(feature = "wasm")] + wasm_store: Mutex>, + } ++ ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ external_files: Option>, ++} + + pub struct CompileConfig<'a> { +@@ -767,7 +773,7 @@ + pub fn get_all_language_configurations(&self) -> Vec<(&LanguageConfiguration, &Path)> { + self.language_configurations + .iter() +- .map(|c| (c, self.languages_by_id[c.language_id].0.as_ref())) ++ .map(|c| (c, self.languages_by_id[c.language_id].path.as_ref())) + .collect() + } + +``` + +## Cursor Position + +```tree-sitter/crates/loader/src/loader.rs + fn language_for_id(&self, id: usize) -> LoaderResult { + let (path, language, externals) = &self.languages_by_id[id]; + // ^[CURSOR_POSITION] + language + .get_or_try_init(|| { + let src_path = path.join("src"); + self.load_language_at_path(CompileConfig::new( + &src_path, + externals.as_deref(), + None, + )) + }) + .cloned() + } +``` + +## Expected Patch + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -926,7 +926,11 @@ + } + + fn language_for_id(&self, id: usize) -> LoaderResult { +- let (path, language, externals) = &self.languages_by_id[id]; ++ let LanguageEntry { ++ path, ++ language, ++ external_files, ++ } = &self.languages_by_id[id]; + language + .get_or_try_init(|| { + let src_path = path.join("src"); + self.load_language_at_path(CompileConfig::new( + &src_path, +- externals.as_deref(), ++ external_files.as_deref(), + None, + )) + }) + .cloned() +``` + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -926,7 +926,11 @@ + } + + fn language_for_id(&self, id: usize) -> LoaderResult { +- let (path, language, externals) = &self.languages_by_id[id]; ++ let LanguageEntry { ++ path, ++ language, ++ external_files: externals, ++ } = &self.languages_by_id[id]; + language + .get_or_try_init(|| { + let src_path = path.join("src"); +``` + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -926,13 +926,14 @@ + } + + fn language_for_id(&self, id: usize) -> LoaderResult { +- let (path, language, externals) = &self.languages_by_id[id]; +- language ++ let entry = &self.languages_by_id[id]; ++ entry ++ .language + .get_or_try_init(|| { +- let src_path = path.join("src"); ++ let src_path = entry.path.join("src"); + self.load_language_at_path(CompileConfig::new( + &src_path, +- externals.as_deref(), ++ entry.external_files.as_deref(), + None, + )) + }) +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-field-access.md b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-field-access.md new file mode 100644 index 0000000000000000000000000000000000000000..03b2fb519e2c46e16edb57eb59398e0dbd68538c --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-field-access.md @@ -0,0 +1,62 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "24007727d42b4caceda3095ac685c463fae1ba1a" ++++ + +## Edit History + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -604,7 +604,7 @@ + + pub struct Loader { + pub parser_lib_path: PathBuf, +- languages_by_id: Vec<(PathBuf, OnceCell, Option>)>, ++ languages_by_id: Vec, + language_configurations: Vec>, + language_configuration_ids_by_file_type: HashMap>, + language_configuration_in_current_path: Option, +@@ -619,6 +619,12 @@ + + #[cfg(feature = "wasm")] + wasm_store: Mutex>, + } ++ ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ external_files: Option>, ++} + + pub struct CompileConfig<'a> { +``` + +## Cursor Position + +```tree-sitter/crates/loader/src/loader.rs + #[must_use] + pub fn get_all_language_configurations(&self) -> Vec<(&LanguageConfiguration, &Path)> { + self.language_configurations + .iter() + .map(|c| (c, self.languages_by_id[c.language_id].0.as_ref())) + // ^[CURSOR_POSITION] + .collect() + } +``` + +## Expected Patch + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -773,7 +773,7 @@ + pub fn get_all_language_configurations(&self) -> Vec<(&LanguageConfiguration, &Path)> { + self.language_configurations + .iter() +- .map(|c| (c, self.languages_by_id[c.language_id].0.as_ref())) ++ .map(|c| (c, self.languages_by_id[c.language_id].path.as_ref())) + .collect() + } + +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-for-loop.md b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-for-loop.md new file mode 100644 index 0000000000000000000000000000000000000000..aeb7042b1850bece104ff51c40ff2643a04a2b03 --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-for-loop.md @@ -0,0 +1,106 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "24007727d42b4caceda3095ac685c463fae1ba1a" ++++ + +## Edit History + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -604,7 +604,7 @@ + + pub struct Loader { + pub parser_lib_path: PathBuf, +- languages_by_id: Vec<(PathBuf, OnceCell, Option>)>, ++ languages_by_id: Vec, + language_configurations: Vec>, + language_configuration_ids_by_file_type: HashMap>, + language_configuration_in_current_path: Option, +@@ -619,6 +619,12 @@ + + #[cfg(feature = "wasm")] + wasm_store: Mutex>, + } ++ ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ external_files: Option>, ++} + + pub struct CompileConfig<'a> { +@@ -767,7 +773,7 @@ + pub fn get_all_language_configurations(&self) -> Vec<(&LanguageConfiguration, &Path)> { + self.language_configurations + .iter() +- .map(|c| (c, self.languages_by_id[c.language_id].0.as_ref())) ++ .map(|c| (c, self.languages_by_id[c.language_id].path.as_ref())) + .collect() + } + +@@ -920,7 +926,11 @@ + } + + fn language_for_id(&self, id: usize) -> LoaderResult { +- let (path, language, externals) = &self.languages_by_id[id]; ++ let LanguageEntry { ++ path, ++ language, ++ external_files, ++ } = &self.languages_by_id[id]; + language + .get_or_try_init(|| { + let src_path = path.join("src"); +``` + +## Cursor Position + +```tree-sitter/crates/loader/src/loader.rs + let mut language_id = None; + for (id, (path, _, _)) in + // ^[CURSOR_POSITION] + self.languages_by_id.iter().enumerate().skip(language_count) + { + if language_path == *path { + language_id = Some(id); + } + } +``` + +## Expected Patch + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -1542,8 +1542,7 @@ + // Determine if a previous language configuration in this package.json file + // already uses the same language. + let mut language_id = None; +- for (id, (path, _, _)) in +- self.languages_by_id.iter().enumerate().skip(language_count) ++ for (id, entry) in self.languages_by_id.iter().enumerate().skip(language_count) + { +- if language_path == *path { ++ if language_path == entry.path { + language_id = Some(id); + } + } +``` + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -1542,8 +1542,7 @@ + // Determine if a previous language configuration in this package.json file + // already uses the same language. + let mut language_id = None; +- for (id, (path, _, _)) in ++ for (id, LanguageEntry { path, .. }) in + self.languages_by_id.iter().enumerate().skip(language_count) + { + if language_path == *path { + language_id = Some(id); + } + } +``` diff --git a/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-literal.md b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-literal.md new file mode 100644 index 0000000000000000000000000000000000000000..92df291faf3e9bfa2657d4b5e202d8b7d626264b --- /dev/null +++ b/crates/edit_prediction_cli/evals/tree-sitter--tuple-to-struct-literal.md @@ -0,0 +1,144 @@ ++++ +repository_url = "git@github.com:tree-sitter/tree-sitter" +revision = "24007727d42b4caceda3095ac685c463fae1ba1a" ++++ + +## Edit History + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -604,7 +604,7 @@ + + pub struct Loader { + pub parser_lib_path: PathBuf, +- languages_by_id: Vec<(PathBuf, OnceCell, Option>)>, ++ languages_by_id: Vec, + language_configurations: Vec>, + language_configuration_ids_by_file_type: HashMap>, + language_configuration_in_current_path: Option, +@@ -619,6 +619,12 @@ + + #[cfg(feature = "wasm")] + wasm_store: Mutex>, ++} ++ ++struct LanguageEntry { ++ path: PathBuf, ++ language: OnceCell, ++ external_files: Option>, + } + + pub struct CompileConfig<'a> { +@@ -767,7 +773,7 @@ + pub fn get_all_language_configurations(&self) -> Vec<(&LanguageConfiguration, &Path)> { + self.language_configurations + .iter() +- .map(|c| (c, self.languages_by_id[c.language_id].0.as_ref())) ++ .map(|c| (c, self.languages_by_id[c.language_id].path.as_ref())) + .collect() + } + +@@ -920,13 +926,17 @@ + } + + fn language_for_id(&self, id: usize) -> LoaderResult { +- let (path, language, externals) = &self.languages_by_id[id]; ++ let LanguageEntry { ++ path, ++ language, ++ external_files, ++ } = &self.languages_by_id[id]; + language + .get_or_try_init(|| { + let src_path = path.join("src"); + self.load_language_at_path(CompileConfig::new( + &src_path, +- externals.as_deref(), ++ external_files.as_deref(), + None, + )) + }) +@@ -1532,10 +1542,9 @@ + // Determine if a previous language configuration in this package.json file + // already uses the same language. + let mut language_id = None; +- for (id, (path, _, _)) in +- self.languages_by_id.iter().enumerate().skip(language_count) ++ for (id, entry) in self.languages_by_id.iter().enumerate().skip(language_count) + { +- if language_path == *path { ++ if language_path == entry.path { + language_id = Some(id); + } + } +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -1553,10 +1553,10 @@ + let language_id = if let Some(language_id) = language_id { + language_id + } else { +- self.languages_by_id.push(( +- language_path, +- OnceCell::new(), +- grammar ++ self.languages_by_id.push(LanguageEntry { ++ path: language_path, ++ language: OnceCell::new(), ++ external_files: grammar + .external_files + .clone() + .into_vec() +``` + +## Cursor Position + +```tree-sitter/crates/loader/src/loader.rs + let language_id = if let Some(language_id) = language_id { + language_id + } else { + self.languages_by_id.push(LanguageEntry { + path: language_path, + language: OnceCell::new(), + external_files: grammar + .external_files + .clone() + .into_vec() + .map(|files| { + files + .into_iter() + .map(|path| { + let path = parser_path.join(path); + // prevent p being above/outside of parser_path + if path.starts_with(parser_path) { + Ok(path) + } else { + Err(LoaderError::ExternalFile( + path.to_string_lossy().to_string(), + parser_path.to_string_lossy().to_string(), + )) + } + }) + .collect::>>() + }) + .transpose()?, + // ^[CURSOR_POSITION] + )); + self.languages_by_id.len() - 1 + }; +``` + +## Expected Patch + +```diff +--- a/tree-sitter/crates/loader/src/loader.rs ++++ b/tree-sitter/crates/loader/src/loader.rs +@@ -1578,7 +1578,7 @@ + .collect::>>() + }) + .transpose()?, +- )); ++ }); + self.languages_by_id.len() - 1 + }; +``` diff --git a/crates/edit_prediction_cli/evals/zed--add-eprintln.md b/crates/edit_prediction_cli/evals/zed--add-eprintln.md new file mode 100644 index 0000000000000000000000000000000000000000..d4252810b5f97df0991de3015c19e12138e8a27b --- /dev/null +++ b/crates/edit_prediction_cli/evals/zed--add-eprintln.md @@ -0,0 +1,56 @@ ++++ +repository_url = "git@github.com:zed-industries/zed" +revision = "780a87dd98f26816876d12e2728933b17faca78d" ++++ + +## Edit History + +```diff +--- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs ++++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs +@@ -206,6 +206,7 @@ + self.select_next_edit(&Default::default(), window, cx); + self.confirm(&Default::default(), window, cx); + ++ epr + cx.notify(); + } + +``` + +## Cursor Position + +```crates/edit_prediction_ui/src/rate_prediction_modal.rs + let current_completion = self + .active_prediction + .as_ref() + .map(|completion| completion.prediction.clone()); + self.select_completion(current_completion, false, window, cx); + self.select_next_edit(&Default::default(), window, cx); + self.confirm(&Default::default(), window, cx); + + epr + // ^[CURSOR_POSITION] + cx.notify(); + } + + pub fn thumbs_down_active( + &mut self, + _: &ThumbsDownActivePrediction, + window: &mut Window, +``` + +## Expected Patch + +```diff +--- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs ++++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs +@@ -201,16 +201,16 @@ + self.confirm(&Default::default(), window, cx); + +- epr ++ eprintln!(""); +# ^[CURSOR_POSITION] + cx.notify(); + } +``` diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 075d5749b82103de8a2cd9951cc5f1f8b6160f6a..f7021f15b4900fd050f1b2019553528f919e038d 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -58,8 +58,6 @@ pub async fn run_prediction( if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) = provider { - let _step_progress = example_progress.start(Step::Predict); - run_format_prompt( example, &FormatPromptArgs { provider }, @@ -69,8 +67,17 @@ pub async fn run_prediction( ) .await?; + let step_progress = example_progress.start(Step::Predict); let batched = matches!(provider, PredictionProvider::Teacher(..)); - return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await; + return predict_teacher( + example, + backend, + batched, + repetition_count, + args.cache_only, + &step_progress, + ) + .await; } run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?; @@ -194,6 +201,16 @@ pub async fn run_prediction( run_dir.clone() }; + if repetition_count > 1 { + step_progress.set_substatus(format!( + "running prediction {}/{}", + ix + 1, + repetition_count + )); + } else { + step_progress.set_substatus("running prediction"); + } + fs::create_dir_all(&run_dir)?; if LATEST_EXAMPLE_RUN_DIR.is_symlink() { fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?; @@ -273,13 +290,30 @@ async fn predict_teacher( batched: bool, repetition_count: usize, cache_only: bool, + step_progress: &crate::progress::StepProgress, ) -> anyhow::Result<()> { match backend { TeacherBackend::Sonnet45 => { - predict_anthropic(example, backend, batched, repetition_count, cache_only).await + predict_anthropic( + example, + backend, + batched, + repetition_count, + cache_only, + step_progress, + ) + .await } TeacherBackend::Gpt52 => { - predict_openai(example, backend, batched, repetition_count, cache_only).await + predict_openai( + example, + backend, + batched, + repetition_count, + cache_only, + step_progress, + ) + .await } } } @@ -290,6 +324,7 @@ async fn predict_anthropic( batched: bool, repetition_count: usize, cache_only: bool, + step_progress: &crate::progress::StepProgress, ) -> anyhow::Result<()> { let llm_model_name = backend.model_name(); let max_tokens = 16384; @@ -305,6 +340,16 @@ async fn predict_anthropic( let prompt = example.prompt.as_ref().context("Prompt is required")?; for ix in 0..repetition_count { + if repetition_count > 1 { + step_progress.set_substatus(format!( + "running prediction {}/{}", + ix + 1, + repetition_count + )); + } else { + step_progress.set_substatus("running prediction"); + } + let messages = vec![anthropic::Message { role: anthropic::Role::User, content: vec![anthropic::RequestContent::Text { @@ -357,6 +402,7 @@ async fn predict_openai( batched: bool, repetition_count: usize, cache_only: bool, + step_progress: &crate::progress::StepProgress, ) -> anyhow::Result<()> { let llm_model_name = backend.model_name(); let max_tokens = 16384; @@ -372,6 +418,16 @@ async fn predict_openai( let prompt = example.prompt.as_ref().context("Prompt is required")?; for ix in 0..repetition_count { + if repetition_count > 1 { + step_progress.set_substatus(format!( + "running prediction {}/{}", + ix + 1, + repetition_count + )); + } else { + step_progress.set_substatus("running prediction"); + } + let messages = vec![open_ai::RequestMessage::User { content: open_ai::MessageContent::Plain(prompt.input.clone()), }]; diff --git a/crates/edit_prediction_cli/src/prompts/teacher.md b/crates/edit_prediction_cli/src/prompts/teacher.md index 4a47bebdad02e3afc4fc863bffd3602d6c738a22..4f202b2c6b068371f4788f0f6db9f1af334f7686 100644 --- a/crates/edit_prediction_cli/src/prompts/teacher.md +++ b/crates/edit_prediction_cli/src/prompts/teacher.md @@ -25,8 +25,9 @@ You are an edit prediction assistant in a code editor. Your task is to predict t - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. - Keep existing formatting unless it's absolutely necessary - When edit history and surrounding code suggest different edits, prioritize the most recent edits in the history as they best reflect current intent. -- When uncertain, predict only the minimal, high-confidence portion of the edit. Prefer a small, correct prediction over a large, speculative one - Treat partial text at or near the cursor as the beginning of something the user is actively typing. Complete the code the user appears to be creating based on context. +- When completing partial code, prefer predictions that save meaningful keystrokes, even if this requires making educated guesses about the user's intent. +- It's better to make a substantive prediction that might be rejected than to make a minimal prediction that saves only a few keystrokes. # Input Format @@ -46,8 +47,7 @@ You will be provided with: ````` NO_EDITS ````` -- If the next edit has some uncertainty, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in. - - e.g. if a user is typing `func<|user_cursor|>`, but you don't know what the function name should be, you can predict `function <|user_cursor|>() {}` +- If there is a specific place in the predicted output where the user is likely to edit next, indicate it using the `<|user_cursor|>` tag. ## Example 1 @@ -149,6 +149,60 @@ fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { ## Example 3 +Here, the user is adding a function. There's no way to tell for sure what the function's name will be. In this situation, you should make a reasonable guess at the function's name and signature, and place the user's cursor in the function body. This way, if you guess correctly, it will save the user a meaningful number of keystrokes, and the file will be left in a coherent state. + +### User Edit History + +````` +--- a/src/modal.rs ++++ b/src/modal.rs +@@ -100,4 +100,4 @@ + fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { + modal_state.close(); + modal_state.dismiss(); + } ++ ++fn + + fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) { +````` + +### Current File + +`````src/modal.rs +// handle the close button click +fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { + modal_state.close(); +<|editable_region_start|> + modal_state.dismiss(); +} + +fn<|user_cursor|> + +fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) { +<|editable_region_end|> + modal_state.begin_edit(); +````` + +### Output + +The user is adding a new function. The existing functions I see are `handle_close_button_click` and `handle_keystroke`, which have similar signatures. One possible function they might be adding is `handle_submit`. + +````` +<|editable_region_start|> + modal_state.dismiss(); +} + +fn handle_submit(modal_state: &mut ModalState, evt: &Event) { + <|user_cursor|> +} + +fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) { +<|editable_region_end|> +````` + +## Example 4 + The code is already complete and there is no clear next edit to make. You should output NO_EDITS. ### User Edit History @@ -181,7 +235,7 @@ The user just fixed a bug in the `add` function, changing subtraction to additio NO_EDITS ````` -## Example 4 +## Example 5 The user just deleted code, leaving behind what looks incomplete. You must NOT "complete" it by restoring deleted content—that would undo their edit. Output NO_EDITS. **This is the correct response even though the code appears broken.**