1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7pub mod text_similarity;
8
9use std::sync::Arc;
10
11use collections::HashMap;
12use gpui::{App, AppContext as _, Entity, Task};
13use language::BufferSnapshot;
14use text::{Point, ToOffset as _};
15
16pub use declaration::*;
17pub use declaration_scoring::*;
18pub use excerpt::*;
19pub use reference::*;
20pub use syntax_index::*;
21
/// Context gathered around a cursor position for edit prediction: an excerpt of the
/// buffer plus, when a syntax index is available, scored declarations relevant to it.
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// Excerpt of the buffer selected around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as produced by `EditPredictionExcerpt::text`.
    pub excerpt_text: EditPredictionExcerptText,
    /// Offset of the cursor relative to the start of `excerpt.range`.
    /// Saturates to 0 if the cursor's buffer offset precedes the excerpt start.
    pub cursor_offset_in_excerpt: usize,
    /// Declarations scored against the syntax index; empty when no index was supplied.
    pub declarations: Vec<ScoredDeclaration>,
}
29
30impl EditPredictionContext {
31 pub fn gather_context_in_background(
32 cursor_point: Point,
33 buffer: BufferSnapshot,
34 excerpt_options: EditPredictionExcerptOptions,
35 syntax_index: Option<Entity<SyntaxIndex>>,
36 cx: &mut App,
37 ) -> Task<Option<Self>> {
38 if let Some(syntax_index) = syntax_index {
39 let index_state =
40 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
41 cx.background_spawn(async move {
42 let index_state = index_state.upgrade()?;
43 let index_state = index_state.lock().await;
44 Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
45 })
46 } else {
47 cx.background_spawn(async move {
48 Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
49 })
50 }
51 }
52
53 pub fn gather_context(
54 cursor_point: Point,
55 buffer: &BufferSnapshot,
56 excerpt_options: &EditPredictionExcerptOptions,
57 index_state: Option<&SyntaxIndexState>,
58 ) -> Option<Self> {
59 Self::gather_context_with_references_fn(
60 cursor_point,
61 buffer,
62 excerpt_options,
63 index_state,
64 references_in_excerpt,
65 )
66 }
67
68 pub fn gather_context_with_references_fn(
69 cursor_point: Point,
70 buffer: &BufferSnapshot,
71 excerpt_options: &EditPredictionExcerptOptions,
72 index_state: Option<&SyntaxIndexState>,
73 get_references: impl FnOnce(
74 &EditPredictionExcerpt,
75 &EditPredictionExcerptText,
76 &BufferSnapshot,
77 ) -> HashMap<Identifier, Vec<Reference>>,
78 ) -> Option<Self> {
79 let excerpt = EditPredictionExcerpt::select_from_buffer(
80 cursor_point,
81 buffer,
82 excerpt_options,
83 index_state,
84 )?;
85 let excerpt_text = excerpt.text(buffer);
86 let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
87
88 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
89 let adjacent_end = Point::new(cursor_point.row + 1, 0);
90 let adjacent_occurrences = text_similarity::Occurrences::within_string(
91 &buffer
92 .text_for_range(adjacent_start..adjacent_end)
93 .collect::<String>(),
94 );
95
96 let cursor_offset_in_file = cursor_point.to_offset(buffer);
97 // TODO fix this to not need saturating_sub
98 let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
99
100 let declarations = if let Some(index_state) = index_state {
101 let references = get_references(&excerpt, &excerpt_text, buffer);
102
103 scored_declarations(
104 &index_state,
105 &excerpt,
106 &excerpt_occurrences,
107 &adjacent_occurrences,
108 references,
109 cursor_offset_in_file,
110 buffer,
111 )
112 } else {
113 vec![]
114 };
115
116 Some(Self {
117 excerpt,
118 excerpt_text,
119 cursor_offset_in_excerpt,
120 declarations,
121 })
122 }
123}
124
/// Tests for `EditPredictionContext` against a fake project containing small Rust files.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Places the cursor on the `process_data` call in `main` of `c.rs` and checks that
    /// the scored declarations are exactly `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Let background work triggered by opening the buffer settle.
        cx.run_until_parked();

        // first process_data call site
        // (row 8 is `let result = process_data(data);` in the c.rs fixture; rows are 0-based)
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    // Small byte budget so the excerpt stays close to the cursor.
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on scoring order.
        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        drop(buffer);
    }

    /// Sets up settings, a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`,
    /// a project over it with the Rust language registered, and a `SyntaxIndex`.
    ///
    /// Returns the project, the index entity, and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Let the index process the project's files before returning.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` matching `.rs` files, using the tree-sitter Rust grammar
    /// and the highlights/outline queries from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}