edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7pub mod text_similarity;
  8
  9use std::sync::Arc;
 10
 11use collections::HashMap;
 12use gpui::{App, AppContext as _, Entity, Task};
 13use language::BufferSnapshot;
 14use text::{Point, ToOffset as _};
 15
 16pub use declaration::*;
 17pub use declaration_scoring::*;
 18pub use excerpt::*;
 19pub use reference::*;
 20pub use syntax_index::*;
 21
 22#[derive(Clone, Debug)]
 23pub struct EditPredictionContext {
 24    pub excerpt: EditPredictionExcerpt,
 25    pub excerpt_text: EditPredictionExcerptText,
 26    pub cursor_offset_in_excerpt: usize,
 27    pub declarations: Vec<ScoredDeclaration>,
 28}
 29
 30impl EditPredictionContext {
 31    pub fn gather_context_in_background(
 32        cursor_point: Point,
 33        buffer: BufferSnapshot,
 34        excerpt_options: EditPredictionExcerptOptions,
 35        syntax_index: Option<Entity<SyntaxIndex>>,
 36        cx: &mut App,
 37    ) -> Task<Option<Self>> {
 38        if let Some(syntax_index) = syntax_index {
 39            let index_state =
 40                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 41            cx.background_spawn(async move {
 42                let index_state = index_state.upgrade()?;
 43                let index_state = index_state.lock().await;
 44                Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
 45            })
 46        } else {
 47            cx.background_spawn(async move {
 48                Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
 49            })
 50        }
 51    }
 52
 53    pub fn gather_context(
 54        cursor_point: Point,
 55        buffer: &BufferSnapshot,
 56        excerpt_options: &EditPredictionExcerptOptions,
 57        index_state: Option<&SyntaxIndexState>,
 58    ) -> Option<Self> {
 59        Self::gather_context_with_references_fn(
 60            cursor_point,
 61            buffer,
 62            excerpt_options,
 63            index_state,
 64            references_in_excerpt,
 65        )
 66    }
 67
 68    pub fn gather_context_with_references_fn(
 69        cursor_point: Point,
 70        buffer: &BufferSnapshot,
 71        excerpt_options: &EditPredictionExcerptOptions,
 72        index_state: Option<&SyntaxIndexState>,
 73        get_references: impl FnOnce(
 74            &EditPredictionExcerpt,
 75            &EditPredictionExcerptText,
 76            &BufferSnapshot,
 77        ) -> HashMap<Identifier, Vec<Reference>>,
 78    ) -> Option<Self> {
 79        let excerpt = EditPredictionExcerpt::select_from_buffer(
 80            cursor_point,
 81            buffer,
 82            excerpt_options,
 83            index_state,
 84        )?;
 85        let excerpt_text = excerpt.text(buffer);
 86        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
 87
 88        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
 89        let adjacent_end = Point::new(cursor_point.row + 1, 0);
 90        let adjacent_occurrences = text_similarity::Occurrences::within_string(
 91            &buffer
 92                .text_for_range(adjacent_start..adjacent_end)
 93                .collect::<String>(),
 94        );
 95
 96        let cursor_offset_in_file = cursor_point.to_offset(buffer);
 97        // TODO fix this to not need saturating_sub
 98        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 99
100        let declarations = if let Some(index_state) = index_state {
101            let references = get_references(&excerpt, &excerpt_text, buffer);
102
103            scored_declarations(
104                &index_state,
105                &excerpt,
106                &excerpt_occurrences,
107                &adjacent_occurrences,
108                references,
109                cursor_offset_in_file,
110                buffer,
111            )
112        } else {
113            vec![]
114        };
115
116        Some(Self {
117            excerpt,
118            excerpt_text,
119            cursor_offset_in_excerpt,
120            declarations,
121        })
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context at a call site of `process_data` in `c.rs` and checks
    /// which declarations are returned.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let open_task = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_task.await.unwrap();

        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let excerpt_options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        let gather_task = cx.update(|cx| {
            EditPredictionContext::gather_context_in_background(
                cursor_point,
                buffer_snapshot,
                excerpt_options,
                Some(index),
                cx,
            )
        });
        let context = gather_task.await.unwrap();

        let mut declaration_names: Vec<_> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        declaration_names.sort();
        assert_eq!(declaration_names, vec!["main", "process_data"]);

        // The buffer handle is held until the assertions are done.
        drop(buffer);
    }

    /// Sets up global settings, a fake file system rooted at `/root` with three
    /// Rust fixture files, a test project with the Rust language registered,
    /// and a `SyntaxIndex` over that project.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let a_rs = indoc! {r#"
            fn main() {
                let x = 1;
                let y = 2;
                let z = add(x, y);
                println!("Result: {}", z);
            }

            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#};
        let b_rs = indoc! {"
            pub struct Config {
                pub name: String,
                pub value: i32,
            }

            impl Config {
                pub fn new(name: String, value: i32) -> Self {
                    Config { name, value }
                }
            }
        "};
        let c_rs = indoc! {r#"
            use std::collections::HashMap;

            fn main() {
                let args: Vec<String> = std::env::args().collect();
                let data: Vec<i32> = args[1..]
                    .iter()
                    .filter_map(|s| s.parse().ok())
                    .collect();
                let result = process_data(data);
                println!("{:?}", result);
            }

            fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                let mut counts = HashMap::new();
                for value in data {
                    *counts.entry(value).or_insert(0) += 1;
                }
                counts
            }

            #[cfg(test)]
            mod tests {
                use super::*;

                #[test]
                fn test_process_data() {
                    let data = vec![1, 2, 2, 3];
                    let result = process_data(data);
                    assert_eq!(result.get(&2), Some(&2));
                }
            }
        "#};

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": a_rs,
                "b.rs": b_rs,
                "c.rs": c_rs
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        language_registry.add(Arc::new(rust));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Let pending work settle before handing the index to the test.
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` matched by the `rs` path suffix, backed by the
    /// tree-sitter Rust grammar and the bundled highlights and outline queries.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}