cursor_excerpt.rs

  1use language::{BufferSnapshot, Point};
  2use std::ops::Range;
  3use zeta_prompt::ExcerptRanges;
  4
  5/// Pre-computed Point ranges for all editable/context budget combinations.
  6pub struct ExcerptRangePoints {
  7    pub editable_150: Range<Point>,
  8    pub editable_180: Range<Point>,
  9    pub editable_350: Range<Point>,
 10    pub editable_150_context_350: Range<Point>,
 11    pub editable_180_context_350: Range<Point>,
 12    pub editable_350_context_150: Range<Point>,
 13}
 14
 15/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350
 16/// token budgets, plus their corresponding context expansions. Returns the full excerpt range
 17/// (union of all context ranges) and the individual sub-ranges as Points.
 18pub fn compute_excerpt_ranges(
 19    position: Point,
 20    snapshot: &BufferSnapshot,
 21) -> (Range<Point>, ExcerptRangePoints) {
 22    let editable_150 = compute_editable_range(snapshot, position, 150);
 23    let editable_180 = compute_editable_range(snapshot, position, 180);
 24    let editable_350 = compute_editable_range(snapshot, position, 350);
 25
 26    let editable_150_context_350 =
 27        expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350);
 28    let editable_180_context_350 =
 29        expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350);
 30    let editable_350_context_150 =
 31        expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150);
 32
 33    let full_start_row = editable_150_context_350
 34        .start
 35        .row
 36        .min(editable_180_context_350.start.row)
 37        .min(editable_350_context_150.start.row);
 38    let full_end_row = editable_150_context_350
 39        .end
 40        .row
 41        .max(editable_180_context_350.end.row)
 42        .max(editable_350_context_150.end.row);
 43
 44    let full_context =
 45        Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row));
 46
 47    let ranges = ExcerptRangePoints {
 48        editable_150,
 49        editable_180,
 50        editable_350,
 51        editable_150_context_350,
 52        editable_180_context_350,
 53        editable_350_context_150,
 54    };
 55
 56    (full_context, ranges)
 57}
 58
 59/// Converts `ExcerptRangePoints` to byte-offset `ExcerptRanges` relative to `excerpt_start`.
 60pub fn excerpt_ranges_to_byte_offsets(
 61    ranges: &ExcerptRangePoints,
 62    excerpt_start: usize,
 63    snapshot: &BufferSnapshot,
 64) -> ExcerptRanges {
 65    let to_offset = |range: &Range<Point>| -> Range<usize> {
 66        let start = range.start.to_offset(snapshot);
 67        let end = range.end.to_offset(snapshot);
 68        (start - excerpt_start)..(end - excerpt_start)
 69    };
 70    ExcerptRanges {
 71        editable_150: to_offset(&ranges.editable_150),
 72        editable_180: to_offset(&ranges.editable_180),
 73        editable_350: to_offset(&ranges.editable_350),
 74        editable_150_context_350: to_offset(&ranges.editable_150_context_350),
 75        editable_180_context_350: to_offset(&ranges.editable_180_context_350),
 76        editable_350_context_150: to_offset(&ranges.editable_350_context_150),
 77    }
 78}
 79
 80pub fn editable_and_context_ranges_for_cursor_position(
 81    position: Point,
 82    snapshot: &BufferSnapshot,
 83    editable_region_token_limit: usize,
 84    context_token_limit: usize,
 85) -> (Range<Point>, Range<Point>) {
 86    let editable_range = compute_editable_range(snapshot, position, editable_region_token_limit);
 87
 88    let context_range = expand_context_syntactically_then_linewise(
 89        snapshot,
 90        editable_range.clone(),
 91        context_token_limit,
 92    );
 93
 94    (editable_range, context_range)
 95}
 96
 97/// Computes the editable range using a three-phase approach:
 98/// 1. Expand symmetrically from cursor (75% of budget)
 99/// 2. Expand to syntax boundaries
