edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use gpui::{App, AppContext as _, Entity, Task};
 10use language::BufferSnapshot;
 11use text::{Point, ToOffset as _};
 12
 13pub use declaration::*;
 14pub use declaration_scoring::*;
 15pub use excerpt::*;
 16pub use reference::*;
 17pub use syntax_index::*;
 18
 19#[derive(Clone, Debug)]
 20pub struct EditPredictionContext {
 21    pub excerpt: EditPredictionExcerpt,
 22    pub excerpt_text: EditPredictionExcerptText,
 23    pub cursor_offset_in_excerpt: usize,
 24    pub declarations: Vec<ScoredDeclaration>,
 25}
 26
 27impl EditPredictionContext {
 28    pub fn gather_context_in_background(
 29        cursor_point: Point,
 30        buffer: BufferSnapshot,
 31        excerpt_options: EditPredictionExcerptOptions,
 32        syntax_index: Option<Entity<SyntaxIndex>>,
 33        cx: &mut App,
 34    ) -> Task<Option<Self>> {
 35        if let Some(syntax_index) = syntax_index {
 36            let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 37            cx.background_spawn(async move {
 38                let index_state = index_state.lock().await;
 39                Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
 40            })
 41        } else {
 42            cx.background_spawn(async move {
 43                Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
 44            })
 45        }
 46    }
 47
 48    pub fn gather_context(
 49        cursor_point: Point,
 50        buffer: &BufferSnapshot,
 51        excerpt_options: &EditPredictionExcerptOptions,
 52        index_state: Option<&SyntaxIndexState>,
 53    ) -> Option<Self> {
 54        let excerpt = EditPredictionExcerpt::select_from_buffer(
 55            cursor_point,
 56            buffer,
 57            excerpt_options,
 58            index_state,
 59        )?;
 60        let excerpt_text = excerpt.text(buffer);
 61        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
 62
 63        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
 64        let adjacent_end = Point::new(cursor_point.row + 1, 0);
 65        let adjacent_occurrences = text_similarity::Occurrences::within_string(
 66            &buffer
 67                .text_for_range(adjacent_start..adjacent_end)
 68                .collect::<String>(),
 69        );
 70
 71        let cursor_offset_in_file = cursor_point.to_offset(buffer);
 72        // TODO fix this to not need saturating_sub
 73        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 74
 75        let declarations = if let Some(index_state) = index_state {
 76            let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 77
 78            scored_declarations(
 79                &index_state,
 80                &excerpt,
 81                &excerpt_occurrences,
 82                &adjacent_occurrences,
 83                references,
 84                cursor_offset_in_file,
 85                buffer,
 86            )
 87        } else {
 88            vec![]
 89        };
 90
 91        Some(Self {
 92            excerpt,
 93            excerpt_text,
 94            cursor_offset_in_excerpt,
 95            declarations,
 96        })
 97    }
 98}
 99
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context at a `process_data` call site in `c.rs` and checks which declarations
    /// are scored.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            .unwrap();

        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        drop(buffer);
    }

    /// Sets up test settings, a fake filesystem with three Rust files under `/root`, a project
    /// over it with the Rust language registered, and a `SyntaxIndex` for that project.
    async fn init_test(cx: &mut TestAppContext) -> (Entity<Project>, Entity<SyntaxIndex>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index)
    }

    /// Builds a Rust `Language` for `.rs` files using the highlights and outline queries from
    /// `languages/src/rust`.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}