edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use gpui::{App, AppContext as _, Entity, Task};
 10use language::BufferSnapshot;
 11use text::{Point, ToOffset as _};
 12
 13pub use declaration::*;
 14pub use declaration_scoring::*;
 15pub use excerpt::*;
 16pub use reference::*;
 17pub use syntax_index::*;
 18
 19#[derive(Debug)]
 20pub struct EditPredictionContext {
 21    pub excerpt: EditPredictionExcerpt,
 22    pub excerpt_text: EditPredictionExcerptText,
 23    pub cursor_offset_in_excerpt: usize,
 24    pub snippets: Vec<ScoredSnippet>,
 25}
 26
 27impl EditPredictionContext {
 28    pub fn gather_context_in_background(
 29        cursor_point: Point,
 30        buffer: BufferSnapshot,
 31        excerpt_options: EditPredictionExcerptOptions,
 32        syntax_index: Option<Entity<SyntaxIndex>>,
 33        cx: &mut App,
 34    ) -> Task<Option<Self>> {
 35        if let Some(syntax_index) = syntax_index {
 36            let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 37            cx.background_spawn(async move {
 38                let index_state = index_state.lock().await;
 39                Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
 40            })
 41        } else {
 42            cx.background_spawn(async move {
 43                Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
 44            })
 45        }
 46    }
 47
 48    pub fn gather_context(
 49        cursor_point: Point,
 50        buffer: &BufferSnapshot,
 51        excerpt_options: &EditPredictionExcerptOptions,
 52        index_state: Option<&SyntaxIndexState>,
 53    ) -> Option<Self> {
 54        let excerpt = EditPredictionExcerpt::select_from_buffer(
 55            cursor_point,
 56            buffer,
 57            excerpt_options,
 58            index_state,
 59        )?;
 60        let excerpt_text = excerpt.text(buffer);
 61        let cursor_offset_in_file = cursor_point.to_offset(buffer);
 62        let cursor_offset_in_excerpt = cursor_offset_in_file - excerpt.range.start;
 63
 64        let snippets = if let Some(index_state) = index_state {
 65            let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 66
 67            scored_snippets(
 68                &index_state,
 69                &excerpt,
 70                &excerpt_text,
 71                references,
 72                cursor_offset_in_file,
 73                buffer,
 74            )
 75        } else {
 76            vec![]
 77        };
 78
 79        Some(Self {
 80            excerpt,
 81            excerpt_text,
 82            cursor_offset_in_excerpt,
 83            snippets,
 84        })
 85    }
 86}
 87
 88#[cfg(test)]
 89mod tests {
 90    use super::*;
 91    use std::sync::Arc;
 92
 93    use gpui::{Entity, TestAppContext};
 94    use indoc::indoc;
 95    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
 96    use project::{FakeFs, Project};
 97    use serde_json::json;
 98    use settings::SettingsStore;
 99    use util::path;
100
101    use crate::{EditPredictionExcerptOptions, SyntaxIndex};
102
103    #[gpui::test]
104    async fn test_call_site(cx: &mut TestAppContext) {
105        let (project, index, _rust_lang_id) = init_test(cx).await;
106
107        let buffer = project
108            .update(cx, |project, cx| {
109                let project_path = project.find_project_path("c.rs", cx).unwrap();
110                project.open_buffer(project_path, cx)
111            })
112            .await
113            .unwrap();
114
115        cx.run_until_parked();
116
117        // first process_data call site
118        let cursor_point = language::Point::new(8, 21);
119        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
120
121        let context = cx
122            .update(|cx| {
123                EditPredictionContext::gather_context_in_background(
124                    cursor_point,
125                    buffer_snapshot,
126                    EditPredictionExcerptOptions {
127                        max_bytes: 60,
128                        min_bytes: 10,
129                        target_before_cursor_over_total_bytes: 0.5,
130                    },
131                    Some(index),
132                    cx,
133                )
134            })
135            .await
136            .unwrap();
137
138        let mut snippet_identifiers = context
139            .snippets
140            .iter()
141            .map(|snippet| snippet.identifier.name.as_ref())
142            .collect::<Vec<_>>();
143        snippet_identifiers.sort();
144        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
145        drop(buffer);
146    }
147
148    async fn init_test(
149        cx: &mut TestAppContext,
150    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
151        cx.update(|cx| {
152            let settings_store = SettingsStore::test(cx);
153            cx.set_global(settings_store);
154            language::init(cx);
155            Project::init_settings(cx);
156        });
157
158        let fs = FakeFs::new(cx.executor());
159        fs.insert_tree(
160            path!("/root"),
161            json!({
162                "a.rs": indoc! {r#"
163                    fn main() {
164                        let x = 1;
165                        let y = 2;
166                        let z = add(x, y);
167                        println!("Result: {}", z);
168                    }
169
170                    fn add(a: i32, b: i32) -> i32 {
171                        a + b
172                    }
173                "#},
174                "b.rs": indoc! {"
175                    pub struct Config {
176                        pub name: String,
177                        pub value: i32,
178                    }
179
180                    impl Config {
181                        pub fn new(name: String, value: i32) -> Self {
182                            Config { name, value }
183                        }
184                    }
185                "},
186                "c.rs": indoc! {r#"
187                    use std::collections::HashMap;
188
189                    fn main() {
190                        let args: Vec<String> = std::env::args().collect();
191                        let data: Vec<i32> = args[1..]
192                            .iter()
193                            .filter_map(|s| s.parse().ok())
194                            .collect();
195                        let result = process_data(data);
196                        println!("{:?}", result);
197                    }
198
199                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
200                        let mut counts = HashMap::new();
201                        for value in data {
202                            *counts.entry(value).or_insert(0) += 1;
203                        }
204                        counts
205                    }
206
207                    #[cfg(test)]
208                    mod tests {
209                        use super::*;
210
211                        #[test]
212                        fn test_process_data() {
213                            let data = vec![1, 2, 2, 3];
214                            let result = process_data(data);
215                            assert_eq!(result.get(&2), Some(&2));
216                        }
217                    }
218                "#}
219            }),
220        )
221        .await;
222        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
223        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
224        let lang = rust_lang();
225        let lang_id = lang.id();
226        language_registry.add(Arc::new(lang));
227
228        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
229        cx.run_until_parked();
230
231        (project, index, lang_id)
232    }
233
234    fn rust_lang() -> Language {
235        Language::new(
236            LanguageConfig {
237                name: "Rust".into(),
238                matcher: LanguageMatcher {
239                    path_suffixes: vec!["rs".to_string()],
240                    ..Default::default()
241                },
242                ..Default::default()
243            },
244            Some(tree_sitter_rust::LANGUAGE.into()),
245        )
246        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
247        .unwrap()
248        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
249        .unwrap()
250    }
251}