1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod imports;
5mod outline;
6mod reference;
7mod syntax_index;
8pub mod text_similarity;
9
10use std::{path::Path, sync::Arc};
11
12use collections::HashMap;
13use gpui::{App, AppContext as _, Entity, Task};
14use language::BufferSnapshot;
15use text::{Point, ToOffset as _};
16
17pub use declaration::*;
18pub use declaration_scoring::*;
19pub use excerpt::*;
20pub use imports::*;
21pub use reference::*;
22pub use syntax_index::*;
23
24#[derive(Clone, Debug, PartialEq)]
25pub struct EditPredictionContextOptions {
26 pub use_imports: bool,
27 pub excerpt: EditPredictionExcerptOptions,
28 pub score: EditPredictionScoreOptions,
29}
30
31#[derive(Clone, Debug)]
32pub struct EditPredictionContext {
33 pub excerpt: EditPredictionExcerpt,
34 pub excerpt_text: EditPredictionExcerptText,
35 pub cursor_offset_in_excerpt: usize,
36 pub declarations: Vec<ScoredDeclaration>,
37}
38
39impl EditPredictionContext {
40 pub fn gather_context_in_background(
41 cursor_point: Point,
42 buffer: BufferSnapshot,
43 options: EditPredictionContextOptions,
44 syntax_index: Option<Entity<SyntaxIndex>>,
45 cx: &mut App,
46 ) -> Task<Option<Self>> {
47 let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
48 let mut path = f.worktree.read(cx).absolutize(&f.path);
49 if path.pop() { Some(path) } else { None }
50 });
51
52 if let Some(syntax_index) = syntax_index {
53 let index_state =
54 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
55 cx.background_spawn(async move {
56 let parent_abs_path = parent_abs_path.as_deref();
57 let index_state = index_state.upgrade()?;
58 let index_state = index_state.lock().await;
59 Self::gather_context(
60 cursor_point,
61 &buffer,
62 parent_abs_path,
63 &options,
64 Some(&index_state),
65 )
66 })
67 } else {
68 cx.background_spawn(async move {
69 let parent_abs_path = parent_abs_path.as_deref();
70 Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
71 })
72 }
73 }
74
75 pub fn gather_context(
76 cursor_point: Point,
77 buffer: &BufferSnapshot,
78 parent_abs_path: Option<&Path>,
79 options: &EditPredictionContextOptions,
80 index_state: Option<&SyntaxIndexState>,
81 ) -> Option<Self> {
82 let imports = if options.use_imports {
83 Imports::gather(&buffer, parent_abs_path)
84 } else {
85 Imports::default()
86 };
87 Self::gather_context_with_references_fn(
88 cursor_point,
89 buffer,
90 &imports,
91 options,
92 index_state,
93 references_in_excerpt,
94 )
95 }
96
97 pub fn gather_context_with_references_fn(
98 cursor_point: Point,
99 buffer: &BufferSnapshot,
100 imports: &Imports,
101 options: &EditPredictionContextOptions,
102 index_state: Option<&SyntaxIndexState>,
103 get_references: impl FnOnce(
104 &EditPredictionExcerpt,
105 &EditPredictionExcerptText,
106 &BufferSnapshot,
107 ) -> HashMap<Identifier, Vec<Reference>>,
108 ) -> Option<Self> {
109 let excerpt = EditPredictionExcerpt::select_from_buffer(
110 cursor_point,
111 buffer,
112 &options.excerpt,
113 index_state,
114 )?;
115 let excerpt_text = excerpt.text(buffer);
116 let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
117
118 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
119 let adjacent_end = Point::new(cursor_point.row + 1, 0);
120 let adjacent_occurrences = text_similarity::Occurrences::within_string(
121 &buffer
122 .text_for_range(adjacent_start..adjacent_end)
123 .collect::<String>(),
124 );
125
126 let cursor_offset_in_file = cursor_point.to_offset(buffer);
127 // TODO fix this to not need saturating_sub
128 let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
129
130 let declarations = if let Some(index_state) = index_state {
131 let references = get_references(&excerpt, &excerpt_text, buffer);
132
133 scored_declarations(
134 &options.score,
135 &index_state,
136 &excerpt,
137 &excerpt_occurrences,
138 &adjacent_occurrences,
139 &imports,
140 references,
141 cursor_offset_in_file,
142 buffer,
143 )
144 } else {
145 vec![]
146 };
147
148 Some(Self {
149 excerpt,
150 excerpt_text,
151 cursor_offset_in_excerpt,
152 declarations,
153 })
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the `process_data(data)` call in
    /// `c.rs`'s `main`, and checks which declarations are returned.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        // Row 8 is `    let result = process_data(data);`; column 21 falls inside
        // the `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionContextOptions {
                        use_imports: true,
                        excerpt: EditPredictionExcerptOptions {
                            max_bytes: 60,
                            min_bytes: 10,
                            target_before_cursor_over_total_bytes: 0.5,
                        },
                        score: EditPredictionScoreOptions {
                            omit_excerpt_overlaps: true,
                        },
                    },
                    Some(index.clone()),
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on scoring order.
        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // Keep the buffer entity alive until the end of the test.
        drop(buffer);
    }

    /// Sets up settings and language globals, plus a fake filesystem containing
    /// `a.rs`, `b.rs` and `c.rs` under `/root`. Also creates a project with a
    /// Rust language registered and a `SyntaxIndex` over that project.
    ///
    /// Returns the project, the index entity, and the Rust language id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Presumably lets the index finish its initial indexing pass — verify.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` for `.rs` files, with the highlights and outline
    /// queries loaded from the `languages` crate's Rust query files.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}