1use super::{
2 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
3 guess_token_count,
4};
5use language::{BufferSnapshot, Point};
6use std::{fmt::Write, ops::Range};
7
8#[derive(Debug)]
9pub struct InputExcerpt {
10 pub context_range: Range<Point>,
11 pub editable_range: Range<Point>,
12 pub prompt: String,
13}
14
15pub fn excerpt_for_cursor_position(
16 position: Point,
17 path: &str,
18 snapshot: &BufferSnapshot,
19 editable_region_token_limit: usize,
20 context_token_limit: usize,
21) -> InputExcerpt {
22 let mut scope_range = position..position;
23 let mut remaining_edit_tokens = editable_region_token_limit;
24
25 while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
26 let parent_tokens = guess_token_count(parent.byte_range().len());
27 let parent_point_range = Point::new(
28 parent.start_position().row as u32,
29 parent.start_position().column as u32,
30 )
31 ..Point::new(
32 parent.end_position().row as u32,
33 parent.end_position().column as u32,
34 );
35 if parent_point_range == scope_range {
36 break;
37 } else if parent_tokens <= editable_region_token_limit {
38 scope_range = parent_point_range;
39 remaining_edit_tokens = editable_region_token_limit - parent_tokens;
40 } else {
41 break;
42 }
43 }
44
45 let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
46 let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
47
48 let mut prompt = String::new();
49
50 writeln!(&mut prompt, "```{path}").unwrap();
51 if context_range.start == Point::zero() {
52 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
53 }
54
55 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
56 prompt.push_str(chunk.text);
57 }
58
59 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
60
61 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
62 prompt.push_str(chunk.text);
63 }
64 write!(prompt, "\n```").unwrap();
65
66 InputExcerpt {
67 context_range,
68 editable_range,
69 prompt,
70 }
71}
72
73fn push_editable_range(
74 cursor_position: Point,
75 snapshot: &BufferSnapshot,
76 editable_range: Range<Point>,
77 prompt: &mut String,
78) {
79 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
80 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
81 prompt.push_str(chunk.text);
82 }
83 prompt.push_str(CURSOR_MARKER);
84 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
85 prompt.push_str(chunk.text);
86 }
87 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
88}
89
90fn expand_range(
91 snapshot: &BufferSnapshot,
92 range: Range<Point>,
93 mut remaining_tokens: usize,
94) -> Range<Point> {
95 let mut expanded_range = range;
96 expanded_range.start.column = 0;
97 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
98 loop {
99 let mut expanded = false;
100
101 if remaining_tokens > 0 && expanded_range.start.row > 0 {
102 expanded_range.start.row -= 1;
103 let line_tokens =
104 guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
105 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
106 expanded = true;
107 }
108
109 if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
110 expanded_range.end.row += 1;
111 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
112 let line_tokens = guess_token_count(expanded_range.end.column as usize);
113 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
114 expanded = true;
115 }
116
117 if !expanded {
118 break;
119 }
120 }
121 expanded_range
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use std::sync::Arc;

    /// Checks the prompt produced for a cursor inside `bar` under two editable-region budgets:
    /// one where the enclosing function fits and one where it does not.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx
            .new(|cx| Buffer::local(source, cx).with_language_immediate(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The cursor sits right after the `r` of `return sum;` inside `bar`.
        let cursor = Point::new(12, 5);
        // Renders the prompt for the given editable-region budget and a fixed context budget.
        let prompt_with_editable_limit = |editable_limit: usize| {
            excerpt_for_cursor_position(cursor, "main.rs", &snapshot, editable_limit, 32).prompt
        };

        // With a budget of 50 the largest syntax scope that fits is used for the editable
        // region, and the leftover budget is spent on line-based expansion around it.
        assert_eq!(
            prompt_with_editable_limit(50),
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With a budget of 40 the `bar` function no longer fits within the editable region,
        // so the region is built by line-based expansion instead.
        assert_eq!(
            prompt_with_editable_limit(40),
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }

    /// Minimal Rust language definition (name, `.rs` suffix, tree-sitter grammar) so the test
    /// buffer has a syntax tree to query.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
    }
}