1use super::{
2 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
3 guess_token_count,
4};
5use language::{BufferSnapshot, Point};
6use std::{fmt::Write, ops::Range};
7
8#[derive(Debug)]
9pub struct InputExcerpt {
10 pub context_range: Range<Point>,
11 pub editable_range: Range<Point>,
12 pub prompt: String,
13}
14
15pub fn excerpt_for_cursor_position(
16 position: Point,
17 path: &str,
18 snapshot: &BufferSnapshot,
19 editable_region_token_limit: usize,
20 context_token_limit: usize,
21) -> InputExcerpt {
22 let mut scope_range = position..position;
23 let mut remaining_edit_tokens = editable_region_token_limit;
24
25 while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
26 let parent_tokens = guess_token_count(parent.byte_range().len());
27 let parent_point_range = Point::new(
28 parent.start_position().row as u32,
29 parent.start_position().column as u32,
30 )
31 ..Point::new(
32 parent.end_position().row as u32,
33 parent.end_position().column as u32,
34 );
35 if parent_point_range == scope_range {
36 break;
37 } else if parent_tokens <= editable_region_token_limit {
38 scope_range = parent_point_range;
39 remaining_edit_tokens = editable_region_token_limit - parent_tokens;
40 } else {
41 break;
42 }
43 }
44
45 let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
46 let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
47
48 let mut prompt = String::new();
49
50 writeln!(&mut prompt, "```{path}").unwrap();
51 if context_range.start == Point::zero() {
52 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
53 }
54
55 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
56 prompt.push_str(chunk.text);
57 }
58
59 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
60
61 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
62 prompt.push_str(chunk.text);
63 }
64 write!(prompt, "\n```").unwrap();
65
66 InputExcerpt {
67 context_range,
68 editable_range,
69 prompt,
70 }
71}
72
73fn push_editable_range(
74 cursor_position: Point,
75 snapshot: &BufferSnapshot,
76 editable_range: Range<Point>,
77 prompt: &mut String,
78) {
79 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
80 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
81 prompt.push_str(chunk.text);
82 }
83 prompt.push_str(CURSOR_MARKER);
84 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
85 prompt.push_str(chunk.text);
86 }
87 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
88}
89
90fn expand_range(
91 snapshot: &BufferSnapshot,
92 range: Range<Point>,
93 mut remaining_tokens: usize,
94) -> Range<Point> {
95 let mut expanded_range = range;
96 expanded_range.start.column = 0;
97 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
98 loop {
99 let mut expanded = false;
100
101 if remaining_tokens > 0 && expanded_range.start.row > 0 {
102 expanded_range.start.row -= 1;
103 let line_tokens =
104 guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
105 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
106 expanded = true;
107 }
108
109 if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
110 expanded_range.end.row += 1;
111 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
112 let line_tokens = guess_token_count(expanded_range.end.column as usize);
113 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
114 expanded = true;
115 }
116
117 if !expanded {
118 break;
119 }
120 }
121 expanded_range
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use std::sync::Arc;

    /// Checks the generated prompt for a cursor inside `bar` under two
    /// different editable-region token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let language = Arc::new(rust_lang());
        let buffer = cx.new(|cx| Buffer::local(source, cx).with_language(language, cx));
        let snapshot = buffer.read(cx).snapshot();

        // Just after the `r` of `return sum;` inside `bar`.
        let cursor = Point::new(12, 5);

        // With a limit of 50 the largest syntax scope that fits is used, and
        // the leftover budget is spent on line-based expansion around it.
        let whole_scope = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 50, 32);
        assert_eq!(
            whole_scope.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With a limit of 40 the `bar` function no longer fits in the
        // editable region, so only line-based expansion applies.
        let line_based = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 40, 32);
        assert_eq!(
            line_based.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }

    /// Builds a `Language` named "Rust" that matches the `rs` path suffix and
    /// uses the tree-sitter Rust grammar.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
    }
}