Detailed changes
@@ -5176,6 +5176,7 @@ dependencies = [
"copilot",
"copilot_ui",
"credentials_provider",
+ "criterion",
"ctor",
"db",
"edit_prediction_context",
@@ -5189,9 +5190,11 @@ dependencies = [
"itertools 0.14.0",
"language",
"language_model",
+ "languages",
"log",
"lsp",
"menu",
+ "node_runtime",
"open_ai",
"parking_lot",
"postage",
@@ -5235,7 +5238,6 @@ dependencies = [
"client",
"cloud_llm_client",
"collections",
- "criterion",
"db",
"debug_adapter_extension",
"dirs 4.0.0",
@@ -72,10 +72,14 @@ zeta_prompt.workspace = true
zstd.workspace = true
[dev-dependencies]
+criterion.workspace = true
+fs = { workspace = true, features = ["test-support"] }
+gpui = { workspace = true, features = ["test-support"] }
+languages = { workspace = true, features = ["load-grammars"] }
+node_runtime.workspace = true
clock = { workspace = true, features = ["test-support"] }
cloud_llm_client = { workspace = true, features = ["test-support"] }
ctor.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
@@ -86,3 +90,11 @@ settings = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true
+
+[[bench]]
+name = "kept_rate"
+harness = false
+
+[[bench]]
+name = "ts_error_count"
+harness = false
@@ -1,5 +1,5 @@
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
-use edit_prediction_cli::kept_rate::compute_kept_rate;
+use edit_prediction::metrics::compute_kept_rate;
fn repeated_function_lines(line_count: usize) -> String {
let mut text = String::with_capacity(line_count * 32);
@@ -0,0 +1,454 @@
+use std::sync::Arc;
+
+use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use edit_prediction::metrics::count_tree_sitter_errors;
+use fs::FakeFs;
+use gpui::{AppContext as _, TestAppContext};
+use language::{Buffer, BufferSnapshot, LanguageRegistry};
+use languages::init as init_languages;
+use node_runtime::NodeRuntime;
+use settings::SettingsStore;
+
+struct ParsedCase {
+ label: String,
+ bytes: usize,
+ error_count: usize,
+ snapshot: BufferSnapshot,
+}
+
+fn replace_nth_occurrences(
+ source: &mut String,
+ needle: &str,
+ replacement: &str,
+ every: usize,
+ max_replacements: usize,
+) {
+ let mut rebuilt = String::with_capacity(source.len());
+ let mut cursor = 0;
+ let mut seen = 0;
+ let mut replaced = 0;
+
+ while let Some(relative_index) = source[cursor..].find(needle) {
+ let start = cursor + relative_index;
+ let end = start + needle.len();
+ rebuilt.push_str(&source[cursor..start]);
+
+ if seen % every == 0 && replaced < max_replacements {
+ rebuilt.push_str(replacement);
+ replaced += 1;
+ } else {
+ rebuilt.push_str(needle);
+ }
+
+ seen += 1;
+ cursor = end;
+ }
+
+ rebuilt.push_str(&source[cursor..]);
+ *source = rebuilt;
+}
+
+fn rust_source(function_count: usize) -> String {
+ let mut source = String::from(
+ "pub struct Counter {\n value: usize,\n}\n\nimpl Counter {\n pub fn new() -> Self {\n Self { value: 0 }\n }\n}\n\n",
+ );
+ for index in 0..function_count {
+ source.push_str(&format!(
+ "pub fn compute_value_{index}(input: usize) -> usize {{\n let mut total = input;\n for offset in 0..32 {{\n total += offset + {index};\n }}\n if total % 2 == 0 {{\n total / 2\n }} else {{\n total * 3 + 1\n }}\n}}\n\n"
+ ));
+ }
+ source
+}
+
+fn rust_source_with_errors(function_count: usize) -> String {
+ let mut source = rust_source(function_count);
+ replace_nth_occurrences(
+ &mut source,
+ " if total % 2 == 0 {\n",
+ " if total % 2 == 0 \n",
+ 17,
+ 48,
+ );
+ source
+}
+
+fn python_source(function_count: usize) -> String {
+ let mut source = String::from(
+ "class Counter:\n def __init__(self) -> None:\n self.value = 0\n\n\n",
+ );
+ for index in 0..function_count {
+ source.push_str(&format!(
+ "def compute_value_{index}(input_value: int) -> int:\n total = input_value\n for offset in range(32):\n total += offset + {index}\n if total % 2 == 0:\n return total // 2\n return total * 3 + 1\n\n"
+ ));
+ }
+ source
+}
+
+fn python_source_with_errors(function_count: usize) -> String {
+ let mut source = python_source(function_count);
+ replace_nth_occurrences(
+ &mut source,
+ " if total % 2 == 0:\n",
+ " if total % 2 == 0\n",
+ 19,
+ 48,
+ );
+ source
+}
+
+fn go_source(function_count: usize) -> String {
+ let mut source = String::from(
+ "package bench\n\ntype Counter struct {\n\tvalue int\n}\n\nfunc NewCounter() Counter {\n\treturn Counter{value: 0}\n}\n\n",
+ );
+ for index in 0..function_count {
+ source.push_str(&format!(
+ "func ComputeValue{index}(inputValue int) int {{\n\ttotal := inputValue\n\tfor offset := 0; offset < 32; offset++ {{\n\t\ttotal += offset + {index}\n\t}}\n\tif total%2 == 0 {{\n\t\treturn total / 2\n\t}}\n\treturn total*3 + 1\n}}\n\n"
+ ));
+ }
+ source
+}
+
+fn go_source_with_errors(function_count: usize) -> String {
+ let mut source = go_source(function_count);
+ replace_nth_occurrences(
+ &mut source,
+ "\tfor offset := 0; offset < 32; offset++ {\n",
+ "\tfor offset := 0; offset < 32; offset++ \n",
+ 17,
+ 48,
+ );
+ source
+}
+
+fn typescript_source(function_count: usize) -> String {
+ let mut source = String::from(
+ "export type Counter = { value: number };\n\nexport function newCounter(): Counter {\n return { value: 0 };\n}\n\n",
+ );
+ for index in 0..function_count {
+ source.push_str(&format!(
+ "export function computeValue{index}(inputValue: number): number {{\n let total = inputValue;\n for (let offset = 0; offset < 32; offset += 1) {{\n total += offset + {index};\n }}\n return total % 2 === 0 ? total / 2 : total * 3 + 1;\n}}\n\n"
+ ));
+ }
+ source
+}
+
+fn typescript_source_with_errors(function_count: usize) -> String {
+ let mut source = typescript_source(function_count);
+ replace_nth_occurrences(
+ &mut source,
+ " return total % 2 === 0 ? total / 2 : total * 3 + 1;\n",
+ " return total % 2 === 0 ? total / 2 : ;\n",
+ 17,
+ 64,
+ );
+ source
+}
+
+fn tsx_source(component_count: usize) -> String {
+ let mut source = String::from(
+ "type ItemProps = { index: number; label: string };\n\nfunction Item({ index, label }: ItemProps) {\n return <li data-index={index}>{label}</li>;\n}\n\nexport function App() {\n return <section><ul>{[0, 1, 2].map((value) => <Item key={value} index={value} label={`item-${value}`} />)}</ul></section>;\n}\n\n",
+ );
+ for index in 0..component_count {
+ source.push_str(&format!(
+ "export function Widget{index}(): JSX.Element {{\n const items = Array.from({{ length: 16 }}, (_, value) => value + {index});\n return (\n <div className=\"widget-{index}\">\n <h2>Widget {index}</h2>\n <ul>\n {{items.map((value) => (\n <Item key={{value}} index={{value}} label={{`widget-{index}-${{value}}`}} />\n ))}}\n </ul>\n </div>\n );\n}}\n\n"
+ ));
+ }
+ source
+}
+
+fn tsx_source_with_errors(component_count: usize) -> String {
+ let mut source = tsx_source(component_count);
+ replace_nth_occurrences(
+ &mut source,
+ " const items = Array.from({ length: 16 }, (_, value) => value + ",
+ " const items = Array.from({ length: 16 }, (_, value) => ); // ",
+ 11,
+ 32,
+ );
+ source
+}
+
+fn json_source(object_count: usize) -> String {
+ let mut source = String::from("{\n \"items\": [\n");
+ for index in 0..object_count {
+ let suffix = if index + 1 == object_count { "" } else { "," };
+ source.push_str(&format!(
+ " {{\n \"id\": {index},\n \"name\": \"item-{index}\",\n \"enabled\": true,\n \"tags\": [\"alpha\", \"beta\", \"gamma\"],\n \"metrics\": {{ \"count\": {}, \"ratio\": {} }}\n }}{suffix}\n",
+ index * 3 + 1,
+ index as f64 / 10.0,
+ ));
+ }
+ source.push_str(" ]\n}\n");
+ source
+}
+
+fn json_source_with_errors(object_count: usize) -> String {
+ let mut source = json_source(object_count);
+ replace_nth_occurrences(
+ &mut source,
+ " \"enabled\": true,\n",
+ " \"enabled\": ,\n",
+ 23,
+ 64,
+ );
+ source
+}
+
+fn yaml_source(document_count: usize) -> String {
+ let mut source = String::new();
+ for index in 0..document_count {
+ source.push_str(&format!(
+ "- id: {index}\n name: item-{index}\n enabled: true\n tags:\n - alpha\n - beta\n - gamma\n metrics:\n count: {}\n ratio: {}\n",
+ index * 3 + 1,
+ index as f64 / 10.0,
+ ));
+ }
+ source
+}
+
+fn yaml_source_with_errors(document_count: usize) -> String {
+ let mut source = yaml_source(document_count);
+ replace_nth_occurrences(&mut source, " count: ", " count ", 23, 64);
+ source
+}
+
+fn css_source(rule_count: usize) -> String {
+ let mut source = String::new();
+ for index in 0..rule_count {
+ source.push_str(&format!(
+ ".widget-{index} {{\n display: grid;\n grid-template-columns: repeat(4, minmax(0, 1fr));\n gap: 12px;\n padding: 8px;\n color: rgb({}, {}, {});\n}}\n\n.widget-{index} > .item-{index} {{\n border: 1px solid rgba(0, 0, 0, 0.15);\n background: linear-gradient(90deg, #fff, #eef);\n}}\n\n",
+ (index * 17) % 255,
+ (index * 31) % 255,
+ (index * 47) % 255,
+ ));
+ }
+ source
+}
+
+fn css_source_with_errors(rule_count: usize) -> String {
+ let mut source = css_source(rule_count);
+ replace_nth_occurrences(&mut source, " gap: 12px;\n", " gap 12px;\n", 29, 64);
+ source
+}
+
+fn build_case(
+ context: &mut TestAppContext,
+ languages: &Arc<LanguageRegistry>,
+ language_name: &'static str,
+ variant_name: &'static str,
+ source: String,
+ expect_errors: bool,
+) -> ParsedCase {
+ let language_task = context.background_spawn({
+ let languages = languages.clone();
+ async move { languages.language_for_name(language_name).await }
+ });
+ while !language_task.is_ready() {
+ context.run_until_parked();
+ }
+ let language = futures::executor::block_on(language_task)
+ .unwrap_or_else(|error| panic!("failed to load {language_name}: {error}"));
+
+ let buffer = context.new(|cx| Buffer::local(source, cx).with_language(language, cx));
+ context.run_until_parked();
+ while buffer.read_with(context, |buffer, _| buffer.is_parsing()) {
+ context.run_until_parked();
+ }
+
+ let snapshot = buffer.read_with(context, |buffer, _| buffer.snapshot());
+ let full_range = 0..snapshot.text.len();
+ let error_count = count_tree_sitter_errors(snapshot.syntax_layers());
+ if expect_errors {
+ assert!(
+ error_count > 0,
+ "expected tree-sitter errors for {language_name}/{variant_name}",
+ );
+ } else {
+ assert_eq!(
+ error_count, 0,
+ "expected no tree-sitter errors for {language_name}/{variant_name}",
+ );
+ }
+
+ let label = format!(
+ "{}/{}_{}kb_{}e",
+ language_name.to_lowercase(),
+ variant_name,
+ full_range.end / 1024,
+ error_count,
+ );
+ ParsedCase {
+ label,
+ bytes: full_range.end,
+ error_count,
+ snapshot,
+ }
+}
+
+fn parsed_cases() -> Vec<ParsedCase> {
+ let mut context = TestAppContext::single();
+ context.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+
+ let languages = Arc::new(LanguageRegistry::new(context.executor()));
+ let fs = FakeFs::new(context.executor());
+ let node_runtime = NodeRuntime::unavailable();
+ context.update(|cx| init_languages(languages.clone(), fs, node_runtime, cx));
+
+ vec![
+ build_case(
+ &mut context,
+ &languages,
+ "Rust",
+ "valid",
+ rust_source(900),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "Rust",
+ "error_heavy",
+ rust_source_with_errors(900),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "Python",
+ "valid",
+ python_source(1100),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "Python",
+ "error_heavy",
+ python_source_with_errors(1100),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "Go",
+ "valid",
+ go_source(1000),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "Go",
+ "error_heavy",
+ go_source_with_errors(1000),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "TypeScript",
+ "valid",
+ typescript_source(1000),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "TypeScript",
+ "error_heavy",
+ typescript_source_with_errors(1000),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "TSX",
+ "valid",
+ tsx_source(350),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "TSX",
+ "error_heavy",
+ tsx_source_with_errors(350),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "JSON",
+ "valid",
+ json_source(2200),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "JSON",
+ "error_heavy",
+ json_source_with_errors(2200),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "YAML",
+ "valid",
+ yaml_source(2200),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "YAML",
+ "error_heavy",
+ yaml_source_with_errors(2200),
+ true,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "CSS",
+ "valid",
+ css_source(2400),
+ false,
+ ),
+ build_case(
+ &mut context,
+ &languages,
+ "CSS",
+ "error_heavy",
+ css_source_with_errors(2400),
+ true,
+ ),
+ ]
+}
+
+fn ts_error_count_benchmark(c: &mut Criterion) {
+ let cases = parsed_cases();
+ let mut group = c.benchmark_group("ts_error_count/full_file");
+
+ for case in &cases {
+ group.bench_with_input(
+ BenchmarkId::from_parameter(&case.label),
+ case,
+ |bench, case| {
+ bench.iter(|| {
+ black_box(case.bytes);
+ black_box(case.error_count);
+ black_box(count_tree_sitter_errors(case.snapshot.syntax_layers()))
+ });
+ },
+ );
+ }
+
+ group.finish();
+}
+
+criterion_group!(benches, ts_error_count_benchmark);
+criterion_main!(benches);
@@ -30,7 +30,7 @@ use gpui::{
};
use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
+use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
@@ -61,6 +61,7 @@ pub mod example_spec;
pub mod fim;
mod license_detection;
pub mod mercury;
+pub mod metrics;
pub mod ollama;
mod onboarding_modal;
pub mod open_ai_response;
@@ -80,6 +81,7 @@ use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
use crate::example_spec::ExampleSpec;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
+pub use crate::metrics::{KeptRateResult, compute_kept_rate};
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
@@ -478,10 +480,13 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
}
#[derive(Clone)]
-
struct PendingSettledPrediction {
request_id: EditPredictionId,
editable_anchor_range: Range<Anchor>,
+ editable_region_before_prediction: String,
+ predicted_editable_region: String,
+ ts_error_count_before_prediction: usize,
+ ts_error_count_after_prediction: usize,
example: Option<ExampleSpec>,
enqueued_at: Instant,
last_edit_at: Instant,
@@ -1603,63 +1608,100 @@ impl EditPredictionStore {
};
let now = cx.background_executor().now();
-
let mut oldest_edited_at = None;
+ let mut ready_predictions = Vec::new();
this.update(cx, |this, _| {
for (_, project_state) in this.projects.iter_mut() {
for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
- registered_buffer
- .pending_predictions
- .retain_mut(|pending_prediction| {
- let age =
- now.saturating_duration_since(pending_prediction.enqueued_at);
- if age >= EDIT_PREDICTION_SETTLED_TTL {
- return false;
- }
-
- let quiet_for =
- now.saturating_duration_since(pending_prediction.last_edit_at);
- if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
- let settled_editable_region = registered_buffer
- .snapshot
- .text_for_range(
- pending_prediction.editable_anchor_range.clone(),
- )
- .collect::<String>();
-
- #[cfg(test)]
- if let Some(callback) = &this.settled_event_callback {
- callback(
- pending_prediction.request_id.clone(),
- settled_editable_region.clone(),
- );
- }
-
- telemetry::event!(
- EDIT_PREDICTION_SETTLED_EVENT,
- request_id = pending_prediction.request_id.0.clone(),
- settled_editable_region,
- example = pending_prediction.example.take(),
- e2e_latency = pending_prediction.e2e_latency.as_millis(),
- );
-
- return false;
- }
+ let mut pending_index = 0;
+ while pending_index < registered_buffer.pending_predictions.len() {
+ let pending_prediction =
+ &registered_buffer.pending_predictions[pending_index];
+ let age = now.saturating_duration_since(pending_prediction.enqueued_at);
+ if age >= EDIT_PREDICTION_SETTLED_TTL {
+ registered_buffer.pending_predictions.remove(pending_index);
+ continue;
+ }
- if oldest_edited_at
- .is_none_or(|t| pending_prediction.last_edit_at < t)
- {
- oldest_edited_at = Some(pending_prediction.last_edit_at);
- }
+ let quiet_for =
+ now.saturating_duration_since(pending_prediction.last_edit_at);
+ if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
+ let pending_prediction =
+ registered_buffer.pending_predictions.remove(pending_index);
+ let settled_editable_region = registered_buffer
+ .snapshot
+ .text_for_range(
+ pending_prediction.editable_anchor_range.clone(),
+ )
+ .collect::<String>();
+ ready_predictions
+ .push((pending_prediction, settled_editable_region));
+ continue;
+ }
- true
- });
+ if oldest_edited_at
+ .is_none_or(|time| pending_prediction.last_edit_at < time)
+ {
+ oldest_edited_at = Some(pending_prediction.last_edit_at);
+ }
+ pending_index += 1;
+ }
}
}
});
- next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
+ for (pending_prediction, settled_editable_region) in ready_predictions {
+ let PendingSettledPrediction {
+ request_id,
+ editable_region_before_prediction,
+ predicted_editable_region,
+ ts_error_count_before_prediction,
+ ts_error_count_after_prediction,
+ example,
+ e2e_latency,
+ ..
+ } = pending_prediction;
+ let settled_editable_region_for_metrics = settled_editable_region.clone();
+ let kept_rate_result = cx
+ .background_spawn(async move {
+ compute_kept_rate(
+ &editable_region_before_prediction,
+ &predicted_editable_region,
+ &settled_editable_region_for_metrics,
+ )
+ })
+ .await;
+
+ #[cfg(test)]
+ {
+ let request_id = request_id.clone();
+ let settled_editable_region = settled_editable_region.clone();
+ this.update(cx, |this, _| {
+ if let Some(callback) = &this.settled_event_callback {
+ callback(request_id, settled_editable_region);
+ }
+ });
+ }
+
+ telemetry::event!(
+ EDIT_PREDICTION_SETTLED_EVENT,
+ request_id = request_id.0.clone(),
+ settled_editable_region,
+ ts_error_count_before_prediction,
+ ts_error_count_after_prediction,
+ edit_bytes_predicted_new = kept_rate_result.predicted_new_chars,
+ edit_bytes_final_new = kept_rate_result.final_new_chars,
+ edit_bytes_kept = kept_rate_result.kept_chars,
+ edit_bytes_discarded = kept_rate_result.discarded_chars,
+ edit_bytes_context = kept_rate_result.context_chars,
+ edit_bytes_kept_rate = kept_rate_result.kept_rate,
+ example,
+ e2e_latency = e2e_latency.as_millis(),
+ );
+ }
+
+ next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
}
}
@@ -1670,28 +1712,58 @@ impl EditPredictionStore {
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
editable_offset_range: Range<usize>,
+ edit_preview: &EditPreview,
example: Option<ExampleSpec>,
e2e_latency: std::time::Duration,
cx: &mut Context<Self>,
) {
let this = &mut *self;
let project_state = this.get_or_init_project(project, cx);
- if let Some(buffer) = project_state
+ let Some(registered_buffer) = project_state
.registered_buffers
.get_mut(&edited_buffer.entity_id())
- {
- let now = cx.background_executor().now();
- buffer.pending_predictions.push(PendingSettledPrediction {
- request_id: request_id,
- editable_anchor_range: edited_buffer_snapshot
- .anchor_range_inside(editable_offset_range),
+ else {
+ return;
+ };
+
+ let editable_region_before_prediction = edited_buffer_snapshot
+ .text_for_range(editable_offset_range.clone())
+ .collect::<String>();
+ let editable_anchor_range_for_result =
+ edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
+ let predicted_editable_region = edit_preview
+ .result_text_snapshot()
+ .text_for_range(editable_anchor_range_for_result.clone())
+ .collect();
+ let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
+ edited_buffer_snapshot
+ .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
+ );
+ let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
+ edit_preview.result_syntax_snapshot().layers_for_range(
+ editable_anchor_range_for_result,
+ edit_preview.result_text_snapshot(),
+ true,
+ ),
+ );
+ let editable_anchor_range =
+ edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
+ let now = cx.background_executor().now();
+ registered_buffer
+ .pending_predictions
+ .push(PendingSettledPrediction {
+ request_id,
+ editable_anchor_range,
+ editable_region_before_prediction,
+ predicted_editable_region,
+ ts_error_count_before_prediction,
+ ts_error_count_after_prediction,
example,
e2e_latency,
enqueued_at: now,
last_edit_at: now,
});
- this.settled_predictions_tx.unbounded_send(now).ok();
- }
+ this.settled_predictions_tx.unbounded_send(now).ok();
}
fn reject_current_prediction(
@@ -3252,6 +3252,12 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
cx.run_until_parked();
let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
+ let edit_preview_a = buffer
+ .read_with(cx, |buffer, cx| {
+ buffer.preview_edits(empty_edits.clone(), cx)
+ })
+ .await;
// Region A: first 10 lines of the buffer.
let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
@@ -3263,6 +3269,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&buffer,
&snapshot_a,
editable_region_a.clone(),
+ &edit_preview_a,
None,
Duration::from_secs(0),
cx,
@@ -3318,6 +3325,9 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
cx.run_until_parked();
let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let edit_preview_b = buffer
+ .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
+ .await;
let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
ep_store.update(cx, |ep_store, cx| {
@@ -3327,6 +3337,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&buffer,
&snapshot_b2,
editable_region_b.clone(),
+ &edit_preview_b,
None,
Duration::from_secs(0),
cx,
@@ -0,0 +1,10 @@
+mod kept_rate;
+mod tokenize;
+mod tree_sitter;
+
+pub use kept_rate::KeptRateResult;
+#[cfg(test)]
+pub use kept_rate::TokenAnnotation;
+pub use kept_rate::compute_kept_rate;
+pub(crate) use tokenize::tokenize;
+pub use tree_sitter::count_tree_sitter_errors;
@@ -1,4 +1,6 @@
-use crate::word_diff::tokenize;
+use crate::metrics::tokenize;
+
+const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -25,20 +27,28 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
row * width + column
}
-/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
-/// side while sharing a single DP table construction.
-fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
+/// while sharing a single DP table construction.
+fn fill_lcs_keep_masks(
+ a: &[&str],
+ b: &[&str],
+ mut keep_a: Option<&mut [bool]>,
+ mut keep_b: Option<&mut [bool]>,
+) {
if a.is_empty() || b.is_empty() {
- return (vec![false; a.len()], vec![false; b.len()]);
+ return;
}
if a == b {
- return (vec![true; a.len()], vec![true; b.len()]);
+ if let Some(keep_a) = keep_a.as_mut() {
+ keep_a.fill(true);
+ }
+ if let Some(keep_b) = keep_b.as_mut() {
+ keep_b.fill(true);
+ }
+ return;
}
- let mut keep_a = vec![false; a.len()];
- let mut keep_b = vec![false; b.len()];
-
let prefix_len = a
.iter()
.zip(b.iter())
@@ -61,22 +71,30 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
};
for index in 0..prefix_len {
- keep_a[index] = true;
- keep_b[index] = true;
+ if let Some(keep_a) = keep_a.as_mut() {
+ keep_a[index] = true;
+ }
+ if let Some(keep_b) = keep_b.as_mut() {
+ keep_b[index] = true;
+ }
}
for offset in 0..suffix_len {
let a_index = a.len() - suffix_len + offset;
let b_index = b.len() - suffix_len + offset;
- keep_a[a_index] = true;
- keep_b[b_index] = true;
+ if let Some(keep_a) = keep_a.as_mut() {
+ keep_a[a_index] = true;
+ }
+ if let Some(keep_b) = keep_b.as_mut() {
+ keep_b[b_index] = true;
+ }
}
let a_mid = &a[prefix_len..a.len() - suffix_len];
let b_mid = &b[prefix_len..b.len() - suffix_len];
if a_mid.is_empty() || b_mid.is_empty() {
- return (keep_a, keep_b);
+ return;
}
let row_count = a_mid.len() + 1;
@@ -97,44 +115,59 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
}
}
- let mut i = a_mid.len();
- let mut j = b_mid.len();
+ if let Some(keep_a) = keep_a.as_mut() {
+ let mut i = a_mid.len();
+ let mut j = b_mid.len();
- while i > 0 && j > 0 {
- if a_mid[i - 1] == b_mid[j - 1] {
- keep_a[prefix_len + i - 1] = true;
- i -= 1;
- j -= 1;
- } else {
- let up = dp[dp_index(column_count, i - 1, j)];
- let left = dp[dp_index(column_count, i, j - 1)];
- if up >= left {
+ while i > 0 && j > 0 {
+ if a_mid[i - 1] == b_mid[j - 1] {
+ keep_a[prefix_len + i - 1] = true;
i -= 1;
- } else {
j -= 1;
+ } else {
+ let up = dp[dp_index(column_count, i - 1, j)];
+ let left = dp[dp_index(column_count, i, j - 1)];
+ if up >= left {
+ i -= 1;
+ } else {
+ j -= 1;
+ }
}
}
}
- let mut i = a_mid.len();
- let mut j = b_mid.len();
+ if let Some(keep_b) = keep_b.as_mut() {
+ let mut i = a_mid.len();
+ let mut j = b_mid.len();
- while i > 0 && j > 0 {
- if a_mid[i - 1] == b_mid[j - 1] {
- keep_b[prefix_len + j - 1] = true;
- i -= 1;
- j -= 1;
- } else {
- let up = dp[dp_index(column_count, i - 1, j)];
- let left = dp[dp_index(column_count, i, j - 1)];
- if left >= up {
+ while i > 0 && j > 0 {
+ if a_mid[i - 1] == b_mid[j - 1] {
+ keep_b[prefix_len + j - 1] = true;
+ i -= 1;
j -= 1;
} else {
- i -= 1;
+ let up = dp[dp_index(column_count, i - 1, j)];
+ let left = dp[dp_index(column_count, i, j - 1)];
+ if left >= up {
+ j -= 1;
+ } else {
+ i -= 1;
+ }
}
}
}
+}
+
+fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
+ let mut keep_a = vec![false; a.len()];
+ fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
+ keep_a
+}
+fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+ let mut keep_a = vec![false; a.len()];
+ let mut keep_b = vec![false; b.len()];
+ fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
(keep_a, keep_b)
}
@@ -155,6 +188,12 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>
(unmasked_tokens, unmasked_chars, masked_chars)
}
+fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
+ let predicted_delta_chars = predicted.len().abs_diff(base.len());
+ let final_delta_chars = final_text.len().abs_diff(base.len());
+ predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
+}
+
pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
if base == predicted && predicted == final_text {
let predicted_tokens = tokenize(predicted);
@@ -171,11 +210,26 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
};
}
+ if should_bail_for_dirty_final(base, predicted, final_text) {
+ let predicted_new_chars = predicted.len().abs_diff(base.len());
+ let final_new_chars = final_text.len().abs_diff(base.len());
+ return KeptRateResult {
+ predicted_new_chars,
+ final_new_chars,
+ kept_chars: 0,
+ discarded_chars: predicted_new_chars,
+ context_chars: 0,
+ kept_rate: 0.0,
+ #[cfg(test)]
+ token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
+ };
+ }
+
let base_tokens = tokenize(base);
let predicted_tokens = tokenize(predicted);
let final_tokens = tokenize(final_text);
- let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
+ let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
let context_mask: Vec<bool> = pred_base_mask
.iter()
@@ -186,7 +240,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
let (stripped_predicted, predicted_new_chars, context_chars) =
analyze_masked_tokens(&predicted_tokens, &context_mask);
- let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
+ let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
let final_context_mask: Vec<bool> = final_base_mask
.iter()
.zip(final_pred_mask.iter())
@@ -196,7 +250,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
let (stripped_final, final_new_chars, _) =
analyze_masked_tokens(&final_tokens, &final_context_mask);
- let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
+ let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);
let kept_chars: usize = stripped_predicted
.iter()
@@ -265,8 +319,8 @@ mod test_kept_rate {
let a = ["x", "a", "x", "b"];
let b = ["a", "x", "b", "x"];
let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
- assert_eq!(a_mask, lcs_keep_masks(&a, &b).0);
- assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
+ assert_eq!(a_mask, lcs_keep_mask(&a, &b));
+ assert_eq!(b_mask, lcs_keep_mask(&b, &a));
}
#[test]
@@ -342,6 +396,21 @@ mod test_kept_rate {
assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
}
+ #[test]
+ fn test_bails_for_dirty_final() {
+ let base = "fn example() {\n work();\n}\n";
+ let predicted = "fn example() {\n work();\n predicted();\n}\n";
+ let final_text = format!(
+ "fn example() {{\n work();\n {}\n}}\n",
+ "settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
+ );
+
+ let result = compute_kept_rate(base, predicted, &final_text);
+ assert_eq!(result.kept_rate, 0.0);
+ assert_eq!(result.kept_chars, 0);
+ assert_eq!(result.discarded_chars, result.predicted_new_chars);
+ }
+
#[test]
fn test_eprintln_token_alignment() {
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
@@ -0,0 +1,54 @@
+fn char_class(character: char) -> u8 {
+ if character.is_alphanumeric() || character == '_' {
+ 0
+ } else if character.is_whitespace() {
+ 1
+ } else {
+ 2
+ }
+}
+
+pub(crate) fn tokenize(text: &str) -> Vec<&str> {
+ let mut tokens = Vec::new();
+ let mut characters = text.char_indices().peekable();
+
+ while let Some((start, character)) = characters.next() {
+ let class = char_class(character);
+ if class == 2 {
+ tokens.push(&text[start..start + character.len_utf8()]);
+ continue;
+ }
+
+ let mut end = start + character.len_utf8();
+ while let Some(&(_, next_character)) = characters.peek() {
+ if char_class(next_character) != class {
+ break;
+ }
+ end += next_character.len_utf8();
+ characters.next();
+ }
+ tokens.push(&text[start..end]);
+ }
+
+ tokens
+}
+
+#[cfg(test)]
+mod tests {
+ use super::tokenize;
+
+ #[test]
+ fn tokenizes_code_like_text() {
+ assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
+ assert_eq!(
+ tokenize("foo_bar123 + baz"),
+ vec!["foo_bar123", " ", "+", " ", "baz"]
+ );
+ assert_eq!(
+ tokenize("print(\"hello\")"),
+ vec!["print", "(", "\"", "hello", "\"", ")"]
+ );
+ assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
+ assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
+ }
+}
@@ -0,0 +1,88 @@
+use language::SyntaxLayer;
+
+pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
+ let mut total_count: usize = 0;
+ for layer in layers {
+ let node = layer.node();
+ let mut cursor = node.walk();
+ 'layer: loop {
+ let current = cursor.node();
+ if current.is_error() || current.is_missing() {
+ total_count += 1;
+ }
+ if current.has_error() && cursor.goto_first_child() {
+ continue;
+ }
+ if cursor.goto_next_sibling() {
+ continue;
+ }
+ loop {
+ if !cursor.goto_parent() {
+ break 'layer;
+ }
+ if cursor.goto_next_sibling() {
+ continue;
+ }
+ }
+ }
+ }
+ total_count
+}
+
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use super::count_tree_sitter_errors;
    use gpui::{AppContext as _, TestAppContext};
    use language::{Buffer, BufferSnapshot, rust_lang};

    /// Counts the syntax errors reported within `range` of `snapshot`.
    fn count_errors(snapshot: &BufferSnapshot, range: Range<usize>) -> usize {
        count_tree_sitter_errors(snapshot.syntax_layers_for_range(range, true))
    }

    /// Creates a Rust buffer containing `text` and waits until parsing settles.
    fn parse_rust(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        loop {
            let still_parsing = buffer.read_with(cx, |buffer, _| buffer.is_parsing());
            if !still_parsing {
                break;
            }
            cx.run_until_parked();
        }
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    #[gpui::test]
    async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) {
        let text = "fn helper(value: usize) -> usize {\n value + 1\n}\n";
        let snapshot = parse_rust(text, cx);
        let whole_file = 0..snapshot.text.len();

        assert_eq!(count_errors(&snapshot, whole_file), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) {
        let text = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n";
        let snapshot = parse_rust(text, cx);
        let whole_file = 0..snapshot.text.len();

        assert_eq!(count_errors(&snapshot, whole_file), 1);
    }

    #[gpui::test]
    async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) {
        let text = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n";
        let snapshot = parse_rust(text, cx);
        let statement = "let value = 1;";
        let start = text.find("let value").unwrap();

        assert_eq!(count_errors(&snapshot, start..start + statement.len()), 0);
    }

    #[gpui::test]
    async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) {
        let text = "fn second() -> usize {\n let broken = ;\n broken\n}\n";
        let snapshot = parse_rust(text, cx);
        let statement = "let broken = ;";
        let start = text.find(statement).unwrap();

        assert_eq!(count_errors(&snapshot, start..start + statement.len()), 1);
    }
}
@@ -85,7 +85,6 @@ pub fn request_prediction_with_zeta(
} else {
None
};
-
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let organization_id = store
@@ -383,11 +382,26 @@ pub fn request_prediction_with_zeta(
}));
};
- if can_collect_data {
+ let result = EditPredictionResult::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ cursor_position,
+ inputs,
+ model_version,
+ request_duration,
+ cx,
+ )
+ .await;
+
+ if can_collect_data && let Ok(prediction) = &result.prediction {
let weak_this = this.clone();
- let id = id.clone();
+ let request_id = prediction.id.clone();
let edited_buffer = edited_buffer.clone();
let edited_buffer_snapshot = edited_buffer_snapshot.clone();
+ let editable_range_in_buffer = editable_range_in_buffer.clone();
+ let edit_preview = prediction.edit_preview.clone();
let example_task = capture_data.and_then(|stored_events| {
cx.update(|cx| {
crate::capture_example(
@@ -410,11 +424,12 @@ pub fn request_prediction_with_zeta(
weak_this
.update(cx, |this, cx| {
this.enqueue_settled_prediction(
- id.clone(),
+ request_id.clone(),
&project,
&edited_buffer,
&edited_buffer_snapshot,
editable_range_in_buffer,
+ &edit_preview,
example_spec,
request_duration,
cx,
@@ -425,20 +440,7 @@ pub fn request_prediction_with_zeta(
.detach();
}
- Ok(Some(
- EditPredictionResult::new(
- id,
- &edited_buffer,
- &edited_buffer_snapshot,
- edits.into(),
- cursor_position,
- inputs,
- model_version,
- request_duration,
- cx,
- )
- .await,
- ))
+ Ok(Some(result))
})
}
@@ -83,7 +83,6 @@ dynamic_prompts = []
ignored = ["wasmtime"]
[dev-dependencies]
-criterion.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
pretty_assertions.workspace = true
@@ -91,6 +90,3 @@ project = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
workspace = { workspace = true, features = ["test-support"] }
-[[bench]]
-name = "kept_rate"
-harness = false
@@ -1,4 +1,2 @@
#[allow(dead_code)]
mod word_diff;
-
-pub mod kept_rate;
@@ -5,7 +5,7 @@ mod filter_languages;
mod format_prompt;
mod git;
mod headless;
-mod kept_rate;
+
mod load_project;
mod metrics;
mod openai_client;
@@ -1298,4 +1298,4 @@ index abc123..def456 100644
}
}
-pub use crate::kept_rate::compute_kept_rate;
+pub use edit_prediction::metrics::compute_kept_rate;
@@ -878,6 +878,14 @@ impl EditPreview {
})
}
    /// Borrowed access to the text snapshot backing the preview's result
    /// (the `applied_edits_snapshot` field, i.e. the text after the
    /// previewed edits were applied).
    pub fn result_text_snapshot(&self) -> &text::BufferSnapshot {
        &self.applied_edits_snapshot
    }
+
    /// Borrowed access to this preview's `syntax_snapshot`.
    // NOTE(review): presumably this is the syntax tree parsed from the
    // applied-edits text returned by `result_text_snapshot` — confirm at the
    // construction site of `EditPreview`.
    pub fn result_syntax_snapshot(&self) -> &SyntaxSnapshot {
        &self.syntax_snapshot
    }
+
pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize {
anchor
.bias_right(&self.old_snapshot)