1use crate::{
2 CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
3 tokens_for_bytes,
4};
5use language::{BufferSnapshot, Point};
6use std::{fmt::Write, ops::Range};
7
8#[derive(Debug)]
9pub struct InputExcerpt {
10 pub editable_range: Range<Point>,
11 pub prompt: String,
12 pub speculated_output: String,
13}
14
15pub fn excerpt_for_cursor_position(
16 position: Point,
17 path: &str,
18 snapshot: &BufferSnapshot,
19 editable_region_token_limit: usize,
20 context_token_limit: usize,
21) -> InputExcerpt {
22 let mut scope_range = position..position;
23 let mut remaining_edit_tokens = editable_region_token_limit;
24
25 while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
26 let parent_tokens = tokens_for_bytes(parent.byte_range().len());
27 let parent_point_range = Point::new(
28 parent.start_position().row as u32,
29 parent.start_position().column as u32,
30 )
31 ..Point::new(
32 parent.end_position().row as u32,
33 parent.end_position().column as u32,
34 );
35 if parent_point_range == scope_range {
36 break;
37 } else if parent_tokens <= editable_region_token_limit {
38 scope_range = parent_point_range;
39 remaining_edit_tokens = editable_region_token_limit - parent_tokens;
40 } else {
41 break;
42 }
43 }
44
45 let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
46 let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
47
48 let mut prompt = String::new();
49 let mut speculated_output = String::new();
50
51 writeln!(&mut prompt, "```{path}").unwrap();
52 if context_range.start == Point::zero() {
53 writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
54 }
55
56 for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
57 prompt.push_str(chunk.text);
58 }
59
60 push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
61 push_editable_range(
62 position,
63 snapshot,
64 editable_range.clone(),
65 &mut speculated_output,
66 );
67
68 for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
69 prompt.push_str(chunk.text);
70 }
71 write!(prompt, "\n```").unwrap();
72
73 InputExcerpt {
74 editable_range,
75 prompt,
76 speculated_output,
77 }
78}
79
80fn push_editable_range(
81 cursor_position: Point,
82 snapshot: &BufferSnapshot,
83 editable_range: Range<Point>,
84 prompt: &mut String,
85) {
86 writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
87 for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
88 prompt.push_str(chunk.text);
89 }
90 prompt.push_str(CURSOR_MARKER);
91 for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
92 prompt.push_str(chunk.text);
93 }
94 write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
95}
96
97fn expand_range(
98 snapshot: &BufferSnapshot,
99 range: Range<Point>,
100 mut remaining_tokens: usize,
101) -> Range<Point> {
102 let mut expanded_range = range.clone();
103 expanded_range.start.column = 0;
104 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
105 loop {
106 let mut expanded = false;
107
108 if remaining_tokens > 0 && expanded_range.start.row > 0 {
109 expanded_range.start.row -= 1;
110 let line_tokens =
111 tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize);
112 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
113 expanded = true;
114 }
115
116 if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
117 expanded_range.end.row += 1;
118 expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
119 let line_tokens = tokens_for_bytes(expanded_range.end.column as usize);
120 remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
121 expanded = true;
122 }
123
124 if !expanded {
125 break;
126 }
127 }
128 expanded_range
129}
130
// Tests for `excerpt_for_cursor_position`, run against a buffer that has a real Rust syntax
// tree, so that syntax-scope selection is exercised.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
    use std::sync::Arc;

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        // Fixture with three functions. Rows are zero-based: `bar` spans rows 5-14, and row 12
        // is `    return sum;`.
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.gen_range(1..101));
                }
                numbers
            }
        "#};
        // Attach the Rust grammar so the snapshot has a syntax tree for `syntax_ancestor`.
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        // `Point::new(12, 5)` places the cursor right after the `r` of `return sum;`. With a
        // 50-token budget, the expected editable region contains all of `bar` plus the lines
        // around it that the leftover budget adds.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                    let x = 42;
                    println!("Hello, world!");
                <|editable_region_start|>
                }

                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                <|editable_region_end|>
                    let mut rng = rand::thread_rng();
                    let mut numbers = Vec::new();
                ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        // With a 40-token budget, the region is grown line-by-line around the cursor instead of
        // being aligned to `bar`'s boundaries.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
                ```main.rs
                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                <|editable_region_start|>
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                    let mut rng = rand::thread_rng();
                <|editable_region_end|>
                    let mut numbers = Vec::new();
                    for _ in 0..5 {
                        numbers.push(rng.gen_range(1..101));
                ```"#}
        );
    }

    // Minimal "Rust" language definition that matches `.rs` paths and is backed by the
    // tree-sitter-rust grammar.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
    }
}