edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use std::time::Instant;
 10
 11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 12pub use declaration_scoring::SnippetStyle;
 13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 14
 15use gpui::{App, AppContext as _, Entity, Task};
 16use language::BufferSnapshot;
 17pub use reference::references_in_excerpt;
 18pub use syntax_index::SyntaxIndex;
 19use text::{Point, ToOffset as _};
 20
 21use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
 22
 23#[derive(Debug)]
 24pub struct EditPredictionContext {
 25    pub excerpt: EditPredictionExcerpt,
 26    pub excerpt_text: EditPredictionExcerptText,
 27    pub snippets: Vec<ScoredSnippet>,
 28    pub retrieval_duration: std::time::Duration,
 29}
 30
 31impl EditPredictionContext {
 32    pub fn gather(
 33        cursor_point: Point,
 34        buffer: BufferSnapshot,
 35        excerpt_options: EditPredictionExcerptOptions,
 36        syntax_index: Entity<SyntaxIndex>,
 37        cx: &mut App,
 38    ) -> Task<Option<Self>> {
 39        let start = Instant::now();
 40        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 41        cx.background_spawn(async move {
 42            let index_state = index_state.lock().await;
 43
 44            let excerpt =
 45                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
 46            let excerpt_text = excerpt.text(&buffer);
 47            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
 48            let cursor_offset = cursor_point.to_offset(&buffer);
 49
 50            let snippets = scored_snippets(
 51                &index_state,
 52                &excerpt,
 53                &excerpt_text,
 54                references,
 55                cursor_offset,
 56                &buffer,
 57            );
 58
 59            Some(Self {
 60                excerpt,
 61                excerpt_text,
 62                snippets,
 63                retrieval_duration: start.elapsed(),
 64            })
 65        })
 66    }
 67}
 68
 69#[cfg(test)]
 70mod tests {
 71    use super::*;
 72    use std::sync::Arc;
 73
 74    use gpui::{Entity, TestAppContext};
 75    use indoc::indoc;
 76    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
 77    use project::{FakeFs, Project};
 78    use serde_json::json;
 79    use settings::SettingsStore;
 80    use util::path;
 81
 82    use crate::{EditPredictionExcerptOptions, SyntaxIndex};
 83
 84    #[gpui::test]
 85    async fn test_call_site(cx: &mut TestAppContext) {
 86        let (project, index, _rust_lang_id) = init_test(cx).await;
 87
 88        let buffer = project
 89            .update(cx, |project, cx| {
 90                let project_path = project.find_project_path("c.rs", cx).unwrap();
 91                project.open_buffer(project_path, cx)
 92            })
 93            .await
 94            .unwrap();
 95
 96        cx.run_until_parked();
 97
 98        // first process_data call site
 99        let cursor_point = language::Point::new(8, 21);
100        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
101
102        let context = cx
103            .update(|cx| {
104                EditPredictionContext::gather(
105                    cursor_point,
106                    buffer_snapshot,
107                    EditPredictionExcerptOptions {
108                        max_bytes: 40,
109                        min_bytes: 10,
110                        target_before_cursor_over_total_bytes: 0.5,
111                        include_parent_signatures: false,
112                    },
113                    index,
114                    cx,
115                )
116            })
117            .await
118            .unwrap();
119
120        assert_eq!(context.snippets.len(), 1);
121        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
122        drop(buffer);
123    }
124
125    async fn init_test(
126        cx: &mut TestAppContext,
127    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
128        cx.update(|cx| {
129            let settings_store = SettingsStore::test(cx);
130            cx.set_global(settings_store);
131            language::init(cx);
132            Project::init_settings(cx);
133        });
134
135        let fs = FakeFs::new(cx.executor());
136        fs.insert_tree(
137            path!("/root"),
138            json!({
139                "a.rs": indoc! {r#"
140                    fn main() {
141                        let x = 1;
142                        let y = 2;
143                        let z = add(x, y);
144                        println!("Result: {}", z);
145                    }
146
147                    fn add(a: i32, b: i32) -> i32 {
148                        a + b
149                    }
150                "#},
151                "b.rs": indoc! {"
152                    pub struct Config {
153                        pub name: String,
154                        pub value: i32,
155                    }
156
157                    impl Config {
158                        pub fn new(name: String, value: i32) -> Self {
159                            Config { name, value }
160                        }
161                    }
162                "},
163                "c.rs": indoc! {r#"
164                    use std::collections::HashMap;
165
166                    fn main() {
167                        let args: Vec<String> = std::env::args().collect();
168                        let data: Vec<i32> = args[1..]
169                            .iter()
170                            .filter_map(|s| s.parse().ok())
171                            .collect();
172                        let result = process_data(data);
173                        println!("{:?}", result);
174                    }
175
176                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
177                        let mut counts = HashMap::new();
178                        for value in data {
179                            *counts.entry(value).or_insert(0) += 1;
180                        }
181                        counts
182                    }
183
184                    #[cfg(test)]
185                    mod tests {
186                        use super::*;
187
188                        #[test]
189                        fn test_process_data() {
190                            let data = vec![1, 2, 2, 3];
191                            let result = process_data(data);
192                            assert_eq!(result.get(&2), Some(&2));
193                        }
194                    }
195                "#}
196            }),
197        )
198        .await;
199        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
200        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
201        let lang = rust_lang();
202        let lang_id = lang.id();
203        language_registry.add(Arc::new(lang));
204
205        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
206        cx.run_until_parked();
207
208        (project, index, lang_id)
209    }
210
211    fn rust_lang() -> Language {
212        Language::new(
213            LanguageConfig {
214                name: "Rust".into(),
215                matcher: LanguageMatcher {
216                    path_suffixes: vec!["rs".to_string()],
217                    ..Default::default()
218                },
219                ..Default::default()
220            },
221            Some(tree_sitter_rust::LANGUAGE.into()),
222        )
223        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
224        .unwrap()
225        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
226        .unwrap()
227    }
228}