edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod imports;
  5mod outline;
  6mod reference;
  7mod syntax_index;
  8pub mod text_similarity;
  9
 10use std::{path::Path, sync::Arc};
 11
 12use cloud_llm_client::predict_edits_v3;
 13use collections::HashMap;
 14use gpui::{App, AppContext as _, Entity, Task};
 15use language::BufferSnapshot;
 16use text::{Point, ToOffset as _};
 17
 18pub use declaration::*;
 19pub use declaration_scoring::*;
 20pub use excerpt::*;
 21pub use imports::*;
 22pub use reference::*;
 23pub use syntax_index::*;
 24
 25pub use predict_edits_v3::Line;
 26
 27#[derive(Clone, Debug, PartialEq)]
 28pub struct EditPredictionContextOptions {
 29    pub use_imports: bool,
 30    pub excerpt: EditPredictionExcerptOptions,
 31    pub score: EditPredictionScoreOptions,
 32}
 33
 34#[derive(Clone, Debug)]
 35pub struct EditPredictionContext {
 36    pub excerpt: EditPredictionExcerpt,
 37    pub excerpt_text: EditPredictionExcerptText,
 38    pub cursor_point: Point,
 39    pub declarations: Vec<ScoredDeclaration>,
 40}
 41
 42impl EditPredictionContext {
 43    pub fn gather_context_in_background(
 44        cursor_point: Point,
 45        buffer: BufferSnapshot,
 46        options: EditPredictionContextOptions,
 47        syntax_index: Option<Entity<SyntaxIndex>>,
 48        cx: &mut App,
 49    ) -> Task<Option<Self>> {
 50        let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
 51            let mut path = f.worktree.read(cx).absolutize(&f.path);
 52            if path.pop() { Some(path) } else { None }
 53        });
 54
 55        if let Some(syntax_index) = syntax_index {
 56            let index_state =
 57                syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
 58            cx.background_spawn(async move {
 59                let parent_abs_path = parent_abs_path.as_deref();
 60                let index_state = index_state.upgrade()?;
 61                let index_state = index_state.lock().await;
 62                Self::gather_context(
 63                    cursor_point,
 64                    &buffer,
 65                    parent_abs_path,
 66                    &options,
 67                    Some(&index_state),
 68                )
 69            })
 70        } else {
 71            cx.background_spawn(async move {
 72                let parent_abs_path = parent_abs_path.as_deref();
 73                Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
 74            })
 75        }
 76    }
 77
 78    pub fn gather_context(
 79        cursor_point: Point,
 80        buffer: &BufferSnapshot,
 81        parent_abs_path: Option<&Path>,
 82        options: &EditPredictionContextOptions,
 83        index_state: Option<&SyntaxIndexState>,
 84    ) -> Option<Self> {
 85        let imports = if options.use_imports {
 86            Imports::gather(&buffer, parent_abs_path)
 87        } else {
 88            Imports::default()
 89        };
 90        Self::gather_context_with_references_fn(
 91            cursor_point,
 92            buffer,
 93            &imports,
 94            options,
 95            index_state,
 96            references_in_excerpt,
 97        )
 98    }
 99
