edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use gpui::{App, AppContext as _, Entity, Task};
 10use language::BufferSnapshot;
 11use text::{Point, ToOffset as _};
 12
 13pub use declaration::*;
 14pub use declaration_scoring::*;
 15pub use excerpt::*;
 16pub use reference::*;
 17pub use syntax_index::*;
 18
 19#[derive(Debug)]
 20pub struct EditPredictionContext {
 21    pub excerpt: EditPredictionExcerpt,
 22    pub excerpt_text: EditPredictionExcerptText,
 23    pub snippets: Vec<ScoredSnippet>,
 24}
 25
 26impl EditPredictionContext {
 27    pub fn gather_context_in_background(
 28        cursor_point: Point,
 29        buffer: BufferSnapshot,
 30        excerpt_options: EditPredictionExcerptOptions,
 31        syntax_index: Entity<SyntaxIndex>,
 32        cx: &mut App,
 33    ) -> Task<Option<Self>> {
 34        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 35        cx.background_spawn(async move {
 36            let index_state = index_state.lock().await;
 37            Self::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
 38        })
 39    }
 40
 41    pub fn gather_context(
 42        cursor_point: Point,
 43        buffer: &BufferSnapshot,
 44        excerpt_options: &EditPredictionExcerptOptions,
 45        index_state: &SyntaxIndexState,
 46    ) -> Option<Self> {
 47        let excerpt = EditPredictionExcerpt::select_from_buffer(
 48            cursor_point,
 49            buffer,
 50            excerpt_options,
 51            Some(index_state),
 52        )?;
 53        let excerpt_text = excerpt.text(buffer);
 54        let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 55        let cursor_offset = cursor_point.to_offset(buffer);
 56
 57        let snippets = scored_snippets(
 58            &index_state,
 59            &excerpt,
 60            &excerpt_text,
 61            references,
 62            cursor_offset,
 63            buffer,
 64        );
 65
 66        Some(Self {
 67            excerpt,
 68            excerpt_text,
 69            snippets,
 70        })
 71    }
 72}
 73
// Tests for gathering `EditPredictionContext` from a project with a `SyntaxIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call site
    /// in `c.rs`.
    ///
    /// Checks that the returned snippets are exactly those named `main` and
    /// `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Drain pending executor work before snapshotting. This presumably
        // includes any index updates triggered by opening the buffer.
        cx.run_until_parked();

        // first process_data call site
        // Row 8 of the `c.rs` fixture is `    let result = process_data(data);`.
        // Column 21 lands inside the `process_data` identifier (assuming Point
        // rows/columns are 0-based).
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    index,
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on the order in which
        // `snippets` are returned.
        let mut snippet_identifiers = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // NOTE(review): explicit drop at the end of scope. It presumably documents
        // that the buffer handle is meant to outlive the context gathering above.
        drop(buffer);
    }

    /// Sets up a project and syntax index for the tests in this module.
    ///
    /// Steps:
    /// 1. Install a test `SettingsStore` and initialize language and project settings.
    /// 2. Populate a fake filesystem under `/root` with `a.rs`, `b.rs` and `c.rs`.
    /// 3. Create a project over `/root`.
    /// 4. Register the Rust language built by `rust_lang`.
    /// 5. Create a `SyntaxIndex` for the project, then run the executor until parked.
    ///
    /// Returns the project, the index entity, and the id of the registered
    /// Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        // NOTE: `test_call_site` uses a cursor position tied to the exact layout
        // of `c.rs` below; editing this fixture may require updating that point.
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        // Rust is registered before the index is created below. This is
        // presumably so the `.rs` files can be parsed when indexed; confirm.
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` matched by the `rs` path suffix.
    ///
    /// Uses the tree-sitter Rust grammar together with the highlights and
    /// outline queries loaded from `../../languages/src/rust/`.
    ///
    /// Panics (via `unwrap`) if either query fails to load.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}