input_excerpt.rs

  1use crate::{
  2    BYTES_PER_TOKEN_GUESS, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
  3    START_OF_FILE_MARKER,
  4};
  5use language::{BufferSnapshot, Point};
  6use std::{fmt::Write, ops::Range};
  7
  8pub struct InputExcerpt {
  9    pub editable_range: Range<Point>,
 10    pub prompt: String,
 11    pub speculated_output: String,
 12}
 13
 14pub fn excerpt_for_cursor_position(
 15    position: Point,
 16    path: &str,
 17    snapshot: &BufferSnapshot,
 18    editable_region_token_limit: usize,
 19    context_token_limit: usize,
 20) -> InputExcerpt {
 21    let mut scope_range = position..position;
 22    let mut remaining_edit_tokens = editable_region_token_limit;
 23
 24    while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
 25        let parent_tokens = tokens_for_bytes(parent.byte_range().len());
 26        if parent_tokens <= editable_region_token_limit {
 27            scope_range = Point::new(
 28                parent.start_position().row as u32,
 29                parent.start_position().column as u32,
 30            )
 31                ..Point::new(
 32                    parent.end_position().row as u32,
 33                    parent.end_position().column as u32,
 34                );
 35            remaining_edit_tokens = editable_region_token_limit - parent_tokens;
 36        } else {
 37            break;
 38        }
 39    }
 40
 41    let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
 42    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
 43
 44    let mut prompt = String::new();
 45    let mut speculated_output = String::new();
 46
 47    writeln!(&mut prompt, "```{path}").unwrap();
 48    if context_range.start == Point::zero() {
 49        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
 50    }
 51
 52    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
 53        prompt.push_str(chunk.text);
 54    }
 55
 56    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
 57    push_editable_range(
 58        position,
 59        snapshot,
 60        editable_range.clone(),
 61        &mut speculated_output,
 62    );
 63
 64    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
 65        prompt.push_str(chunk.text);
 66    }
 67    write!(prompt, "\n```").unwrap();
 68
 69    InputExcerpt {
 70        editable_range,
 71        prompt,
 72        speculated_output,
 73    }
 74}
 75
 76fn push_editable_range(
 77    cursor_position: Point,
 78    snapshot: &BufferSnapshot,
 79    editable_range: Range<Point>,
 80    prompt: &mut String,
 81) {
 82    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 83    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
 84        prompt.push_str(chunk.text);
 85    }
 86    prompt.push_str(CURSOR_MARKER);
 87    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
 88        prompt.push_str(chunk.text);
 89    }
 90    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 91}
 92
 93fn expand_range(
 94    snapshot: &BufferSnapshot,
 95    range: Range<Point>,
 96    mut remaining_tokens: usize,
 97) -> Range<Point> {
 98    let mut expanded_range = range.clone();
 99    expanded_range.start.column = 0;
100    expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
101    loop {
102        let mut expanded = false;
103
104        if remaining_tokens > 0 && expanded_range.start.row > 0 {
105            expanded_range.start.row -= 1;
106            let line_tokens =
107                tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize);
108            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
109            expanded = true;
110        }
111
112        if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
113            expanded_range.end.row += 1;
114            expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
115            let line_tokens = tokens_for_bytes(expanded_range.end.column as usize);
116            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
117            expanded = true;
118        }
119
120        if !expanded {
121            break;
122        }
123    }
124    expanded_range
125}
126
127fn tokens_for_bytes(bytes: usize) -> usize {
128    bytes / BYTES_PER_TOKEN_GUESS
129}
130
// Tests for `excerpt_for_cursor_position`. They run against a small Rust buffer so that the
// syntax-scope expansion has a real tree-sitter tree to walk.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
    use std::sync::Arc;

    // Runs both calls with the cursor at row 12, column 5. That row is `    return sum;`
    // inside `bar` (rows 5-13), and the cursor lands between `r` and `eturn`. Only the
    // editable-region token limit changes between the two calls.
    // NOTE(review): only `prompt` is asserted; `editable_range` and `speculated_output`
    // are not checked here.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.gen_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.gen_range(1..101));
            ```"#}
        );
    }

    // Builds a minimal Rust `Language`. It uses the tree-sitter-rust grammar, matches the
    // `.rs` suffix, and attaches no queries.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
    }
}