1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use std::sync::Arc;
10
11use gpui::{App, AppContext as _, Entity, Task};
12use language::BufferSnapshot;
13use text::{Point, ToOffset as _};
14
15pub use declaration::*;
16pub use declaration_scoring::*;
17pub use excerpt::*;
18pub use reference::*;
19pub use syntax_index::*;
20
/// Context gathered around a cursor position for edit prediction.
///
/// Built by `EditPredictionContext::gather_context`.
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// Excerpt of the buffer selected around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as returned by `EditPredictionExcerpt::text`.
    pub excerpt_text: EditPredictionExcerptText,
    /// Cursor offset relative to the start of `excerpt.range`
    /// (clamped to 0 if the excerpt starts after the cursor).
    pub cursor_offset_in_excerpt: usize,
    /// Declarations scored against the excerpt; empty when no syntax index state was supplied.
    pub declarations: Vec<ScoredDeclaration>,
}
28
29impl EditPredictionContext {
30 pub fn gather_context_in_background(
31 cursor_point: Point,
32 buffer: BufferSnapshot,
33 excerpt_options: EditPredictionExcerptOptions,
34 syntax_index: Option<Entity<SyntaxIndex>>,
35 cx: &mut App,
36 ) -> Task<Option<Self>> {
37 if let Some(syntax_index) = syntax_index {
38 let index_state =
39 syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
40 cx.background_spawn(async move {
41 let index_state = index_state.upgrade()?;
42 let index_state = index_state.lock().await;
43 Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
44 })
45 } else {
46 cx.background_spawn(async move {
47 Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
48 })
49 }
50 }
51
52 pub fn gather_context(
53 cursor_point: Point,
54 buffer: &BufferSnapshot,
55 excerpt_options: &EditPredictionExcerptOptions,
56 index_state: Option<&SyntaxIndexState>,
57 ) -> Option<Self> {
58 let excerpt = EditPredictionExcerpt::select_from_buffer(
59 cursor_point,
60 buffer,
61 excerpt_options,
62 index_state,
63 )?;
64 let excerpt_text = excerpt.text(buffer);
65 let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
66
67 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
68 let adjacent_end = Point::new(cursor_point.row + 1, 0);
69 let adjacent_occurrences = text_similarity::Occurrences::within_string(
70 &buffer
71 .text_for_range(adjacent_start..adjacent_end)
72 .collect::<String>(),
73 );
74
75 let cursor_offset_in_file = cursor_point.to_offset(buffer);
76 // TODO fix this to not need saturating_sub
77 let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
78
79 let declarations = if let Some(index_state) = index_state {
80 let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
81
82 scored_declarations(
83 &index_state,
84 &excerpt,
85 &excerpt_occurrences,
86 &adjacent_occurrences,
87 references,
88 cursor_offset_in_file,
89 buffer,
90 )
91 } else {
92 vec![]
93 };
94
95 Some(Self {
96 excerpt,
97 excerpt_text,
98 cursor_offset_in_excerpt,
99 declarations,
100 })
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context at the first `process_data` call site in `c.rs` and checks which
    /// declarations are returned.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Presumably lets pending parsing/indexing work for the opened buffer settle — verify.
        cx.run_until_parked();

        // first process_data call site
        // (row 8 of c.rs is `let result = process_data(data);`; column 21 falls inside
        // `process_data`. These coordinates depend on the exact fixture text in `init_test`.)
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            .unwrap();

        // Compare sorted names so the assertion does not depend on scoring order.
        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // NOTE(review): explicit drop presumably keeps the buffer handle alive until here — confirm.
        drop(buffer);
    }

    /// Sets up test settings, a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`,
    /// a project over it with a Rust language registered, and a `SyntaxIndex` for that project.
    ///
    /// Returns the project, the index entity, and the registered Rust language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let file_indexing_parallelism = 2;
        let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
        // Presumably lets the initial indexing pass finish before tests query the index — verify.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` backed by tree-sitter-rust, with the highlights and outline
    /// queries loaded from the `languages` crate's Rust query files.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}