edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use std::time::Instant;
 10
 11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 12pub use declaration_scoring::SnippetStyle;
 13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 14
 15use gpui::{App, AppContext as _, Entity, Task};
 16use language::BufferSnapshot;
 17pub use reference::references_in_excerpt;
 18pub use syntax_index::SyntaxIndex;
 19use text::{Point, ToOffset as _};
 20
 21use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
 22
 23#[derive(Debug)]
 24pub struct EditPredictionContext {
 25    pub excerpt: EditPredictionExcerpt,
 26    pub excerpt_text: EditPredictionExcerptText,
 27    pub snippets: Vec<ScoredSnippet>,
 28    pub retrieval_duration: std::time::Duration,
 29}
 30
 31impl EditPredictionContext {
 32    pub fn gather(
 33        cursor_point: Point,
 34        buffer: BufferSnapshot,
 35        excerpt_options: EditPredictionExcerptOptions,
 36        syntax_index: Entity<SyntaxIndex>,
 37        cx: &mut App,
 38    ) -> Task<Option<Self>> {
 39        let start = Instant::now();
 40        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 41        cx.background_spawn(async move {
 42            let index_state = index_state.lock().await;
 43
 44            let excerpt =
 45                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
 46            let excerpt_text = excerpt.text(&buffer);
 47            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
 48            let cursor_offset = cursor_point.to_offset(&buffer);
 49
 50            let snippets = scored_snippets(
 51                &index_state,
 52                &excerpt,
 53                &excerpt_text,
 54                references,
 55                cursor_offset,
 56                &buffer,
 57            );
 58
 59            Some(Self {
 60                excerpt,
 61                excerpt_text,
 62                snippets,
 63                retrieval_duration: start.elapsed(),
 64            })
 65        })
 66    }
 67}
 68
#[cfg(test)]
mod tests {
    // Tests for `EditPredictionContext::gather` against a fake project containing a few
    // small Rust files, indexed by a real `SyntaxIndex`.
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call in `c.rs`.
    ///
    /// Asserts that exactly one snippet is retrieved and that its identifier is
    /// `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Drain pending tasks before snapshotting. This presumably lets any parsing or
        // indexing kicked off by opening the buffer finish first.
        cx.run_until_parked();

        // first process_data call site
        // Rows and columns are zero-based. Row 8 of `c.rs` is
        // `    let result = process_data(data);`, and column 21 falls inside `process_data`.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather(
                    cursor_point,
                    buffer_snapshot,
                    // A 10..=40 byte budget keeps the excerpt tight around the cursor.
                    // Presumably this is so only the call site's line contributes references.
                    EditPredictionExcerptOptions {
                        max_bytes: 40,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                        include_parent_signatures: false,
                    },
                    index,
                    cx,
                )
            })
            .await
            .unwrap();

        assert_eq!(context.snippets.len(), 1);
        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
        // Keep the buffer handle alive until the very end of the test. Presumably this
        // prevents the buffer from being released while `gather` is still using it.
        drop(buffer);
    }

    /// Builds the shared test fixture.
    ///
    /// The fixture consists of:
    /// - test settings with language support initialised;
    /// - a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`;
    /// - a project over that directory with the Rust language registered;
    /// - a [`SyntaxIndex`] for the project.
    ///
    /// Returns the project, the index, and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            SettingsStore::load_registered_settings(cx);

            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        // Register Rust with the project's language registry before the index is created.
        // Presumably the index needs the language to be able to parse the `.rs` files.
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Let the index finish any work it schedules on creation before tests use it.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` for tests.
    ///
    /// The language matches `.rs` files and uses the tree-sitter Rust grammar. Its
    /// highlight and outline queries are loaded via `include_str!` from
    /// `languages/src/rust`.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}