1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use cloud_llm_client::predict_edits_v3::{self, Signature};
10use collections::HashMap;
11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
12pub use declaration_scoring::SnippetStyle;
13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
14
15use gpui::{App, AppContext as _, Entity, Task};
16use language::BufferSnapshot;
17pub use reference::references_in_excerpt;
18pub use syntax_index::SyntaxIndex;
19use text::{Point, ToOffset as _};
20
21use crate::{
22 declaration::DeclarationId,
23 declaration_scoring::{ScoredSnippet, scored_snippets},
24 syntax_index::SyntaxIndexState,
25};
26
/// Context gathered around a cursor position for edit prediction: the excerpt
/// surrounding the cursor plus scored snippets for references found inside it.
#[derive(Debug)]
pub struct EditPredictionContext {
    /// Excerpt of the buffer selected around the cursor
    /// (via `EditPredictionExcerpt::select_from_buffer`).
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as produced by `EditPredictionExcerpt::text`.
    pub excerpt_text: EditPredictionExcerptText,
    /// Snippets scored by `scored_snippets` for the references found in the excerpt.
    pub snippets: Vec<ScoredSnippet>,
}
33
34impl EditPredictionContext {
35 pub fn gather(
36 cursor_point: Point,
37 buffer: BufferSnapshot,
38 excerpt_options: EditPredictionExcerptOptions,
39 syntax_index: Entity<SyntaxIndex>,
40 cx: &mut App,
41 ) -> Task<Option<Self>> {
42 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
43 cx.background_spawn(async move {
44 let index_state = index_state.lock().await;
45 Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
46 })
47 }
48
49 fn gather_context(
50 cursor_point: Point,
51 buffer: BufferSnapshot,
52 excerpt_options: EditPredictionExcerptOptions,
53 index_state: &SyntaxIndexState,
54 ) -> Option<Self> {
55 let excerpt = EditPredictionExcerpt::select_from_buffer(
56 cursor_point,
57 &buffer,
58 &excerpt_options,
59 Some(index_state),
60 )?;
61 let excerpt_text = excerpt.text(&buffer);
62 let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
63 let cursor_offset = cursor_point.to_offset(&buffer);
64
65 let snippets = scored_snippets(
66 &index_state,
67 &excerpt,
68 &excerpt_text,
69 references,
70 cursor_offset,
71 &buffer,
72 );
73
74 Some(Self {
75 excerpt,
76 excerpt_text,
77 snippets,
78 })
79 }
80
81 pub fn cloud_request(
82 cursor_point: Point,
83 buffer: BufferSnapshot,
84 excerpt_options: EditPredictionExcerptOptions,
85 syntax_index: Entity<SyntaxIndex>,
86 cx: &mut App,
87 ) -> Task<Option<predict_edits_v3::Body>> {
88 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
89 cx.background_spawn(async move {
90 let index_state = index_state.lock().await;
91 Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
92 .map(|context| context.into_cloud_request(&index_state))
93 })
94 }
95
96 pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body {
97 let mut signatures = Vec::new();
98 let mut declaration_to_signature_index = HashMap::default();
99 let mut referenced_declarations = Vec::new();
100 let excerpt_parent = self
101 .excerpt
102 .parent_declarations
103 .last()
104 .and_then(|(parent, _)| {
105 add_signature(
106 *parent,
107 &mut declaration_to_signature_index,
108 &mut signatures,
109 index,
110 )
111 });
112 for snippet in self.snippets {
113 let parent_index = snippet.declaration.parent().and_then(|parent| {
114 add_signature(
115 parent,
116 &mut declaration_to_signature_index,
117 &mut signatures,
118 index,
119 )
120 });
121 let (text, text_is_truncated) = snippet.declaration.item_text();
122 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
123 text: text.into(),
124 text_is_truncated,
125 signature_range: snippet.declaration.signature_range_in_item_text(),
126 parent_index,
127 score_components: snippet.score_components,
128 signature_score: snippet.scores.signature,
129 declaration_score: snippet.scores.declaration,
130 });
131 }
132 predict_edits_v3::Body {
133 excerpt: self.excerpt_text.body,
134 referenced_declarations,
135 signatures,
136 excerpt_parent,
137 // todo!
138 events: vec![],
139 can_collect_data: false,
140 diagnostic_groups: None,
141 git_info: None,
142 }
143 }
144}
145
146fn add_signature(
147 declaration_id: DeclarationId,
148 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
149 signatures: &mut Vec<Signature>,
150 index: &SyntaxIndexState,
151) -> Option<usize> {
152 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
153 return Some(*signature_index);
154 }
155 let Some(parent_declaration) = index.declaration(declaration_id) else {
156 log::error!("bug: missing parent declaration");
157 return None;
158 };
159 let parent_index = parent_declaration.parent().and_then(|parent| {
160 add_signature(parent, declaration_to_signature_index, signatures, index)
161 });
162 let (text, text_is_truncated) = parent_declaration.signature_text();
163 let signature_index = signatures.len();
164 signatures.push(Signature {
165 text: text.into(),
166 text_is_truncated,
167 parent_index,
168 });
169 declaration_to_signature_index.insert(declaration_id, signature_index);
170 Some(signature_index)
171}
172
// Integration tests for context gathering against a fake project indexed by `SyntaxIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    // Gathers context with the cursor on a `process_data` call in `c.rs`, then checks
    // that the resulting snippets are exactly for the `main` and `process_data`
    // identifiers (compared after sorting).
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        // Row 8 of c.rs is `let result = process_data(data);`; column 21 falls inside the
        // `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather(
                    cursor_point,
                    buffer_snapshot,
                    // Small byte budget so the excerpt stays close to the cursor.
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    index,
                    cx,
                )
            })
            .await
            .unwrap();

        let mut snippet_identifiers = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        // Sort so the assertion does not depend on snippet ordering.
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // Explicitly keeps the buffer handle alive until the end of the test.
        drop(buffer);
    }

    // Builds the test environment:
    // - installs test settings and initializes language/project settings;
    // - creates a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`;
    // - opens a project on it and registers the Rust language;
    // - creates a `SyntaxIndex` for the project.
    // `run_until_parked` is called afterwards, presumably so initial indexing settles
    // before tests query the index.
    // Returns the project, the index entity and the registered Rust language id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        // Capture the id before the language is moved into the registry.
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, lang_id)
    }

    // Builds a Rust `Language` matching `.rs` files. Its highlights and outline queries
    // are loaded from the `languages` crate's Rust query files.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}