edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod imports;
  5mod outline;
  6mod reference;
  7mod similar_snippets;
  8mod syntax_index;
  9pub mod text_similarity;
 10
 11use cloud_llm_client::predict_edits_v3;
 12use collections::HashMap;
 13use gpui::{App, AppContext as _, Entity, Task};
 14use language::BufferSnapshot;
 15use std::{path::Path, sync::Arc, time::Instant};
 16use text::{Point, ToOffset as _};
 17
 18pub use declaration::*;
 19pub use declaration_scoring::*;
 20pub use excerpt::*;
 21pub use imports::*;
 22pub use reference::*;
 23pub use similar_snippets::*;
 24pub use syntax_index::*;
 25pub use text_similarity::*;
 26
 27pub use predict_edits_v3::Line;
 28
 29#[derive(Clone, Debug, PartialEq)]
 30pub struct EditPredictionContextOptions {
 31    pub use_imports: bool,
 32    pub excerpt: EditPredictionExcerptOptions,
 33    pub score: EditPredictionScoreOptions,
 34    pub similar_snippets: SimilarSnippetOptions,
 35    pub max_retrieved_declarations: u8,
 36}
 37
 38#[derive(Clone, Debug)]
 39pub struct EditPredictionContext {
 40    pub excerpt: EditPredictionExcerpt,
 41    pub excerpt_text: EditPredictionExcerptText,
 42    pub cursor_point: Point,
 43    pub declarations: Vec<ScoredDeclaration>,
 44    pub similar_snippets: Vec<SimilarSnippet>,
 45}
 46
 47impl EditPredictionContext {
 48    pub fn gather_context_in_background(
 49        cursor_point: Point,
 50        buffer: BufferSnapshot,
 51        options: EditPredictionContextOptions,
 52        syntax_index: Option<Entity<SyntaxIndex>>,
 53        cx: &mut App,
 54    ) -> Task<Option<Self>> {
 55        let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
 56            let mut path = f.worktree.read(cx).absolutize(&f.path);
 57            if path.pop() { Some(path) } else { None }
 58        });
 59
 60        if let Some(syntax_index) = syntax_index {
 61            let index_state =
 62                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 63            cx.background_spawn(async move {
 64                let parent_abs_path = parent_abs_path.as_deref();
 65                let index_state = index_state.upgrade()?;
 66                let index_state = index_state.lock().await;
 67                Self::gather_context(
 68                    cursor_point,
 69                    &buffer,
 70                    parent_abs_path,
 71                    &options,
 72                    Some(&index_state),
 73                )
 74            })
 75        } else {
 76            cx.background_spawn(async move {
 77                let parent_abs_path = parent_abs_path.as_deref();
 78                Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
 79            })
 80        }
 81    }
 82
 83    pub fn gather_context(
 84        cursor_point: Point,
 85        buffer: &BufferSnapshot,
 86        parent_abs_path: Option<&Path>,
 87        options: &EditPredictionContextOptions,
 88        index_state: Option<&SyntaxIndexState>,
 89    ) -> Option<Self> {
 90        let imports = if options.use_imports {
 91            Imports::gather(&buffer, parent_abs_path)
 92        } else {
 93            Imports::default()
 94        };
 95        Self::gather_context_with_references_fn(
 96            cursor_point,
 97            buffer,
 98            &imports,
 99            options,
100            index_state,
101            references_in_excerpt,
102        )
103    }
104
105    pub fn gather_context_with_references_fn(
106        cursor_point: Point,
107        buffer: &BufferSnapshot,
108        imports: &Imports,
109        options: &EditPredictionContextOptions,
110        index_state: Option<&SyntaxIndexState>,
111        get_references: impl FnOnce(
112            &EditPredictionExcerpt,
113            &EditPredictionExcerptText,
114            &BufferSnapshot,
115        ) -> HashMap<Identifier, Vec<Reference>>,
116    ) -> Option<Self> {
117        let excerpt = EditPredictionExcerpt::select_from_buffer(
118            cursor_point,
119            buffer,
120            &options.excerpt,
121            index_state,
122        )?;
123        let excerpt_text = excerpt.text(buffer);
124
125        let declarations = if options.max_retrieved_declarations > 0
126            && let Some(index_state) = index_state
127        {
128            let excerpt_occurrences =
129                Occurrences::new(IdentifierParts::occurrences_in_str(&excerpt_text.body));
130            let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
131            let adjacent_end = Point::new(cursor_point.row + 1, 0);
132            let adjacent_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
133                &buffer
134                    .text_for_range(adjacent_start..adjacent_end)
135                    .collect::<String>(),
136            ));
137
138            let cursor_offset_in_file = cursor_point.to_offset(buffer);
139
140            let references = get_references(&excerpt, &excerpt_text, buffer);
141
142            let mut declarations = scored_declarations(
143                &options.score,
144                &index_state,
145                &excerpt,
146                &excerpt_occurrences,
147                &adjacent_occurrences,
148                &imports,
149                references,
150                cursor_offset_in_file,
151                buffer,
152            );
153            // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
154            declarations.truncate(options.max_retrieved_declarations as usize);
155            declarations
156        } else {
157            vec![]
158        };
159
160        let similar_snippets = if options.similar_snippets.max_result_count > 0 {
161            let before = Instant::now();
162            let excerpt_trigram_occurrences: Occurrences<NGram<3, CodeParts>> =
163                Occurrences::new(NGram::occurrences_in_str(&excerpt_text.body));
164            let similar_snippets = similar_snippets(
165                &excerpt_trigram_occurrences,
166                excerpt.range.clone(),
167                buffer,
168                &options.similar_snippets,
169            );
170            dbg!(before.elapsed());
171            similar_snippets
172        } else {
173            Vec::new()
174        };
175
176        // buffer.debug(&excerpt.range, "excerpt");
177
178        // buffer.debug(
179        //     &similar_snippets
180        //         .iter()
181        //         .map(|s| s.range.clone())
182        //         .collect::<Vec<_>>(),
183        //     "similar_snippets",
184        // );
185
186        Some(Self {
187            excerpt,
188            excerpt_text,
189            cursor_point,
190            declarations,
191            similar_snippets,
192        })
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call site in `c.rs` and checks
    /// the names of the retrieved declarations.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let open_buffer = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_buffer.await.unwrap();

        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let options = EditPredictionContextOptions {
            use_imports: true,
            excerpt: EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            score: EditPredictionScoreOptions {
                omit_excerpt_overlaps: true,
            },
            similar_snippets: SimilarSnippetOptions::default(),
            max_retrieved_declarations: u8::MAX,
        };

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    options,
                    Some(index.clone()),
                    cx,
                )
            })
            .await
            .unwrap();

        let mut declaration_names = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect::<Vec<_>>();
        declaration_names.sort();
        assert_eq!(declaration_names, vec!["main", "process_data"]);
        drop(buffer);
    }

    /// Sets up settings, a fake file system holding three small Rust files under `/root`, a test
    /// project over it with the Rust language registered, and a `SyntaxIndex` for that project.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        languages.add(Arc::new(rust));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Let pending work started above settle before handing the index to the test.
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` matching `.rs` files, with the highlights and outline queries
    /// loaded from the `languages` crate sources.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}