edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 10pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 11use gpui::{App, AppContext as _, Entity, Task};
 12use language::BufferSnapshot;
 13pub use reference::references_in_excerpt;
 14pub use syntax_index::SyntaxIndex;
 15use text::{Point, ToOffset as _};
 16
 17use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
 18
 19pub struct EditPredictionContext {
 20    pub excerpt: EditPredictionExcerpt,
 21    pub excerpt_text: EditPredictionExcerptText,
 22    pub snippets: Vec<ScoredSnippet>,
 23}
 24
 25impl EditPredictionContext {
 26    pub fn gather(
 27        cursor_point: Point,
 28        buffer: BufferSnapshot,
 29        excerpt_options: EditPredictionExcerptOptions,
 30        syntax_index: Entity<SyntaxIndex>,
 31        cx: &mut App,
 32    ) -> Task<Self> {
 33        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 34        cx.background_spawn(async move {
 35            let index_state = index_state.lock().await;
 36
 37            let excerpt =
 38                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
 39                    .unwrap();
 40            let excerpt_text = excerpt.text(&buffer);
 41            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
 42            let cursor_offset = cursor_point.to_offset(&buffer);
 43
 44            let snippets = scored_snippets(
 45                &index_state,
 46                &excerpt,
 47                &excerpt_text,
 48                references,
 49                cursor_offset,
 50                &buffer,
 51            );
 52
 53            Self {
 54                excerpt,
 55                excerpt_text,
 56                snippets,
 57            }
 58        })
 59    }
 60}
 61
 62#[cfg(test)]
 63mod tests {
 64    use super::*;
 65    use std::sync::Arc;
 66
 67    use gpui::{Entity, TestAppContext};
 68    use indoc::indoc;
 69    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
 70    use project::{FakeFs, Project};
 71    use serde_json::json;
 72    use settings::SettingsStore;
 73    use util::path;
 74
 75    use crate::{EditPredictionExcerptOptions, SyntaxIndex};
 76
 77    #[gpui::test]
 78    async fn test_call_site(cx: &mut TestAppContext) {
 79        let (project, index, _rust_lang_id) = init_test(cx).await;
 80
 81        let buffer = project
 82            .update(cx, |project, cx| {
 83                let project_path = project.find_project_path("c.rs", cx).unwrap();
 84                project.open_buffer(project_path, cx)
 85            })
 86            .await
 87            .unwrap();
 88
 89        cx.run_until_parked();
 90
 91        // first process_data call site
 92        let cursor_point = language::Point::new(8, 21);
 93        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
 94
 95        let context = cx
 96            .update(|cx| {
 97                EditPredictionContext::gather(
 98                    cursor_point,
 99                    buffer_snapshot,
100                    EditPredictionExcerptOptions {
101                        max_bytes: 40,
102                        min_bytes: 10,
103                        target_before_cursor_over_total_bytes: 0.5,
104                        include_parent_signatures: false,
105                    },
106                    index,
107                    cx,
108                )
109            })
110            .await;
111
112        assert_eq!(context.snippets.len(), 1);
113        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
114        drop(buffer);
115    }
116
117    async fn init_test(
118        cx: &mut TestAppContext,
119    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
120        cx.update(|cx| {
121            let settings_store = SettingsStore::test(cx);
122            cx.set_global(settings_store);
123            language::init(cx);
124            Project::init_settings(cx);
125        });
126
127        let fs = FakeFs::new(cx.executor());
128        fs.insert_tree(
129            path!("/root"),
130            json!({
131                "a.rs": indoc! {r#"
132                    fn main() {
133                        let x = 1;
134                        let y = 2;
135                        let z = add(x, y);
136                        println!("Result: {}", z);
137                    }
138
139                    fn add(a: i32, b: i32) -> i32 {
140                        a + b
141                    }
142                "#},
143                "b.rs": indoc! {"
144                    pub struct Config {
145                        pub name: String,
146                        pub value: i32,
147                    }
148
149                    impl Config {
150                        pub fn new(name: String, value: i32) -> Self {
151                            Config { name, value }
152                        }
153                    }
154                "},
155                "c.rs": indoc! {r#"
156                    use std::collections::HashMap;
157
158                    fn main() {
159                        let args: Vec<String> = std::env::args().collect();
160                        let data: Vec<i32> = args[1..]
161                            .iter()
162                            .filter_map(|s| s.parse().ok())
163                            .collect();
164                        let result = process_data(data);
165                        println!("{:?}", result);
166                    }
167
168                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
169                        let mut counts = HashMap::new();
170                        for value in data {
171                            *counts.entry(value).or_insert(0) += 1;
172                        }
173                        counts
174                    }
175
176                    #[cfg(test)]
177                    mod tests {
178                        use super::*;
179
180                        #[test]
181                        fn test_process_data() {
182                            let data = vec![1, 2, 2, 3];
183                            let result = process_data(data);
184                            assert_eq!(result.get(&2), Some(&2));
185                        }
186                    }
187                "#}
188            }),
189        )
190        .await;
191        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
192        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
193        let lang = rust_lang();
194        let lang_id = lang.id();
195        language_registry.add(Arc::new(lang));
196
197        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
198        cx.run_until_parked();
199
200        (project, index, lang_id)
201    }
202
203    fn rust_lang() -> Language {
204        Language::new(
205            LanguageConfig {
206                name: "Rust".into(),
207                matcher: LanguageMatcher {
208                    path_suffixes: vec!["rs".to_string()],
209                    ..Default::default()
210                },
211                ..Default::default()
212            },
213            Some(tree_sitter_rust::LANGUAGE.into()),
214        )
215        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
216        .unwrap()
217        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
218        .unwrap()
219    }
220}