Track additional metrics in settled (#52938)

Ben Kunkle created

Stacked on https://github.com/zed-industries/zed/pull/50566.

Begin collecting kept chars rate, as well as the count of tree-sitter
errors in the code before and after applying the prediction.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [ ] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes: no linked issue

Release Notes:

- N/A (internal telemetry/metrics change; no user-facing behavior)

Change summary

Cargo.lock                                          |   4 
crates/edit_prediction/Cargo.toml                   |  14 
crates/edit_prediction/benches/kept_rate.rs         |   2 
crates/edit_prediction/benches/ts_error_count.rs    | 454 +++++++++++++++
crates/edit_prediction/src/edit_prediction.rs       | 186 ++++-
crates/edit_prediction/src/edit_prediction_tests.rs |  11 
crates/edit_prediction/src/metrics.rs               |  10 
crates/edit_prediction/src/metrics/kept_rate.rs     | 155 +++-
crates/edit_prediction/src/metrics/tokenize.rs      |  54 +
crates/edit_prediction/src/metrics/tree_sitter.rs   |  88 ++
crates/edit_prediction/src/zeta.rs                  |  38 
crates/edit_prediction_cli/Cargo.toml               |   4 
crates/edit_prediction_cli/src/lib.rs               |   2 
crates/edit_prediction_cli/src/main.rs              |   2 
crates/edit_prediction_cli/src/metrics.rs           |   2 
crates/language/src/buffer.rs                       |   8 
16 files changed, 905 insertions(+), 129 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5176,6 +5176,7 @@ dependencies = [
  "copilot",
  "copilot_ui",
  "credentials_provider",
+ "criterion",
  "ctor",
  "db",
  "edit_prediction_context",
@@ -5189,9 +5190,11 @@ dependencies = [
  "itertools 0.14.0",
  "language",
  "language_model",
+ "languages",
  "log",
  "lsp",
  "menu",
+ "node_runtime",
  "open_ai",
  "parking_lot",
  "postage",
@@ -5235,7 +5238,6 @@ dependencies = [
  "client",
  "cloud_llm_client",
  "collections",
- "criterion",
  "db",
  "debug_adapter_extension",
  "dirs 4.0.0",

crates/edit_prediction/Cargo.toml 🔗

@@ -72,10 +72,14 @@ zeta_prompt.workspace = true
 zstd.workspace = true
 
 [dev-dependencies]
+criterion.workspace = true
+fs = { workspace = true, features = ["test-support"] }
+gpui = { workspace = true, features = ["test-support"] }
+languages = { workspace = true, features = ["load-grammars"] }
+node_runtime.workspace = true
 clock = { workspace = true, features = ["test-support"] }
 cloud_llm_client = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
 language_model = { workspace = true, features = ["test-support"] }
@@ -86,3 +90,11 @@ settings = { workspace = true, features = ["test-support"] }
 workspace = { workspace = true, features = ["test-support"] }
 
 zlog.workspace = true
+
+[[bench]]
+name = "kept_rate"
+harness = false
+
+[[bench]]
+name = "ts_error_count"
+harness = false

crates/edit_prediction_cli/benches/kept_rate.rs → crates/edit_prediction/benches/kept_rate.rs 🔗

@@ -1,5 +1,5 @@
 use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
-use edit_prediction_cli::kept_rate::compute_kept_rate;
+use edit_prediction::metrics::compute_kept_rate;
 
 fn repeated_function_lines(line_count: usize) -> String {
     let mut text = String::with_capacity(line_count * 32);

crates/edit_prediction/benches/ts_error_count.rs 🔗

@@ -0,0 +1,454 @@
+use std::sync::Arc;
+
+use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use edit_prediction::metrics::count_tree_sitter_errors;
+use fs::FakeFs;
+use gpui::{AppContext as _, TestAppContext};
+use language::{Buffer, BufferSnapshot, LanguageRegistry};
+use languages::init as init_languages;
+use node_runtime::NodeRuntime;
+use settings::SettingsStore;
+
+struct ParsedCase {
+    label: String,
+    bytes: usize,
+    error_count: usize,
+    snapshot: BufferSnapshot,
+}
+
+fn replace_nth_occurrences(
+    source: &mut String,
+    needle: &str,
+    replacement: &str,
+    every: usize,
+    max_replacements: usize,
+) {
+    let mut rebuilt = String::with_capacity(source.len());
+    let mut cursor = 0;
+    let mut seen = 0;
+    let mut replaced = 0;
+
+    while let Some(relative_index) = source[cursor..].find(needle) {
+        let start = cursor + relative_index;
+        let end = start + needle.len();
+        rebuilt.push_str(&source[cursor..start]);
+
+        if seen % every == 0 && replaced < max_replacements {
+            rebuilt.push_str(replacement);
+            replaced += 1;
+        } else {
+            rebuilt.push_str(needle);
+        }
+
+        seen += 1;
+        cursor = end;
+    }
+
+    rebuilt.push_str(&source[cursor..]);
+    *source = rebuilt;
+}
+
+fn rust_source(function_count: usize) -> String {
+    let mut source = String::from(
+        "pub struct Counter {\n    value: usize,\n}\n\nimpl Counter {\n    pub fn new() -> Self {\n        Self { value: 0 }\n    }\n}\n\n",
+    );
+    for index in 0..function_count {
+        source.push_str(&format!(
+            "pub fn compute_value_{index}(input: usize) -> usize {{\n    let mut total = input;\n    for offset in 0..32 {{\n        total += offset + {index};\n    }}\n    if total % 2 == 0 {{\n        total / 2\n    }} else {{\n        total * 3 + 1\n    }}\n}}\n\n"
+        ));
+    }
+    source
+}
+
+fn rust_source_with_errors(function_count: usize) -> String {
+    let mut source = rust_source(function_count);
+    replace_nth_occurrences(
+        &mut source,
+        "    if total % 2 == 0 {\n",
+        "    if total % 2 == 0 \n",
+        17,
+        48,
+    );
+    source
+}
+
+fn python_source(function_count: usize) -> String {
+    let mut source = String::from(
+        "class Counter:\n    def __init__(self) -> None:\n        self.value = 0\n\n\n",
+    );
+    for index in 0..function_count {
+        source.push_str(&format!(
+            "def compute_value_{index}(input_value: int) -> int:\n    total = input_value\n    for offset in range(32):\n        total += offset + {index}\n    if total % 2 == 0:\n        return total // 2\n    return total * 3 + 1\n\n"
+        ));
+    }
+    source
+}
+
+fn python_source_with_errors(function_count: usize) -> String {
+    let mut source = python_source(function_count);
+    replace_nth_occurrences(
+        &mut source,
+        "    if total % 2 == 0:\n",
+        "    if total % 2 == 0\n",
+        19,
+        48,
+    );
+    source
+}
+
+fn go_source(function_count: usize) -> String {
+    let mut source = String::from(
+        "package bench\n\ntype Counter struct {\n\tvalue int\n}\n\nfunc NewCounter() Counter {\n\treturn Counter{value: 0}\n}\n\n",
+    );
+    for index in 0..function_count {
+        source.push_str(&format!(
+            "func ComputeValue{index}(inputValue int) int {{\n\ttotal := inputValue\n\tfor offset := 0; offset < 32; offset++ {{\n\t\ttotal += offset + {index}\n\t}}\n\tif total%2 == 0 {{\n\t\treturn total / 2\n\t}}\n\treturn total*3 + 1\n}}\n\n"
+        ));
+    }
+    source
+}
+
+fn go_source_with_errors(function_count: usize) -> String {
+    let mut source = go_source(function_count);
+    replace_nth_occurrences(
+        &mut source,
+        "\tfor offset := 0; offset < 32; offset++ {\n",
+        "\tfor offset := 0; offset < 32; offset++ \n",
+        17,
+        48,
+    );
+    source
+}
+
+fn typescript_source(function_count: usize) -> String {
+    let mut source = String::from(
+        "export type Counter = { value: number };\n\nexport function newCounter(): Counter {\n  return { value: 0 };\n}\n\n",
+    );
+    for index in 0..function_count {
+        source.push_str(&format!(
+            "export function computeValue{index}(inputValue: number): number {{\n  let total = inputValue;\n  for (let offset = 0; offset < 32; offset += 1) {{\n    total += offset + {index};\n  }}\n  return total % 2 === 0 ? total / 2 : total * 3 + 1;\n}}\n\n"
+        ));
+    }
+    source
+}
+
+fn typescript_source_with_errors(function_count: usize) -> String {
+    let mut source = typescript_source(function_count);
+    replace_nth_occurrences(
+        &mut source,
+        "  return total % 2 === 0 ? total / 2 : total * 3 + 1;\n",
+        "  return total % 2 === 0 ? total / 2 : ;\n",
+        17,
+        64,
+    );
+    source
+}
+
+fn tsx_source(component_count: usize) -> String {
+    let mut source = String::from(
+        "type ItemProps = { index: number; label: string };\n\nfunction Item({ index, label }: ItemProps) {\n  return <li data-index={index}>{label}</li>;\n}\n\nexport function App() {\n  return <section><ul>{[0, 1, 2].map((value) => <Item key={value} index={value} label={`item-${value}`} />)}</ul></section>;\n}\n\n",
+    );
+    for index in 0..component_count {
+        source.push_str(&format!(
+            "export function Widget{index}(): JSX.Element {{\n  const items = Array.from({{ length: 16 }}, (_, value) => value + {index});\n  return (\n    <div className=\"widget-{index}\">\n      <h2>Widget {index}</h2>\n      <ul>\n        {{items.map((value) => (\n          <Item key={{value}} index={{value}} label={{`widget-{index}-${{value}}`}} />\n        ))}}\n      </ul>\n    </div>\n  );\n}}\n\n"
+        ));
+    }
+    source
+}
+
+fn tsx_source_with_errors(component_count: usize) -> String {
+    let mut source = tsx_source(component_count);
+    replace_nth_occurrences(
+        &mut source,
+        "  const items = Array.from({ length: 16 }, (_, value) => value + ",
+        "  const items = Array.from({ length: 16 }, (_, value) => ); // ",
+        11,
+        32,
+    );
+    source
+}
+
+fn json_source(object_count: usize) -> String {
+    let mut source = String::from("{\n  \"items\": [\n");
+    for index in 0..object_count {
+        let suffix = if index + 1 == object_count { "" } else { "," };
+        source.push_str(&format!(
+            "    {{\n      \"id\": {index},\n      \"name\": \"item-{index}\",\n      \"enabled\": true,\n      \"tags\": [\"alpha\", \"beta\", \"gamma\"],\n      \"metrics\": {{ \"count\": {}, \"ratio\": {} }}\n    }}{suffix}\n",
+            index * 3 + 1,
+            index as f64 / 10.0,
+        ));
+    }
+    source.push_str("  ]\n}\n");
+    source
+}
+
+fn json_source_with_errors(object_count: usize) -> String {
+    let mut source = json_source(object_count);
+    replace_nth_occurrences(
+        &mut source,
+        "      \"enabled\": true,\n",
+        "      \"enabled\": ,\n",
+        23,
+        64,
+    );
+    source
+}
+
+fn yaml_source(document_count: usize) -> String {
+    let mut source = String::new();
+    for index in 0..document_count {
+        source.push_str(&format!(
+            "- id: {index}\n  name: item-{index}\n  enabled: true\n  tags:\n    - alpha\n    - beta\n    - gamma\n  metrics:\n    count: {}\n    ratio: {}\n",
+            index * 3 + 1,
+            index as f64 / 10.0,
+        ));
+    }
+    source
+}
+
+fn yaml_source_with_errors(document_count: usize) -> String {
+    let mut source = yaml_source(document_count);
+    replace_nth_occurrences(&mut source, "    count: ", "    count ", 23, 64);
+    source
+}
+
+fn css_source(rule_count: usize) -> String {
+    let mut source = String::new();
+    for index in 0..rule_count {
+        source.push_str(&format!(
+            ".widget-{index} {{\n  display: grid;\n  grid-template-columns: repeat(4, minmax(0, 1fr));\n  gap: 12px;\n  padding: 8px;\n  color: rgb({}, {}, {});\n}}\n\n.widget-{index} > .item-{index} {{\n  border: 1px solid rgba(0, 0, 0, 0.15);\n  background: linear-gradient(90deg, #fff, #eef);\n}}\n\n",
+            (index * 17) % 255,
+            (index * 31) % 255,
+            (index * 47) % 255,
+        ));
+    }
+    source
+}
+
+fn css_source_with_errors(rule_count: usize) -> String {
+    let mut source = css_source(rule_count);
+    replace_nth_occurrences(&mut source, "  gap: 12px;\n", "  gap 12px;\n", 29, 64);
+    source
+}
+
+fn build_case(
+    context: &mut TestAppContext,
+    languages: &Arc<LanguageRegistry>,
+    language_name: &'static str,
+    variant_name: &'static str,
+    source: String,
+    expect_errors: bool,
+) -> ParsedCase {
+    let language_task = context.background_spawn({
+        let languages = languages.clone();
+        async move { languages.language_for_name(language_name).await }
+    });
+    while !language_task.is_ready() {
+        context.run_until_parked();
+    }
+    let language = futures::executor::block_on(language_task)
+        .unwrap_or_else(|error| panic!("failed to load {language_name}: {error}"));
+
+    let buffer = context.new(|cx| Buffer::local(source, cx).with_language(language, cx));
+    context.run_until_parked();
+    while buffer.read_with(context, |buffer, _| buffer.is_parsing()) {
+        context.run_until_parked();
+    }
+
+    let snapshot = buffer.read_with(context, |buffer, _| buffer.snapshot());
+    let full_range = 0..snapshot.text.len();
+    let error_count = count_tree_sitter_errors(snapshot.syntax_layers());
+    if expect_errors {
+        assert!(
+            error_count > 0,
+            "expected tree-sitter errors for {language_name}/{variant_name}",
+        );
+    } else {
+        assert_eq!(
+            error_count, 0,
+            "expected no tree-sitter errors for {language_name}/{variant_name}",
+        );
+    }
+
+    let label = format!(
+        "{}/{}_{}kb_{}e",
+        language_name.to_lowercase(),
+        variant_name,
+        full_range.end / 1024,
+        error_count,
+    );
+    ParsedCase {
+        label,
+        bytes: full_range.end,
+        error_count,
+        snapshot,
+    }
+}
+
+fn parsed_cases() -> Vec<ParsedCase> {
+    let mut context = TestAppContext::single();
+    context.update(|cx| {
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
+    });
+
+    let languages = Arc::new(LanguageRegistry::new(context.executor()));
+    let fs = FakeFs::new(context.executor());
+    let node_runtime = NodeRuntime::unavailable();
+    context.update(|cx| init_languages(languages.clone(), fs, node_runtime, cx));
+
+    vec![
+        build_case(
+            &mut context,
+            &languages,
+            "Rust",
+            "valid",
+            rust_source(900),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "Rust",
+            "error_heavy",
+            rust_source_with_errors(900),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "Python",
+            "valid",
+            python_source(1100),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "Python",
+            "error_heavy",
+            python_source_with_errors(1100),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "Go",
+            "valid",
+            go_source(1000),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "Go",
+            "error_heavy",
+            go_source_with_errors(1000),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "TypeScript",
+            "valid",
+            typescript_source(1000),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "TypeScript",
+            "error_heavy",
+            typescript_source_with_errors(1000),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "TSX",
+            "valid",
+            tsx_source(350),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "TSX",
+            "error_heavy",
+            tsx_source_with_errors(350),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "JSON",
+            "valid",
+            json_source(2200),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "JSON",
+            "error_heavy",
+            json_source_with_errors(2200),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "YAML",
+            "valid",
+            yaml_source(2200),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "YAML",
+            "error_heavy",
+            yaml_source_with_errors(2200),
+            true,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "CSS",
+            "valid",
+            css_source(2400),
+            false,
+        ),
+        build_case(
+            &mut context,
+            &languages,
+            "CSS",
+            "error_heavy",
+            css_source_with_errors(2400),
+            true,
+        ),
+    ]
+}
+
+fn ts_error_count_benchmark(c: &mut Criterion) {
+    let cases = parsed_cases();
+    let mut group = c.benchmark_group("ts_error_count/full_file");
+
+    for case in &cases {
+        group.bench_with_input(
+            BenchmarkId::from_parameter(&case.label),
+            case,
+            |bench, case| {
+                bench.iter(|| {
+                    black_box(case.bytes);
+                    black_box(case.error_count);
+                    black_box(count_tree_sitter_errors(case.snapshot.syntax_layers()))
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
+criterion_group!(benches, ts_error_count_benchmark);
+criterion_main!(benches);

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -30,7 +30,7 @@ use gpui::{
 };
 use heapless::Vec as ArrayVec;
 use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
+use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
@@ -61,6 +61,7 @@ pub mod example_spec;
 pub mod fim;
 mod license_detection;
 pub mod mercury;
+pub mod metrics;
 pub mod ollama;
 mod onboarding_modal;
 pub mod open_ai_response;
@@ -80,6 +81,7 @@ use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
 use crate::example_spec::ExampleSpec;
 use crate::license_detection::LicenseDetectionWatcher;
 use crate::mercury::Mercury;
+pub use crate::metrics::{KeptRateResult, compute_kept_rate};
 use crate::onboarding_modal::ZedPredictModal;
 pub use crate::prediction::EditPrediction;
 pub use crate::prediction::EditPredictionId;
@@ -478,10 +480,13 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
 }
 
 #[derive(Clone)]
-
 struct PendingSettledPrediction {
     request_id: EditPredictionId,
     editable_anchor_range: Range<Anchor>,
+    editable_region_before_prediction: String,
+    predicted_editable_region: String,
+    ts_error_count_before_prediction: usize,
+    ts_error_count_after_prediction: usize,
     example: Option<ExampleSpec>,
     enqueued_at: Instant,
     last_edit_at: Instant,
@@ -1603,63 +1608,100 @@ impl EditPredictionStore {
             };
 
             let now = cx.background_executor().now();
-
             let mut oldest_edited_at = None;
+            let mut ready_predictions = Vec::new();
 
             this.update(cx, |this, _| {
                 for (_, project_state) in this.projects.iter_mut() {
                     for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
-                        registered_buffer
-                            .pending_predictions
-                            .retain_mut(|pending_prediction| {
-                                let age =
-                                    now.saturating_duration_since(pending_prediction.enqueued_at);
-                                if age >= EDIT_PREDICTION_SETTLED_TTL {
-                                    return false;
-                                }
-
-                                let quiet_for =
-                                    now.saturating_duration_since(pending_prediction.last_edit_at);
-                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
-                                    let settled_editable_region = registered_buffer
-                                        .snapshot
-                                        .text_for_range(
-                                            pending_prediction.editable_anchor_range.clone(),
-                                        )
-                                        .collect::<String>();
-
-                                    #[cfg(test)]
-                                    if let Some(callback) = &this.settled_event_callback {
-                                        callback(
-                                            pending_prediction.request_id.clone(),
-                                            settled_editable_region.clone(),
-                                        );
-                                    }
-
-                                    telemetry::event!(
-                                        EDIT_PREDICTION_SETTLED_EVENT,
-                                        request_id = pending_prediction.request_id.0.clone(),
-                                        settled_editable_region,
-                                        example = pending_prediction.example.take(),
-                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
-                                    );
-
-                                    return false;
-                                }
+                        let mut pending_index = 0;
+                        while pending_index < registered_buffer.pending_predictions.len() {
+                            let pending_prediction =
+                                &registered_buffer.pending_predictions[pending_index];
+                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
+                            if age >= EDIT_PREDICTION_SETTLED_TTL {
+                                registered_buffer.pending_predictions.remove(pending_index);
+                                continue;
+                            }
 
-                                if oldest_edited_at
-                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
-                                {
-                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
-                                }
+                            let quiet_for =
+                                now.saturating_duration_since(pending_prediction.last_edit_at);
+                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
+                                let pending_prediction =
+                                    registered_buffer.pending_predictions.remove(pending_index);
+                                let settled_editable_region = registered_buffer
+                                    .snapshot
+                                    .text_for_range(
+                                        pending_prediction.editable_anchor_range.clone(),
+                                    )
+                                    .collect::<String>();
+                                ready_predictions
+                                    .push((pending_prediction, settled_editable_region));
+                                continue;
+                            }
 
-                                true
-                            });
+                            if oldest_edited_at
+                                .is_none_or(|time| pending_prediction.last_edit_at < time)
+                            {
+                                oldest_edited_at = Some(pending_prediction.last_edit_at);
+                            }
+                            pending_index += 1;
+                        }
                     }
                 }
             });
 
-            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
+            for (pending_prediction, settled_editable_region) in ready_predictions {
+                let PendingSettledPrediction {
+                    request_id,
+                    editable_region_before_prediction,
+                    predicted_editable_region,
+                    ts_error_count_before_prediction,
+                    ts_error_count_after_prediction,
+                    example,
+                    e2e_latency,
+                    ..
+                } = pending_prediction;
+                let settled_editable_region_for_metrics = settled_editable_region.clone();
+                let kept_rate_result = cx
+                    .background_spawn(async move {
+                        compute_kept_rate(
+                            &editable_region_before_prediction,
+                            &predicted_editable_region,
+                            &settled_editable_region_for_metrics,
+                        )
+                    })
+                    .await;
+
+                #[cfg(test)]
+                {
+                    let request_id = request_id.clone();
+                    let settled_editable_region = settled_editable_region.clone();
+                    this.update(cx, |this, _| {
+                        if let Some(callback) = &this.settled_event_callback {
+                            callback(request_id, settled_editable_region);
+                        }
+                    });
+                }
+
+                telemetry::event!(
+                    EDIT_PREDICTION_SETTLED_EVENT,
+                    request_id = request_id.0.clone(),
+                    settled_editable_region,
+                    ts_error_count_before_prediction,
+                    ts_error_count_after_prediction,
+                    edit_bytes_predicted_new = kept_rate_result.predicted_new_chars,
+                    edit_bytes_final_new = kept_rate_result.final_new_chars,
+                    edit_bytes_kept = kept_rate_result.kept_chars,
+                    edit_bytes_discarded = kept_rate_result.discarded_chars,
+                    edit_bytes_context = kept_rate_result.context_chars,
+                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
+                    example,
+                    e2e_latency = e2e_latency.as_millis(),
+                );
+            }
+
+            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
         }
     }
 
@@ -1670,28 +1712,58 @@ impl EditPredictionStore {
         edited_buffer: &Entity<Buffer>,
         edited_buffer_snapshot: &BufferSnapshot,
         editable_offset_range: Range<usize>,
+        edit_preview: &EditPreview,
         example: Option<ExampleSpec>,
         e2e_latency: std::time::Duration,
         cx: &mut Context<Self>,
     ) {
         let this = &mut *self;
         let project_state = this.get_or_init_project(project, cx);
-        if let Some(buffer) = project_state
+        let Some(registered_buffer) = project_state
             .registered_buffers
             .get_mut(&edited_buffer.entity_id())
-        {
-            let now = cx.background_executor().now();
-            buffer.pending_predictions.push(PendingSettledPrediction {
-                request_id: request_id,
-                editable_anchor_range: edited_buffer_snapshot
-                    .anchor_range_inside(editable_offset_range),
+        else {
+            return;
+        };
+
+        let editable_region_before_prediction = edited_buffer_snapshot
+            .text_for_range(editable_offset_range.clone())
+            .collect::<String>();
+        let editable_anchor_range_for_result =
+            edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
+        let predicted_editable_region = edit_preview
+            .result_text_snapshot()
+            .text_for_range(editable_anchor_range_for_result.clone())
+            .collect();
+        let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
+            edited_buffer_snapshot
+                .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
+        );
+        let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
+            edit_preview.result_syntax_snapshot().layers_for_range(
+                editable_anchor_range_for_result,
+                edit_preview.result_text_snapshot(),
+                true,
+            ),
+        );
+        let editable_anchor_range =
+            edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
+        let now = cx.background_executor().now();
+        registered_buffer
+            .pending_predictions
+            .push(PendingSettledPrediction {
+                request_id,
+                editable_anchor_range,
+                editable_region_before_prediction,
+                predicted_editable_region,
+                ts_error_count_before_prediction,
+                ts_error_count_after_prediction,
                 example,
                 e2e_latency,
                 enqueued_at: now,
                 last_edit_at: now,
             });
-            this.settled_predictions_tx.unbounded_send(now).ok();
-        }
+        this.settled_predictions_tx.unbounded_send(now).ok();
     }
 
     fn reject_current_prediction(

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -3252,6 +3252,12 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
     cx.run_until_parked();
 
     let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
+    let edit_preview_a = buffer
+        .read_with(cx, |buffer, cx| {
+            buffer.preview_edits(empty_edits.clone(), cx)
+        })
+        .await;
 
     // Region A: first 10 lines of the buffer.
     let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
@@ -3263,6 +3269,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
             &buffer,
             &snapshot_a,
             editable_region_a.clone(),
+            &edit_preview_a,
             None,
             Duration::from_secs(0),
             cx,
@@ -3318,6 +3325,9 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
     cx.run_until_parked();
 
     let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let edit_preview_b = buffer
+        .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
+        .await;
     let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
 
     ep_store.update(cx, |ep_store, cx| {
@@ -3327,6 +3337,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
             &buffer,
             &snapshot_b2,
             editable_region_b.clone(),
+            &edit_preview_b,
             None,
             Duration::from_secs(0),
             cx,

crates/edit_prediction/src/metrics.rs 🔗

@@ -0,0 +1,10 @@
+mod kept_rate;
+mod tokenize;
+mod tree_sitter;
+
+pub use kept_rate::KeptRateResult;
+#[cfg(test)]
+pub use kept_rate::TokenAnnotation;
+pub use kept_rate::compute_kept_rate;
+pub(crate) use tokenize::tokenize;
+pub use tree_sitter::count_tree_sitter_errors;

crates/edit_prediction_cli/src/kept_rate.rs → crates/edit_prediction/src/metrics/kept_rate.rs 🔗

@@ -1,4 +1,6 @@
-use crate::word_diff::tokenize;
+use crate::metrics::tokenize;
+
+const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
 
 #[cfg(test)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -25,20 +27,28 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
     row * width + column
 }
 
-/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
-/// side while sharing a single DP table construction.
-fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
+/// while sharing a single DP table construction.
+fn fill_lcs_keep_masks(
+    a: &[&str],
+    b: &[&str],
+    mut keep_a: Option<&mut [bool]>,
+    mut keep_b: Option<&mut [bool]>,
+) {
     if a.is_empty() || b.is_empty() {
-        return (vec![false; a.len()], vec![false; b.len()]);
+        return;
     }
 
     if a == b {
-        return (vec![true; a.len()], vec![true; b.len()]);
+        if let Some(keep_a) = keep_a.as_mut() {
+            keep_a.fill(true);
+        }
+        if let Some(keep_b) = keep_b.as_mut() {
+            keep_b.fill(true);
+        }
+        return;
     }
 
-    let mut keep_a = vec![false; a.len()];
-    let mut keep_b = vec![false; b.len()];
-
     let prefix_len = a
         .iter()
         .zip(b.iter())
@@ -61,22 +71,30 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
     };
 
     for index in 0..prefix_len {
-        keep_a[index] = true;
-        keep_b[index] = true;
+        if let Some(keep_a) = keep_a.as_mut() {
+            keep_a[index] = true;
+        }
+        if let Some(keep_b) = keep_b.as_mut() {
+            keep_b[index] = true;
+        }
     }
 
     for offset in 0..suffix_len {
         let a_index = a.len() - suffix_len + offset;
         let b_index = b.len() - suffix_len + offset;
-        keep_a[a_index] = true;
-        keep_b[b_index] = true;
+        if let Some(keep_a) = keep_a.as_mut() {
+            keep_a[a_index] = true;
+        }
+        if let Some(keep_b) = keep_b.as_mut() {
+            keep_b[b_index] = true;
+        }
     }
 
     let a_mid = &a[prefix_len..a.len() - suffix_len];
     let b_mid = &b[prefix_len..b.len() - suffix_len];
 
     if a_mid.is_empty() || b_mid.is_empty() {
-        return (keep_a, keep_b);
+        return;
     }
 
     let row_count = a_mid.len() + 1;
@@ -97,44 +115,59 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
         }
     }
 
-    let mut i = a_mid.len();
-    let mut j = b_mid.len();
+    if let Some(keep_a) = keep_a.as_mut() {
+        let mut i = a_mid.len();
+        let mut j = b_mid.len();
 
-    while i > 0 && j > 0 {
-        if a_mid[i - 1] == b_mid[j - 1] {
-            keep_a[prefix_len + i - 1] = true;
-            i -= 1;
-            j -= 1;
-        } else {
-            let up = dp[dp_index(column_count, i - 1, j)];
-            let left = dp[dp_index(column_count, i, j - 1)];
-            if up >= left {
+        while i > 0 && j > 0 {
+            if a_mid[i - 1] == b_mid[j - 1] {
+                keep_a[prefix_len + i - 1] = true;
                 i -= 1;
-            } else {
                 j -= 1;
+            } else {
+                let up = dp[dp_index(column_count, i - 1, j)];
+                let left = dp[dp_index(column_count, i, j - 1)];
+                if up >= left {
+                    i -= 1;
+                } else {
+                    j -= 1;
+                }
             }
         }
     }
 
-    let mut i = a_mid.len();
-    let mut j = b_mid.len();
+    if let Some(keep_b) = keep_b.as_mut() {
+        let mut i = a_mid.len();
+        let mut j = b_mid.len();
 
-    while i > 0 && j > 0 {
-        if a_mid[i - 1] == b_mid[j - 1] {
-            keep_b[prefix_len + j - 1] = true;
-            i -= 1;
-            j -= 1;
-        } else {
-            let up = dp[dp_index(column_count, i - 1, j)];
-            let left = dp[dp_index(column_count, i, j - 1)];
-            if left >= up {
+        while i > 0 && j > 0 {
+            if a_mid[i - 1] == b_mid[j - 1] {
+                keep_b[prefix_len + j - 1] = true;
+                i -= 1;
                 j -= 1;
             } else {
-                i -= 1;
+                let up = dp[dp_index(column_count, i - 1, j)];
+                let left = dp[dp_index(column_count, i, j - 1)];
+                if left >= up {
+                    j -= 1;
+                } else {
+                    i -= 1;
+                }
             }
         }
     }
+}
+
+fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
+    let mut keep_a = vec![false; a.len()];
+    fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
+    keep_a
+}
 
+fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+    let mut keep_a = vec![false; a.len()];
+    let mut keep_b = vec![false; b.len()];
+    fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
     (keep_a, keep_b)
 }
 
@@ -155,6 +188,12 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>
     (unmasked_tokens, unmasked_chars, masked_chars)
 }
 
+fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
+    let predicted_delta_chars = predicted.len().abs_diff(base.len());
+    let final_delta_chars = final_text.len().abs_diff(base.len());
+    predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
+}
+
 pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
     if base == predicted && predicted == final_text {
         let predicted_tokens = tokenize(predicted);
@@ -171,11 +210,26 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
         };
     }
 
+    if should_bail_for_dirty_final(base, predicted, final_text) {
+        let predicted_new_chars = predicted.len().abs_diff(base.len());
+        let final_new_chars = final_text.len().abs_diff(base.len());
+        return KeptRateResult {
+            predicted_new_chars,
+            final_new_chars,
+            kept_chars: 0,
+            discarded_chars: predicted_new_chars,
+            context_chars: 0,
+            kept_rate: 0.0,
+            #[cfg(test)]
+            token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
+        };
+    }
+
     let base_tokens = tokenize(base);
     let predicted_tokens = tokenize(predicted);
     let final_tokens = tokenize(final_text);
 
-    let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
+    let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
     let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
     let context_mask: Vec<bool> = pred_base_mask
         .iter()
@@ -186,7 +240,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
     let (stripped_predicted, predicted_new_chars, context_chars) =
         analyze_masked_tokens(&predicted_tokens, &context_mask);
 
-    let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
+    let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
     let final_context_mask: Vec<bool> = final_base_mask
         .iter()
         .zip(final_pred_mask.iter())
@@ -196,7 +250,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
     let (stripped_final, final_new_chars, _) =
         analyze_masked_tokens(&final_tokens, &final_context_mask);
 
-    let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
+    let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);
 
     let kept_chars: usize = stripped_predicted
         .iter()
@@ -265,8 +319,8 @@ mod test_kept_rate {
         let a = ["x", "a", "x", "b"];
         let b = ["a", "x", "b", "x"];
         let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
-        assert_eq!(a_mask, lcs_keep_masks(&a, &b).0);
-        assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
+        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
+        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
     }
 
     #[test]
@@ -342,6 +396,21 @@ mod test_kept_rate {
         assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
     }
 
    #[test]
    fn test_bails_for_dirty_final() {
        // The final text grows by far more characters than the prediction
        // did, so the length-delta divergence exceeds
        // MAX_DIRTY_LENGTH_DELTA_CHARS and compute_kept_rate should bail,
        // reporting zero kept and everything predicted as discarded.
        let base = "fn example() {\n    work();\n}\n";
        let predicted = "fn example() {\n    work();\n    predicted();\n}\n";
        let final_text = format!(
            "fn example() {{\n    work();\n    {}\n}}\n",
            "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, predicted, &final_text);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
    }
+
     #[test]
     fn test_eprintln_token_alignment() {
         let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";

crates/edit_prediction/src/metrics/tokenize.rs 🔗

@@ -0,0 +1,54 @@
/// Classify a character for tokenization:
/// 0 = word character (alphanumeric or underscore),
/// 1 = whitespace,
/// 2 = any other symbol.
fn char_class(character: char) -> u8 {
    match character {
        c if c.is_alphanumeric() || c == '_' => 0,
        c if c.is_whitespace() => 1,
        _ => 2,
    }
}

/// Split `text` into tokens: maximal runs of word characters, maximal runs
/// of whitespace, and each remaining symbol as its own single-character
/// token. Tokens are borrowed slices of the input.
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    let mut iter = text.char_indices().peekable();

    while let Some((start, first)) = iter.next() {
        let class = char_class(first);
        let mut end = start + first.len_utf8();

        // Symbols (class 2) never coalesce; word and whitespace runs extend
        // across every following character of the same class.
        if class != 2 {
            while let Some(&(_, next)) = iter.peek() {
                if char_class(next) != class {
                    break;
                }
                end += next.len_utf8();
                iter.next();
            }
        }

        tokens.push(&text[start..end]);
    }

    tokens
}
+
#[cfg(test)]
mod tests {
    use super::tokenize;

    #[test]
    fn coalesces_word_and_whitespace_runs() {
        // Runs of word characters and of whitespace each become one token.
        assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
        assert_eq!(
            tokenize("foo_bar123 + baz"),
            vec!["foo_bar123", " ", "+", " ", "baz"]
        );
        // Underscores count as word characters.
        assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
    }

    #[test]
    fn splits_symbols_individually() {
        // Every non-word, non-whitespace character is its own token.
        assert_eq!(
            tokenize("print(\"hello\")"),
            vec!["print", "(", "\"", "hello", "\"", ")"]
        );
        assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
    }
}

crates/edit_prediction/src/metrics/tree_sitter.rs 🔗

@@ -0,0 +1,88 @@
+use language::SyntaxLayer;
+
+pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
+    let mut total_count: usize = 0;
+    for layer in layers {
+        let node = layer.node();
+        let mut cursor = node.walk();
+        'layer: loop {
+            let current = cursor.node();
+            if current.is_error() || current.is_missing() {
+                total_count += 1;
+            }
+            if current.has_error() && cursor.goto_first_child() {
+                continue;
+            }
+            if cursor.goto_next_sibling() {
+                continue;
+            }
+            loop {
+                if !cursor.goto_parent() {
+                    break 'layer;
+                }
+                if cursor.goto_next_sibling() {
+                    continue;
+                }
+            }
+        }
+    }
+    total_count
+}
+
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use super::count_tree_sitter_errors;
    use gpui::{AppContext as _, TestAppContext};
    use language::{Buffer, BufferSnapshot, rust_lang};

    /// Create a Rust buffer from `text` and wait until parsing settles,
    /// then return its snapshot.
    fn rust_snapshot(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        while buffer.read_with(cx, |buffer, _| buffer.is_parsing()) {
            cx.run_until_parked();
        }
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    /// Count tree-sitter errors in the syntax layers covering `range`.
    fn error_count_in_range(edited_buffer_snapshot: &BufferSnapshot, range: Range<usize>) -> usize {
        let layers = edited_buffer_snapshot.syntax_layers_for_range(range, true);
        count_tree_sitter_errors(layers)
    }

    #[gpui::test]
    async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) {
        let source = "fn helper(value: usize) -> usize {\n    value + 1\n}\n";
        let parsed = rust_snapshot(source, cx);

        assert_eq!(error_count_in_range(&parsed, 0..parsed.text.len()), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) {
        // `let total = ;` is a syntax error.
        let source = "fn helper(value: usize) -> usize {\n    let total = ;\n    total\n}\n";
        let parsed = rust_snapshot(source, cx);

        assert_eq!(error_count_in_range(&parsed, 0..parsed.text.len()), 1);
    }

    #[gpui::test]
    async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) {
        let source = "fn first() -> usize {\n    let value = 1;\n    value + 1\n}\n";
        let parsed = rust_snapshot(source, cx);
        // Restrict the count to just the `let` statement in the body.
        let statement_start = source.find("let value").unwrap();
        let statement_end = statement_start + "let value = 1;".len();

        assert_eq!(error_count_in_range(&parsed, statement_start..statement_end), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) {
        let source = "fn second() -> usize {\n    let broken = ;\n    broken\n}\n";
        let parsed = rust_snapshot(source, cx);
        // Restrict the count to just the broken `let` statement.
        let statement_start = source.find("let broken = ;").unwrap();
        let statement_end = statement_start + "let broken = ;".len();

        assert_eq!(error_count_in_range(&parsed, statement_start..statement_end), 1);
    }
}

crates/edit_prediction/src/zeta.rs 🔗

@@ -85,7 +85,6 @@ pub fn request_prediction_with_zeta(
     } else {
         None
     };
-
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
     let organization_id = store
@@ -383,11 +382,26 @@ pub fn request_prediction_with_zeta(
             }));
         };
 
-        if can_collect_data {
+        let result = EditPredictionResult::new(
+            id,
+            &edited_buffer,
+            &edited_buffer_snapshot,
+            edits.into(),
+            cursor_position,
+            inputs,
+            model_version,
+            request_duration,
+            cx,
+        )
+        .await;
+
+        if can_collect_data && let Ok(prediction) = &result.prediction {
             let weak_this = this.clone();
-            let id = id.clone();
+            let request_id = prediction.id.clone();
             let edited_buffer = edited_buffer.clone();
             let edited_buffer_snapshot = edited_buffer_snapshot.clone();
+            let editable_range_in_buffer = editable_range_in_buffer.clone();
+            let edit_preview = prediction.edit_preview.clone();
             let example_task = capture_data.and_then(|stored_events| {
                 cx.update(|cx| {
                     crate::capture_example(
@@ -410,11 +424,12 @@ pub fn request_prediction_with_zeta(
                 weak_this
                     .update(cx, |this, cx| {
                         this.enqueue_settled_prediction(
-                            id.clone(),
+                            request_id.clone(),
                             &project,
                             &edited_buffer,
                             &edited_buffer_snapshot,
                             editable_range_in_buffer,
+                            &edit_preview,
                             example_spec,
                             request_duration,
                             cx,
@@ -425,20 +440,7 @@ pub fn request_prediction_with_zeta(
             .detach();
         }
 
-        Ok(Some(
-            EditPredictionResult::new(
-                id,
-                &edited_buffer,
-                &edited_buffer_snapshot,
-                edits.into(),
-                cursor_position,
-                inputs,
-                model_version,
-                request_duration,
-                cx,
-            )
-            .await,
-        ))
+        Ok(Some(result))
     })
 }
 

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -83,7 +83,6 @@ dynamic_prompts = []
 ignored = ["wasmtime"]
 
 [dev-dependencies]
-criterion.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 pretty_assertions.workspace = true
@@ -91,6 +90,3 @@ project = { workspace = true, features = ["test-support"] }
 tempfile.workspace = true
 workspace = { workspace = true, features = ["test-support"] }
 
-[[bench]]
-name = "kept_rate"
-harness = false

crates/language/src/buffer.rs 🔗

@@ -878,6 +878,14 @@ impl EditPreview {
         })
     }
 
    /// Returns the text snapshot of the buffer with the previewed edits
    /// applied (presumably what `applied_edits_snapshot` holds — the field
    /// name suggests so; its construction is outside this view).
    pub fn result_text_snapshot(&self) -> &text::BufferSnapshot {
        &self.applied_edits_snapshot
    }
+
    /// Returns the syntax snapshot carried by this preview, for inspecting
    /// the parse state of the edited result (e.g. counting syntax errors).
    pub fn result_syntax_snapshot(&self) -> &SyntaxSnapshot {
        &self.syntax_snapshot
    }
+
     pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize {
         anchor
             .bias_right(&self.old_snapshot)