edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use gpui::{App, AppContext as _, Entity, Task};
 10use language::BufferSnapshot;
 11use text::{Point, ToOffset as _};
 12
 13pub use declaration::*;
 14pub use declaration_scoring::*;
 15pub use excerpt::*;
 16pub use reference::*;
 17pub use syntax_index::*;
 18
 19#[derive(Clone, Debug)]
 20pub struct EditPredictionContext {
 21    pub excerpt: EditPredictionExcerpt,
 22    pub excerpt_text: EditPredictionExcerptText,
 23    pub cursor_offset_in_excerpt: usize,
 24    pub snippets: Vec<ScoredSnippet>,
 25}
 26
 27impl EditPredictionContext {
 28    pub fn gather_context_in_background(
 29        cursor_point: Point,
 30        buffer: BufferSnapshot,
 31        excerpt_options: EditPredictionExcerptOptions,
 32        syntax_index: Option<Entity<SyntaxIndex>>,
 33        cx: &mut App,
 34    ) -> Task<Option<Self>> {
 35        if let Some(syntax_index) = syntax_index {
 36            let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 37            cx.background_spawn(async move {
 38                let index_state = index_state.lock().await;
 39                Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
 40            })
 41        } else {
 42            cx.background_spawn(async move {
 43                Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
 44            })
 45        }
 46    }
 47
 48    pub fn gather_context(
 49        cursor_point: Point,
 50        buffer: &BufferSnapshot,
 51        excerpt_options: &EditPredictionExcerptOptions,
 52        index_state: Option<&SyntaxIndexState>,
 53    ) -> Option<Self> {
 54        let excerpt = EditPredictionExcerpt::select_from_buffer(
 55            cursor_point,
 56            buffer,
 57            excerpt_options,
 58            index_state,
 59        )?;
 60        let excerpt_text = excerpt.text(buffer);
 61        let cursor_offset_in_file = cursor_point.to_offset(buffer);
 62        // TODO fix this to not need saturating_sub
 63        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 64
 65        let snippets = if let Some(index_state) = index_state {
 66            let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 67
 68            scored_snippets(
 69                &index_state,
 70                &excerpt,
 71                &excerpt_text,
 72                references,
 73                cursor_offset_in_file,
 74                buffer,
 75            )
 76        } else {
 77            vec![]
 78        };
 79
 80        Some(Self {
 81            excerpt,
 82            excerpt_text,
 83            cursor_offset_in_excerpt,
 84            snippets,
 85        })
 86    }
 87}
 88
 89#[cfg(test)]
 90mod tests {
 91    use super::*;
 92    use std::sync::Arc;
 93
 94    use gpui::{Entity, TestAppContext};
 95    use indoc::indoc;
 96    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
 97    use project::{FakeFs, Project};
 98    use serde_json::json;
 99    use settings::SettingsStore;
100    use util::path;
101
102    use crate::{EditPredictionExcerptOptions, SyntaxIndex};
103
104    #[gpui::test]
105    async fn test_call_site(cx: &mut TestAppContext) {
106        let (project, index, _rust_lang_id) = init_test(cx).await;
107
108        let buffer = project
109            .update(cx, |project, cx| {
110                let project_path = project.find_project_path("c.rs", cx).unwrap();
111                project.open_buffer(project_path, cx)
112            })
113            .await
114            .unwrap();
115
116        cx.run_until_parked();
117
118        // first process_data call site
119        let cursor_point = language::Point::new(8, 21);
120        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
121
122        let context = cx
123            .update(|cx| {
124                EditPredictionContext::gather_context_in_background(
125                    cursor_point,
126                    buffer_snapshot,
127                    EditPredictionExcerptOptions {
128                        max_bytes: 60,
129                        min_bytes: 10,
130                        target_before_cursor_over_total_bytes: 0.5,
131                    },
132                    Some(index),
133                    cx,
134                )
135            })
136            .await
137            .unwrap();
138
139        let mut snippet_identifiers = context
140            .snippets
141            .iter()
142            .map(|snippet| snippet.identifier.name.as_ref())
143            .collect::<Vec<_>>();
144        snippet_identifiers.sort();
145        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
146        drop(buffer);
147    }
148
149    async fn init_test(
150        cx: &mut TestAppContext,
151    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
152        cx.update(|cx| {
153            let settings_store = SettingsStore::test(cx);
154            cx.set_global(settings_store);
155            language::init(cx);
156            Project::init_settings(cx);
157        });
158
159        let fs = FakeFs::new(cx.executor());
160        fs.insert_tree(
161            path!("/root"),
162            json!({
163                "a.rs": indoc! {r#"
164                    fn main() {
165                        let x = 1;
166                        let y = 2;
167                        let z = add(x, y);
168                        println!("Result: {}", z);
169                    }
170
171                    fn add(a: i32, b: i32) -> i32 {
172                        a + b
173                    }
174                "#},
175                "b.rs": indoc! {"
176                    pub struct Config {
177                        pub name: String,
178                        pub value: i32,
179                    }
180
181                    impl Config {
182                        pub fn new(name: String, value: i32) -> Self {
183                            Config { name, value }
184                        }
185                    }
186                "},
187                "c.rs": indoc! {r#"
188                    use std::collections::HashMap;
189
190                    fn main() {
191                        let args: Vec<String> = std::env::args().collect();
192                        let data: Vec<i32> = args[1..]
193                            .iter()
194                            .filter_map(|s| s.parse().ok())
195                            .collect();
196                        let result = process_data(data);
197                        println!("{:?}", result);
198                    }
199
200                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
201                        let mut counts = HashMap::new();
202                        for value in data {
203                            *counts.entry(value).or_insert(0) += 1;
204                        }
205                        counts
206                    }
207
208                    #[cfg(test)]
209                    mod tests {
210                        use super::*;
211
212                        #[test]
213                        fn test_process_data() {
214                            let data = vec![1, 2, 2, 3];
215                            let result = process_data(data);
216                            assert_eq!(result.get(&2), Some(&2));
217                        }
218                    }
219                "#}
220            }),
221        )
222        .await;
223        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
224        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
225        let lang = rust_lang();
226        let lang_id = lang.id();
227        language_registry.add(Arc::new(lang));
228
229        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
230        cx.run_until_parked();
231
232        (project, index, lang_id)
233    }
234
235    fn rust_lang() -> Language {
236        Language::new(
237            LanguageConfig {
238                name: "Rust".into(),
239                matcher: LanguageMatcher {
240                    path_suffixes: vec!["rs".to_string()],
241                    ..Default::default()
242                },
243                ..Default::default()
244            },
245            Some(tree_sitter_rust::LANGUAGE.into()),
246        )
247        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
248        .unwrap()
249        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
250        .unwrap()
251    }
252}