edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod imports;
  5mod outline;
  6mod reference;
  7mod syntax_index;
  8pub mod text_similarity;
  9
 10use std::{path::Path, sync::Arc};
 11
 12use collections::HashMap;
 13use gpui::{App, AppContext as _, Entity, Task};
 14use language::BufferSnapshot;
 15use text::{Point, ToOffset as _};
 16
 17pub use declaration::*;
 18pub use declaration_scoring::*;
 19pub use excerpt::*;
 20pub use imports::*;
 21pub use reference::*;
 22pub use syntax_index::*;
 23
 24#[derive(Clone, Debug, PartialEq)]
 25pub struct EditPredictionContextOptions {
 26    pub use_imports: bool,
 27    pub excerpt: EditPredictionExcerptOptions,
 28    pub score: EditPredictionScoreOptions,
 29}
 30
 31#[derive(Clone, Debug)]
 32pub struct EditPredictionContext {
 33    pub excerpt: EditPredictionExcerpt,
 34    pub excerpt_text: EditPredictionExcerptText,
 35    pub cursor_offset_in_excerpt: usize,
 36    pub declarations: Vec<ScoredDeclaration>,
 37}
 38
 39impl EditPredictionContext {
 40    pub fn gather_context_in_background(
 41        cursor_point: Point,
 42        buffer: BufferSnapshot,
 43        options: EditPredictionContextOptions,
 44        syntax_index: Option<Entity<SyntaxIndex>>,
 45        cx: &mut App,
 46    ) -> Task<Option<Self>> {
 47        let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
 48            let mut path = f.worktree.read(cx).absolutize(&f.path);
 49            if path.pop() { Some(path) } else { None }
 50        });
 51
 52        if let Some(syntax_index) = syntax_index {
 53            let index_state =
 54                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 55            cx.background_spawn(async move {
 56                let parent_abs_path = parent_abs_path.as_deref();
 57                let index_state = index_state.upgrade()?;
 58                let index_state = index_state.lock().await;
 59                Self::gather_context(
 60                    cursor_point,
 61                    &buffer,
 62                    parent_abs_path,
 63                    &options,
 64                    Some(&index_state),
 65                )
 66            })
 67        } else {
 68            cx.background_spawn(async move {
 69                let parent_abs_path = parent_abs_path.as_deref();
 70                Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
 71            })
 72        }
 73    }
 74
 75    pub fn gather_context(
 76        cursor_point: Point,
 77        buffer: &BufferSnapshot,
 78        parent_abs_path: Option<&Path>,
 79        options: &EditPredictionContextOptions,
 80        index_state: Option<&SyntaxIndexState>,
 81    ) -> Option<Self> {
 82        let imports = if options.use_imports {
 83            Imports::gather(&buffer, parent_abs_path)
 84        } else {
 85            Imports::default()
 86        };
 87        Self::gather_context_with_references_fn(
 88            cursor_point,
 89            buffer,
 90            &imports,
 91            options,
 92            index_state,
 93            references_in_excerpt,
 94        )
 95    }
 96
 97    pub fn gather_context_with_references_fn(
 98        cursor_point: Point,
 99        buffer: &BufferSnapshot,
100        imports: &Imports,
101        options: &EditPredictionContextOptions,
102        index_state: Option<&SyntaxIndexState>,
103        get_references: impl FnOnce(
104            &EditPredictionExcerpt,
105            &EditPredictionExcerptText,
106            &BufferSnapshot,
107        ) -> HashMap<Identifier, Vec<Reference>>,
108    ) -> Option<Self> {
109        let excerpt = EditPredictionExcerpt::select_from_buffer(
110            cursor_point,
111            buffer,
112            &options.excerpt,
113            index_state,
114        )?;
115        let excerpt_text = excerpt.text(buffer);
116        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
117
118        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
119        let adjacent_end = Point::new(cursor_point.row + 1, 0);
120        let adjacent_occurrences = text_similarity::Occurrences::within_string(
121            &buffer
122                .text_for_range(adjacent_start..adjacent_end)
123                .collect::<String>(),
124        );
125
126        let cursor_offset_in_file = cursor_point.to_offset(buffer);
127        // TODO fix this to not need saturating_sub
128        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
129
130        let declarations = if let Some(index_state) = index_state {
131            let references = get_references(&excerpt, &excerpt_text, buffer);
132
133            scored_declarations(
134                &options.score,
135                &index_state,
136                &excerpt,
137                &excerpt_occurrences,
138                &adjacent_occurrences,
139                &imports,
140                references,
141                cursor_offset_in_file,
142                buffer,
143            )
144        } else {
145            vec![]
146        };
147
148        Some(Self {
149            excerpt,
150            excerpt_text,
151            cursor_offset_in_excerpt,
152            declarations,
153        })
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// With the cursor on a `process_data` call site in `c.rs`, the gathered
    /// declarations should be exactly `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let open_buffer_task = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_buffer_task.await.unwrap();

        cx.run_until_parked();

        // Cursor sits inside the first `process_data` call site (in `main`).
        let cursor_point = language::Point::new(8, 21);
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let options = EditPredictionContextOptions {
            use_imports: true,
            excerpt: EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            score: EditPredictionScoreOptions {
                omit_excerpt_overlaps: true,
            },
        };

        let gather_task = cx.update(|cx| {
            EditPredictionContext::gather_context_in_background(
                cursor_point,
                snapshot,
                options,
                Some(index.clone()),
                cx,
            )
        });
        let context = gather_task.await.unwrap();

        let mut declaration_names: Vec<&str> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        declaration_names.sort();
        assert_eq!(declaration_names, ["main", "process_data"]);

        // Keep the buffer alive until the end of the test.
        drop(buffer);
    }

    /// Sets up settings, a fake filesystem with three Rust files, a project with
    /// the Rust language registered, and a `SyntaxIndex` over that project.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        languages.add(Arc::new(rust));

        let indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, indexing_parallelism, cx));
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` (matching `.rs` files) with highlight and outline
    /// queries from the `languages` crate.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}