From 7597666c08c8a2bbb45e9b02954112194e31f6f4 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Wed, 8 Apr 2026 17:39:17 -0500 Subject: [PATCH] Track additional metrics in settled (#52938) Stacked on https://github.com/zed-industries/zed/pull/50566. Begin collecting kept chars rate, as well as the count of tree-sitter errors in the code before and after applying the prediction. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [ ] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ... --- Cargo.lock | 4 +- crates/edit_prediction/Cargo.toml | 14 +- .../benches/kept_rate.rs | 2 +- .../edit_prediction/benches/ts_error_count.rs | 454 ++++++++++++++++++ crates/edit_prediction/src/edit_prediction.rs | 186 ++++--- .../src/edit_prediction_tests.rs | 11 + crates/edit_prediction/src/metrics.rs | 10 + .../src/metrics}/kept_rate.rs | 155 ++++-- .../edit_prediction/src/metrics/tokenize.rs | 54 +++ .../src/metrics/tree_sitter.rs | 88 ++++ crates/edit_prediction/src/zeta.rs | 38 +- crates/edit_prediction_cli/Cargo.toml | 4 - crates/edit_prediction_cli/src/lib.rs | 2 - crates/edit_prediction_cli/src/main.rs | 2 +- crates/edit_prediction_cli/src/metrics.rs | 2 +- crates/language/src/buffer.rs | 8 + 16 files changed, 905 insertions(+), 129 deletions(-) rename crates/{edit_prediction_cli => edit_prediction}/benches/kept_rate.rs (98%) create mode 100644 crates/edit_prediction/benches/ts_error_count.rs create mode 100644 crates/edit_prediction/src/metrics.rs rename crates/{edit_prediction_cli/src => edit_prediction/src/metrics}/kept_rate.rs (77%) create mode 100644 crates/edit_prediction/src/metrics/tokenize.rs create mode 100644 
crates/edit_prediction/src/metrics/tree_sitter.rs diff --git a/Cargo.lock b/Cargo.lock index fdd1a67b752a6486ed3ece0e94dc65b2520e1e14..81d697838a0eba0afb874889517c1a822e3ee68e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5176,6 +5176,7 @@ dependencies = [ "copilot", "copilot_ui", "credentials_provider", + "criterion", "ctor", "db", "edit_prediction_context", @@ -5189,9 +5190,11 @@ dependencies = [ "itertools 0.14.0", "language", "language_model", + "languages", "log", "lsp", "menu", + "node_runtime", "open_ai", "parking_lot", "postage", @@ -5235,7 +5238,6 @@ dependencies = [ "client", "cloud_llm_client", "collections", - "criterion", "db", "debug_adapter_extension", "dirs 4.0.0", diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 87ad4e42e7826cdda4fc6a8c31a27afe888830f0..ae6def3686c727c18607b5cc6c135e4a0d16613d 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -72,10 +72,14 @@ zeta_prompt.workspace = true zstd.workspace = true [dev-dependencies] +criterion.workspace = true +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +languages = { workspace = true, features = ["load-grammars"] } +node_runtime.workspace = true clock = { workspace = true, features = ["test-support"] } cloud_llm_client = { workspace = true, features = ["test-support"] } ctor.workspace = true -gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } @@ -86,3 +90,11 @@ settings = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true + +[[bench]] +name = "kept_rate" +harness = false + +[[bench]] +name = "ts_error_count" +harness = false diff --git a/crates/edit_prediction_cli/benches/kept_rate.rs 
b/crates/edit_prediction/benches/kept_rate.rs similarity index 98% rename from crates/edit_prediction_cli/benches/kept_rate.rs rename to crates/edit_prediction/benches/kept_rate.rs index eccbb42dc0591ee15a0b942a4c326d0e4f2123ee..44defe5d19d3c3a294e07914e39ef8467bf260c9 100644 --- a/crates/edit_prediction_cli/benches/kept_rate.rs +++ b/crates/edit_prediction/benches/kept_rate.rs @@ -1,5 +1,5 @@ use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; -use edit_prediction_cli::kept_rate::compute_kept_rate; +use edit_prediction::metrics::compute_kept_rate; fn repeated_function_lines(line_count: usize) -> String { let mut text = String::with_capacity(line_count * 32); diff --git a/crates/edit_prediction/benches/ts_error_count.rs b/crates/edit_prediction/benches/ts_error_count.rs new file mode 100644 index 0000000000000000000000000000000000000000..518f9ab5bc258fb9726e3bb8b118ba29a84d4c73 --- /dev/null +++ b/crates/edit_prediction/benches/ts_error_count.rs @@ -0,0 +1,454 @@ +use std::sync::Arc; + +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use edit_prediction::metrics::count_tree_sitter_errors; +use fs::FakeFs; +use gpui::{AppContext as _, TestAppContext}; +use language::{Buffer, BufferSnapshot, LanguageRegistry}; +use languages::init as init_languages; +use node_runtime::NodeRuntime; +use settings::SettingsStore; + +struct ParsedCase { + label: String, + bytes: usize, + error_count: usize, + snapshot: BufferSnapshot, +} + +fn replace_nth_occurrences( + source: &mut String, + needle: &str, + replacement: &str, + every: usize, + max_replacements: usize, +) { + let mut rebuilt = String::with_capacity(source.len()); + let mut cursor = 0; + let mut seen = 0; + let mut replaced = 0; + + while let Some(relative_index) = source[cursor..].find(needle) { + let start = cursor + relative_index; + let end = start + needle.len(); + rebuilt.push_str(&source[cursor..start]); + + if seen % every == 0 && replaced < 
max_replacements { + rebuilt.push_str(replacement); + replaced += 1; + } else { + rebuilt.push_str(needle); + } + + seen += 1; + cursor = end; + } + + rebuilt.push_str(&source[cursor..]); + *source = rebuilt; +} + +fn rust_source(function_count: usize) -> String { + let mut source = String::from( + "pub struct Counter {\n value: usize,\n}\n\nimpl Counter {\n pub fn new() -> Self {\n Self { value: 0 }\n }\n}\n\n", + ); + for index in 0..function_count { + source.push_str(&format!( + "pub fn compute_value_{index}(input: usize) -> usize {{\n let mut total = input;\n for offset in 0..32 {{\n total += offset + {index};\n }}\n if total % 2 == 0 {{\n total / 2\n }} else {{\n total * 3 + 1\n }}\n}}\n\n" + )); + } + source +} + +fn rust_source_with_errors(function_count: usize) -> String { + let mut source = rust_source(function_count); + replace_nth_occurrences( + &mut source, + " if total % 2 == 0 {\n", + " if total % 2 == 0 \n", + 17, + 48, + ); + source +} + +fn python_source(function_count: usize) -> String { + let mut source = String::from( + "class Counter:\n def __init__(self) -> None:\n self.value = 0\n\n\n", + ); + for index in 0..function_count { + source.push_str(&format!( + "def compute_value_{index}(input_value: int) -> int:\n total = input_value\n for offset in range(32):\n total += offset + {index}\n if total % 2 == 0:\n return total // 2\n return total * 3 + 1\n\n" + )); + } + source +} + +fn python_source_with_errors(function_count: usize) -> String { + let mut source = python_source(function_count); + replace_nth_occurrences( + &mut source, + " if total % 2 == 0:\n", + " if total % 2 == 0\n", + 19, + 48, + ); + source +} + +fn go_source(function_count: usize) -> String { + let mut source = String::from( + "package bench\n\ntype Counter struct {\n\tvalue int\n}\n\nfunc NewCounter() Counter {\n\treturn Counter{value: 0}\n}\n\n", + ); + for index in 0..function_count { + source.push_str(&format!( + "func ComputeValue{index}(inputValue int) int {{\n\ttotal := 
inputValue\n\tfor offset := 0; offset < 32; offset++ {{\n\t\ttotal += offset + {index}\n\t}}\n\tif total%2 == 0 {{\n\t\treturn total / 2\n\t}}\n\treturn total*3 + 1\n}}\n\n" + )); + } + source +} + +fn go_source_with_errors(function_count: usize) -> String { + let mut source = go_source(function_count); + replace_nth_occurrences( + &mut source, + "\tfor offset := 0; offset < 32; offset++ {\n", + "\tfor offset := 0; offset < 32; offset++ \n", + 17, + 48, + ); + source +} + +fn typescript_source(function_count: usize) -> String { + let mut source = String::from( + "export type Counter = { value: number };\n\nexport function newCounter(): Counter {\n return { value: 0 };\n}\n\n", + ); + for index in 0..function_count { + source.push_str(&format!( + "export function computeValue{index}(inputValue: number): number {{\n let total = inputValue;\n for (let offset = 0; offset < 32; offset += 1) {{\n total += offset + {index};\n }}\n return total % 2 === 0 ? total / 2 : total * 3 + 1;\n}}\n\n" + )); + } + source +} + +fn typescript_source_with_errors(function_count: usize) -> String { + let mut source = typescript_source(function_count); + replace_nth_occurrences( + &mut source, + " return total % 2 === 0 ? total / 2 : total * 3 + 1;\n", + " return total % 2 === 0 ? total / 2 : ;\n", + 17, + 64, + ); + source +} + +fn tsx_source(component_count: usize) -> String { + let mut source = String::from( + "type ItemProps = { index: number; label: string };\n\nfunction Item({ index, label }: ItemProps) {\n return
  • {label}
  • ;\n}\n\nexport function App() {\n return
    ;\n}\n\n", + ); + for index in 0..component_count { + source.push_str(&format!( + "export function Widget{index}(): JSX.Element {{\n const items = Array.from({{ length: 16 }}, (_, value) => value + {index});\n return (\n
    \n

    Widget {index}

    \n \n
    \n );\n}}\n\n" + )); + } + source +} + +fn tsx_source_with_errors(component_count: usize) -> String { + let mut source = tsx_source(component_count); + replace_nth_occurrences( + &mut source, + " const items = Array.from({ length: 16 }, (_, value) => value + ", + " const items = Array.from({ length: 16 }, (_, value) => ); // ", + 11, + 32, + ); + source +} + +fn json_source(object_count: usize) -> String { + let mut source = String::from("{\n \"items\": [\n"); + for index in 0..object_count { + let suffix = if index + 1 == object_count { "" } else { "," }; + source.push_str(&format!( + " {{\n \"id\": {index},\n \"name\": \"item-{index}\",\n \"enabled\": true,\n \"tags\": [\"alpha\", \"beta\", \"gamma\"],\n \"metrics\": {{ \"count\": {}, \"ratio\": {} }}\n }}{suffix}\n", + index * 3 + 1, + index as f64 / 10.0, + )); + } + source.push_str(" ]\n}\n"); + source +} + +fn json_source_with_errors(object_count: usize) -> String { + let mut source = json_source(object_count); + replace_nth_occurrences( + &mut source, + " \"enabled\": true,\n", + " \"enabled\": ,\n", + 23, + 64, + ); + source +} + +fn yaml_source(document_count: usize) -> String { + let mut source = String::new(); + for index in 0..document_count { + source.push_str(&format!( + "- id: {index}\n name: item-{index}\n enabled: true\n tags:\n - alpha\n - beta\n - gamma\n metrics:\n count: {}\n ratio: {}\n", + index * 3 + 1, + index as f64 / 10.0, + )); + } + source +} + +fn yaml_source_with_errors(document_count: usize) -> String { + let mut source = yaml_source(document_count); + replace_nth_occurrences(&mut source, " count: ", " count ", 23, 64); + source +} + +fn css_source(rule_count: usize) -> String { + let mut source = String::new(); + for index in 0..rule_count { + source.push_str(&format!( + ".widget-{index} {{\n display: grid;\n grid-template-columns: repeat(4, minmax(0, 1fr));\n gap: 12px;\n padding: 8px;\n color: rgb({}, {}, {});\n}}\n\n.widget-{index} > .item-{index} {{\n border: 1px solid 
rgba(0, 0, 0, 0.15);\n background: linear-gradient(90deg, #fff, #eef);\n}}\n\n", + (index * 17) % 255, + (index * 31) % 255, + (index * 47) % 255, + )); + } + source +} + +fn css_source_with_errors(rule_count: usize) -> String { + let mut source = css_source(rule_count); + replace_nth_occurrences(&mut source, " gap: 12px;\n", " gap 12px;\n", 29, 64); + source +} + +fn build_case( + context: &mut TestAppContext, + languages: &Arc, + language_name: &'static str, + variant_name: &'static str, + source: String, + expect_errors: bool, +) -> ParsedCase { + let language_task = context.background_spawn({ + let languages = languages.clone(); + async move { languages.language_for_name(language_name).await } + }); + while !language_task.is_ready() { + context.run_until_parked(); + } + let language = futures::executor::block_on(language_task) + .unwrap_or_else(|error| panic!("failed to load {language_name}: {error}")); + + let buffer = context.new(|cx| Buffer::local(source, cx).with_language(language, cx)); + context.run_until_parked(); + while buffer.read_with(context, |buffer, _| buffer.is_parsing()) { + context.run_until_parked(); + } + + let snapshot = buffer.read_with(context, |buffer, _| buffer.snapshot()); + let full_range = 0..snapshot.text.len(); + let error_count = count_tree_sitter_errors(snapshot.syntax_layers()); + if expect_errors { + assert!( + error_count > 0, + "expected tree-sitter errors for {language_name}/{variant_name}", + ); + } else { + assert_eq!( + error_count, 0, + "expected no tree-sitter errors for {language_name}/{variant_name}", + ); + } + + let label = format!( + "{}/{}_{}kb_{}e", + language_name.to_lowercase(), + variant_name, + full_range.end / 1024, + error_count, + ); + ParsedCase { + label, + bytes: full_range.end, + error_count, + snapshot, + } +} + +fn parsed_cases() -> Vec { + let mut context = TestAppContext::single(); + context.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); + + let 
languages = Arc::new(LanguageRegistry::new(context.executor())); + let fs = FakeFs::new(context.executor()); + let node_runtime = NodeRuntime::unavailable(); + context.update(|cx| init_languages(languages.clone(), fs, node_runtime, cx)); + + vec![ + build_case( + &mut context, + &languages, + "Rust", + "valid", + rust_source(900), + false, + ), + build_case( + &mut context, + &languages, + "Rust", + "error_heavy", + rust_source_with_errors(900), + true, + ), + build_case( + &mut context, + &languages, + "Python", + "valid", + python_source(1100), + false, + ), + build_case( + &mut context, + &languages, + "Python", + "error_heavy", + python_source_with_errors(1100), + true, + ), + build_case( + &mut context, + &languages, + "Go", + "valid", + go_source(1000), + false, + ), + build_case( + &mut context, + &languages, + "Go", + "error_heavy", + go_source_with_errors(1000), + true, + ), + build_case( + &mut context, + &languages, + "TypeScript", + "valid", + typescript_source(1000), + false, + ), + build_case( + &mut context, + &languages, + "TypeScript", + "error_heavy", + typescript_source_with_errors(1000), + true, + ), + build_case( + &mut context, + &languages, + "TSX", + "valid", + tsx_source(350), + false, + ), + build_case( + &mut context, + &languages, + "TSX", + "error_heavy", + tsx_source_with_errors(350), + true, + ), + build_case( + &mut context, + &languages, + "JSON", + "valid", + json_source(2200), + false, + ), + build_case( + &mut context, + &languages, + "JSON", + "error_heavy", + json_source_with_errors(2200), + true, + ), + build_case( + &mut context, + &languages, + "YAML", + "valid", + yaml_source(2200), + false, + ), + build_case( + &mut context, + &languages, + "YAML", + "error_heavy", + yaml_source_with_errors(2200), + true, + ), + build_case( + &mut context, + &languages, + "CSS", + "valid", + css_source(2400), + false, + ), + build_case( + &mut context, + &languages, + "CSS", + "error_heavy", + css_source_with_errors(2400), + true, + ), + ] 
+} + +fn ts_error_count_benchmark(c: &mut Criterion) { + let cases = parsed_cases(); + let mut group = c.benchmark_group("ts_error_count/full_file"); + + for case in &cases { + group.bench_with_input( + BenchmarkId::from_parameter(&case.label), + case, + |bench, case| { + bench.iter(|| { + black_box(case.bytes); + black_box(case.error_count); + black_box(count_tree_sitter_errors(case.snapshot.syntax_layers())) + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, ts_error_count_benchmark); +criterion_main!(benches); diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 2d90e13fb9b45aedd354f753502cd4e616ae3bcd..9148a0bb62462a6ab32ce4837312c5de701d21f2 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -30,7 +30,7 @@ use gpui::{ }; use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; -use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; +use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; @@ -61,6 +61,7 @@ pub mod example_spec; pub mod fim; mod license_detection; pub mod mercury; +pub mod metrics; pub mod ollama; mod onboarding_modal; pub mod open_ai_response; @@ -80,6 +81,7 @@ use crate::cursor_excerpt::expand_context_syntactically_then_linewise; use crate::example_spec::ExampleSpec; use crate::license_detection::LicenseDetectionWatcher; use crate::mercury::Mercury; +pub use crate::metrics::{KeptRateResult, compute_kept_rate}; use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; @@ -478,10 +480,13 @@ impl std::ops::Deref for BufferEditPrediction<'_> { } #[derive(Clone)] - struct 
PendingSettledPrediction { request_id: EditPredictionId, editable_anchor_range: Range, + editable_region_before_prediction: String, + predicted_editable_region: String, + ts_error_count_before_prediction: usize, + ts_error_count_after_prediction: usize, example: Option, enqueued_at: Instant, last_edit_at: Instant, @@ -1603,63 +1608,100 @@ impl EditPredictionStore { }; let now = cx.background_executor().now(); - let mut oldest_edited_at = None; + let mut ready_predictions = Vec::new(); this.update(cx, |this, _| { for (_, project_state) in this.projects.iter_mut() { for (_, registered_buffer) in project_state.registered_buffers.iter_mut() { - registered_buffer - .pending_predictions - .retain_mut(|pending_prediction| { - let age = - now.saturating_duration_since(pending_prediction.enqueued_at); - if age >= EDIT_PREDICTION_SETTLED_TTL { - return false; - } - - let quiet_for = - now.saturating_duration_since(pending_prediction.last_edit_at); - if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE { - let settled_editable_region = registered_buffer - .snapshot - .text_for_range( - pending_prediction.editable_anchor_range.clone(), - ) - .collect::(); - - #[cfg(test)] - if let Some(callback) = &this.settled_event_callback { - callback( - pending_prediction.request_id.clone(), - settled_editable_region.clone(), - ); - } - - telemetry::event!( - EDIT_PREDICTION_SETTLED_EVENT, - request_id = pending_prediction.request_id.0.clone(), - settled_editable_region, - example = pending_prediction.example.take(), - e2e_latency = pending_prediction.e2e_latency.as_millis(), - ); - - return false; - } + let mut pending_index = 0; + while pending_index < registered_buffer.pending_predictions.len() { + let pending_prediction = + &registered_buffer.pending_predictions[pending_index]; + let age = now.saturating_duration_since(pending_prediction.enqueued_at); + if age >= EDIT_PREDICTION_SETTLED_TTL { + registered_buffer.pending_predictions.remove(pending_index); + continue; + } - if
oldest_edited_at - .is_none_or(|t| pending_prediction.last_edit_at < t) - { - oldest_edited_at = Some(pending_prediction.last_edit_at); - } + let quiet_for = + now.saturating_duration_since(pending_prediction.last_edit_at); + if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE { + let pending_prediction = + registered_buffer.pending_predictions.remove(pending_index); + let settled_editable_region = registered_buffer + .snapshot + .text_for_range( + pending_prediction.editable_anchor_range.clone(), + ) + .collect::(); + ready_predictions + .push((pending_prediction, settled_editable_region)); + continue; + } - true - }); + if oldest_edited_at + .is_none_or(|time| pending_prediction.last_edit_at < time) + { + oldest_edited_at = Some(pending_prediction.last_edit_at); + } + pending_index += 1; + } } } }); - next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE); + for (pending_prediction, settled_editable_region) in ready_predictions { + let PendingSettledPrediction { + request_id, + editable_region_before_prediction, + predicted_editable_region, + ts_error_count_before_prediction, + ts_error_count_after_prediction, + example, + e2e_latency, + .. 
+ } = pending_prediction; + let settled_editable_region_for_metrics = settled_editable_region.clone(); + let kept_rate_result = cx + .background_spawn(async move { + compute_kept_rate( + &editable_region_before_prediction, + &predicted_editable_region, + &settled_editable_region_for_metrics, + ) + }) + .await; + + #[cfg(test)] + { + let request_id = request_id.clone(); + let settled_editable_region = settled_editable_region.clone(); + this.update(cx, |this, _| { + if let Some(callback) = &this.settled_event_callback { + callback(request_id, settled_editable_region); + } + }); + } + + telemetry::event!( + EDIT_PREDICTION_SETTLED_EVENT, + request_id = request_id.0.clone(), + settled_editable_region, + ts_error_count_before_prediction, + ts_error_count_after_prediction, + edit_bytes_predicted_new = kept_rate_result.predicted_new_chars, + edit_bytes_final_new = kept_rate_result.final_new_chars, + edit_bytes_kept = kept_rate_result.kept_chars, + edit_bytes_discarded = kept_rate_result.discarded_chars, + edit_bytes_context = kept_rate_result.context_chars, + edit_bytes_kept_rate = kept_rate_result.kept_rate, + example, + e2e_latency = e2e_latency.as_millis(), + ); + } + + next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE); } } @@ -1670,28 +1712,58 @@ impl EditPredictionStore { edited_buffer: &Entity, edited_buffer_snapshot: &BufferSnapshot, editable_offset_range: Range, + edit_preview: &EditPreview, example: Option, e2e_latency: std::time::Duration, cx: &mut Context, ) { let this = &mut *self; let project_state = this.get_or_init_project(project, cx); - if let Some(buffer) = project_state + let Some(registered_buffer) = project_state .registered_buffers .get_mut(&edited_buffer.entity_id()) - { - let now = cx.background_executor().now(); - buffer.pending_predictions.push(PendingSettledPrediction { - request_id: request_id, - editable_anchor_range: edited_buffer_snapshot - .anchor_range_inside(editable_offset_range), + else { + return; 
+ }; + + let editable_region_before_prediction = edited_buffer_snapshot + .text_for_range(editable_offset_range.clone()) + .collect::(); + let editable_anchor_range_for_result = + edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone()); + let predicted_editable_region = edit_preview + .result_text_snapshot() + .text_for_range(editable_anchor_range_for_result.clone()) + .collect(); + let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors( + edited_buffer_snapshot + .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true), + ); + let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors( + edit_preview.result_syntax_snapshot().layers_for_range( + editable_anchor_range_for_result, + edit_preview.result_text_snapshot(), + true, + ), + ); + let editable_anchor_range = + edited_buffer_snapshot.anchor_range_inside(editable_offset_range); + let now = cx.background_executor().now(); + registered_buffer + .pending_predictions + .push(PendingSettledPrediction { + request_id, + editable_anchor_range, + editable_region_before_prediction, + predicted_editable_region, + ts_error_count_before_prediction, + ts_error_count_after_prediction, example, e2e_latency, enqueued_at: now, last_edit_at: now, }); - this.settled_predictions_tx.unbounded_send(now).ok(); - } + this.settled_predictions_tx.unbounded_send(now).ok(); } fn reject_current_prediction( diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index ea7233cd976148f5eb726730635e0efaf6ceef86..54dabf93f8da290d76c13222ae5a110e80d0b388 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -3252,6 +3252,12 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { cx.run_until_parked(); let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let empty_edits: Arc<[(Range, Arc)]> = 
Vec::new().into(); + let edit_preview_a = buffer + .read_with(cx, |buffer, cx| { + buffer.preview_edits(empty_edits.clone(), cx) + }) + .await; // Region A: first 10 lines of the buffer. let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0)); @@ -3263,6 +3269,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { &buffer, &snapshot_a, editable_region_a.clone(), + &edit_preview_a, None, Duration::from_secs(0), cx, @@ -3318,6 +3325,9 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { cx.run_until_parked(); let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let edit_preview_b = buffer + .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx)) + .await; let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0)); ep_store.update(cx, |ep_store, cx| { @@ -3327,6 +3337,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { &buffer, &snapshot_b2, editable_region_b.clone(), + &edit_preview_b, None, Duration::from_secs(0), cx, diff --git a/crates/edit_prediction/src/metrics.rs b/crates/edit_prediction/src/metrics.rs new file mode 100644 index 0000000000000000000000000000000000000000..20abd683a53fa34397a32a24abb0b49f553c0895 --- /dev/null +++ b/crates/edit_prediction/src/metrics.rs @@ -0,0 +1,10 @@ +mod kept_rate; +mod tokenize; +mod tree_sitter; + +pub use kept_rate::KeptRateResult; +#[cfg(test)] +pub use kept_rate::TokenAnnotation; +pub use kept_rate::compute_kept_rate; +pub(crate) use tokenize::tokenize; +pub use tree_sitter::count_tree_sitter_errors; diff --git a/crates/edit_prediction_cli/src/kept_rate.rs b/crates/edit_prediction/src/metrics/kept_rate.rs similarity index 77% rename from crates/edit_prediction_cli/src/kept_rate.rs rename to crates/edit_prediction/src/metrics/kept_rate.rs index 565597fd12b567e7f7f23be233b87ba2284a176f..4843c4465251756f47b9f1e82726c70bba6940c4 100644 --- a/crates/edit_prediction_cli/src/kept_rate.rs +++ 
b/crates/edit_prediction/src/metrics/kept_rate.rs @@ -1,4 +1,6 @@ -use crate::word_diff::tokenize; +use crate::metrics::tokenize; + +const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512; #[cfg(test)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -25,20 +27,28 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize { row * width + column } -/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each -/// side while sharing a single DP table construction. -fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec, Vec) { +/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side +/// while sharing a single DP table construction. +fn fill_lcs_keep_masks( + a: &[&str], + b: &[&str], + mut keep_a: Option<&mut [bool]>, + mut keep_b: Option<&mut [bool]>, +) { if a.is_empty() || b.is_empty() { - return (vec![false; a.len()], vec![false; b.len()]); + return; } if a == b { - return (vec![true; a.len()], vec![true; b.len()]); + if let Some(keep_a) = keep_a.as_mut() { + keep_a.fill(true); + } + if let Some(keep_b) = keep_b.as_mut() { + keep_b.fill(true); + } + return; } - let mut keep_a = vec![false; a.len()]; - let mut keep_b = vec![false; b.len()]; - let prefix_len = a .iter() .zip(b.iter()) @@ -61,22 +71,30 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec, Vec) { }; for index in 0..prefix_len { - keep_a[index] = true; - keep_b[index] = true; + if let Some(keep_a) = keep_a.as_mut() { + keep_a[index] = true; + } + if let Some(keep_b) = keep_b.as_mut() { + keep_b[index] = true; + } } for offset in 0..suffix_len { let a_index = a.len() - suffix_len + offset; let b_index = b.len() - suffix_len + offset; - keep_a[a_index] = true; - keep_b[b_index] = true; + if let Some(keep_a) = keep_a.as_mut() { + keep_a[a_index] = true; + } + if let Some(keep_b) = keep_b.as_mut() { + keep_b[b_index] = true; + } } let a_mid = &a[prefix_len..a.len() - suffix_len]; let b_mid = &b[prefix_len..b.len() - suffix_len]; if a_mid.is_empty() || b_mid.is_empty() { - 
return (keep_a, keep_b); + return; } let row_count = a_mid.len() + 1; @@ -97,44 +115,59 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec, Vec) { } } - let mut i = a_mid.len(); - let mut j = b_mid.len(); + if let Some(keep_a) = keep_a.as_mut() { + let mut i = a_mid.len(); + let mut j = b_mid.len(); - while i > 0 && j > 0 { - if a_mid[i - 1] == b_mid[j - 1] { - keep_a[prefix_len + i - 1] = true; - i -= 1; - j -= 1; - } else { - let up = dp[dp_index(column_count, i - 1, j)]; - let left = dp[dp_index(column_count, i, j - 1)]; - if up >= left { + while i > 0 && j > 0 { + if a_mid[i - 1] == b_mid[j - 1] { + keep_a[prefix_len + i - 1] = true; i -= 1; - } else { j -= 1; + } else { + let up = dp[dp_index(column_count, i - 1, j)]; + let left = dp[dp_index(column_count, i, j - 1)]; + if up >= left { + i -= 1; + } else { + j -= 1; + } } } } - let mut i = a_mid.len(); - let mut j = b_mid.len(); + if let Some(keep_b) = keep_b.as_mut() { + let mut i = a_mid.len(); + let mut j = b_mid.len(); - while i > 0 && j > 0 { - if a_mid[i - 1] == b_mid[j - 1] { - keep_b[prefix_len + j - 1] = true; - i -= 1; - j -= 1; - } else { - let up = dp[dp_index(column_count, i - 1, j)]; - let left = dp[dp_index(column_count, i, j - 1)]; - if left >= up { + while i > 0 && j > 0 { + if a_mid[i - 1] == b_mid[j - 1] { + keep_b[prefix_len + j - 1] = true; + i -= 1; j -= 1; } else { - i -= 1; + let up = dp[dp_index(column_count, i - 1, j)]; + let left = dp[dp_index(column_count, i, j - 1)]; + if left >= up { + j -= 1; + } else { + i -= 1; + } } } } +} + +fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec { + let mut keep_a = vec![false; a.len()]; + fill_lcs_keep_masks(a, b, Some(&mut keep_a), None); + keep_a +} +fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec, Vec) { + let mut keep_a = vec![false; a.len()]; + let mut keep_b = vec![false; b.len()]; + fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b)); (keep_a, keep_b) } @@ -155,6 +188,12 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], 
mask: &[bool]) -> (Vec<&'a str> (unmasked_tokens, unmasked_chars, masked_chars) } +fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool { + let predicted_delta_chars = predicted.len().abs_diff(base.len()); + let final_delta_chars = final_text.len().abs_diff(base.len()); + predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS +} + pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult { if base == predicted && predicted == final_text { let predicted_tokens = tokenize(predicted); @@ -171,11 +210,26 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR }; } + if should_bail_for_dirty_final(base, predicted, final_text) { + let predicted_new_chars = predicted.len().abs_diff(base.len()); + let final_new_chars = final_text.len().abs_diff(base.len()); + return KeptRateResult { + predicted_new_chars, + final_new_chars, + kept_chars: 0, + discarded_chars: predicted_new_chars, + context_chars: 0, + kept_rate: 0.0, + #[cfg(test)] + token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()], + }; + } + let base_tokens = tokenize(base); let predicted_tokens = tokenize(predicted); let final_tokens = tokenize(final_text); - let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens); + let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens); let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens); let context_mask: Vec = pred_base_mask .iter() @@ -186,7 +240,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR let (stripped_predicted, predicted_new_chars, context_chars) = analyze_masked_tokens(&predicted_tokens, &context_mask); - let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens); + let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens); let final_context_mask: Vec = final_base_mask .iter() 
.zip(final_pred_mask.iter()) @@ -196,7 +250,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR let (stripped_final, final_new_chars, _) = analyze_masked_tokens(&final_tokens, &final_context_mask); - let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0; + let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final); let kept_chars: usize = stripped_predicted .iter() @@ -265,8 +319,8 @@ mod test_kept_rate { let a = ["x", "a", "x", "b"]; let b = ["a", "x", "b", "x"]; let (a_mask, b_mask) = lcs_keep_masks(&a, &b); - assert_eq!(a_mask, lcs_keep_masks(&a, &b).0); - assert_eq!(b_mask, lcs_keep_masks(&b, &a).0); + assert_eq!(a_mask, lcs_keep_mask(&a, &b)); + assert_eq!(b_mask, lcs_keep_mask(&b, &a)); } #[test] @@ -342,6 +396,21 @@ mod test_kept_rate { assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0); } + #[test] + fn test_bails_for_dirty_final() { + let base = "fn example() {\n work();\n}\n"; + let predicted = "fn example() {\n work();\n predicted();\n}\n"; + let final_text = format!( + "fn example() {{\n work();\n {}\n}}\n", + "settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64) + ); + + let result = compute_kept_rate(base, predicted, &final_text); + assert_eq!(result.kept_rate, 0.0); + assert_eq!(result.kept_chars, 0); + assert_eq!(result.discarded_chars, result.predicted_new_chars); + } + #[test] fn test_eprintln_token_alignment() { let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) {\n epr\n"; diff --git a/crates/edit_prediction/src/metrics/tokenize.rs b/crates/edit_prediction/src/metrics/tokenize.rs new file mode 100644 index 0000000000000000000000000000000000000000..250a5c15167cbcd00e5ee3fb0397cfed011be5bb --- /dev/null +++ b/crates/edit_prediction/src/metrics/tokenize.rs @@ -0,0 +1,54 @@ +fn char_class(character: char) -> u8 { + if character.is_alphanumeric() || character == '_' { + 0 + } else if character.is_whitespace() { + 1 + } else { + 2 + } +} 
+ +pub(crate) fn tokenize(text: &str) -> Vec<&str> { + let mut tokens = Vec::new(); + let mut characters = text.char_indices().peekable(); + + while let Some((start, character)) = characters.next() { + let class = char_class(character); + if class == 2 { + tokens.push(&text[start..start + character.len_utf8()]); + continue; + } + + let mut end = start + character.len_utf8(); + while let Some(&(_, next_character)) = characters.peek() { + if char_class(next_character) != class { + break; + } + end += next_character.len_utf8(); + characters.next(); + } + tokens.push(&text[start..end]); + } + + tokens +} + +#[cfg(test)] +mod tests { + use super::tokenize; + + #[test] + fn tokenizes_code_like_text() { + assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]); + assert_eq!( + tokenize("foo_bar123 + baz"), + vec!["foo_bar123", " ", "+", " ", "baz"] + ); + assert_eq!( + tokenize("print(\"hello\")"), + vec!["print", "(", "\"", "hello", "\"", ")"] + ); + assert_eq!(tokenize("hello_world"), vec!["hello_world"]); + assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]); + } +} diff --git a/crates/edit_prediction/src/metrics/tree_sitter.rs b/crates/edit_prediction/src/metrics/tree_sitter.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bb200289ca5007fd4711f0cb46c80ea1153bf28 --- /dev/null +++ b/crates/edit_prediction/src/metrics/tree_sitter.rs @@ -0,0 +1,88 @@ +use language::SyntaxLayer; + +pub fn count_tree_sitter_errors<'a>(layers: impl Iterator>) -> usize { + let mut total_count: usize = 0; + for layer in layers { + let node = layer.node(); + let mut cursor = node.walk(); + 'layer: loop { + let current = cursor.node(); + if current.is_error() || current.is_missing() { + total_count += 1; + } + if current.has_error() && cursor.goto_first_child() { + continue; + } + if cursor.goto_next_sibling() { + continue; + } + loop { + if !cursor.goto_parent() { + break 'layer; + } + if cursor.goto_next_sibling() { + continue; + } + } + } + } + 
total_count +} + +#[cfg(test)] +mod tests { + use std::ops::Range; + + use super::count_tree_sitter_errors; + use gpui::{AppContext as _, TestAppContext}; + use language::{Buffer, BufferSnapshot, rust_lang}; + + fn error_count_in_range(edited_buffer_snapshot: &BufferSnapshot, range: Range) -> usize { + let layers = edited_buffer_snapshot.syntax_layers_for_range(range, true); + count_tree_sitter_errors(layers) + } + + fn rust_snapshot(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); + while buffer.read_with(cx, |buffer, _| buffer.is_parsing()) { + cx.run_until_parked(); + } + buffer.read_with(cx, |buffer, _| buffer.snapshot()) + } + + #[gpui::test] + async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) { + let text = "fn helper(value: usize) -> usize {\n value + 1\n}\n"; + let snapshot = rust_snapshot(text, cx); + + assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 0); + } + + #[gpui::test] + async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) { + let text = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n"; + let snapshot = rust_snapshot(text, cx); + + assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 1); + } + + #[gpui::test] + async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) { + let text = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n"; + let snapshot = rust_snapshot(text, cx); + let body_start = text.find("let value").unwrap(); + let body_end = body_start + "let value = 1;".len(); + + assert_eq!(error_count_in_range(&snapshot, body_start..body_end), 0); + } + + #[gpui::test] + async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) { + let text = "fn second() -> usize {\n let broken = ;\n broken\n}\n"; + let snapshot = rust_snapshot(text, cx); + let error_start = text.find("let broken = ;").unwrap(); + let error_end = error_start + 
"let broken = ;".len(); + + assert_eq!(error_count_in_range(&snapshot, error_start..error_end), 1); + } +} diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index b4556e58b9247624e2d4caeddb5614ff5000d854..1173cd047a93253add13da946f02cbccb8da55f9 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -85,7 +85,6 @@ pub fn request_prediction_with_zeta( } else { None }; - let client = store.client.clone(); let llm_token = store.llm_token.clone(); let organization_id = store @@ -383,11 +382,26 @@ pub fn request_prediction_with_zeta( })); }; - if can_collect_data { + let result = EditPredictionResult::new( + id, + &edited_buffer, + &edited_buffer_snapshot, + edits.into(), + cursor_position, + inputs, + model_version, + request_duration, + cx, + ) + .await; + + if can_collect_data && let Ok(prediction) = &result.prediction { let weak_this = this.clone(); - let id = id.clone(); + let request_id = prediction.id.clone(); let edited_buffer = edited_buffer.clone(); let edited_buffer_snapshot = edited_buffer_snapshot.clone(); + let editable_range_in_buffer = editable_range_in_buffer.clone(); + let edit_preview = prediction.edit_preview.clone(); let example_task = capture_data.and_then(|stored_events| { cx.update(|cx| { crate::capture_example( @@ -410,11 +424,12 @@ pub fn request_prediction_with_zeta( weak_this .update(cx, |this, cx| { this.enqueue_settled_prediction( - id.clone(), + request_id.clone(), &project, &edited_buffer, &edited_buffer_snapshot, editable_range_in_buffer, + &edit_preview, example_spec, request_duration, cx, @@ -425,20 +440,7 @@ pub fn request_prediction_with_zeta( .detach(); } - Ok(Some( - EditPredictionResult::new( - id, - &edited_buffer, - &edited_buffer_snapshot, - edits.into(), - cursor_position, - inputs, - model_version, - request_duration, - cx, - ) - .await, - )) + Ok(Some(result)) }) } diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml 
index a999fed2baf990273f0801bac15573b3aed0cc78..8aa4ff63aca1d9b6f418924c4ccc232d368d5a69 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -83,7 +83,6 @@ dynamic_prompts = [] ignored = ["wasmtime"] [dev-dependencies] -criterion.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true pretty_assertions.workspace = true @@ -91,6 +90,3 @@ project = { workspace = true, features = ["test-support"] } tempfile.workspace = true workspace = { workspace = true, features = ["test-support"] } -[[bench]] -name = "kept_rate" -harness = false diff --git a/crates/edit_prediction_cli/src/lib.rs b/crates/edit_prediction_cli/src/lib.rs index 920bd942675b460c1a292cda7024ad914ba8167c..c47a3e53f35bf9f33b608aa65be15f419238b711 100644 --- a/crates/edit_prediction_cli/src/lib.rs +++ b/crates/edit_prediction_cli/src/lib.rs @@ -1,4 +1,2 @@ #[allow(dead_code)] mod word_diff; - -pub mod kept_rate; diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 0f29d33947612d64b74f4fd847957ced5ad359a4..d144f998ff27b90e3009f82c367bf4699db4341e 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -5,7 +5,7 @@ mod filter_languages; mod format_prompt; mod git; mod headless; -mod kept_rate; + mod load_project; mod metrics; mod openai_client; diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index ffa26beea6eeb52a9dfdfe823ad474f9e63627a8..b28edbb7eb12929ee883eed29a9ef775e100281f 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1298,4 +1298,4 @@ index abc123..def456 100644 } } -pub use crate::kept_rate::compute_kept_rate; +pub use edit_prediction::metrics::compute_kept_rate; diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 
1e54134efcab4f0074a73b241f8e0d04cfbcbcdd..698efbfeed8363d38aa79f5afd93ba00b42e80b4 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -878,6 +878,14 @@ impl EditPreview { }) } + pub fn result_text_snapshot(&self) -> &text::BufferSnapshot { + &self.applied_edits_snapshot + } + + pub fn result_syntax_snapshot(&self) -> &SyntaxSnapshot { + &self.syntax_snapshot + } + pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize { anchor .bias_right(&self.old_snapshot)