cursor_excerpt.rs

  1use language::{BufferSnapshot, Point};
  2use std::ops::Range;
  3use text::OffsetRangeExt as _;
  4use zeta_prompt::ExcerptRanges;
  5
  6/// Computes all range variants for a cursor position: editable ranges at 150, 180, and 350
  7/// token budgets, plus their corresponding context expansions. Returns the full excerpt range
  8/// (union of all context ranges) and the individual sub-ranges as Points.
  9pub fn compute_excerpt_ranges(
 10    position: Point,
 11    snapshot: &BufferSnapshot,
 12) -> (Range<Point>, Range<usize>, ExcerptRanges) {
 13    let editable_150 = compute_editable_range(snapshot, position, 150);
 14    let editable_180 = compute_editable_range(snapshot, position, 180);
 15    let editable_350 = compute_editable_range(snapshot, position, 350);
 16    let editable_512 = compute_editable_range(snapshot, position, 512);
 17
 18    let editable_150_context_350 =
 19        expand_context_syntactically_then_linewise(snapshot, editable_150.clone(), 350);
 20    let editable_180_context_350 =
 21        expand_context_syntactically_then_linewise(snapshot, editable_180.clone(), 350);
 22    let editable_350_context_150 =
 23        expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 150);
 24    let editable_350_context_512 =
 25        expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 512);
 26    let editable_350_context_1024 =
 27        expand_context_syntactically_then_linewise(snapshot, editable_350.clone(), 1024);
 28    let context_4096 = expand_context_syntactically_then_linewise(
 29        snapshot,
 30        editable_350_context_1024.clone(),
 31        4096 - 1024,
 32    );
 33    let context_8192 =
 34        expand_context_syntactically_then_linewise(snapshot, context_4096.clone(), 8192 - 4096);
 35
 36    let full_start_row = context_8192.start.row;
 37    let full_end_row = context_8192.end.row;
 38
 39    let full_context =
 40        Point::new(full_start_row, 0)..Point::new(full_end_row, snapshot.line_len(full_end_row));
 41
 42    let full_context_offset_range = full_context.to_offset(snapshot);
 43
 44    let to_offset = |range: &Range<Point>| -> Range<usize> {
 45        let start = range.start.to_offset(snapshot);
 46        let end = range.end.to_offset(snapshot);
 47        (start - full_context_offset_range.start)..(end - full_context_offset_range.start)
 48    };
 49
 50    let ranges = ExcerptRanges {
 51        editable_150: to_offset(&editable_150),
 52        editable_180: to_offset(&editable_180),
 53        editable_350: to_offset(&editable_350),
 54        editable_512: Some(to_offset(&editable_512)),
 55        editable_150_context_350: to_offset(&editable_150_context_350),
 56        editable_180_context_350: to_offset(&editable_180_context_350),
 57        editable_350_context_150: to_offset(&editable_350_context_150),
 58        editable_350_context_512: Some(to_offset(&editable_350_context_512)),
 59        editable_350_context_1024: Some(to_offset(&editable_350_context_1024)),
 60        context_4096: Some(to_offset(&context_4096)),
 61        context_8192: Some(to_offset(&context_8192)),
 62    };
 63
 64    (full_context, full_context_offset_range, ranges)
 65}
 66
 67pub fn editable_and_context_ranges_for_cursor_position(
 68    position: Point,
 69    snapshot: &BufferSnapshot,
 70    editable_region_token_limit: usize,
 71    context_token_limit: usize,
 72) -> (Range<Point>, Range<Point>) {
 73    let editable_range = compute_editable_range(snapshot, position, editable_region_token_limit);
 74
 75    let context_range = expand_context_syntactically_then_linewise(
 76        snapshot,
 77        editable_range.clone(),
 78        context_token_limit,
 79    );
 80
 81    (editable_range, context_range)
 82}
 83
 84/// Computes the editable range using a three-phase approach:
 85/// 1. Expand symmetrically from cursor (75% of budget)
 86/// 2. Expand to syntax boundaries
 87/// 3. Continue line-wise in the least-expanded direction
 88fn compute_editable_range(
 89    snapshot: &BufferSnapshot,
 90    cursor: Point,
 91    token_limit: usize,
 92) -> Range<Point> {
 93    // Phase 1: Expand symmetrically from cursor using 75% of budget.
 94    let initial_budget = (token_limit * 3) / 4;
 95    let (mut start_row, mut end_row, mut remaining_tokens) =
 96        expand_symmetric_from_cursor(snapshot, cursor.row, initial_budget);
 97
 98    // Add remaining budget from phase 1.
 99    remaining_tokens += token_limit.saturating_sub(initial_budget);
100
101    let original_start = start_row;
102    let original_end = end_row;
103
104    // Phase 2: Expand to syntax boundaries that fit within budget.
105    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
106    {
107        let tokens_for_start = if boundary_start < start_row {
108            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
109        } else {
110            0
111        };
112        let tokens_for_end = if boundary_end > end_row {
113            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
114        } else {
115            0
116        };
117
118        let total_needed = tokens_for_start + tokens_for_end;
119
120        if total_needed <= remaining_tokens {
121            if boundary_start < start_row {
122                start_row = boundary_start;
123            }
124            if boundary_end > end_row {
125                end_row = boundary_end;
126            }
127            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
128        } else {
129            break;
130        }
131    }
132
133    // Phase 3: Continue line-wise in the direction we expanded least during syntax phase.
134    let expanded_up = original_start.saturating_sub(start_row);
135    let expanded_down = end_row.saturating_sub(original_end);
136
137    (start_row, end_row, _) = expand_linewise_biased(
138        snapshot,
139        start_row,
140        end_row,
141        remaining_tokens,
142        expanded_up <= expanded_down, // prefer_up if we expanded less upward
143    );
144
145    let start = Point::new(start_row, 0);
146    let end = Point::new(end_row, snapshot.line_len(end_row));
147    start..end
148}
149
150/// Expands symmetrically from cursor, one line at a time, alternating down then up.
151/// Returns (start_row, end_row, remaining_tokens).
152fn expand_symmetric_from_cursor(
153    snapshot: &BufferSnapshot,
154    cursor_row: u32,
155    mut token_budget: usize,
156) -> (u32, u32, usize) {
157    let mut start_row = cursor_row;
158    let mut end_row = cursor_row;
159
160    // Account for the cursor's line.
161    let cursor_line_tokens = line_token_count(snapshot, cursor_row);
162    token_budget = token_budget.saturating_sub(cursor_line_tokens);
163
164    loop {
165        let can_expand_up = start_row > 0;
166        let can_expand_down = end_row < snapshot.max_point().row;
167
168        if token_budget == 0 || (!can_expand_up && !can_expand_down) {
169            break;
170        }
171
172        // Expand down first (slight forward bias for edit prediction).
173        if can_expand_down {
174            let next_row = end_row + 1;
175            let line_tokens = line_token_count(snapshot, next_row);
176            if line_tokens <= token_budget {
177                end_row = next_row;
178                token_budget = token_budget.saturating_sub(line_tokens);
179            } else {
180                break;
181            }
182        }
183
184        // Then expand up.
185        if can_expand_up && token_budget > 0 {
186            let next_row = start_row - 1;
187            let line_tokens = line_token_count(snapshot, next_row);
188            if line_tokens <= token_budget {
189                start_row = next_row;
190                token_budget = token_budget.saturating_sub(line_tokens);
191            } else {
192                break;
193            }
194        }
195    }
196
197    (start_row, end_row, token_budget)
198}
199
200/// Expands line-wise with a bias toward one direction.
201/// Returns (start_row, end_row, remaining_tokens).
202fn expand_linewise_biased(
203    snapshot: &BufferSnapshot,
204    mut start_row: u32,
205    mut end_row: u32,
206    mut remaining_tokens: usize,
207    prefer_up: bool,
208) -> (u32, u32, usize) {
209    loop {
210        let can_expand_up = start_row > 0;
211        let can_expand_down = end_row < snapshot.max_point().row;
212
213        if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
214            break;
215        }
216
217        let mut expanded = false;
218
219        // Try preferred direction first.
220        if prefer_up {
221            if can_expand_up {
222                let next_row = start_row - 1;
223                let line_tokens = line_token_count(snapshot, next_row);
224                if line_tokens <= remaining_tokens {
225                    start_row = next_row;
226                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
227                    expanded = true;
228                }
229            }
230            if can_expand_down && remaining_tokens > 0 {
231                let next_row = end_row + 1;
232                let line_tokens = line_token_count(snapshot, next_row);
233                if line_tokens <= remaining_tokens {
234                    end_row = next_row;
235                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
236                    expanded = true;
237                }
238            }
239        } else {
240            if can_expand_down {
241                let next_row = end_row + 1;
242                let line_tokens = line_token_count(snapshot, next_row);
243                if line_tokens <= remaining_tokens {
244                    end_row = next_row;
245                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
246                    expanded = true;
247                }
248            }
249            if can_expand_up && remaining_tokens > 0 {
250                let next_row = start_row - 1;
251                let line_tokens = line_token_count(snapshot, next_row);
252                if line_tokens <= remaining_tokens {
253                    start_row = next_row;
254                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
255                    expanded = true;
256                }
257            }
258        }
259
260        if !expanded {
261            break;
262        }
263    }
264
265    (start_row, end_row, remaining_tokens)
266}
267
/// Assumed average number of string bytes per model token, used to turn byte lengths into token
/// estimates. It is deliberately low so token counts come out high, which errs on the side of
/// underestimating how much text fits within a limit.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
pub fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}
275
276fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
277    guess_token_count(snapshot.line_len(row) as usize).max(1)
278}
279
280/// Estimates token count for rows in range [start_row, end_row).
281fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
282    let mut tokens = 0;
283    for row in start_row..end_row {
284        tokens += line_token_count(snapshot, row);
285    }
286    tokens
287}
288
289/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
290/// containing the given row range. Smallest containing node first.
291fn containing_syntax_boundaries(
292    snapshot: &BufferSnapshot,
293    start_row: u32,
294    end_row: u32,
295) -> impl Iterator<Item = (u32, u32)> {
296    let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
297    let mut current = snapshot.syntax_ancestor(range);
298    let mut last_rows: Option<(u32, u32)> = None;
299
300    std::iter::from_fn(move || {
301        while let Some(node) = current.take() {
302            let node_start_row = node.start_position().row as u32;
303            let node_end_row = node.end_position().row as u32;
304            let rows = (node_start_row, node_end_row);
305
306            current = node.parent();
307
308            // Skip nodes that don't extend beyond our range.
309            if node_start_row >= start_row && node_end_row <= end_row {
310                continue;
311            }
312
313            // Skip if same as last returned (some nodes have same span).
314            if last_rows == Some(rows) {
315                continue;
316            }
317
318            last_rows = Some(rows);
319            return Some(rows);
320        }
321        None
322    })
323}
324
325/// Expands context by first trying to reach syntax boundaries,
326/// then expanding line-wise only if no syntax expansion occurred.
327fn expand_context_syntactically_then_linewise(
328    snapshot: &BufferSnapshot,
329    editable_range: Range<Point>,
330    context_token_limit: usize,
331) -> Range<Point> {
332    let mut start_row = editable_range.start.row;
333    let mut end_row = editable_range.end.row;
334    let mut remaining_tokens = context_token_limit;
335    let mut did_syntax_expand = false;
336
337    // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
338    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
339    {
340        let tokens_for_start = if boundary_start < start_row {
341            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
342        } else {
343            0
344        };
345        let tokens_for_end = if boundary_end > end_row {
346            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
347        } else {
348            0
349        };
350
351        let total_needed = tokens_for_start + tokens_for_end;
352
353        if total_needed <= remaining_tokens {
354            if boundary_start < start_row {
355                start_row = boundary_start;
356            }
357            if boundary_end > end_row {
358                end_row = boundary_end;
359            }
360            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
361            did_syntax_expand = true;
362        } else {
363            break;
364        }
365    }
366
367    // Phase 2: Only expand line-wise if no syntax expansion occurred.
368    if !did_syntax_expand {
369        (start_row, end_row, _) =
370            expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
371    }
372
373    let start = Point::new(start_row, 0);
374    let end = Point::new(end_row, snapshot.line_len(end_row));
375    start..end
376}
377
378use language::ToOffset as _;
379
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, rust_lang};
    use util::test::{TextRangeMarker, marked_text_ranges_by};

    /// One scenario: annotated Rust source plus the token budgets to run it with.
    struct TestCase {
        name: &'static str,
        marked_text: &'static str,
        editable_token_limit: usize,
        context_token_limit: usize,
    }

    #[gpui::test]
    fn test_editable_and_context_ranges(cx: &mut App) {
        // Markers embedded in `marked_text`:
        //   ˇ    cursor position
        //   « »  expected editable range
        //   [ ]  expected context range
        let test_cases = [
            TestCase {
                name: "cursor near end of function - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn first() {
                        let a = 1;
                        let b = 2;
                    }

                    fn foo() {
                    «    let x = 1;
                        let y = 2;
                        println!("{}", x + y);ˇ
                    }»]
                "#},
                // 18 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 18,
                context_token_limit: 35,
            },
            TestCase {
                name: "cursor at function start - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() {
                    «    let a = 1;
                    }

                    fn foo() {ˇ
                        let x = 1;
                        let y = 2;
                        let z = 3;
                    }
                    »
                    fn after() {
                        let b = 2;
                    }]
                "#},
                // 25 tokens - expands symmetrically then to syntax boundaries
                editable_token_limit: 25,
                context_token_limit: 50,
            },
            TestCase {
                name: "tiny budget - just lines around cursor",
                marked_text: indoc! {r#"
                    fn outer() {
                    [    let line1 = 1;
                        let line2 = 2;
                    «    let line3 = 3;
                        let line4 = 4;ˇ»
                        let line5 = 5;
                        let line6 = 6;]
                        let line7 = 7;
                    }
                "#},
                // 12 tokens (~36 bytes) = just the cursor line with tiny budget
                editable_token_limit: 12,
                context_token_limit: 24,
            },
            TestCase {
                name: "small function fits entirely",
                marked_text: indoc! {r#"
                    [«fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }»]
                "#},
                // Plenty of budget for this small function
                editable_token_limit: 30,
                context_token_limit: 60,
            },
            TestCase {
                name: "context extends beyond editable",
                marked_text: indoc! {r#"
                    [fn first() { let a = 1; }
                    «fn second() { let b = 2; }
                    fn third() { let c = 3; }ˇ
                    fn fourth() { let d = 4; }»
                    fn fifth() { let e = 5; }]
                "#},
                // Small editable, larger context
                editable_token_limit: 25,
                context_token_limit: 45,
            },
            // Tests for syntax-aware editable and context expansion
            TestCase {
                name: "cursor in first if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [«fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;ˇ
                            let b = 2;
                        }
                        if condition2 {»
                            let c = 3;
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 35 tokens allows expansion to include function header and first two if blocks
                editable_token_limit: 35,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor in middle if-statement - expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;
                    «        let b = 2;
                        }
                        if condition2 {
                            let c = 3;ˇ
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;»
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                // 40 tokens allows expansion to surrounding if blocks
                editable_token_limit: 40,
                // 60 tokens allows context to include the whole file
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near bottom of long function - editable expands toward syntax, context reaches function",
                marked_text: indoc! {r#"
                    [fn other() { }

                    fn long_function() {
                        let line1 = 1;
                        let line2 = 2;
                        let line3 = 3;
                        let line4 = 4;
                        let line5 = 5;
                        let line6 = 6;
                    «    let line7 = 7;
                        let line8 = 8;
                        let line9 = 9;
                        let line10 = 10;ˇ
                        let line11 = 11;
                    }

                    fn another() { }»]
                "#},
                // 40 tokens for editable - allows several lines plus syntax expansion
                editable_token_limit: 40,
                // 55 tokens - enough for function but not whole file
                context_token_limit: 55,
            },
        ];

        let cursor_marker: TextRangeMarker = 'ˇ'.into();
        let editable_marker: TextRangeMarker = ('«', '»').into();
        let context_marker: TextRangeMarker = ('[', ']').into();

        for TestCase {
            name,
            marked_text,
            editable_token_limit,
            context_token_limit,
        } in test_cases
        {
            let (text, mut ranges) = marked_text_ranges_by(
                marked_text,
                vec![
                    cursor_marker.clone(),
                    editable_marker.clone(),
                    context_marker.clone(),
                ],
            );
            let [cursor_ranges, editable_offsets, context_offsets] =
                [&cursor_marker, &editable_marker, &context_marker]
                    .map(|marker| ranges.remove(marker).unwrap_or_default());
            assert_eq!(editable_offsets.len(), 1);
            assert_eq!(context_offsets.len(), 1);

            cx.new(|cx| {
                let text = text.trim_end_matches('\n');
                let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
                let snapshot = buffer.snapshot();

                let point_at = |offset: usize| snapshot.offset_to_point(offset);
                let cursor_point = point_at(cursor_ranges[0].start);
                let editable_want =
                    point_at(editable_offsets[0].start)..point_at(editable_offsets[0].end);
                let context_want =
                    point_at(context_offsets[0].start)..point_at(context_offsets[0].end);

                let (actual_editable, actual_context) =
                    editable_and_context_ranges_for_cursor_position(
                        cursor_point,
                        &snapshot,
                        editable_token_limit,
                        context_token_limit,
                    );

                let text_in = |range: &Range<Point>| -> String {
                    snapshot.text_for_range(range.clone()).collect()
                };

                let editable_ok = actual_editable == editable_want;
                let context_ok = actual_context == context_want;

                if !(editable_ok && context_ok) {
                    println!("\n=== FAILED: {} ===", name);
                    if !editable_ok {
                        println!(
                            "\nExpected editable ({:?}..{:?}):",
                            editable_want.start, editable_want.end
                        );
                        println!("---\n{}---", text_in(&editable_want));
                        println!(
                            "\nActual editable ({:?}..{:?}):",
                            actual_editable.start, actual_editable.end
                        );
                        println!("---\n{}---", text_in(&actual_editable));
                    }
                    if !context_ok {
                        println!(
                            "\nExpected context ({:?}..{:?}):",
                            context_want.start, context_want.end
                        );
                        println!("---\n{}---", text_in(&context_want));
                        println!(
                            "\nActual context ({:?}..{:?}):",
                            actual_context.start, actual_context.end
                        );
                        println!("---\n{}---", text_in(&actual_context));
                    }
                    panic!("Test '{}' failed - see output above", name);
                }

                buffer
            });
        }
    }
}