edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use std::sync::Arc;
 10
 11use gpui::{App, AppContext as _, Entity, Task};
 12use language::BufferSnapshot;
 13use text::{Point, ToOffset as _};
 14
 15pub use declaration::*;
 16pub use declaration_scoring::*;
 17pub use excerpt::*;
 18pub use reference::*;
 19pub use syntax_index::*;
 20
 21#[derive(Clone, Debug)]
 22pub struct EditPredictionContext {
 23    pub excerpt: EditPredictionExcerpt,
 24    pub excerpt_text: EditPredictionExcerptText,
 25    pub cursor_offset_in_excerpt: usize,
 26    pub declarations: Vec<ScoredDeclaration>,
 27}
 28
 29impl EditPredictionContext {
 30    pub fn gather_context_in_background(
 31        cursor_point: Point,
 32        buffer: BufferSnapshot,
 33        excerpt_options: EditPredictionExcerptOptions,
 34        syntax_index: Option<Entity<SyntaxIndex>>,
 35        cx: &mut App,
 36    ) -> Task<Option<Self>> {
 37        if let Some(syntax_index) = syntax_index {
 38            let index_state =
 39                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 40            cx.background_spawn(async move {
 41                let index_state = index_state.upgrade()?;
 42                let index_state = index_state.lock().await;
 43                Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
 44            })
 45        } else {
 46            cx.background_spawn(async move {
 47                Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
 48            })
 49        }
 50    }
 51
 52    pub fn gather_context(
 53        cursor_point: Point,
 54        buffer: &BufferSnapshot,
 55        excerpt_options: &EditPredictionExcerptOptions,
 56        index_state: Option<&SyntaxIndexState>,
 57    ) -> Option<Self> {
 58        let excerpt = EditPredictionExcerpt::select_from_buffer(
 59            cursor_point,
 60            buffer,
 61            excerpt_options,
 62            index_state,
 63        )?;
 64        let excerpt_text = excerpt.text(buffer);
 65        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
 66
 67        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
 68        let adjacent_end = Point::new(cursor_point.row + 1, 0);
 69        let adjacent_occurrences = text_similarity::Occurrences::within_string(
 70            &buffer
 71                .text_for_range(adjacent_start..adjacent_end)
 72                .collect::<String>(),
 73        );
 74
 75        let cursor_offset_in_file = cursor_point.to_offset(buffer);
 76        // TODO fix this to not need saturating_sub
 77        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 78
 79        let declarations = if let Some(index_state) = index_state {
 80            let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 81
 82            scored_declarations(
 83                &index_state,
 84                &excerpt,
 85                &excerpt_occurrences,
 86                &adjacent_occurrences,
 87                references,
 88                cursor_offset_in_file,
 89                buffer,
 90            )
 91        } else {
 92            vec![]
 93        };
 94
 95        Some(Self {
 96            excerpt,
 97            excerpt_text,
 98            cursor_offset_in_excerpt,
 99            declarations,
100        })
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Places the cursor inside the first `process_data(data)` call in `c.rs`.
    /// The scored declarations are expected to be exactly `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _lang_id) = init_test(cx).await;

        let open_buffer = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_buffer.await.unwrap();

        cx.run_until_parked();

        // Row 8, column 21 lands inside `process_data` in `let result = process_data(data);`.
        let cursor_point = language::Point::new(8, 21);
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let excerpt_options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        let gather = cx.update(|cx| {
            EditPredictionContext::gather_context_in_background(
                cursor_point,
                snapshot,
                excerpt_options,
                Some(index),
                cx,
            )
        });
        let context = gather.await.unwrap();

        let mut names: Vec<&str> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        names.sort_unstable();
        assert_eq!(names, vec!["main", "process_data"]);

        drop(buffer);
    }

    /// Sets up the test environment:
    /// - installs a test settings store and initializes language and project settings;
    /// - creates a project over a fake filesystem containing `a.rs`, `b.rs` and `c.rs`
    ///   under `/root`;
    /// - registers a Rust language and builds a `SyntaxIndex` over the project.
    ///
    /// Returns the project, the index entity, and the registered language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fake_fs = FakeFs::new(cx.executor());
        fake_fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fake_fs.clone(), [path!("/root").as_ref()], cx).await;

        let registry = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        registry.add(Arc::new(rust));

        let parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, parallelism, cx));
        // Presumably lets the index's initial background work settle before tests query it.
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` backed by the tree-sitter-rust grammar.
    ///
    /// Loads the highlights and outline queries from `../../languages/src/rust/`.
    /// Panics if either query fails to load.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}