1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod imports;
5mod outline;
6mod reference;
7mod syntax_index;
8pub mod text_similarity;
9
10use std::{path::Path, sync::Arc};
11
12use cloud_llm_client::predict_edits_v3;
13use collections::HashMap;
14use gpui::{App, AppContext as _, Entity, Task};
15use language::BufferSnapshot;
16use text::{Point, ToOffset as _};
17
18pub use declaration::*;
19pub use declaration_scoring::*;
20pub use excerpt::*;
21pub use imports::*;
22pub use reference::*;
23pub use syntax_index::*;
24
25pub use predict_edits_v3::Line;
26
/// Options controlling how an [`EditPredictionContext`] is gathered.
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionContextOptions {
    /// When `true`, imports are collected from the buffer with [`Imports::gather`];
    /// otherwise `Imports::default()` is used.
    pub use_imports: bool,
    /// Options passed to [`EditPredictionExcerpt::select_from_buffer`] when choosing the
    /// excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Options passed to [`scored_declarations`] when ranking declarations.
    pub score: EditPredictionScoreOptions,
}
33
/// Context gathered around a cursor position for edit prediction.
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// The excerpt of the buffer selected around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as produced by `EditPredictionExcerpt::text`.
    pub excerpt_text: EditPredictionExcerptText,
    /// The cursor position this context was gathered for.
    pub cursor_point: Point,
    /// Declarations scored against the excerpt. Empty when no syntax index state was
    /// available while gathering.
    pub declarations: Vec<ScoredDeclaration>,
}
41
42impl EditPredictionContext {
43 pub fn gather_context_in_background(
44 cursor_point: Point,
45 buffer: BufferSnapshot,
46 options: EditPredictionContextOptions,
47 syntax_index: Option<Entity<SyntaxIndex>>,
48 cx: &mut App,
49 ) -> Task<Option<Self>> {
50 let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
51 let mut path = f.worktree.read(cx).absolutize(&f.path);
52 if path.pop() { Some(path) } else { None }
53 });
54
55 if let Some(syntax_index) = syntax_index {
56 let index_state =
57 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
58 cx.background_spawn(async move {
59 let parent_abs_path = parent_abs_path.as_deref();
60 let index_state = index_state.upgrade()?;
61 let index_state = index_state.lock().await;
62 Self::gather_context(
63 cursor_point,
64 &buffer,
65 parent_abs_path,
66 &options,
67 Some(&index_state),
68 )
69 })
70 } else {
71 cx.background_spawn(async move {
72 let parent_abs_path = parent_abs_path.as_deref();
73 Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
74 })
75 }
76 }
77
78 pub fn gather_context(
79 cursor_point: Point,
80 buffer: &BufferSnapshot,
81 parent_abs_path: Option<&Path>,
82 options: &EditPredictionContextOptions,
83 index_state: Option<&SyntaxIndexState>,
84 ) -> Option<Self> {
85 let imports = if options.use_imports {
86 Imports::gather(&buffer, parent_abs_path)
87 } else {
88 Imports::default()
89 };
90 Self::gather_context_with_references_fn(
91 cursor_point,
92 buffer,
93 &imports,
94 options,
95 index_state,
96 references_in_excerpt,
97 )
98 }
99
100 pub fn gather_context_with_references_fn(
101 cursor_point: Point,
102 buffer: &BufferSnapshot,
103 imports: &Imports,
104 options: &EditPredictionContextOptions,
105 index_state: Option<&SyntaxIndexState>,
106 get_references: impl FnOnce(
107 &EditPredictionExcerpt,
108 &EditPredictionExcerptText,
109 &BufferSnapshot,
110 ) -> HashMap<Identifier, Vec<Reference>>,
111 ) -> Option<Self> {
112 let excerpt = EditPredictionExcerpt::select_from_buffer(
113 cursor_point,
114 buffer,
115 &options.excerpt,
116 index_state,
117 )?;
118 let excerpt_text = excerpt.text(buffer);
119 let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
120
121 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
122 let adjacent_end = Point::new(cursor_point.row + 1, 0);
123 let adjacent_occurrences = text_similarity::Occurrences::within_string(
124 &buffer
125 .text_for_range(adjacent_start..adjacent_end)
126 .collect::<String>(),
127 );
128
129 let cursor_offset_in_file = cursor_point.to_offset(buffer);
130
131 let declarations = if let Some(index_state) = index_state {
132 let references = get_references(&excerpt, &excerpt_text, buffer);
133
134 scored_declarations(
135 &options.score,
136 &index_state,
137 &excerpt,
138 &excerpt_occurrences,
139 &adjacent_occurrences,
140 &imports,
141 references,
142 cursor_offset_in_file,
143 buffer,
144 )
145 } else {
146 vec![]
147 };
148
149 Some(Self {
150 excerpt,
151 excerpt_text,
152 cursor_point,
153 declarations,
154 })
155 }
156}
157
#[cfg(test)]
mod tests {
    // These tests gather context against an in-memory project (`FakeFs`) with a
    // `SyntaxIndex` built over it.
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// With the cursor inside the first `process_data` call in `c.rs`, the scored
    /// declarations should be exactly `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Run pending tasks to completion before taking the snapshot.
        cx.run_until_parked();

        // first process_data call site
        // (row 8 of c.rs is `let result = process_data(data);`; column 21 falls inside the
        // `process_data` identifier)
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionContextOptions {
                        use_imports: true,
                        // A deliberately small excerpt budget for the test fixture.
                        excerpt: EditPredictionExcerptOptions {
                            max_bytes: 60,
                            min_bytes: 10,
                            target_before_cursor_over_total_bytes: 0.5,
                        },
                        score: EditPredictionScoreOptions {
                            omit_excerpt_overlaps: true,
                        },
                    },
                    Some(index.clone()),
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on scoring order.
        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // NOTE(review): presumably keeps the buffer handle alive until the assertions ran.
        drop(buffer);
    }

    /// Sets up the test environment:
    /// - test settings;
    /// - a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`;
    /// - a test project with a Rust language registered;
    /// - a `SyntaxIndex` over that project.
    ///
    /// The executor is run until parked before returning.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` backed by tree-sitter-rust, with the highlights and outline
    /// queries loaded from the `languages` crate sources.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}