From f1edfd9163adbe7f93e054030186ea52569feccc Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 27 Jan 2026 18:08:49 -0800 Subject: [PATCH] Expand editable region line-wise, then syntactically, then line-wise again --- Cargo.lock | 1 + crates/edit_prediction/Cargo.toml | 1 + crates/edit_prediction/src/cursor_excerpt.rs | 317 +++++++++++++++--- .../src/edit_prediction_tests.rs | 282 +++++++++++++++- crates/edit_prediction/src/zeta1.rs | 18 +- 5 files changed, 564 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 31b1a06c4372d8af4ef0a103b24f83d40afac63b..dc1ddaaef246fa6691fb37419340422ce96746f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5300,6 +5300,7 @@ dependencies = [ "thiserror 2.0.17", "time", "toml 0.8.23", + "tree-sitter-rust", "ui", "util", "uuid", diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index b91b30a1bc29771866d0237b2f5a108196495b82..767fd49a9f7f3fe5d481f93583ef14e17a87e9f3 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -80,4 +80,5 @@ parking_lot.workspace = true project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } +tree-sitter-rust.workspace = true zlog.workspace = true diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs index 1f2f1d32ebcb2eaa151433bd49d275e0e2a3b817..76da9c8d99cf256998ddff3bd4e253150f0599c8 100644 --- a/crates/edit_prediction/src/cursor_excerpt.rs +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -7,66 +7,199 @@ pub fn editable_and_context_ranges_for_cursor_position( editable_region_token_limit: usize, context_token_limit: usize, ) -> (Range, Range) { - let mut scope_range = position..position; - let mut remaining_edit_tokens = editable_region_token_limit; - - while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = guess_token_count(parent.byte_range().len()); - let parent_point_range = Point::new( - parent.start_position().row as u32, - parent.start_position().column as u32, - ) - ..Point::new( - parent.end_position().row as u32, - parent.end_position().column as u32, - ); - if parent_point_range == scope_range { - break; - } else if parent_tokens <= editable_region_token_limit { - scope_range = parent_point_range; - remaining_edit_tokens = editable_region_token_limit - parent_tokens; + let editable_range = compute_editable_range(snapshot, position, editable_region_token_limit); + + let context_range = expand_context_syntactically_then_linewise( + snapshot, + editable_range.clone(), + context_token_limit, + ); + + (editable_range, context_range) +} + +/// Computes the editable range using a three-phase approach: +/// 1. Expand symmetrically from cursor (75% of budget) +/// 2. Expand to syntax boundaries +/// 3. Continue line-wise in the least-expanded direction +fn compute_editable_range( + snapshot: &BufferSnapshot, + cursor: Point, + token_limit: usize, +) -> Range { + // Phase 1: Expand symmetrically from cursor using 75% of budget. + let initial_budget = (token_limit * 3) / 4; + let (mut start_row, mut end_row, mut remaining_tokens) = + expand_symmetric_from_cursor(snapshot, cursor.row, initial_budget); + + // Add remaining budget from phase 1. + remaining_tokens += token_limit.saturating_sub(initial_budget); + + let original_start = start_row; + let original_end = end_row; + + // Phase 2: Expand to syntax boundaries that fit within budget. + for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row) + { + let tokens_for_start = if boundary_start < start_row { + estimate_tokens_for_rows(snapshot, boundary_start, start_row) + } else { + 0 + }; + let tokens_for_end = if boundary_end > end_row { + estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1) + } else { + 0 + }; + + let total_needed = tokens_for_start + tokens_for_end; + + if total_needed <= remaining_tokens { + if boundary_start < start_row { + start_row = boundary_start; + } + if boundary_end > end_row { + end_row = boundary_end; + } + remaining_tokens = remaining_tokens.saturating_sub(total_needed); } else { break; } } - let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); - let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); - (editable_range, context_range) + // Phase 3: Continue line-wise in the direction we expanded least during syntax phase. + let expanded_up = original_start.saturating_sub(start_row); + let expanded_down = end_row.saturating_sub(original_end); + + (start_row, end_row, _) = expand_linewise_biased( + snapshot, + start_row, + end_row, + remaining_tokens, + expanded_up <= expanded_down, // prefer_up if we expanded less upward + ); + + let start = Point::new(start_row, 0); + let end = Point::new(end_row, snapshot.line_len(end_row)); + start..end } -fn expand_range( +/// Expands symmetrically from cursor, one line at a time, alternating down then up. +/// Returns (start_row, end_row, remaining_tokens). +fn expand_symmetric_from_cursor( snapshot: &BufferSnapshot, - range: Range, + cursor_row: u32, + mut token_budget: usize, +) -> (u32, u32, usize) { + let mut start_row = cursor_row; + let mut end_row = cursor_row; + + // Account for the cursor's line. + let cursor_line_tokens = line_token_count(snapshot, cursor_row); + token_budget = token_budget.saturating_sub(cursor_line_tokens); + + loop { + let can_expand_up = start_row > 0; + let can_expand_down = end_row < snapshot.max_point().row; + + if token_budget == 0 || (!can_expand_up && !can_expand_down) { + break; + } + + // Expand down first (slight forward bias for edit prediction). + if can_expand_down { + let next_row = end_row + 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= token_budget { + end_row = next_row; + token_budget = token_budget.saturating_sub(line_tokens); + } else { + break; + } + } + + // Then expand up. + if can_expand_up && token_budget > 0 { + let next_row = start_row - 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= token_budget { + start_row = next_row; + token_budget = token_budget.saturating_sub(line_tokens); + } else { + break; + } + } + } + + (start_row, end_row, token_budget) +} + +/// Expands line-wise with a bias toward one direction. +/// Returns (start_row, end_row, remaining_tokens). +fn expand_linewise_biased( + snapshot: &BufferSnapshot, + mut start_row: u32, + mut end_row: u32, mut remaining_tokens: usize, -) -> Range { - let mut expanded_range = range; - expanded_range.start.column = 0; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + prefer_up: bool, +) -> (u32, u32, usize) { loop { - let mut expanded = false; + let can_expand_up = start_row > 0; + let can_expand_down = end_row < snapshot.max_point().row; - if remaining_tokens > 0 && expanded_range.start.row > 0 { - expanded_range.start.row -= 1; - let line_tokens = - guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; + if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) { + break; } - if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { - expanded_range.end.row += 1; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = guess_token_count(expanded_range.end.column as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; + let mut expanded = false; + + // Try preferred direction first. + if prefer_up { + if can_expand_up { + let next_row = start_row - 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= remaining_tokens { + start_row = next_row; + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + } + if can_expand_down && remaining_tokens > 0 { + let next_row = end_row + 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= remaining_tokens { + end_row = next_row; + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + } + } else { + if can_expand_down { + let next_row = end_row + 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= remaining_tokens { + end_row = next_row; + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + } + if can_expand_up && remaining_tokens > 0 { + let next_row = start_row - 1; + let line_tokens = line_token_count(snapshot, next_row); + if line_tokens <= remaining_tokens { + start_row = next_row; + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + } } if !expanded { break; } } - expanded_range + + (start_row, end_row, remaining_tokens) } /// Typical number of string bytes per token for the purposes of limiting model input. This is @@ -76,3 +209,105 @@ pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; pub fn guess_token_count(bytes: usize) -> usize { bytes / BYTES_PER_TOKEN_GUESS } + +fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize { + guess_token_count(snapshot.line_len(row) as usize).max(1) +} + +/// Estimates token count for rows in range [start_row, end_row). +fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize { + let mut tokens = 0; + for row in start_row..end_row { + tokens += line_token_count(snapshot, row); + } + tokens +} + +/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes +/// containing the given row range. Smallest containing node first. +fn containing_syntax_boundaries( + snapshot: &BufferSnapshot, + start_row: u32, + end_row: u32, +) -> impl Iterator { + let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row)); + let mut current = snapshot.syntax_ancestor(range); + let mut last_rows: Option<(u32, u32)> = None; + + std::iter::from_fn(move || { + while let Some(node) = current.take() { + let node_start_row = node.start_position().row as u32; + let node_end_row = node.end_position().row as u32; + let rows = (node_start_row, node_end_row); + + current = node.parent(); + + // Skip nodes that don't extend beyond our range. + if node_start_row >= start_row && node_end_row <= end_row { + continue; + } + + // Skip if same as last returned (some nodes have same span). + if last_rows == Some(rows) { + continue; + } + + last_rows = Some(rows); + return Some(rows); + } + None + }) +} + +/// Expands context by first trying to reach syntax boundaries, +/// then expanding line-wise only if no syntax expansion occurred. +fn expand_context_syntactically_then_linewise( + snapshot: &BufferSnapshot, + editable_range: Range, + context_token_limit: usize, +) -> Range { + let mut start_row = editable_range.start.row; + let mut end_row = editable_range.end.row; + let mut remaining_tokens = context_token_limit; + let mut did_syntax_expand = false; + + // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits. + for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row) + { + let tokens_for_start = if boundary_start < start_row { + estimate_tokens_for_rows(snapshot, boundary_start, start_row) + } else { + 0 + }; + let tokens_for_end = if boundary_end > end_row { + estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1) + } else { + 0 + }; + + let total_needed = tokens_for_start + tokens_for_end; + + if total_needed <= remaining_tokens { + if boundary_start < start_row { + start_row = boundary_start; + } + if boundary_end > end_row { + end_row = boundary_end; + } + remaining_tokens = remaining_tokens.saturating_sub(total_needed); + did_syntax_expand = true; + } else { + break; + } + } + + // Phase 2: Only expand line-wise if no syntax expansion occurred. + if !did_syntax_expand { + (start_row, end_row, _) = + expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true); + } + + let start = Point::new(start_row, 0); + let end = Point::new(end_row, snapshot.line_len(end_row)); + start..end +} diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 1291d23a80896e53f2a4d2ceaa595fd26b39b949..3cfe6866954c118940904b7a581173215f5eac5d 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position; use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS}; use client::{UserStore, test::FakeServer}; use clock::{FakeSystemClock, ReplicaId}; @@ -12,12 +13,13 @@ use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, }; +use gpui::App; use gpui::{ Entity, TestAppContext, http_client::{FakeHttpClient, Response}, }; use indoc::indoc; -use language::Point; +use language::{Buffer, Point, rust_lang}; use lsp::LanguageServerId; use parking_lot::Mutex; use pretty_assertions::{assert_eq, assert_matches}; @@ -25,7 +27,11 @@ use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use std::{path::Path, sync::Arc, time::Duration}; -use util::{path, rel_path::rel_path}; +use util::{ + path, + rel_path::rel_path, + test::{TextRangeMarker, marked_text_ranges_by}, +}; use uuid::Uuid; use zeta_prompt::ZetaPromptInput; @@ -2374,3 +2380,275 @@ fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) { fn init_logger() { zlog::init_test(); } + +struct TestCase { + name: &'static str, + marked_text: &'static str, + editable_token_limit: usize, + context_token_limit: usize, +} + +#[gpui::test] +fn test_editable_and_context_ranges(cx: &mut App) { + // Markers: + // ˇ = cursor position + // « » = expected editable range + // [ ] = expected context range + let test_cases = vec![ + TestCase { + name: "cursor near end of function - expands to syntax boundaries", + marked_text: indoc! {r#" + [fn first() { + let a = 1; + let b = 2; + } + + fn foo() { + « let x = 1; + let y = 2; + println!("{}", x + y);ˇ + }»] + "#}, + // 18 tokens - expands symmetrically then to syntax boundaries + editable_token_limit: 18, + context_token_limit: 35, + }, + TestCase { + name: "cursor at function start - expands to syntax boundaries", + marked_text: indoc! {r#" + [fn before() { + « let a = 1; + } + + fn foo() {ˇ + let x = 1; + let y = 2; + let z = 3; + } + » + fn after() { + let b = 2; + }] + "#}, + // 25 tokens - expands symmetrically then to syntax boundaries + editable_token_limit: 25, + context_token_limit: 50, + }, + TestCase { + name: "tiny budget - just lines around cursor", + marked_text: indoc! {r#" + fn outer() { + [ let line1 = 1; + let line2 = 2; + « let line3 = 3; + let line4 = 4;ˇ» + let line5 = 5; + let line6 = 6;] + let line7 = 7; + } + "#}, + // 12 tokens (~36 bytes) = just the cursor line with tiny budget + editable_token_limit: 12, + context_token_limit: 24, + }, + TestCase { + name: "small function fits entirely", + marked_text: indoc! {r#" + [«fn foo() { + let x = 1;ˇ + let y = 2; + }»] + "#}, + // Plenty of budget for this small function + editable_token_limit: 30, + context_token_limit: 60, + }, + TestCase { + name: "context extends beyond editable", + marked_text: indoc! {r#" + [fn first() { let a = 1; } + «fn second() { let b = 2; } + fn third() { let c = 3; }ˇ + fn fourth() { let d = 4; }» + fn fifth() { let e = 5; }] + "#}, + // Small editable, larger context + editable_token_limit: 25, + context_token_limit: 45, + }, + // Tests for syntax-aware editable and context expansion + TestCase { + name: "cursor in first if-statement - expands to syntax boundaries", + marked_text: indoc! {r#" + [«fn before() { } + + fn process() { + if condition1 { + let a = 1;ˇ + let b = 2; + } + if condition2 {» + let c = 3; + let d = 4; + } + if condition3 { + let e = 5; + let f = 6; + } + } + + fn after() { }] + "#}, + // 35 tokens allows expansion to include function header and first two if blocks + editable_token_limit: 35, + // 60 tokens allows context to include the whole file + context_token_limit: 60, + }, + TestCase { + name: "cursor in middle if-statement - expands to syntax boundaries", + marked_text: indoc! {r#" + [fn before() { } + + fn process() { + if condition1 { + let a = 1; + « let b = 2; + } + if condition2 { + let c = 3;ˇ + let d = 4; + } + if condition3 { + let e = 5;» + let f = 6; + } + } + + fn after() { }] + "#}, + // 40 tokens allows expansion to surrounding if blocks + editable_token_limit: 40, + // 60 tokens allows context to include the whole file + context_token_limit: 60, + }, + TestCase { + name: "cursor near bottom of long function - editable expands toward syntax, context reaches function", + marked_text: indoc! {r#" + [fn other() { } + + fn long_function() { + let line1 = 1; + let line2 = 2; + let line3 = 3; + let line4 = 4; + let line5 = 5; + let line6 = 6; + « let line7 = 7; + let line8 = 8; + let line9 = 9; + let line10 = 10;ˇ + let line11 = 11; + } + + fn another() { }»] + "#}, + // 40 tokens for editable - allows several lines plus syntax expansion + editable_token_limit: 40, + // 55 tokens - enough for function but not whole file + context_token_limit: 55, + }, + ]; + + for test_case in test_cases { + let cursor_marker: TextRangeMarker = 'ˇ'.into(); + let editable_marker: TextRangeMarker = ('«', '»').into(); + let context_marker: TextRangeMarker = ('[', ']').into(); + + let (text, mut ranges) = marked_text_ranges_by( + test_case.marked_text, + vec![ + cursor_marker.clone(), + editable_marker.clone(), + context_marker.clone(), + ], + ); + + let cursor_ranges = ranges.remove(&cursor_marker).unwrap_or_default(); + let expected_editable = ranges.remove(&editable_marker).unwrap_or_default(); + let expected_context = ranges.remove(&context_marker).unwrap_or_default(); + assert_eq!(expected_editable.len(), 1); + assert_eq!(expected_context.len(), 1); + + cx.new(|cx| { + let text = text.trim_end_matches('\n'); + let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); + let snapshot = buffer.snapshot(); + + let cursor_offset = cursor_ranges[0].start; + let cursor_point = snapshot.offset_to_point(cursor_offset); + let expected_editable_start = snapshot.offset_to_point(expected_editable[0].start); + let expected_editable_end = snapshot.offset_to_point(expected_editable[0].end); + let expected_context_start = snapshot.offset_to_point(expected_context[0].start); + let expected_context_end = snapshot.offset_to_point(expected_context[0].end); + + let (actual_editable, actual_context) = editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + test_case.editable_token_limit, + test_case.context_token_limit, + ); + + let range_text = |start: Point, end: Point| -> String { + snapshot.text_for_range(start..end).collect() + }; + + let editable_match = actual_editable.start == expected_editable_start + && actual_editable.end == expected_editable_end; + let context_match = actual_context.start == expected_context_start + && actual_context.end == expected_context_end; + + if !editable_match || !context_match { + println!("\n=== FAILED: {} ===", test_case.name); + if !editable_match { + println!( + "\nExpected editable ({:?}..{:?}):", + expected_editable_start, expected_editable_end + ); + println!( + "---\n{}---", + range_text(expected_editable_start, expected_editable_end) + ); + println!( + "\nActual editable ({:?}..{:?}):", + actual_editable.start, actual_editable.end + ); + println!( + "---\n{}---", + range_text(actual_editable.start, actual_editable.end) + ); + } + if !context_match { + println!( + "\nExpected context ({:?}..{:?}):", + expected_context_start, expected_context_end + ); + println!( + "---\n{}---", + range_text(expected_context_start, expected_context_end) + ); + println!( + "\nActual context ({:?}..{:?}):", + actual_context.start, actual_context.end + ); + println!( + "---\n{}---", + range_text(actual_context.start, actual_context.end) + ); + } + panic!("Test '{}' failed - see output above", test_case.name); + } + + buffer + }); + } +} diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 216074aae15d75aea9f43592d64097d2153e9e6a..6ce1a94228fc24cafcf95321017c9bb30b045ae1 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -610,20 +610,17 @@ mod tests { let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); let snapshot = buffer.read(cx).snapshot(); - // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion - // when a larger scope doesn't fit the editable region. + // The excerpt expands to syntax boundaries. + // With 50 token editable limit, we get a region that expands to syntax nodes. let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); assert_eq!( excerpt.prompt, indoc! {r#" ```main.rs - let x = 42; - println!("Hello, world!"); - <|editable_region_start|> - } fn bar() { let x = 42; + <|editable_region_start|> let mut sum = 0; for i in 0..x { sum += i; @@ -639,7 +636,7 @@ mod tests { ```"#} ); - // The `bar` function won't fit within the editable region, so we resort to line-based expansion. + // With smaller budget, the region expands to syntax boundaries but is tighter. let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); assert_eq!( excerpt.prompt, @@ -648,8 +645,8 @@ mod tests { fn bar() { let x = 42; let mut sum = 0; - <|editable_region_start|> for i in 0..x { + <|editable_region_start|> sum += i; } println!("Sum: {}", sum); @@ -657,11 +654,8 @@ mod tests { } fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); <|editable_region_end|> - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.random_range(1..101)); + let mut rng = rand::thread_rng(); ```"#} ); }