cursor_excerpt.rs

  1use language::{BufferSnapshot, Point, ToPoint as _};
  2use std::ops::Range;
  3use text::OffsetRangeExt as _;
  4
  5const CURSOR_EXCERPT_TOKEN_BUDGET: usize = 8192;
  6
  7/// Computes a cursor excerpt as the largest linewise symmetric region around
  8/// the cursor that fits within an 8192-token budget. Returns the point range,
  9/// byte offset range, and the cursor offset relative to the excerpt start.
 10pub fn compute_cursor_excerpt(
 11    snapshot: &BufferSnapshot,
 12    cursor_offset: usize,
 13) -> (Range<Point>, Range<usize>, usize) {
 14    let cursor_point = cursor_offset.to_point(snapshot);
 15    let cursor_row = cursor_point.row;
 16    let (start_row, end_row, _) =
 17        expand_symmetric_from_cursor(snapshot, cursor_row, CURSOR_EXCERPT_TOKEN_BUDGET);
 18
 19    let excerpt_range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
 20    let excerpt_offset_range = excerpt_range.to_offset(snapshot);
 21    let cursor_offset_in_excerpt = cursor_offset - excerpt_offset_range.start;
 22
 23    (
 24        excerpt_range,
 25        excerpt_offset_range,
 26        cursor_offset_in_excerpt,
 27    )
 28}
 29
 30/// Expands symmetrically from cursor, one line at a time, alternating down then up.
 31/// Returns (start_row, end_row, remaining_tokens).
 32fn expand_symmetric_from_cursor(
 33    snapshot: &BufferSnapshot,
 34    cursor_row: u32,
 35    mut token_budget: usize,
 36) -> (u32, u32, usize) {
 37    let mut start_row = cursor_row;
 38    let mut end_row = cursor_row;
 39
 40    let cursor_line_tokens = line_token_count(snapshot, cursor_row);
 41    token_budget = token_budget.saturating_sub(cursor_line_tokens);
 42
 43    loop {
 44        let can_expand_up = start_row > 0;
 45        let can_expand_down = end_row < snapshot.max_point().row;
 46
 47        if token_budget == 0 || (!can_expand_up && !can_expand_down) {
 48            break;
 49        }
 50
 51        if can_expand_down {
 52            let next_row = end_row + 1;
 53            let line_tokens = line_token_count(snapshot, next_row);
 54            if line_tokens <= token_budget {
 55                end_row = next_row;
 56                token_budget = token_budget.saturating_sub(line_tokens);
 57            } else {
 58                break;
 59            }
 60        }
 61
 62        if can_expand_up && token_budget > 0 {
 63            let next_row = start_row - 1;
 64            let line_tokens = line_token_count(snapshot, next_row);
 65            if line_tokens <= token_budget {
 66                start_row = next_row;
 67                token_budget = token_budget.saturating_sub(line_tokens);
 68            } else {
 69                break;
 70            }
 71        }
 72    }
 73
 74    (start_row, end_row, token_budget)
 75}
 76
 77/// Typical number of string bytes per token for the purposes of limiting model input. This is
 78/// intentionally low to err on the side of underestimating limits.
 79pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
 80
 81pub fn guess_token_count(bytes: usize) -> usize {
 82    bytes / BYTES_PER_TOKEN_GUESS
 83}
 84
 85fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
 86    guess_token_count(snapshot.line_len(row) as usize).max(1)
 87}
 88
 89/// Computes the byte offset ranges of all syntax nodes containing the cursor,
 90/// ordered from innermost to outermost. The offsets are relative to
 91/// `excerpt_offset_range.start`.
 92pub fn compute_syntax_ranges(
 93    snapshot: &BufferSnapshot,
 94    cursor_offset: usize,
 95    excerpt_offset_range: &Range<usize>,
 96) -> Vec<Range<usize>> {
 97    let cursor_point = cursor_offset.to_point(snapshot);
 98    let range = cursor_point..cursor_point;
 99    let mut current = snapshot.syntax_ancestor(range);
100    let mut ranges = Vec::new();
101    let mut last_range: Option<(usize, usize)> = None;
102
103    while let Some(node) = current.take() {
104        let node_start = node.start_byte();
105        let node_end = node.end_byte();
106        let key = (node_start, node_end);
107
108        current = node.parent();
109
110        if last_range == Some(key) {
111            continue;
112        }
113        last_range = Some(key);
114
115        let start = node_start.saturating_sub(excerpt_offset_range.start);
116        let end = node_end
117            .min(excerpt_offset_range.end)
118            .saturating_sub(excerpt_offset_range.start);
119        ranges.push(start..end);
120    }
121
122    ranges
123}
124
125/// Expands context by first trying to reach syntax boundaries,
126/// then expanding line-wise only if no syntax expansion occurred.
127pub fn expand_context_syntactically_then_linewise(
128    snapshot: &BufferSnapshot,
129    editable_range: Range<Point>,
130    context_token_limit: usize,
131) -> Range<Point> {
132    let mut start_row = editable_range.start.row;
133    let mut end_row = editable_range.end.row;
134    let mut remaining_tokens = context_token_limit;
135    let mut did_syntax_expand = false;
136
137    // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
138    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
139    {
140        let tokens_for_start = if boundary_start < start_row {
141            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
142        } else {
143            0
144        };
145        let tokens_for_end = if boundary_end > end_row {
146            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
147        } else {
148            0
149        };
150
151        let total_needed = tokens_for_start + tokens_for_end;
152
153        if total_needed <= remaining_tokens {
154            if boundary_start < start_row {
155                start_row = boundary_start;
156            }
157            if boundary_end > end_row {
158                end_row = boundary_end;
159            }
160            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
161            did_syntax_expand = true;
162        } else {
163            break;
164        }
165    }
166
167    // Phase 2: Only expand line-wise if no syntax expansion occurred.
168    if !did_syntax_expand {
169        (start_row, end_row, _) =
170            expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
171    }
172
173    let start = Point::new(start_row, 0);
174    let end = Point::new(end_row, snapshot.line_len(end_row));
175    start..end
176}
177
178/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
179/// containing the given row range. Smallest containing node first.
180fn containing_syntax_boundaries(
181    snapshot: &BufferSnapshot,
182    start_row: u32,
183    end_row: u32,
184) -> impl Iterator<Item = (u32, u32)> {
185    let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
186    let mut current = snapshot.syntax_ancestor(range);
187    let mut last_rows: Option<(u32, u32)> = None;
188
189    std::iter::from_fn(move || {
190        while let Some(node) = current.take() {
191            let node_start_row = node.start_position().row as u32;
192            let node_end_row = node.end_position().row as u32;
193            let rows = (node_start_row, node_end_row);
194
195            current = node.parent();
196
197            // Skip nodes that don't extend beyond our range.
198            if node_start_row >= start_row && node_end_row <= end_row {
199                continue;
200            }
201
202            // Skip if same as last returned (some nodes have same span).
203            if last_rows == Some(rows) {
204                continue;
205            }
206
207            last_rows = Some(rows);
208            return Some(rows);
209        }
210        None
211    })
212}
213
214/// Expands line-wise with a bias toward one direction.
215/// Returns (start_row, end_row, remaining_tokens).
216fn expand_linewise_biased(
217    snapshot: &BufferSnapshot,
218    mut start_row: u32,
219    mut end_row: u32,
220    mut remaining_tokens: usize,
221    prefer_up: bool,
222) -> (u32, u32, usize) {
223    loop {
224        let can_expand_up = start_row > 0;
225        let can_expand_down = end_row < snapshot.max_point().row;
226
227        if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
228            break;
229        }
230
231        let mut expanded = false;
232
233        // Try preferred direction first.
234        if prefer_up {
235            if can_expand_up {
236                let next_row = start_row - 1;
237                let line_tokens = line_token_count(snapshot, next_row);
238                if line_tokens <= remaining_tokens {
239                    start_row = next_row;
240                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
241                    expanded = true;
242                }
243            }
244            if can_expand_down && remaining_tokens > 0 {
245                let next_row = end_row + 1;
246                let line_tokens = line_token_count(snapshot, next_row);
247                if line_tokens <= remaining_tokens {
248                    end_row = next_row;
249                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
250                    expanded = true;
251                }
252            }
253        } else {
254            if can_expand_down {
255                let next_row = end_row + 1;
256                let line_tokens = line_token_count(snapshot, next_row);
257                if line_tokens <= remaining_tokens {
258                    end_row = next_row;
259                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
260                    expanded = true;
261                }
262            }
263            if can_expand_up && remaining_tokens > 0 {
264                let next_row = start_row - 1;
265                let line_tokens = line_token_count(snapshot, next_row);
266                if line_tokens <= remaining_tokens {
267                    start_row = next_row;
268                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
269                    expanded = true;
270                }
271            }
272        }
273
274        if !expanded {
275            break;
276        }
277    }
278
279    (start_row, end_row, remaining_tokens)
280}
281
282/// Estimates token count for rows in range [start_row, end_row).
283fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
284    let mut tokens = 0;
285    for row in start_row..end_row {
286        tokens += line_token_count(snapshot, row);
287    }
288    tokens
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext as _};
    use indoc::indoc;
    use language::{Buffer, rust_lang};
    use util::test::{TextRangeMarker, marked_text_ranges_by};
    use zeta_prompt::compute_editable_and_context_ranges;

    /// One scenario for `test_editable_and_context_ranges`.
    struct TestCase {
        /// Human-readable label used in failure output.
        name: &'static str,
        /// Source text annotated with the cursor and the expected ranges.
        marked_text: &'static str,
        editable_token_limit: usize,
        context_token_limit: usize,
    }

    /// The scenario table.
    ///
    /// Markers:
    /// ˇ = cursor position
    /// « » = expected editable range
    /// [ ] = expected context range
    fn test_cases() -> Vec<TestCase> {
        vec![
            TestCase {
                name: "small function fits entirely in editable and context",
                marked_text: indoc! {r#"
                    [«fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }»]
                "#},
                editable_token_limit: 30,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near end of function - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn first() {
                        let a = 1;
                        let b = 2;
                    }

                    fn foo() {
                    «    let x = 1;
                        let y = 2;
                        println!("{}", x + y);ˇ
                    }»]
                "#},
                editable_token_limit: 18,
                context_token_limit: 35,
            },
            TestCase {
                name: "cursor at function start - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() {
                    «    let a = 1;
                    }

                    fn foo() {ˇ
                        let x = 1;
                        let y = 2;
                        let z = 3;
                    }
                    »
                    fn after() {
                        let b = 2;
                    }]
                "#},
                editable_token_limit: 25,
                context_token_limit: 50,
            },
            TestCase {
                name: "tiny budget - just lines around cursor, no syntax expansion",
                marked_text: indoc! {r#"
                    fn outer() {
                    [    let line1 = 1;
                        let line2 = 2;
                    «    let line3 = 3;
                        let line4 = 4;ˇ»
                        let line5 = 5;
                        let line6 = 6;]
                        let line7 = 7;
                    }
                "#},
                editable_token_limit: 12,
                context_token_limit: 24,
            },
            TestCase {
                name: "context extends beyond editable",
                marked_text: indoc! {r#"
                    [fn first() { let a = 1; }
                    «fn second() { let b = 2; }
                    fn third() { let c = 3; }ˇ
                    fn fourth() { let d = 4; }»
                    fn fifth() { let e = 5; }]
                "#},
                editable_token_limit: 25,
                context_token_limit: 45,
            },
            TestCase {
                name: "cursor in first if-block - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [«fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;ˇ
                            let b = 2;
                        }
                        if condition2 {»
                            let c = 3;
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                editable_token_limit: 35,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor in middle if-block - editable spans surrounding blocks",
                marked_text: indoc! {r#"
                    [fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;
                    «        let b = 2;
                        }
                        if condition2 {
                            let c = 3;ˇ
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;»
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                editable_token_limit: 40,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near bottom of long function - context reaches function boundary",
                marked_text: indoc! {r#"
                    [fn other() { }

                    fn long_function() {
                        let line1 = 1;
                        let line2 = 2;
                        let line3 = 3;
                        let line4 = 4;
                        let line5 = 5;
                        let line6 = 6;
                    «    let line7 = 7;
                        let line8 = 8;
                        let line9 = 9;
                        let line10 = 10;ˇ
                        let line11 = 11;
                    }

                    fn another() { }»]
                "#},
                editable_token_limit: 40,
                context_token_limit: 55,
            },
            TestCase {
                name: "zero context budget - context equals editable",
                marked_text: indoc! {r#"
                    fn before() {
                        let p = 1;
                        let q = 2;
                    [«}

                    fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }
                    »]
                    fn after() {
                        let r = 3;
                        let s = 4;
                    }
                "#},
                editable_token_limit: 15,
                context_token_limit: 0,
            },
        ]
    }

    /// Runs every scenario: computes the cursor excerpt and syntax ranges for
    /// the marked cursor, feeds them to `compute_editable_and_context_ranges`,
    /// and compares the resulting ranges (mapped back to buffer offsets) with
    /// the ranges marked in the scenario text.
    #[gpui::test]
    fn test_editable_and_context_ranges(cx: &mut App) {
        for test_case in test_cases() {
            let cursor_marker: TextRangeMarker = 'ˇ'.into();
            let editable_marker: TextRangeMarker = ('«', '»').into();
            let context_marker: TextRangeMarker = ('[', ']').into();

            let (text, mut ranges) = marked_text_ranges_by(
                test_case.marked_text,
                vec![
                    cursor_marker.clone(),
                    editable_marker.clone(),
                    context_marker.clone(),
                ],
            );

            let cursor_ranges = ranges.remove(&cursor_marker).unwrap_or_default();
            let expected_editable = ranges.remove(&editable_marker).unwrap_or_default();
            let expected_context = ranges.remove(&context_marker).unwrap_or_default();
            assert_eq!(expected_editable.len(), 1, "{}", test_case.name);
            assert_eq!(expected_context.len(), 1, "{}", test_case.name);

            let cursor_offset = cursor_ranges[0].start;
            let expected_editable_range = expected_editable[0].clone();
            let expected_context_range = expected_context[0].clone();

            cx.new(|cx: &mut gpui::Context<Buffer>| {
                // The scenario text ends with a newline that is not part of the buffer.
                let text = text.trim_end_matches('\n');
                let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
                let snapshot = buffer.snapshot();

                let (_, excerpt_offset_range, cursor_offset_in_excerpt) =
                    compute_cursor_excerpt(&snapshot, cursor_offset);
                let excerpt_text = snapshot
                    .text_for_range(excerpt_offset_range.clone())
                    .collect::<String>();
                let syntax_ranges =
                    compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);

                let (editable_in_excerpt, context_in_excerpt) =
                    compute_editable_and_context_ranges(
                        &excerpt_text,
                        cursor_offset_in_excerpt,
                        &syntax_ranges,
                        test_case.editable_token_limit,
                        test_case.context_token_limit,
                    );

                // Results are relative to the excerpt; shift them back into
                // buffer coordinates before comparing with the marked ranges.
                let excerpt_start = excerpt_offset_range.start;
                let to_buffer_range = |range: Range<usize>| -> Range<usize> {
                    (excerpt_start + range.start)..(excerpt_start + range.end)
                };
                let actual_editable = to_buffer_range(editable_in_excerpt);
                let actual_context = to_buffer_range(context_in_excerpt);

                let editable_match = actual_editable == expected_editable_range;
                let context_match = actual_context == expected_context_range;

                if !(editable_match && context_match) {
                    let range_text = |range: &Range<usize>| {
                        snapshot.text_for_range(range.clone()).collect::<String>()
                    };

                    println!("\n=== FAILED: {} ===", test_case.name);
                    if !editable_match {
                        println!("\nExpected editable ({:?}):", expected_editable_range);
                        println!("---\n{}---", range_text(&expected_editable_range));
                        println!("\nActual editable ({:?}):", actual_editable);
                        println!("---\n{}---", range_text(&actual_editable));
                    }
                    if !context_match {
                        println!("\nExpected context ({:?}):", expected_context_range);
                        println!("---\n{}---", range_text(&expected_context_range));
                        println!("\nActual context ({:?}):", actual_context);
                        println!("---\n{}---", range_text(&actual_context));
                    }
                    panic!("Test '{}' failed - see output above", test_case.name);
                }

                buffer
            });
        }
    }
}