cursor_excerpt.rs

  1use language::{BufferSnapshot, Point};
  2use std::ops::Range;
  3use text::OffsetRangeExt as _;
  4use zeta_prompt::ExcerptRanges;
  5
  6/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350
  7/// token budgets, plus their corresponding context expansions. Returns the full excerpt range
  8/// (union of all context ranges) and the individual sub-ranges as Points.
  9pub fn compute_excerpt_ranges(
 10    position: Point,
 11    snapshot: &BufferSnapshot,
 12) -> (Range<Point>, Range<usize>, ExcerptRanges) {
 13    let editable_150 = compute_editable_range(snapshot, position, 150);
 14    let editable_180 = compute_editable_range(snapshot, position, 180);
 15    let editable_350 = compute_editable_range(snapshot, position, 350);
 16
 17    let editable_150_context_350 =
 18        expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350);
 19    let editable_180_context_350 =
 20        expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350);
 21    let editable_350_context_150 =
 22        expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150);
 23
 24    let full_start_row = editable_150_context_350
 25        .start
 26        .row
 27        .min(editable_180_context_350.start.row)
 28        .min(editable_350_context_150.start.row);
 29    let full_end_row = editable_150_context_350
 30        .end
 31        .row
 32        .max(editable_180_context_350.end.row)
 33        .max(editable_350_context_150.end.row);
 34
 35    let full_context =
 36        Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row));
 37
 38    let full_context_offset_range = full_context.to_offset(snapshot);
 39
 40    let to_offset = |range: &Range<Point>| -> Range<usize> {
 41        let start = range.start.to_offset(snapshot);
 42        let end = range.end.to_offset(snapshot);
 43        (start - full_context_offset_range.start)..(end - full_context_offset_range.start)
 44    };
 45
 46    let ranges = ExcerptRanges {
 47        editable_150: to_offset(&editable_150),
 48        editable_180: to_offset(&editable_180),
 49        editable_350: to_offset(&editable_350),
 50        editable_150_context_350: to_offset(&editable_150_context_350),
 51        editable_180_context_350: to_offset(&editable_180_context_350),
 52        editable_350_context_150: to_offset(&editable_350_context_150),
 53    };
 54
 55    (full_context, full_context_offset_range, ranges)
 56}
 57
 58pub fn editable_and_context_ranges_for_cursor_position(
 59    position: Point,
 60    snapshot: &BufferSnapshot,
 61    editable_region_token_limit: usize,
 62    context_token_limit: usize,
 63) -> (Range<Point>, Range<Point>) {
 64    let editable_range = compute_editable_range(snapshot, position, editable_region_token_limit);
 65
 66    let context_range = expand_context_syntactically_then_linewise(
 67        snapshot,
 68        editable_range.clone(),
 69        context_token_limit,
 70    );
 71
 72    (editable_range, context_range)
 73}
 74
 75/// Computes the editable range using a three-phase approach:
 76/// 1. Expand symmetrically from cursor (75% of budget)
 77/// 2. Expand to syntax boundaries
 78/// 3. Continue line-wise in the least-expanded direction
 79fn compute_editable_range(
 80    snapshot: &BufferSnapshot,
 81    cursor: Point,
 82    token_limit: usize,
 83) -> Range<Point> {
 84    // Phase 1: Expand symmetrically from cursor using 75% of budget.
 85    let initial_budget = (token_limit * 3) / 4;
 86    let (mut start_row, mut end_row, mut remaining_tokens) =
 87        expand_symmetric_from_cursor(snapshot, cursor.row, initial_budget);
 88
 89    // Add remaining budget from phase 1.
 90    remaining_tokens += token_limit.saturating_sub(initial_budget);
 91
 92    let original_start = start_row;
 93    let original_end = end_row;
 94
 95    // Phase 2: Expand to syntax boundaries that fit within budget.
 96    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
 97    {
 98        let tokens_for_start = if boundary_start < start_row {
 99            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
100        } else {
101            0
102        };
103        let tokens_for_end = if boundary_end > end_row {
104            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
105        } else {
106            0
107        };
108
109        let total_needed = tokens_for_start + tokens_for_end;
110
111        if total_needed <= remaining_tokens {
112            if boundary_start < start_row {
113                start_row = boundary_start;
114            }
115            if boundary_end > end_row {
116                end_row = boundary_end;
117            }
118            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
119        } else {
120            break;
121        }
122    }
123
124    // Phase 3: Continue line-wise in the direction we expanded least during syntax phase.
125    let expanded_up = original_start.saturating_sub(start_row);
126    let expanded_down = end_row.saturating_sub(original_end);
127
128    (start_row, end_row, _) = expand_linewise_biased(
129        snapshot,
130        start_row,
131        end_row,
132        remaining_tokens,
133        expanded_up <= expanded_down, // prefer_up if we expanded less upward
134    );
135
136    let start = Point::new(start_row, 0);
137    let end = Point::new(end_row, snapshot.line_len(end_row));
138    start..end
139}
140
141/// Expands symmetrically from cursor, one line at a time, alternating down then up.
142/// Returns (start_row, end_row, remaining_tokens).
143fn expand_symmetric_from_cursor(
144    snapshot: &BufferSnapshot,
145    cursor_row: u32,
146    mut token_budget: usize,
147) -> (u32, u32, usize) {
148    let mut start_row = cursor_row;
149    let mut end_row = cursor_row;
150
151    // Account for the cursor's line.
152    let cursor_line_tokens = line_token_count(snapshot, cursor_row);
153    token_budget = token_budget.saturating_sub(cursor_line_tokens);
154
155    loop {
156        let can_expand_up = start_row > 0;
157        let can_expand_down = end_row < snapshot.max_point().row;
158
159        if token_budget == 0 || (!can_expand_up && !can_expand_down) {
160            break;
161        }
162
163        // Expand down first (slight forward bias for edit prediction).
164        if can_expand_down {
165            let next_row = end_row + 1;
166            let line_tokens = line_token_count(snapshot, next_row);
167            if line_tokens <= token_budget {
168                end_row = next_row;
169                token_budget = token_budget.saturating_sub(line_tokens);
170            } else {
171                break;
172            }
173        }
174
175        // Then expand up.
176        if can_expand_up && token_budget > 0 {
177            let next_row = start_row - 1;
178            let line_tokens = line_token_count(snapshot, next_row);
179            if line_tokens <= token_budget {
180                start_row = next_row;
181                token_budget = token_budget.saturating_sub(line_tokens);
182            } else {
183                break;
184            }
185        }
186    }
187
188    (start_row, end_row, token_budget)
189}
190
191/// Expands line-wise with a bias toward one direction.
192/// Returns (start_row, end_row, remaining_tokens).
193fn expand_linewise_biased(
194    snapshot: &BufferSnapshot,
195    mut start_row: u32,
196    mut end_row: u32,
197    mut remaining_tokens: usize,
198    prefer_up: bool,
199) -> (u32, u32, usize) {
200    loop {
201        let can_expand_up = start_row > 0;
202        let can_expand_down = end_row < snapshot.max_point().row;
203
204        if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
205            break;
206        }
207
208        let mut expanded = false;
209
210        // Try preferred direction first.
211        if prefer_up {
212            if can_expand_up {
213                let next_row = start_row - 1;
214                let line_tokens = line_token_count(snapshot, next_row);
215                if line_tokens <= remaining_tokens {
216                    start_row = next_row;
217                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
218                    expanded = true;
219                }
220            }
221            if can_expand_down && remaining_tokens > 0 {
222                let next_row = end_row + 1;
223                let line_tokens = line_token_count(snapshot, next_row);
224                if line_tokens <= remaining_tokens {
225                    end_row = next_row;
226                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
227                    expanded = true;
228                }
229            }
230        } else {
231            if can_expand_down {
232                let next_row = end_row + 1;
233                let line_tokens = line_token_count(snapshot, next_row);
234                if line_tokens <= remaining_tokens {
235                    end_row = next_row;
236                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
237                    expanded = true;
238                }
239            }
240            if can_expand_up && remaining_tokens > 0 {
241                let next_row = start_row - 1;
242                let line_tokens = line_token_count(snapshot, next_row);
243                if line_tokens <= remaining_tokens {
244                    start_row = next_row;
245                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
246                    expanded = true;
247                }
248            }
249        }
250
251        if !expanded {
252            break;
253        }
254    }
255
256    (start_row, end_row, remaining_tokens)
257}
258
/// Rough bytes-per-token ratio used when sizing model input. It is deliberately small, so
/// token counts come out high and the resulting input limits stay conservative.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
pub fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
266
267fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
268    guess_token_count(snapshot.line_len(row) as usize).max(1)
269}
270
271/// Estimates token count for rows in range [start_row, end_row).
272fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
273    let mut tokens = 0;
274    for row in start_row..end_row {
275        tokens += line_token_count(snapshot, row);
276    }
277    tokens
278}
279
280/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
281/// containing the given row range. Smallest containing node first.
282fn containing_syntax_boundaries(
283    snapshot: &BufferSnapshot,
284    start_row: u32,
285    end_row: u32,
286) -> impl Iterator<Item = (u32, u32)> {
287    let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
288    let mut current = snapshot.syntax_ancestor(range);
289    let mut last_rows: Option<(u32, u32)> = None;
290
291    std::iter::from_fn(move || {
292        while let Some(node) = current.take() {
293            let node_start_row = node.start_position().row as u32;
294            let node_end_row = node.end_position().row as u32;
295            let rows = (node_start_row, node_end_row);
296
297            current = node.parent();
298
299            // Skip nodes that don't extend beyond our range.
300            if node_start_row >= start_row && node_end_row <= end_row {
301                continue;
302            }
303
304            // Skip if same as last returned (some nodes have same span).
305            if last_rows == Some(rows) {
306                continue;
307            }
308
309            last_rows = Some(rows);
310            return Some(rows);
311        }
312        None
313    })
314}
315
316/// Expands context by first trying to reach syntax boundaries,
317/// then expanding line-wise only if no syntax expansion occurred.
318fn expand_context_syntactically_then_linewise(
319    snapshot: &BufferSnapshot,
320    editable_range: Range<Point>,
321    context_token_limit: usize,
322) -> Range<Point> {
323    let mut start_row = editable_range.start.row;
324    let mut end_row = editable_range.end.row;
325    let mut remaining_tokens = context_token_limit;
326    let mut did_syntax_expand = false;
327
328    // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
329    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
330    {
331        let tokens_for_start = if boundary_start < start_row {
332            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
333        } else {
334            0
335        };
336        let tokens_for_end = if boundary_end > end_row {
337            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
338        } else {
339            0
340        };
341
342        let total_needed = tokens_for_start + tokens_for_end;
343
344        if total_needed <= remaining_tokens {
345            if boundary_start < start_row {
346                start_row = boundary_start;
347            }
348            if boundary_end > end_row {
349                end_row = boundary_end;
350            }
351            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
352            did_syntax_expand = true;
353        } else {
354            break;
355        }
356    }
357
358    // Phase 2: Only expand line-wise if no syntax expansion occurred.
359    if !did_syntax_expand {
360        (start_row, end_row, _) =
361            expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
362    }
363
364    let start = Point::new(start_row, 0);
365    let end = Point::new(end_row, snapshot.line_len(end_row));
366    start..end
367}
368
369use language::ToOffset as _;
370
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, rust_lang};
    use util::test::{TextRangeMarker, marked_text_ranges_by};

    /// One table-driven case: the marked text carries the cursor plus the expected editable
    /// and context ranges, and the limits are passed straight to
    /// `editable_and_context_ranges_for_cursor_position`.
    struct TestCase {
        name: &'static str,
        marked_text: &'static str,
        editable_token_limit: usize,
        context_token_limit: usize,
    }

    #[gpui::test]
    fn test_editable_and_context_ranges(cx: &mut App) {
        // Markers:
        // ˇ = cursor position
        // « » = expected editable range
        // [ ] = expected context range
        let test_cases = vec![
            TestCase {
                name: "cursor near end of function - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn first() {
                        let a = 1;
                        let b = 2;
                    }

                    fn foo() {
                    «    let x = 1;
                        let y = 2;
                        println!("{}", x + y);ˇ
                    }»]
                "#},
                // 18 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 18,
                context_token_limit: 35,
            },
            TestCase {
                name: "cursor at function start - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() {
                    «    let a = 1;
                    }

                    fn foo() {ˇ
                        let x = 1;
                        let y = 2;
                        let z = 3;
                    }
                    »
                    fn after() {
                        let b = 2;
                    }]
                "#},
                // 25 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 25,
                context_token_limit: 50,
            },
            TestCase {
                name: "tiny budget - just lines around cursor",
                marked_text: indoc! {r#"
                    fn outer() {
                    [    let line1 = 1;
                        let line2 = 2;
                    «    let line3 = 3;
                        let line4 = 4;ˇ»
                        let line5 = 5;
                        let line6 = 6;]
                        let line7 = 7;
                    }
                "#},
                // 12 tokens: the symmetric pass only affords the cursor line, and the
                // leftover budget then adds the single line above it.
                editable_token_limit: 12,
                context_token_limit: 24,
            },
            TestCase {
                name: "small function fits entirely",
                marked_text: indoc! {r#"
                    [«fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }»]
                "#},
                // Plenty of budget for this small function
                editable_token_limit: 30,
                context_token_limit: 60,
            },
            TestCase {
                name: "context extends beyond editable",
                marked_text: indoc! {r#"
                    [fn first() { let a = 1; }
                    «fn second() { let b = 2; }
                    fn third() { let c = 3; }ˇ
                    fn fourth() { let d = 4; }»
                    fn fifth() { let e = 5; }]
                "#},
                // Small editable, larger context
                editable_token_limit: 25,
                context_token_limit: 45,
            },
            // Tests for syntax-aware editable and context expansion
            TestCase {
                name: "cursor in first if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [«fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;ˇ
                            let b = 2;
                        }
                        if condition2 {»
                            let c = 3;
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 35 tokens: editable runs from `before` through the first if block and
                // stops after the opening line of the second one.
                editable_token_limit: 35,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor in middle if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;
                    «        let b = 2;
                        }
                        if condition2 {
                            let c = 3;ˇ
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;»
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 40 tokens: editable reaches into, but does not fully cover, the
                // neighbouring if blocks.
                editable_token_limit: 40,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near bottom of long function - editable expands toward syntax, context reaches function",
                marked_text: indoc! {r#"
                    [fn other() { }

                    fn long_function() {
                        let line1 = 1;
                        let line2 = 2;
                        let line3 = 3;
                        let line4 = 4;
                        let line5 = 5;
                        let line6 = 6;
                    «    let line7 = 7;
                        let line8 = 8;
                        let line9 = 9;
                        let line10 = 10;ˇ
                        let line11 = 11;
                    }

                    fn another() { }»]
                "#},
                // 40 tokens for editable - allows several lines plus syntax expansion
                editable_token_limit: 40,
                // 55 tokens: enough for the rows above the editable range, so the context
                // ends up covering the whole file.
                context_token_limit: 55,
            },
        ];

        let cursor_marker: TextRangeMarker = 'ˇ'.into();
        let editable_marker: TextRangeMarker = ('«', '»').into();
        let context_marker: TextRangeMarker = ('[', ']').into();

        for test_case in test_cases {
            let (text, mut ranges) = marked_text_ranges_by(
                test_case.marked_text,
                vec![
                    cursor_marker.clone(),
                    editable_marker.clone(),
                    context_marker.clone(),
                ],
            );

            let cursor_ranges = ranges.remove(&cursor_marker).unwrap_or_default();
            let expected_editable = ranges.remove(&editable_marker).unwrap_or_default();
            let expected_context = ranges.remove(&context_marker).unwrap_or_default();
            // Validate the fixture itself before indexing into the marker ranges below.
            assert_eq!(
                cursor_ranges.len(),
                1,
                "{}: expected exactly one cursor marker",
                test_case.name
            );
            assert_eq!(
                expected_editable.len(),
                1,
                "{}: expected exactly one editable range",
                test_case.name
            );
            assert_eq!(
                expected_context.len(),
                1,
                "{}: expected exactly one context range",
                test_case.name
            );

            cx.new(|cx| {
                let text = text.trim_end_matches('\n');
                let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
                let snapshot = buffer.snapshot();

                let cursor_offset = cursor_ranges[0].start;
                let cursor_point = snapshot.offset_to_point(cursor_offset);
                let expected_editable_start = snapshot.offset_to_point(expected_editable[0].start);
                let expected_editable_end = snapshot.offset_to_point(expected_editable[0].end);
                let expected_context_start = snapshot.offset_to_point(expected_context[0].start);
                let expected_context_end = snapshot.offset_to_point(expected_context[0].end);

                let (actual_editable, actual_context) =
                    editable_and_context_ranges_for_cursor_position(
                        cursor_point,
                        &snapshot,
                        test_case.editable_token_limit,
                        test_case.context_token_limit,
                    );

                let range_text = |start: Point, end: Point| -> String {
                    snapshot.text_for_range(start..end).collect()
                };

                let editable_match = actual_editable.start == expected_editable_start
                    && actual_editable.end == expected_editable_end;
                let context_match = actual_context.start == expected_context_start
                    && actual_context.end == expected_context_end;

                if !editable_match || !context_match {
                    println!("\n=== FAILED: {} ===", test_case.name);
                    if !editable_match {
                        println!(
                            "\nExpected editable ({:?}..{:?}):",
                            expected_editable_start, expected_editable_end
                        );
                        println!(
                            "---\n{}---",
                            range_text(expected_editable_start, expected_editable_end)
                        );
                        println!(
                            "\nActual editable ({:?}..{:?}):",
                            actual_editable.start, actual_editable.end
                        );
                        println!(
                            "---\n{}---",
                            range_text(actual_editable.start, actual_editable.end)
                        );
                    }
                    if !context_match {
                        println!(
                            "\nExpected context ({:?}..{:?}):",
                            expected_context_start, expected_context_end
                        );
                        println!(
                            "---\n{}---",
                            range_text(expected_context_start, expected_context_end)
                        );
                        println!(
                            "\nActual context ({:?}..{:?}):",
                            actual_context.start, actual_context.end
                        );
                        println!(
                            "---\n{}---",
                            range_text(actual_context.start, actual_context.end)
                        );
                    }
                    panic!("Test '{}' failed - see output above", test_case.name);
                }

                buffer
            });
        }
    }
}