input_excerpt.rs

  1use super::{
  2    CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
  3    guess_token_count,
  4};
  5use language::{BufferSnapshot, Point};
  6use std::{fmt::Write, ops::Range};
  7
  8#[derive(Debug)]
  9pub struct InputExcerpt {
 10    pub context_range: Range<Point>,
 11    pub editable_range: Range<Point>,
 12    pub prompt: String,
 13}
 14
 15pub fn excerpt_for_cursor_position(
 16    position: Point,
 17    path: &str,
 18    snapshot: &BufferSnapshot,
 19    editable_region_token_limit: usize,
 20    context_token_limit: usize,
 21) -> InputExcerpt {
 22    let mut scope_range = position..position;
 23    let mut remaining_edit_tokens = editable_region_token_limit;
 24
 25    while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
 26        let parent_tokens = guess_token_count(parent.byte_range().len());
 27        let parent_point_range = Point::new(
 28            parent.start_position().row as u32,
 29            parent.start_position().column as u32,
 30        )
 31            ..Point::new(
 32                parent.end_position().row as u32,
 33                parent.end_position().column as u32,
 34            );
 35        if parent_point_range == scope_range {
 36            break;
 37        } else if parent_tokens <= editable_region_token_limit {
 38            scope_range = parent_point_range;
 39            remaining_edit_tokens = editable_region_token_limit - parent_tokens;
 40        } else {
 41            break;
 42        }
 43    }
 44
 45    let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
 46    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
 47
 48    let mut prompt = String::new();
 49
 50    writeln!(&mut prompt, "```{path}").unwrap();
 51    if context_range.start == Point::zero() {
 52        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
 53    }
 54
 55    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
 56        prompt.push_str(chunk.text);
 57    }
 58
 59    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
 60
 61    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
 62        prompt.push_str(chunk.text);
 63    }
 64    write!(prompt, "\n```").unwrap();
 65
 66    InputExcerpt {
 67        context_range,
 68        editable_range,
 69        prompt,
 70    }
 71}
 72
 73fn push_editable_range(
 74    cursor_position: Point,
 75    snapshot: &BufferSnapshot,
 76    editable_range: Range<Point>,
 77    prompt: &mut String,
 78) {
 79    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 80    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
 81        prompt.push_str(chunk.text);
 82    }
 83    prompt.push_str(CURSOR_MARKER);
 84    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
 85        prompt.push_str(chunk.text);
 86    }
 87    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 88}
 89
 90fn expand_range(
 91    snapshot: &BufferSnapshot,
 92    range: Range<Point>,
 93    mut remaining_tokens: usize,
 94) -> Range<Point> {
 95    let mut expanded_range = range;
 96    expanded_range.start.column = 0;
 97    expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
 98    loop {
 99        let mut expanded = false;
100
101        if remaining_tokens > 0 && expanded_range.start.row > 0 {
102            expanded_range.start.row -= 1;
103            let line_tokens =
104                guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
105            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
106            expanded = true;
107        }
108
109        if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
110            expanded_range.end.row += 1;
111            expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
112            let line_tokens = guess_token_count(expanded_range.end.column as usize);
113            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
114            expanded = true;
115        }
116
117        if !expanded {
118            break;
119        }
120    }
121    expanded_range
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use std::sync::Arc;

    /// Builds a Rust `Language` for the test buffer, using the tree-sitter
    /// Rust grammar and matching the `rs` path suffix.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
    }

    /// Checks the rendered prompt for the same cursor position under two
    /// different editable-region budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx
            .new(|cx| Buffer::local(source, cx).with_language_immediate(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Cursor sits just after the `r` of `return sum;` inside `bar`.
        let cursor = Point::new(12, 5);

        // With a budget of 50 the editable region should take the largest
        // syntax scope that fits, then fall back to line-based expansion once
        // the next larger scope no longer fits.
        let roomy = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 50, 32);
        assert_eq!(
            roomy.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With a budget of 40 the `bar` function does not fit in the editable
        // region, so line-based expansion is used instead.
        let tight = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 40, 32);
        assert_eq!(
            tight.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}