100    pub fn gather_context_with_references_fn(
101        cursor_point: Point,
102        buffer: &BufferSnapshot,
103        imports: &Imports,
104        options: &EditPredictionContextOptions,
105        index_state: Option<&SyntaxIndexState>,
106        get_references: impl FnOnce(
107            &EditPredictionExcerpt,
108            &EditPredictionExcerptText,
109            &BufferSnapshot,
110        ) -> HashMap<Identifier, Vec<Reference>>,
111    ) -> Option<Self> {
112        let excerpt = EditPredictionExcerpt::select_from_buffer(
113            cursor_point,
114            buffer,
115            &options.excerpt,
116            index_state,
117        )?;
118        let excerpt_text = excerpt.text(buffer);
119        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
120
121        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
122        let adjacent_end = Point::new(cursor_point.row + 1, 0);
123        let adjacent_occurrences = text_similarity::Occurrences::within_string(
124            &buffer
125                .text_for_range(adjacent_start..adjacent_end)
126                .collect::<String>(),
127        );
128
129        let cursor_offset_in_file = cursor_point.to_offset(buffer);
130
131        let declarations = if let Some(index_state) = index_state {
132            let references = get_references(&excerpt, &excerpt_text, buffer);
133
134            scored_declarations(
135                &options.score,
136                &index_state,
137                &excerpt,
138                &excerpt_occurrences,
139                &adjacent_occurrences,
140                &imports,
141                references,
142                cursor_offset_in_file,
143                buffer,
144            )
145        } else {
146            vec![]
147        };
148
149        Some(Self {
150            excerpt,
151            excerpt_text,
152            cursor_point,
153            declarations,
154        })
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// With the cursor on the first `process_data` call in `c.rs`, the gathered context
    /// should contain exactly the `main` and `process_data` declarations.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let open_task = project.update(cx, |project, cx| {
            let c_rs = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(c_rs, cx)
        });
        let buffer = open_task.await.unwrap();
        cx.run_until_parked();

        // Row 8 of `c.rs` is `let result = process_data(data);`; column 21 falls inside
        // the `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let options = EditPredictionContextOptions {
            use_imports: true,
            excerpt: EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            score: EditPredictionScoreOptions {
                omit_excerpt_overlaps: true,
            },
        };

        let gather_task = cx.update(|cx| {
            EditPredictionContext::gather_context_in_background(
                cursor_point,
                snapshot,
                options,
                Some(index.clone()),
                cx,
            )
        });
        let context = gather_task.await.unwrap();

        let mut declaration_names: Vec<&str> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        declaration_names.sort_unstable();
        assert_eq!(declaration_names, vec!["main", "process_data"]);

        // The buffer handle is held until after the assertions and released here.
        drop(buffer);
    }

    /// Sets up this test's environment:
    /// - initializes test settings and the language subsystem;
    /// - fills a fake filesystem at `/root` with three Rust files;
    /// - registers a Rust language.
    ///
    /// Returns the project, a syntax index over it, and the Rust language id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        let files = json!({
            "a.rs": indoc! {r#"
                fn main() {
                    let x = 1;
                    let y = 2;
                    let z = add(x, y);
                    println!("Result: {}", z);
                }

                fn add(a: i32, b: i32) -> i32 {
                    a + b
                }
            "#},
            "b.rs": indoc! {"
                pub struct Config {
                    pub name: String,
                    pub value: i32,
                }

                impl Config {
                    pub fn new(name: String, value: i32) -> Self {
                        Config { name, value }
                    }
                }
            "},
            "c.rs": indoc! {r#"
                use std::collections::HashMap;

                fn main() {
                    let args: Vec<String> = std::env::args().collect();
                    let data: Vec<i32> = args[1..]
                        .iter()
                        .filter_map(|s| s.parse().ok())
                        .collect();
                    let result = process_data(data);
                    println!("{:?}", result);
                }

                fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                    let mut counts = HashMap::new();
                    for value in data {
                        *counts.entry(value).or_insert(0) += 1;
                    }
                    counts
                }

                #[cfg(test)]
                mod tests {
                    use super::*;

                    #[test]
                    fn test_process_data() {
                        let data = vec![1, 2, 2, 3];
                        let result = process_data(data);
                        assert_eq!(result.get(&2), Some(&2));
                    }
                }
            "#}
        });
        fs.insert_tree(path!("/root"), files).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let rust = rust_lang();
        let rust_id = rust.id();
        let registry = project.read_with(cx, |project, _| project.languages().clone());
        registry.add(Arc::new(rust));

        let parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, parallelism, cx));
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` matching `*.rs`, with the highlights and outline
    /// queries loaded from the `languages` crate.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}