1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod imports;
5mod outline;
6mod reference;
7mod syntax_index;
8pub mod text_similarity;
9
10use std::{path::Path, sync::Arc};
11
12use cloud_llm_client::predict_edits_v3;
13use collections::HashMap;
14use gpui::{App, AppContext as _, Entity, Task};
15use language::BufferSnapshot;
16use text::{Point, ToOffset as _};
17
18pub use declaration::*;
19pub use declaration_scoring::*;
20pub use excerpt::*;
21pub use imports::*;
22pub use reference::*;
23pub use syntax_index::*;
24
25pub use predict_edits_v3::Line;
26
/// Options controlling how an [`EditPredictionContext`] is gathered.
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionContextOptions {
    /// When `true`, `gather_context` collects imports with `Imports::gather`; when `false` it
    /// uses an empty `Imports::default()` instead.
    pub use_imports: bool,
    /// Options forwarded to `EditPredictionExcerpt::select_from_buffer` when selecting the
    /// excerpt for the cursor position.
    pub excerpt: EditPredictionExcerptOptions,
    /// Options forwarded to `scored_declarations` when scoring declarations.
    pub score: EditPredictionScoreOptions,
    /// Maximum number of scored declarations kept in the gathered context. A value of `0`
    /// skips declaration retrieval altogether.
    pub max_retrieved_declarations: u8,
}
34
/// Context gathered for a cursor position: the selected excerpt, its text, and the declarations
/// scored for it.
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// Excerpt chosen by `EditPredictionExcerpt::select_from_buffer` for the cursor position.
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as produced by `excerpt.text(buffer)`.
    pub excerpt_text: EditPredictionExcerptText,
    /// Cursor position this context was gathered for.
    pub cursor_point: Point,
    /// Declarations returned by `scored_declarations`, truncated to
    /// `EditPredictionContextOptions::max_retrieved_declarations`. Empty when no syntax index
    /// state was supplied or when that limit is zero.
    pub declarations: Vec<ScoredDeclaration>,
}
42
43impl EditPredictionContext {
44 pub fn gather_context_in_background(
45 cursor_point: Point,
46 buffer: BufferSnapshot,
47 options: EditPredictionContextOptions,
48 syntax_index: Option<Entity<SyntaxIndex>>,
49 cx: &mut App,
50 ) -> Task<Option<Self>> {
51 let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
52 let mut path = f.worktree.read(cx).absolutize(&f.path);
53 if path.pop() { Some(path) } else { None }
54 });
55
56 if let Some(syntax_index) = syntax_index {
57 let index_state =
58 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
59 cx.background_spawn(async move {
60 let parent_abs_path = parent_abs_path.as_deref();
61 let index_state = index_state.upgrade()?;
62 let index_state = index_state.lock().await;
63 Self::gather_context(
64 cursor_point,
65 &buffer,
66 parent_abs_path,
67 &options,
68 Some(&index_state),
69 )
70 })
71 } else {
72 cx.background_spawn(async move {
73 let parent_abs_path = parent_abs_path.as_deref();
74 Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
75 })
76 }
77 }
78
79 pub fn gather_context(
80 cursor_point: Point,
81 buffer: &BufferSnapshot,
82 parent_abs_path: Option<&Path>,
83 options: &EditPredictionContextOptions,
84 index_state: Option<&SyntaxIndexState>,
85 ) -> Option<Self> {
86 let imports = if options.use_imports {
87 Imports::gather(&buffer, parent_abs_path)
88 } else {
89 Imports::default()
90 };
91 Self::gather_context_with_references_fn(
92 cursor_point,
93 buffer,
94 &imports,
95 options,
96 index_state,
97 references_in_excerpt,
98 )
99 }
100
101 pub fn gather_context_with_references_fn(
102 cursor_point: Point,
103 buffer: &BufferSnapshot,
104 imports: &Imports,
105 options: &EditPredictionContextOptions,
106 index_state: Option<&SyntaxIndexState>,
107 get_references: impl FnOnce(
108 &EditPredictionExcerpt,
109 &EditPredictionExcerptText,
110 &BufferSnapshot,
111 ) -> HashMap<Identifier, Vec<Reference>>,
112 ) -> Option<Self> {
113 let excerpt = EditPredictionExcerpt::select_from_buffer(
114 cursor_point,
115 buffer,
116 &options.excerpt,
117 index_state,
118 )?;
119 let excerpt_text = excerpt.text(buffer);
120
121 let declarations = if options.max_retrieved_declarations > 0
122 && let Some(index_state) = index_state
123 {
124 let excerpt_occurrences =
125 text_similarity::Occurrences::within_string(&excerpt_text.body);
126
127 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
128 let adjacent_end = Point::new(cursor_point.row + 1, 0);
129 let adjacent_occurrences = text_similarity::Occurrences::within_string(
130 &buffer
131 .text_for_range(adjacent_start..adjacent_end)
132 .collect::<String>(),
133 );
134
135 let cursor_offset_in_file = cursor_point.to_offset(buffer);
136
137 let references = get_references(&excerpt, &excerpt_text, buffer);
138
139 let mut declarations = scored_declarations(
140 &options.score,
141 &index_state,
142 &excerpt,
143 &excerpt_occurrences,
144 &adjacent_occurrences,
145 &imports,
146 references,
147 cursor_offset_in_file,
148 buffer,
149 );
150 // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
151 declarations.truncate(options.max_retrieved_declarations as usize);
152 declarations
153 } else {
154 vec![]
155 };
156
157 Some(Self {
158 excerpt,
159 excerpt_text,
160 cursor_point,
161 declarations,
162 })
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the `process_data` call in `c.rs` and checks the
    /// names of the retrieved declarations.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let open_buffer = project.update(cx, |project, cx| {
            let project_path = project.find_project_path("c.rs", cx).unwrap();
            project.open_buffer(project_path, cx)
        });
        let buffer = open_buffer.await.unwrap();

        cx.run_until_parked();

        let options = EditPredictionContextOptions {
            use_imports: true,
            excerpt: EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            score: EditPredictionScoreOptions {
                omit_excerpt_overlaps: true,
            },
            max_retrieved_declarations: u8::MAX,
        };

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        // Pass a clone so `index` stays alive: the background task only holds a weak reference
        // to the index state.
        let gather = cx.update(|cx| {
            EditPredictionContext::gather_context_in_background(
                cursor_point,
                buffer_snapshot,
                options,
                Some(index.clone()),
                cx,
            )
        });
        let context = gather.await.unwrap();

        let mut declaration_names: Vec<_> = context
            .declarations
            .iter()
            .map(|declaration| declaration.identifier.name.as_ref())
            .collect();
        declaration_names.sort();
        assert_eq!(declaration_names, vec!["main", "process_data"]);

        // Hold the buffer handle until the assertions are done.
        drop(buffer);
    }

    /// Sets up test globals, a fake file system rooted at `/root` containing three Rust files,
    /// a project over it with the Rust language registered, and a `SyntaxIndex` for that
    /// project.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let files = json!({
            "a.rs": indoc! {r#"
                fn main() {
                    let x = 1;
                    let y = 2;
                    let z = add(x, y);
                    println!("Result: {}", z);
                }

                fn add(a: i32, b: i32) -> i32 {
                    a + b
                }
            "#},
            "b.rs": indoc! {"
                pub struct Config {
                    pub name: String,
                    pub value: i32,
                }

                impl Config {
                    pub fn new(name: String, value: i32) -> Self {
                        Config { name, value }
                    }
                }
            "},
            "c.rs": indoc! {r#"
                use std::collections::HashMap;

                fn main() {
                    let args: Vec<String> = std::env::args().collect();
                    let data: Vec<i32> = args[1..]
                        .iter()
                        .filter_map(|s| s.parse().ok())
                        .collect();
                    let result = process_data(data);
                    println!("{:?}", result);
                }

                fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                    let mut counts = HashMap::new();
                    for value in data {
                        *counts.entry(value).or_insert(0) += 1;
                    }
                    counts
                }

                #[cfg(test)]
                mod tests {
                    use super::*;

                    #[test]
                    fn test_process_data() {
                        let data = vec![1, 2, 2, 3];
                        let result = process_data(data);
                        assert_eq!(result.get(&2), Some(&2));
                    }
                }
            "#}
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), files).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let registry = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        registry.add(Arc::new(rust));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Let pending work settle before handing the index to the test.
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` matched by the `rs` path suffix, with the highlights and
    /// outline queries loaded from the `languages` crate sources.
    fn rust_lang() -> Language {
        let matcher = LanguageMatcher {
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher,
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}