100/// 3. Continue line-wise in the least-expanded direction
101fn compute_editable_range(
102    snapshot: &BufferSnapshot,
103    cursor: Point,
104    token_limit: usize,
105) -> Range<Point> {
106    // Phase 1: Expand symmetrically from cursor using 75% of budget.
107    let initial_budget = (token_limit * 3) / 4;
108    let (mut start_row, mut end_row, mut remaining_tokens) =
109        expand_symmetric_from_cursor(snapshot, cursor.row, initial_budget);
110
111    // Add remaining budget from phase 1.
112    remaining_tokens += token_limit.saturating_sub(initial_budget);
113
114    let original_start = start_row;
115    let original_end = end_row;
116
117    // Phase 2: Expand to syntax boundaries that fit within budget.
118    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
119    {
120        let tokens_for_start = if boundary_start < start_row {
121            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
122        } else {
123            0
124        };
125        let tokens_for_end = if boundary_end > end_row {
126            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
127        } else {
128            0
129        };
130
131        let total_needed = tokens_for_start + tokens_for_end;
132
133        if total_needed <= remaining_tokens {
134            if boundary_start < start_row {
135                start_row = boundary_start;
136            }
137            if boundary_end > end_row {
138                end_row = boundary_end;
139            }
140            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
141        } else {
142            break;
143        }
144    }
145
146    // Phase 3: Continue line-wise in the direction we expanded least during syntax phase.
147    let expanded_up = original_start.saturating_sub(start_row);
148    let expanded_down = end_row.saturating_sub(original_end);
149
150    (start_row, end_row, _) = expand_linewise_biased(
151        snapshot,
152        start_row,
153        end_row,
154        remaining_tokens,
155        expanded_up <= expanded_down, // prefer_up if we expanded less upward
156    );
157
158    let start = Point::new(start_row, 0);
159    let end = Point::new(end_row, snapshot.line_len(end_row));
160    start..end
161}
162
163/// Expands symmetrically from cursor, one line at a time, alternating down then up.
164/// Returns (start_row, end_row, remaining_tokens).
165fn expand_symmetric_from_cursor(
166    snapshot: &BufferSnapshot,
167    cursor_row: u32,
168    mut token_budget: usize,
169) -> (u32, u32, usize) {
170    let mut start_row = cursor_row;
171    let mut end_row = cursor_row;
172
173    // Account for the cursor's line.
174    let cursor_line_tokens = line_token_count(snapshot, cursor_row);
175    token_budget = token_budget.saturating_sub(cursor_line_tokens);
176
177    loop {
178        let can_expand_up = start_row > 0;
179        let can_expand_down = end_row < snapshot.max_point().row;
180
181        if token_budget == 0 || (!can_expand_up && !can_expand_down) {
182            break;
183        }
184
185        // Expand down first (slight forward bias for edit prediction).
186        if can_expand_down {
187            let next_row = end_row + 1;
188            let line_tokens = line_token_count(snapshot, next_row);
189            if line_tokens <= token_budget {
190                end_row = next_row;
191                token_budget = token_budget.saturating_sub(line_tokens);
192            } else {
193                break;
194            }
195        }
196
197        // Then expand up.
198        if can_expand_up && token_budget > 0 {
199            let next_row = start_row - 1;
200            let line_tokens = line_token_count(snapshot, next_row);
201            if line_tokens <= token_budget {
202                start_row = next_row;
203                token_budget = token_budget.saturating_sub(line_tokens);
204            } else {
205                break;
206            }
207        }
208    }
209
210    (start_row, end_row, token_budget)
211}
212
213/// Expands line-wise with a bias toward one direction.
214/// Returns (start_row, end_row, remaining_tokens).
215fn expand_linewise_biased(
216    snapshot: &BufferSnapshot,
217    mut start_row: u32,
218    mut end_row: u32,
219    mut remaining_tokens: usize,
220    prefer_up: bool,
221) -> (u32, u32, usize) {
222    loop {
223        let can_expand_up = start_row > 0;
224        let can_expand_down = end_row < snapshot.max_point().row;
225
226        if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
227            break;
228        }
229
230        let mut expanded = false;
231
232        // Try preferred direction first.
233        if prefer_up {
234            if can_expand_up {
235                let next_row = start_row - 1;
236                let line_tokens = line_token_count(snapshot, next_row);
237                if line_tokens <= remaining_tokens {
238                    start_row = next_row;
239                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
240                    expanded = true;
241                }
242            }
243            if can_expand_down && remaining_tokens > 0 {
244                let next_row = end_row + 1;
245                let line_tokens = line_token_count(snapshot, next_row);
246                if line_tokens <= remaining_tokens {
247                    end_row = next_row;
248                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
249                    expanded = true;
250                }
251            }
252        } else {
253            if can_expand_down {
254                let next_row = end_row + 1;
255                let line_tokens = line_token_count(snapshot, next_row);
256                if line_tokens <= remaining_tokens {
257                    end_row = next_row;
258                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
259                    expanded = true;
260                }
261            }
262            if can_expand_up && remaining_tokens > 0 {
263                let next_row = start_row - 1;
264                let line_tokens = line_token_count(snapshot, next_row);
265                if line_tokens <= remaining_tokens {
266                    start_row = next_row;
267                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
268                    expanded = true;
269                }
270            }
271        }
272
273        if !expanded {
274            break;
275        }
276    }
277
278    (start_row, end_row, remaining_tokens)
279}
280
/// Assumed number of string bytes per token when limiting model input. The value is kept
/// deliberately low, which overestimates token counts and so errs on the side of staying under
/// the limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
pub fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
288
289fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
290    guess_token_count(snapshot.line_len(row) as usize).max(1)
291}
292
293/// Estimates token count for rows in range [start_row, end_row).
294fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
295    let mut tokens = 0;
296    for row in start_row..end_row {
297        tokens += line_token_count(snapshot, row);
298    }
299    tokens
300}
301
302/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
303/// containing the given row range. Smallest containing node first.
304fn containing_syntax_boundaries(
305    snapshot: &BufferSnapshot,
306    start_row: u32,
307    end_row: u32,
308) -> impl Iterator<Item = (u32, u32)> {
309    let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
310    let mut current = snapshot.syntax_ancestor(range);
311    let mut last_rows: Option<(u32, u32)> = None;
312
313    std::iter::from_fn(move || {
314        while let Some(node) = current.take() {
315            let node_start_row = node.start_position().row as u32;
316            let node_end_row = node.end_position().row as u32;
317            let rows = (node_start_row, node_end_row);
318
319            current = node.parent();
320
321            // Skip nodes that don't extend beyond our range.
322            if node_start_row >= start_row && node_end_row <= end_row {
323                continue;
324            }
325
326            // Skip if same as last returned (some nodes have same span).
327            if last_rows == Some(rows) {
328                continue;
329            }
330
331            last_rows = Some(rows);
332            return Some(rows);
333        }
334        None
335    })
336}
337
338/// Expands context by first trying to reach syntax boundaries,
339/// then expanding line-wise only if no syntax expansion occurred.
340fn expand_context_syntactically_then_linewise(
341    snapshot: &BufferSnapshot,
342    editable_range: Range<Point>,
343    context_token_limit: usize,
344) -> Range<Point> {
345    let mut start_row = editable_range.start.row;
346    let mut end_row = editable_range.end.row;
347    let mut remaining_tokens = context_token_limit;
348    let mut did_syntax_expand = false;
349
350    // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
351    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
352    {
353        let tokens_for_start = if boundary_start < start_row {
354            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
355        } else {
356            0
357        };
358        let tokens_for_end = if boundary_end > end_row {
359            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
360        } else {
361            0
362        };
363
364        let total_needed = tokens_for_start + tokens_for_end;
365
366        if total_needed <= remaining_tokens {
367            if boundary_start < start_row {
368                start_row = boundary_start;
369            }
370            if boundary_end > end_row {
371                end_row = boundary_end;
372            }
373            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
374            did_syntax_expand = true;
375        } else {
376            break;
377        }
378    }
379
380    // Phase 2: Only expand line-wise if no syntax expansion occurred.
381    if !did_syntax_expand {
382        (start_row, end_row, _) =
383            expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
384    }
385
386    let start = Point::new(start_row, 0);
387    let end = Point::new(end_row, snapshot.line_len(end_row));
388    start..end
389}
390
391use language::ToOffset as _;
392
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, rust_lang};
    use util::test::{TextRangeMarker, marked_text_ranges_by};

    /// One expectation: marked-up source plus the token budgets to run it with.
    struct TestCase {
        name: &'static str,
        marked_text: &'static str,
        editable_token_limit: usize,
        context_token_limit: usize,
    }

    /// Prints `heading`, then the buffer text covered by `range` fenced by `---` lines.
    fn print_range_text(snapshot: &BufferSnapshot, heading: String, range: &Range<Point>) {
        println!("{heading}");
        let text: String = snapshot.text_for_range(range.clone()).collect();
        println!("---\n{}---", text);
    }

    #[gpui::test]
    fn test_editable_and_context_ranges(cx: &mut App) {
        // Markers:
        // ˇ = cursor position
        // « » = expected editable range
        // [ ] = expected context range
        let test_cases = [
            TestCase {
                name: "cursor near end of function - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn first() {
                        let a = 1;
                        let b = 2;
                    }

                    fn foo() {
                    «    let x = 1;
                        let y = 2;
                        println!("{}", x + y);ˇ
                    }»]
                "#},
                // 18 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 18,
                context_token_limit: 35,
            },
            TestCase {
                name: "cursor at function start - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() {
                    «    let a = 1;
                    }

                    fn foo() {ˇ
                        let x = 1;
                        let y = 2;
                        let z = 3;
                    }
                    »
                    fn after() {
                        let b = 2;
                    }]
                "#},
                // 25 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 25,
                context_token_limit: 50,
            },
            TestCase {
                name: "tiny budget - just lines around cursor",
                marked_text: indoc! {r#"
                    fn outer() {
                    [    let line1 = 1;
                        let line2 = 2;
                    «    let line3 = 3;
                        let line4 = 4;ˇ»
                        let line5 = 5;
                        let line6 = 6;]
                        let line7 = 7;
                    }
                "#},
                // 12 tokens (~36 bytes) = just the cursor line with tiny budget
                editable_token_limit: 12,
                context_token_limit: 24,
            },
            TestCase {
                name: "small function fits entirely",
                marked_text: indoc! {r#"
                    [«fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }»]
                "#},
                // Plenty of budget for this small function
                editable_token_limit: 30,
                context_token_limit: 60,
            },
            TestCase {
                name: "context extends beyond editable",
                marked_text: indoc! {r#"
                    [fn first() { let a = 1; }
                    «fn second() { let b = 2; }
                    fn third() { let c = 3; }ˇ
                    fn fourth() { let d = 4; }»
                    fn fifth() { let e = 5; }]
                "#},
                // Small editable, larger context
                editable_token_limit: 25,
                context_token_limit: 45,
            },
            // Tests for syntax-aware editable and context expansion
            TestCase {
                name: "cursor in first if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [«fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;ˇ
                            let b = 2;
                        }
                        if condition2 {»
                            let c = 3;
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 35 tokens allows expansion to include function header and first two if blocks
                editable_token_limit: 35,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor in middle if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;
                    «        let b = 2;
                        }
                        if condition2 {
                            let c = 3;ˇ
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;»
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 40 tokens allows expansion to surrounding if blocks
                editable_token_limit: 40,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near bottom of long function - editable expands toward syntax, context reaches function",
                marked_text: indoc! {r#"
                    [fn other() { }

                    fn long_function() {
                        let line1 = 1;
                        let line2 = 2;
                        let line3 = 3;
                        let line4 = 4;
                        let line5 = 5;
                        let line6 = 6;
                    «    let line7 = 7;
                        let line8 = 8;
                        let line9 = 9;
                        let line10 = 10;ˇ
                        let line11 = 11;
                    }

                    fn another() { }»]
                "#},
                // 40 tokens for editable - allows several lines plus syntax expansion
                editable_token_limit: 40,
                // 55 tokens - enough for function but not whole file
                context_token_limit: 55,
            },
        ];

        let cursor_marker: TextRangeMarker = 'ˇ'.into();
        let editable_marker: TextRangeMarker = ('«', '»').into();
        let context_marker: TextRangeMarker = ('[', ']').into();

        for case in test_cases {
            let (text, mut marked_ranges) = marked_text_ranges_by(
                case.marked_text,
                vec![
                    cursor_marker.clone(),
                    editable_marker.clone(),
                    context_marker.clone(),
                ],
            );

            let cursors = marked_ranges.remove(&cursor_marker).unwrap_or_default();
            let editable_spans = marked_ranges.remove(&editable_marker).unwrap_or_default();
            let context_spans = marked_ranges.remove(&context_marker).unwrap_or_default();
            assert_eq!(editable_spans.len(), 1);
            assert_eq!(context_spans.len(), 1);

            cx.new(|cx| {
                let buffer = Buffer::local(text.trim_end_matches('\n'), cx)
                    .with_language(rust_lang(), cx);
                let snapshot = buffer.snapshot();
                let span_to_points = |span: &Range<usize>| {
                    snapshot.offset_to_point(span.start)..snapshot.offset_to_point(span.end)
                };

                let cursor = snapshot.offset_to_point(cursors[0].start);
                let expected_editable = span_to_points(&editable_spans[0]);
                let expected_context = span_to_points(&context_spans[0]);

                let (actual_editable, actual_context) =
                    editable_and_context_ranges_for_cursor_position(
                        cursor,
                        &snapshot,
                        case.editable_token_limit,
                        case.context_token_limit,
                    );

                let editable_ok = actual_editable == expected_editable;
                let context_ok = actual_context == expected_context;

                if !(editable_ok && context_ok) {
                    println!("\n=== FAILED: {} ===", case.name);
                    if !editable_ok {
                        print_range_text(
                            &snapshot,
                            format!(
                                "\nExpected editable ({:?}..{:?}):",
                                expected_editable.start, expected_editable.end
                            ),
                            &expected_editable,
                        );
                        print_range_text(
                            &snapshot,
                            format!(
                                "\nActual editable ({:?}..{:?}):",
                                actual_editable.start, actual_editable.end
                            ),
                            &actual_editable,
                        );
                    }
                    if !context_ok {
                        print_range_text(
                            &snapshot,
                            format!(
                                "\nExpected context ({:?}..{:?}):",
                                expected_context.start, expected_context.end
                            ),
                            &expected_context,
                        );
                        print_range_text(
                            &snapshot,
                            format!(
                                "\nActual context ({:?}..{:?}):",
                                actual_context.start, actual_context.end
                            ),
                            &actual_context,
                        );
                    }
                    panic!("Test '{}' failed - see output above", case.name);
                }

                buffer
            });
        }
    }
}