1use crate::{
2 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
3 guess_token_count,
4};
5use language::{BufferSnapshot, Point};
6use std::{fmt::Write, ops::Range};
7
8#[derive(Debug)]
9pub struct InputExcerpt {
10 pub editable_range: Range<Point>,
11 pub prompt: String,
12}
13
14pub fn excerpt_for_cursor_position(
15 position: Point,
16 path: &str,
17 snapshot: &BufferSnapshot,
18 editable_region_token_limit: usize,
19 context_token_limit: usize,
20) -> InputExcerpt {
21 let mut scope_range = position..position;
22 let mut remaining_edit_tokens = editable_region_token_limit;
23
24 while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
25 let parent_tokens = guess_token_count(parent.byte_range().len());
26 let parent_point_range = Point::new(
27 parent.start_position().row as u32,
28 parent.start_position().column as u32,
29 )
30 ..Point::new(
31 parent.end_position().row as u32,
32 parent.end_position().column as u32,
33 );
34 if parent_point_range == scope_range {
35 break;
36 } else if parent_tokens <= editable_region_token_limit {
37 scope_range = parent_point_range;
38 remaining_edit_tokens = editable_region_token_limit - parent_tokens;
39 } else {
40 break;
41 }
42 }
43
44 let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
45 let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
46
47 let mut prompt = String::new();
48
49 writeln!(&mut prompt, "```{path}").unwrap();
50 if context_range.start == Point::zero() {
51 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
52 }
53
54 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
55 prompt.push_str(chunk.text);
56 }
57
58 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
59
60 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
61 prompt.push_str(chunk.text);
62 }
63 write!(prompt, "\n```").unwrap();
64
65 InputExcerpt {
66 editable_range,
67 prompt,
68 }
69}
70
71fn push_editable_range(
72 cursor_position: Point,
73 snapshot: &BufferSnapshot,
74 editable_range: Range<Point>,
75 prompt: &mut String,
76) {
77 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
78 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
79 prompt.push_str(chunk.text);
80 }
81 prompt.push_str(CURSOR_MARKER);
82 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
83 prompt.push_str(chunk.text);
84 }
85 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
86}
87
88fn expand_range(
89 snapshot: &BufferSnapshot,
90 range: Range<Point>,
91 mut remaining_tokens: usize,
92) -> Range<Point> {
93 let mut expanded_range = range;
94 expanded_range.start.column = 0;
95 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
96 loop {
97 let mut expanded = false;
98
99 if remaining_tokens > 0 && expanded_range.start.row > 0 {
100 expanded_range.start.row -= 1;
101 let line_tokens =
102 guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
103 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
104 expanded = true;
105 }
106
107 if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
108 expanded_range.end.row += 1;
109 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
110 let line_tokens = guess_token_count(expanded_range.end.column as usize);
111 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
112 expanded = true;
113 }
114
115 if !expanded {
116 break;
117 }
118 }
119 expanded_range
120}
121
// Tests pin the exact prompt text produced by `excerpt_for_cursor_position`.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
    use std::sync::Arc;

    /// Checks the full prompt for a cursor inside `bar` under two editable-region
    /// budgets (50 and 40 tokens), each with a 32-token context budget.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        // Three functions. Row 12 is `    return sum;`, so `Point::new(12, 5)`
        // puts the cursor right after the `r` of `return`.
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        // The Rust grammar gives the buffer syntax nodes for the scope search to climb.
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        // Here the editable region covers all of `bar` plus extra lines on both sides.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }

    /// Minimal Rust `Language` for tests: matches `.rs` paths and uses the
    /// tree-sitter Rust grammar.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
    }
}