1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod imports;
5mod outline;
6mod reference;
7mod similar_snippets;
8mod syntax_index;
9pub mod text_similarity;
10
11use cloud_llm_client::predict_edits_v3;
12use collections::HashMap;
13use gpui::{App, AppContext as _, Entity, Task};
14use language::BufferSnapshot;
15use std::{path::Path, sync::Arc, time::Instant};
16use text::{Point, ToOffset as _};
17
18pub use declaration::*;
19pub use declaration_scoring::*;
20pub use excerpt::*;
21pub use imports::*;
22pub use reference::*;
23pub use similar_snippets::*;
24pub use syntax_index::*;
25pub use text_similarity::*;
26
27pub use predict_edits_v3::Line;
28
/// Options controlling what [`EditPredictionContext`] gathers around the cursor.
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionContextOptions {
    /// When true, the buffer's imports are gathered and passed to declaration scoring;
    /// otherwise an empty `Imports` is used.
    pub use_imports: bool,
    /// Options for selecting the excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Options passed to declaration scoring.
    pub score: EditPredictionScoreOptions,
    /// Options for the similar-snippet search; a `max_result_count` of zero disables it.
    pub similar_snippets: SimilarSnippetOptions,
    /// Maximum number of scored declarations to keep; zero disables declaration retrieval.
    pub max_retrieved_declarations: u8,
}
37
/// Context gathered around a cursor position for edit prediction.
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// The excerpt of the buffer selected around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// The text of `excerpt`, as produced by `excerpt.text(buffer)`.
    pub excerpt_text: EditPredictionExcerptText,
    /// The cursor position the context was gathered for.
    pub cursor_point: Point,
    /// Scored declarations retrieved from the syntax index, truncated to
    /// `max_retrieved_declarations` (presumably ordered best-first by the scorer — confirm).
    /// Empty when no index state is available or retrieval is disabled.
    pub declarations: Vec<ScoredDeclaration>,
    /// Snippets found by comparing the excerpt's trigram occurrences against the buffer.
    /// Empty when the similar-snippet search is disabled.
    pub similar_snippets: Vec<SimilarSnippet>,
}
46
47impl EditPredictionContext {
48 pub fn gather_context_in_background(
49 cursor_point: Point,
50 buffer: BufferSnapshot,
51 options: EditPredictionContextOptions,
52 syntax_index: Option<Entity<SyntaxIndex>>,
53 cx: &mut App,
54 ) -> Task<Option<Self>> {
55 let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
56 let mut path = f.worktree.read(cx).absolutize(&f.path);
57 if path.pop() { Some(path) } else { None }
58 });
59
60 if let Some(syntax_index) = syntax_index {
61 let index_state =
62 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
63 cx.background_spawn(async move {
64 let parent_abs_path = parent_abs_path.as_deref();
65 let index_state = index_state.upgrade()?;
66 let index_state = index_state.lock().await;
67 Self::gather_context(
68 cursor_point,
69 &buffer,
70 parent_abs_path,
71 &options,
72 Some(&index_state),
73 )
74 })
75 } else {
76 cx.background_spawn(async move {
77 let parent_abs_path = parent_abs_path.as_deref();
78 Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
79 })
80 }
81 }
82
83 pub fn gather_context(
84 cursor_point: Point,
85 buffer: &BufferSnapshot,
86 parent_abs_path: Option<&Path>,
87 options: &EditPredictionContextOptions,
88 index_state: Option<&SyntaxIndexState>,
89 ) -> Option<Self> {
90 let imports = if options.use_imports {
91 Imports::gather(&buffer, parent_abs_path)
92 } else {
93 Imports::default()
94 };
95 Self::gather_context_with_references_fn(
96 cursor_point,
97 buffer,
98 &imports,
99 options,
100 index_state,
101 references_in_excerpt,
102 )
103 }
104
105 pub fn gather_context_with_references_fn(
106 cursor_point: Point,
107 buffer: &BufferSnapshot,
108 imports: &Imports,
109 options: &EditPredictionContextOptions,
110 index_state: Option<&SyntaxIndexState>,
111 get_references: impl FnOnce(
112 &EditPredictionExcerpt,
113 &EditPredictionExcerptText,
114 &BufferSnapshot,
115 ) -> HashMap<Identifier, Vec<Reference>>,
116 ) -> Option<Self> {
117 let excerpt = EditPredictionExcerpt::select_from_buffer(
118 cursor_point,
119 buffer,
120 &options.excerpt,
121 index_state,
122 )?;
123 let excerpt_text = excerpt.text(buffer);
124
125 let declarations = if options.max_retrieved_declarations > 0
126 && let Some(index_state) = index_state
127 {
128 let excerpt_occurrences =
129 Occurrences::new(IdentifierParts::occurrences_in_str(&excerpt_text.body));
130 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
131 let adjacent_end = Point::new(cursor_point.row + 1, 0);
132 let adjacent_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
133 &buffer
134 .text_for_range(adjacent_start..adjacent_end)
135 .collect::<String>(),
136 ));
137
138 let cursor_offset_in_file = cursor_point.to_offset(buffer);
139
140 let references = get_references(&excerpt, &excerpt_text, buffer);
141
142 let mut declarations = scored_declarations(
143 &options.score,
144 &index_state,
145 &excerpt,
146 &excerpt_occurrences,
147 &adjacent_occurrences,
148 &imports,
149 references,
150 cursor_offset_in_file,
151 buffer,
152 );
153 // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
154 declarations.truncate(options.max_retrieved_declarations as usize);
155 declarations
156 } else {
157 vec![]
158 };
159
160 let similar_snippets = if options.similar_snippets.max_result_count > 0 {
161 let before = Instant::now();
162 let excerpt_trigram_occurrences: Occurrences<NGram<3, CodeParts>> =
163 Occurrences::new(NGram::occurrences_in_str(&excerpt_text.body));
164 let similar_snippets = similar_snippets(
165 &excerpt_trigram_occurrences,
166 excerpt.range.clone(),
167 buffer,
168 &options.similar_snippets,
169 );
170 dbg!(before.elapsed());
171 similar_snippets
172 } else {
173 Vec::new()
174 };
175
176 // buffer.debug(&excerpt.range, "excerpt");
177
178 // buffer.debug(
179 // &similar_snippets
180 // .iter()
181 // .map(|s| s.range.clone())
182 // .collect::<Vec<_>>(),
183 // "similar_snippets",
184 // );
185
186 Some(Self {
187 excerpt,
188 excerpt_text,
189 cursor_point,
190 declarations,
191 similar_snippets,
192 })
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call in `c.rs`.
    /// Checks that exactly `main` and `process_data` are retrieved as declarations.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Run pending tasks before snapshotting the buffer.
        cx.run_until_parked();

        // first process_data call site
        // Row 8, column 21 falls inside `process_data` in `let result = process_data(data);`.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionContextOptions {
                        use_imports: true,
                        excerpt: EditPredictionExcerptOptions {
                            max_bytes: 60,
                            min_bytes: 10,
                            target_before_cursor_over_total_bytes: 0.5,
                        },
                        score: EditPredictionScoreOptions {
                            omit_excerpt_overlaps: true,
                        },
                        similar_snippets: SimilarSnippetOptions::default(),
                        // Keep every scored declaration so the assertion sees the full set.
                        max_retrieved_declarations: u8::MAX,
                    },
                    Some(index.clone()),
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on scoring order.
        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // NOTE(review): presumably keeps the buffer entity alive until the assertions have
        // run — confirm.
        drop(buffer);
    }

    /// Test setup:
    /// - initializes test settings and the language crate;
    /// - creates a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`;
    /// - opens it as a project and registers a Rust language;
    /// - builds a `SyntaxIndex` over the project.
    ///
    /// Returns the project, the index, and the Rust language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Run pending tasks created by `SyntaxIndex::new` before returning.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` with the tree-sitter-rust grammar.
    /// It also loads the highlights and outline queries from the `languages` crate's
    /// `src/rust` directory.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}