edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod imports;
  5mod outline;
  6mod reference;
  7mod syntax_index;
  8pub mod text_similarity;
  9
 10use std::{path::Path, sync::Arc};
 11
 12use cloud_llm_client::predict_edits_v3;
 13use collections::HashMap;
 14use gpui::{App, AppContext as _, Entity, Task};
 15use language::BufferSnapshot;
 16use text::{Point, ToOffset as _};
 17
 18pub use declaration::*;
 19pub use declaration_scoring::*;
 20pub use excerpt::*;
 21pub use imports::*;
 22pub use reference::*;
 23pub use syntax_index::*;
 24
 25pub use predict_edits_v3::Line;
 26
 27#[derive(Clone, Debug, PartialEq)]
 28pub struct EditPredictionContextOptions {
 29    pub use_imports: bool,
 30    pub excerpt: EditPredictionExcerptOptions,
 31    pub score: EditPredictionScoreOptions,
 32    pub max_retrieved_declarations: u8,
 33}
 34
 35#[derive(Clone, Debug)]
 36pub struct EditPredictionContext {
 37    pub excerpt: EditPredictionExcerpt,
 38    pub excerpt_text: EditPredictionExcerptText,
 39    pub cursor_point: Point,
 40    pub declarations: Vec<ScoredDeclaration>,
 41}
 42
 43impl EditPredictionContext {
 44    pub fn gather_context_in_background(
 45        cursor_point: Point,
 46        buffer: BufferSnapshot,
 47        options: EditPredictionContextOptions,
 48        syntax_index: Option<Entity<SyntaxIndex>>,
 49        cx: &mut App,
 50    ) -> Task<Option<Self>> {
 51        let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
 52            let mut path = f.worktree.read(cx).absolutize(&f.path);
 53            if path.pop() { Some(path) } else { None }
 54        });
 55
 56        if let Some(syntax_index) = syntax_index {
 57            let index_state =
 58                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 59            cx.background_spawn(async move {
 60                let parent_abs_path = parent_abs_path.as_deref();
 61                let index_state = index_state.upgrade()?;
 62                let index_state = index_state.lock().await;
 63                Self::gather_context(
 64                    cursor_point,
 65                    &buffer,
 66                    parent_abs_path,
 67                    &options,
 68                    Some(&index_state),
 69                )
 70            })
 71        } else {
 72            cx.background_spawn(async move {
 73                let parent_abs_path = parent_abs_path.as_deref();
 74                Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
 75            })
 76        }
 77    }
 78
 79    pub fn gather_context(
 80        cursor_point: Point,
 81        buffer: &BufferSnapshot,
 82        parent_abs_path: Option<&Path>,
 83        options: &EditPredictionContextOptions,
 84        index_state: Option<&SyntaxIndexState>,
 85    ) -> Option<Self> {
 86        let imports = if options.use_imports {
 87            Imports::gather(&buffer, parent_abs_path)
 88        } else {
 89            Imports::default()
 90        };
 91        Self::gather_context_with_references_fn(
 92            cursor_point,
 93            buffer,
 94            &imports,
 95            options,
 96            index_state,
 97            references_in_excerpt,
 98        )
 99    }
100
101    pub fn gather_context_with_references_fn(
102        cursor_point: Point,
103        buffer: &BufferSnapshot,
104        imports: &Imports,
105        options: &EditPredictionContextOptions,
106        index_state: Option<&SyntaxIndexState>,
107        get_references: impl FnOnce(
108            &EditPredictionExcerpt,
109            &EditPredictionExcerptText,
110            &BufferSnapshot,
111        ) -> HashMap<Identifier, Vec<Reference>>,
112    ) -> Option<Self> {
113        let excerpt = EditPredictionExcerpt::select_from_buffer(
114            cursor_point,
115            buffer,
116            &options.excerpt,
117            index_state,
118        )?;
119        let excerpt_text = excerpt.text(buffer);
120
121        let declarations = if options.max_retrieved_declarations > 0
122            && let Some(index_state) = index_state
123        {
124            let excerpt_occurrences =
125                text_similarity::Occurrences::within_string(&excerpt_text.body);
126
127            let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
128            let adjacent_end = Point::new(cursor_point.row + 1, 0);
129            let adjacent_occurrences = text_similarity::Occurrences::within_string(
130                &buffer
131                    .text_for_range(adjacent_start..adjacent_end)
132                    .collect::<String>(),
133            );
134
135            let cursor_offset_in_file = cursor_point.to_offset(buffer);
136
137            let references = get_references(&excerpt, &excerpt_text, buffer);
138
139            let mut declarations = scored_declarations(
140                &options.score,
141                &index_state,
142                &excerpt,
143                &excerpt_occurrences,
144                &adjacent_occurrences,
145                &imports,
146                references,
147                cursor_offset_in_file,
148                buffer,
149            );
150            // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
151            declarations.truncate(options.max_retrieved_declarations as usize);
152            declarations
153        } else {
154            vec![]
155        };
156
157        Some(Self {
158            excerpt,
159            excerpt_text,
160            cursor_point,
161            declarations,
162        })
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// With the cursor on the first `process_data` call site in `c.rs`, the
    /// retrieved declarations are exactly `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // Row 8 of c.rs is `let result = process_data(data);`; column 21 lies
        // inside the `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let options = EditPredictionContextOptions {
            use_imports: true,
            excerpt: EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            score: EditPredictionScoreOptions {
                omit_excerpt_overlaps: true,
            },
            max_retrieved_declarations: u8::MAX,
        };

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    snapshot,
                    options,
                    Some(index.clone()),
                    cx,
                )
            })
            .await
            .unwrap();

        let mut names: Vec<&str> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        names.sort();
        assert_eq!(names, vec!["main", "process_data"]);

        // Dropped explicitly at the end so the open buffer outlives context gathering.
        drop(buffer);
    }

    /// Creates a fake filesystem rooted at `/root` with three Rust files, opens
    /// a test project over it, registers a Rust language, and builds a
    /// [`SyntaxIndex`] for the project.
    ///
    /// Returns the project, the index entity, and the registered language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        languages.add(Arc::new(rust));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust [`Language`] backed by tree-sitter-rust, using the
    /// highlights and outline queries from the `languages` crate.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}