1use language::{BufferSnapshot, Point, ToPoint as _};
2use std::ops::Range;
3use text::OffsetRangeExt as _;
4
/// Token budget that `compute_cursor_excerpt` hands to
/// `expand_symmetric_from_cursor` when sizing the excerpt around the cursor.
const CURSOR_EXCERPT_TOKEN_BUDGET: usize = 8192;
6
7/// Computes a cursor excerpt as the largest linewise symmetric region around
8/// the cursor that fits within an 8192-token budget. Returns the point range,
9/// byte offset range, and the cursor offset relative to the excerpt start.
10pub fn compute_cursor_excerpt(
11 snapshot: &BufferSnapshot,
12 cursor_offset: usize,
13) -> (Range<Point>, Range<usize>, usize) {
14 let cursor_point = cursor_offset.to_point(snapshot);
15 let cursor_row = cursor_point.row;
16 let (start_row, end_row, _) =
17 expand_symmetric_from_cursor(snapshot, cursor_row, CURSOR_EXCERPT_TOKEN_BUDGET);
18
19 let excerpt_range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
20 let excerpt_offset_range = excerpt_range.to_offset(snapshot);
21 let cursor_offset_in_excerpt = cursor_offset - excerpt_offset_range.start;
22
23 (
24 excerpt_range,
25 excerpt_offset_range,
26 cursor_offset_in_excerpt,
27 )
28}
29
30/// Expands symmetrically from cursor, one line at a time, alternating down then up.
31/// Returns (start_row, end_row, remaining_tokens).
32fn expand_symmetric_from_cursor(
33 snapshot: &BufferSnapshot,
34 cursor_row: u32,
35 mut token_budget: usize,
36) -> (u32, u32, usize) {
37 let mut start_row = cursor_row;
38 let mut end_row = cursor_row;
39
40 let cursor_line_tokens = line_token_count(snapshot, cursor_row);
41 token_budget = token_budget.saturating_sub(cursor_line_tokens);
42
43 loop {
44 let can_expand_up = start_row > 0;
45 let can_expand_down = end_row < snapshot.max_point().row;
46
47 if token_budget == 0 || (!can_expand_up && !can_expand_down) {
48 break;
49 }
50
51 if can_expand_down {
52 let next_row = end_row + 1;
53 let line_tokens = line_token_count(snapshot, next_row);
54 if line_tokens <= token_budget {
55 end_row = next_row;
56 token_budget = token_budget.saturating_sub(line_tokens);
57 } else {
58 break;
59 }
60 }
61
62 if can_expand_up && token_budget > 0 {
63 let next_row = start_row - 1;
64 let line_tokens = line_token_count(snapshot, next_row);
65 if line_tokens <= token_budget {
66 start_row = next_row;
67 token_budget = token_budget.saturating_sub(line_tokens);
68 } else {
69 break;
70 }
71 }
72 }
73
74 (start_row, end_row, token_budget)
75}
76
/// Typical number of string bytes per token for the purposes of limiting model input.
///
/// The value is intentionally low: dividing a byte count by a smaller number
/// produces a larger token estimate, so text selected against a token budget
/// errs on the side of being too small rather than too large.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy.
///
/// Uses integer division by [`BYTES_PER_TOKEN_GUESS`], so the estimate rounds
/// down and is `0` for inputs shorter than one token's worth of bytes.
///
/// This is a `const fn` so derived limits can be computed in constant
/// contexts as well as at runtime.
pub const fn guess_token_count(bytes: usize) -> usize {
    bytes / BYTES_PER_TOKEN_GUESS
}
84
85fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
86 guess_token_count(snapshot.line_len(row) as usize).max(1)
87}
88
89/// Computes the byte offset ranges of all syntax nodes containing the cursor,
90/// ordered from innermost to outermost. The offsets are relative to
91/// `excerpt_offset_range.start`.
92pub fn compute_syntax_ranges(
93 snapshot: &BufferSnapshot,
94 cursor_offset: usize,
95 excerpt_offset_range: &Range<usize>,
96) -> Vec<Range<usize>> {
97 let cursor_point = cursor_offset.to_point(snapshot);
98 let range = cursor_point..cursor_point;
99 let mut current = snapshot.syntax_ancestor(range);
100 let mut ranges = Vec::new();
101 let mut last_range: Option<(usize, usize)> = None;
102
103 while let Some(node) = current.take() {
104 let node_start = node.start_byte();
105 let node_end = node.end_byte();
106 let key = (node_start, node_end);
107
108 current = node.parent();
109
110 if last_range == Some(key) {
111 continue;
112 }
113 last_range = Some(key);
114
115 let start = node_start.saturating_sub(excerpt_offset_range.start);
116 let end = node_end
117 .min(excerpt_offset_range.end)
118 .saturating_sub(excerpt_offset_range.start);
119 ranges.push(start..end);
120 }
121
122 ranges
123}
124
125/// Expands context by first trying to reach syntax boundaries,
126/// then expanding line-wise only if no syntax expansion occurred.
127pub fn expand_context_syntactically_then_linewise(
128 snapshot: &BufferSnapshot,
129 editable_range: Range<Point>,
130 context_token_limit: usize,
131) -> Range<Point> {
132 let mut start_row = editable_range.start.row;
133 let mut end_row = editable_range.end.row;
134 let mut remaining_tokens = context_token_limit;
135 let mut did_syntax_expand = false;
136
137 // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
138 for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
139 {
140 let tokens_for_start = if boundary_start < start_row {
141 estimate_tokens_for_rows(snapshot, boundary_start, start_row)
142 } else {
143 0
144 };
145 let tokens_for_end = if boundary_end > end_row {
146 estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
147 } else {
148 0
149 };
150
151 let total_needed = tokens_for_start + tokens_for_end;
152
153 if total_needed <= remaining_tokens {
154 if boundary_start < start_row {
155 start_row = boundary_start;
156 }
157 if boundary_end > end_row {
158 end_row = boundary_end;
159 }
160 remaining_tokens = remaining_tokens.saturating_sub(total_needed);
161 did_syntax_expand = true;
162 } else {
163 break;
164 }
165 }
166
167 // Phase 2: Only expand line-wise if no syntax expansion occurred.
168 if !did_syntax_expand {
169 (start_row, end_row, _) =
170 expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
171 }
172
173 let start = Point::new(start_row, 0);
174 let end = Point::new(end_row, snapshot.line_len(end_row));
175 start..end
176}
177
178/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
179/// containing the given row range. Smallest containing node first.
180fn containing_syntax_boundaries(
181 snapshot: &BufferSnapshot,
182 start_row: u32,
183 end_row: u32,
184) -> impl Iterator<Item = (u32, u32)> {
185 let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
186 let mut current = snapshot.syntax_ancestor(range);
187 let mut last_rows: Option<(u32, u32)> = None;
188
189 std::iter::from_fn(move || {
190 while let Some(node) = current.take() {
191 let node_start_row = node.start_position().row as u32;
192 let node_end_row = node.end_position().row as u32;
193 let rows = (node_start_row, node_end_row);
194
195 current = node.parent();
196
197 // Skip nodes that don't extend beyond our range.
198 if node_start_row >= start_row && node_end_row <= end_row {
199 continue;
200 }
201
202 // Skip if same as last returned (some nodes have same span).
203 if last_rows == Some(rows) {
204 continue;
205 }
206
207 last_rows = Some(rows);
208 return Some(rows);
209 }
210 None
211 })
212}
213
214/// Expands line-wise with a bias toward one direction.
215/// Returns (start_row, end_row, remaining_tokens).
216fn expand_linewise_biased(
217 snapshot: &BufferSnapshot,
218 mut start_row: u32,
219 mut end_row: u32,
220 mut remaining_tokens: usize,
221 prefer_up: bool,
222) -> (u32, u32, usize) {
223 loop {
224 let can_expand_up = start_row > 0;
225 let can_expand_down = end_row < snapshot.max_point().row;
226
227 if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
228 break;
229 }
230
231 let mut expanded = false;
232
233 // Try preferred direction first.
234 if prefer_up {
235 if can_expand_up {
236 let next_row = start_row - 1;
237 let line_tokens = line_token_count(snapshot, next_row);
238 if line_tokens <= remaining_tokens {
239 start_row = next_row;
240 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
241 expanded = true;
242 }
243 }
244 if can_expand_down && remaining_tokens > 0 {
245 let next_row = end_row + 1;
246 let line_tokens = line_token_count(snapshot, next_row);
247 if line_tokens <= remaining_tokens {
248 end_row = next_row;
249 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
250 expanded = true;
251 }
252 }
253 } else {
254 if can_expand_down {
255 let next_row = end_row + 1;
256 let line_tokens = line_token_count(snapshot, next_row);
257 if line_tokens <= remaining_tokens {
258 end_row = next_row;
259 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
260 expanded = true;
261 }
262 }
263 if can_expand_up && remaining_tokens > 0 {
264 let next_row = start_row - 1;
265 let line_tokens = line_token_count(snapshot, next_row);
266 if line_tokens <= remaining_tokens {
267 start_row = next_row;
268 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
269 expanded = true;
270 }
271 }
272 }
273
274 if !expanded {
275 break;
276 }
277 }
278
279 (start_row, end_row, remaining_tokens)
280}
281
282/// Estimates token count for rows in range [start_row, end_row).
283fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
284 let mut tokens = 0;
285 for row in start_row..end_row {
286 tokens += line_token_count(snapshot, row);
287 }
288 tokens
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext as _};
    use indoc::indoc;
    use language::{Buffer, rust_lang};
    use util::test::{TextRangeMarker, marked_text_ranges_by};
    use zeta_prompt::compute_editable_and_context_ranges;

    /// One table-driven scenario for `test_editable_and_context_ranges`.
    struct TestCase {
        /// Label used in assertion and failure messages.
        name: &'static str,
        /// Source text annotated with the cursor and expected-range markers.
        marked_text: &'static str,
        /// Editable token limit passed to `compute_editable_and_context_ranges`.
        editable_token_limit: usize,
        /// Context token limit passed to `compute_editable_and_context_ranges`.
        context_token_limit: usize,
    }

    /// Runs every scenario through `compute_cursor_excerpt`,
    /// `compute_syntax_ranges` and `compute_editable_and_context_ranges`, and
    /// checks the resulting editable and context ranges against the markers
    /// embedded in the scenario text.
    #[gpui::test]
    fn test_editable_and_context_ranges(cx: &mut App) {
        // Markers:
        // ˇ = cursor position
        // « » = expected editable range
        // [ ] = expected context range
        let test_cases = vec![
            TestCase {
                name: "small function fits entirely in editable and context",
                marked_text: indoc! {r#"
                    [«fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }»]
                "#},
                editable_token_limit: 30,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near end of function - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn first() {
                        let a = 1;
                        let b = 2;
                    }

                    fn foo() {
                    «    let x = 1;
                        let y = 2;
                        println!("{}", x + y);ˇ
                    }»]
                "#},
                editable_token_limit: 18,
                context_token_limit: 35,
            },
            TestCase {
                name: "cursor at function start - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [fn before() {
                    «    let a = 1;
                    }

                    fn foo() {ˇ
                        let x = 1;
                        let y = 2;
                        let z = 3;
                    }
                    »
                    fn after() {
                        let b = 2;
                    }]
                "#},
                editable_token_limit: 25,
                context_token_limit: 50,
            },
            TestCase {
                name: "tiny budget - just lines around cursor, no syntax expansion",
                marked_text: indoc! {r#"
                    fn outer() {
                    [    let line1 = 1;
                        let line2 = 2;
                    «    let line3 = 3;
                        let line4 = 4;ˇ»
                        let line5 = 5;
                        let line6 = 6;]
                        let line7 = 7;
                    }
                "#},
                editable_token_limit: 12,
                context_token_limit: 24,
            },
            TestCase {
                name: "context extends beyond editable",
                marked_text: indoc! {r#"
                    [fn first() { let a = 1; }
                    «fn second() { let b = 2; }
                    fn third() { let c = 3; }ˇ
                    fn fourth() { let d = 4; }»
                    fn fifth() { let e = 5; }]
                "#},
                editable_token_limit: 25,
                context_token_limit: 45,
            },
            TestCase {
                name: "cursor in first if-block - editable expands to syntax boundaries",
                marked_text: indoc! {r#"
                    [«fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;ˇ
                            let b = 2;
                        }
                        if condition2 {»
                            let c = 3;
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                editable_token_limit: 35,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor in middle if-block - editable spans surrounding blocks",
                marked_text: indoc! {r#"
                    [fn before() { }

                    fn process() {
                        if condition1 {
                            let a = 1;
                    «        let b = 2;
                        }
                        if condition2 {
                            let c = 3;ˇ
                            let d = 4;
                        }
                        if condition3 {
                            let e = 5;»
                            let f = 6;
                        }
                    }

                    fn after() { }]
                "#},
                editable_token_limit: 40,
                context_token_limit: 60,
            },
            TestCase {
                name: "cursor near bottom of long function - context reaches function boundary",
                marked_text: indoc! {r#"
                    [fn other() { }

                    fn long_function() {
                        let line1 = 1;
                        let line2 = 2;
                        let line3 = 3;
                        let line4 = 4;
                        let line5 = 5;
                        let line6 = 6;
                    «    let line7 = 7;
                        let line8 = 8;
                        let line9 = 9;
                        let line10 = 10;ˇ
                        let line11 = 11;
                    }

                    fn another() { }»]
                "#},
                editable_token_limit: 40,
                context_token_limit: 55,
            },
            TestCase {
                name: "zero context budget - context equals editable",
                marked_text: indoc! {r#"
                    fn before() {
                        let p = 1;
                        let q = 2;
                    [«}

                    fn foo() {
                        let x = 1;ˇ
                        let y = 2;
                    }
                    »]
                    fn after() {
                        let r = 3;
                        let s = 4;
                    }
                "#},
                editable_token_limit: 15,
                context_token_limit: 0,
            },
        ];

        for test_case in test_cases {
            let cursor_marker: TextRangeMarker = 'ˇ'.into();
            let editable_marker: TextRangeMarker = ('«', '»').into();
            let context_marker: TextRangeMarker = ('[', ']').into();

            // Split the annotated text into the text itself and the ranges
            // recorded for each kind of marker (presumably with the marker
            // characters removed from the returned text).
            let (text, mut ranges) = marked_text_ranges_by(
                test_case.marked_text,
                vec![
                    cursor_marker.clone(),
                    editable_marker.clone(),
                    context_marker.clone(),
                ],
            );

            let cursor_ranges = ranges.remove(&cursor_marker).unwrap_or_default();
            let expected_editable = ranges.remove(&editable_marker).unwrap_or_default();
            let expected_context = ranges.remove(&context_marker).unwrap_or_default();
            // Every scenario must mark exactly one editable and one context range.
            assert_eq!(expected_editable.len(), 1, "{}", test_case.name);
            assert_eq!(expected_context.len(), 1, "{}", test_case.name);

            cx.new(|cx: &mut gpui::Context<Buffer>| {
                // Strip the trailing newline(s) of the scenario text before
                // building the buffer.
                let text = text.trim_end_matches('\n');
                let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
                let snapshot = buffer.snapshot();

                let cursor_offset = cursor_ranges[0].start;

                // Build the inputs `compute_editable_and_context_ranges`
                // takes: the excerpt text, the cursor offset within it, and
                // the excerpt-relative syntax ranges.
                let (_, excerpt_offset_range, cursor_offset_in_excerpt) =
                    compute_cursor_excerpt(&snapshot, cursor_offset);
                let excerpt_text: String = snapshot
                    .text_for_range(excerpt_offset_range.clone())
                    .collect();
                let syntax_ranges =
                    compute_syntax_ranges(&snapshot, cursor_offset, &excerpt_offset_range);

                let (actual_editable, actual_context) = compute_editable_and_context_ranges(
                    &excerpt_text,
                    cursor_offset_in_excerpt,
                    &syntax_ranges,
                    test_case.editable_token_limit,
                    test_case.context_token_limit,
                );

                // The computed ranges are relative to the excerpt; shift them
                // by the excerpt start so they can be compared with the
                // marker ranges, which are offsets into the whole text.
                let to_buffer_range = |range: Range<usize>| -> Range<usize> {
                    (excerpt_offset_range.start + range.start)
                        ..(excerpt_offset_range.start + range.end)
                };

                let actual_editable = to_buffer_range(actual_editable);
                let actual_context = to_buffer_range(actual_context);

                let expected_editable_range = expected_editable[0].clone();
                let expected_context_range = expected_context[0].clone();

                let editable_match = actual_editable == expected_editable_range;
                let context_match = actual_context == expected_context_range;

                // On mismatch, print the expected and actual text of each
                // differing range before panicking, so the failure is readable
                // without decoding raw offsets.
                if !editable_match || !context_match {
                    let range_text = |range: &Range<usize>| {
                        snapshot.text_for_range(range.clone()).collect::<String>()
                    };

                    println!("\n=== FAILED: {} ===", test_case.name);
                    if !editable_match {
                        println!("\nExpected editable ({:?}):", expected_editable_range);
                        println!("---\n{}---", range_text(&expected_editable_range));
                        println!("\nActual editable ({:?}):", actual_editable);
                        println!("---\n{}---", range_text(&actual_editable));
                    }
                    if !context_match {
                        println!("\nExpected context ({:?}):", expected_context_range);
                        println!("---\n{}---", range_text(&expected_context_range));
                        println!("\nActual context ({:?}):", actual_context);
                        println!("---\n{}---", range_text(&actual_context));
                    }
                    panic!("Test '{}' failed - see output above", test_case.name);
                }

                // The closure passed to `cx.new` returns the buffer it built.
                buffer
            });
        }
    }
}