edit_prediction_context.rs

  1mod declaration;
  2mod declaration_scoring;
  3mod excerpt;
  4mod outline;
  5mod reference;
  6mod syntax_index;
  7mod text_similarity;
  8
  9use cloud_llm_client::predict_edits_v3::{self, Signature};
 10use collections::HashMap;
 11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 12pub use declaration_scoring::SnippetStyle;
 13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 14
 15use gpui::{App, AppContext as _, Entity, Task};
 16use language::BufferSnapshot;
 17pub use reference::references_in_excerpt;
 18pub use syntax_index::SyntaxIndex;
 19use text::{Point, ToOffset as _};
 20
 21use crate::{
 22    declaration::DeclarationId,
 23    declaration_scoring::{ScoredSnippet, scored_snippets},
 24    syntax_index::SyntaxIndexState,
 25};
 26
 27#[derive(Debug)]
 28pub struct EditPredictionContext {
 29    pub excerpt: EditPredictionExcerpt,
 30    pub excerpt_text: EditPredictionExcerptText,
 31    pub snippets: Vec<ScoredSnippet>,
 32}
 33
 34impl EditPredictionContext {
 35    pub fn gather(
 36        cursor_point: Point,
 37        buffer: BufferSnapshot,
 38        excerpt_options: EditPredictionExcerptOptions,
 39        syntax_index: Entity<SyntaxIndex>,
 40        cx: &mut App,
 41    ) -> Task<Option<Self>> {
 42        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 43        cx.background_spawn(async move {
 44            let index_state = index_state.lock().await;
 45            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
 46        })
 47    }
 48
 49    fn gather_context(
 50        cursor_point: Point,
 51        buffer: BufferSnapshot,
 52        excerpt_options: EditPredictionExcerptOptions,
 53        index_state: &SyntaxIndexState,
 54    ) -> Option<Self> {
 55        let excerpt = EditPredictionExcerpt::select_from_buffer(
 56            cursor_point,
 57            &buffer,
 58            &excerpt_options,
 59            Some(index_state),
 60        )?;
 61        let excerpt_text = excerpt.text(&buffer);
 62        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
 63        let cursor_offset = cursor_point.to_offset(&buffer);
 64
 65        let snippets = scored_snippets(
 66            &index_state,
 67            &excerpt,
 68            &excerpt_text,
 69            references,
 70            cursor_offset,
 71            &buffer,
 72        );
 73
 74        Some(Self {
 75            excerpt,
 76            excerpt_text,
 77            snippets,
 78        })
 79    }
 80
 81    pub fn cloud_request(
 82        cursor_point: Point,
 83        buffer: BufferSnapshot,
 84        excerpt_options: EditPredictionExcerptOptions,
 85        syntax_index: Entity<SyntaxIndex>,
 86        cx: &mut App,
 87    ) -> Task<Option<predict_edits_v3::Body>> {
 88        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
 89        cx.background_spawn(async move {
 90            let index_state = index_state.lock().await;
 91            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
 92                .map(|context| context.into_cloud_request(&index_state))
 93        })
 94    }
 95
 96    pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body {
 97        let mut signatures = Vec::new();
 98        let mut declaration_to_signature_index = HashMap::default();
 99        let mut referenced_declarations = Vec::new();
100        let excerpt_parent = self
101            .excerpt
102            .parent_declarations
103            .last()
104            .and_then(|(parent, _)| {
105                add_signature(
106                    *parent,
107                    &mut declaration_to_signature_index,
108                    &mut signatures,
109                    index,
110                )
111            });
112        for snippet in self.snippets {
113            let parent_index = snippet.declaration.parent().and_then(|parent| {
114                add_signature(
115                    parent,
116                    &mut declaration_to_signature_index,
117                    &mut signatures,
118                    index,
119                )
120            });
121            let (text, text_is_truncated) = snippet.declaration.item_text();
122            referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
123                text: text.into(),
124                text_is_truncated,
125                signature_range: snippet.declaration.signature_range_in_item_text(),
126                parent_index,
127                score_components: snippet.score_components,
128                signature_score: snippet.scores.signature,
129                declaration_score: snippet.scores.declaration,
130            });
131        }
132        predict_edits_v3::Body {
133            excerpt: self.excerpt_text.body,
134            referenced_declarations,
135            signatures,
136            excerpt_parent,
137            // todo!
138            events: vec![],
139            can_collect_data: false,
140            diagnostic_groups: None,
141            git_info: None,
142        }
143    }
144}
145
146fn add_signature(
147    declaration_id: DeclarationId,
148    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
149    signatures: &mut Vec<Signature>,
150    index: &SyntaxIndexState,
151) -> Option<usize> {
152    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
153        return Some(*signature_index);
154    }
155    let Some(parent_declaration) = index.declaration(declaration_id) else {
156        log::error!("bug: missing parent declaration");
157        return None;
158    };
159    let parent_index = parent_declaration.parent().and_then(|parent| {
160        add_signature(parent, declaration_to_signature_index, signatures, index)
161    });
162    let (text, text_is_truncated) = parent_declaration.signature_text();
163    let signature_index = signatures.len();
164    signatures.push(Signature {
165        text: text.into(),
166        text_is_truncated,
167        parent_index,
168    });
169    declaration_to_signature_index.insert(declaration_id, signature_index);
170    Some(signature_index)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call site in `c.rs`
    /// and checks which identifiers show up among the scored snippets.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _) = init_test(cx).await;

        let open_task = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_task.await.unwrap();

        cx.run_until_parked();

        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        // first process_data call site
        let cursor = language::Point::new(8, 21);
        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        let gather_task =
            cx.update(|cx| EditPredictionContext::gather(cursor, snapshot, options, index, cx));
        let context = gather_task.await.unwrap();

        let mut names: Vec<_> = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect();
        names.sort();
        assert_eq!(names, vec!["main", "process_data"]);

        // Hold the buffer handle explicitly until the assertions are done.
        drop(buffer);
    }

    /// Builds a test project over a fake file system containing `a.rs`, `b.rs` and
    /// `c.rs`, registers the Rust language, and creates a `SyntaxIndex` for it.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let tree = json!({
            "a.rs": indoc! {r#"
                fn main() {
                    let x = 1;
                    let y = 2;
                    let z = add(x, y);
                    println!("Result: {}", z);
                }

                fn add(a: i32, b: i32) -> i32 {
                    a + b
                }
            "#},
            "b.rs": indoc! {"
                pub struct Config {
                    pub name: String,
                    pub value: i32,
                }

                impl Config {
                    pub fn new(name: String, value: i32) -> Self {
                        Config { name, value }
                    }
                }
            "},
            "c.rs": indoc! {r#"
                use std::collections::HashMap;

                fn main() {
                    let args: Vec<String> = std::env::args().collect();
                    let data: Vec<i32> = args[1..]
                        .iter()
                        .filter_map(|s| s.parse().ok())
                        .collect();
                    let result = process_data(data);
                    println!("{:?}", result);
                }

                fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                    let mut counts = HashMap::new();
                    for value in data {
                        *counts.entry(value).or_insert(0) += 1;
                    }
                    counts
                }

                #[cfg(test)]
                mod tests {
                    use super::*;

                    #[test]
                    fn test_process_data() {
                        let data = vec![1, 2, 2, 3];
                        let result = process_data(data);
                        assert_eq!(result.get(&2), Some(&2));
                    }
                }
            "#}
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), tree).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let rust = Arc::new(rust_lang());
        let rust_id = rust.id();
        language_registry.add(rust);

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Constructs a Rust `Language` matched by the `rs` suffix, with the highlights and
    /// outline queries loaded from the `languages` crate sources.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };

        let language = Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()));
        let language = language
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap();
        language
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}