1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use gpui::{App, AppContext as _, Entity, Task};
10use language::BufferSnapshot;
11use text::{Point, ToOffset as _};
12
13pub use declaration::*;
14pub use declaration_scoring::*;
15pub use excerpt::*;
16pub use reference::*;
17pub use syntax_index::*;
18
/// Context gathered around a cursor position for edit prediction: a buffer
/// excerpt containing the cursor plus, when a syntax index is available,
/// scored snippets for references found in that excerpt.
///
/// Built by [`EditPredictionContext::gather_context`] or
/// [`EditPredictionContext::gather_context_in_background`].
#[derive(Clone, Debug)]
pub struct EditPredictionContext {
    /// The excerpt selected from the buffer around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// The text of `excerpt`, read from the same buffer snapshot.
    pub excerpt_text: EditPredictionExcerptText,
    /// Cursor offset relative to the start of `excerpt.range`.
    /// Computed with a saturating subtraction, so it is 0 if the cursor
    /// offset would precede the excerpt start.
    pub cursor_offset_in_excerpt: usize,
    /// Scored snippets for references in the excerpt; empty when no syntax
    /// index state was supplied.
    pub snippets: Vec<ScoredSnippet>,
}
26
27impl EditPredictionContext {
28 pub fn gather_context_in_background(
29 cursor_point: Point,
30 buffer: BufferSnapshot,
31 excerpt_options: EditPredictionExcerptOptions,
32 syntax_index: Option<Entity<SyntaxIndex>>,
33 cx: &mut App,
34 ) -> Task<Option<Self>> {
35 if let Some(syntax_index) = syntax_index {
36 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
37 cx.background_spawn(async move {
38 let index_state = index_state.lock().await;
39 Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
40 })
41 } else {
42 cx.background_spawn(async move {
43 Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
44 })
45 }
46 }
47
48 pub fn gather_context(
49 cursor_point: Point,
50 buffer: &BufferSnapshot,
51 excerpt_options: &EditPredictionExcerptOptions,
52 index_state: Option<&SyntaxIndexState>,
53 ) -> Option<Self> {
54 let excerpt = EditPredictionExcerpt::select_from_buffer(
55 cursor_point,
56 buffer,
57 excerpt_options,
58 index_state,
59 )?;
60 let excerpt_text = excerpt.text(buffer);
61 let cursor_offset_in_file = cursor_point.to_offset(buffer);
62 // TODO fix this to not need saturating_sub
63 let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
64
65 let snippets = if let Some(index_state) = index_state {
66 let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
67
68 scored_snippets(
69 &index_state,
70 &excerpt,
71 &excerpt_text,
72 references,
73 cursor_offset_in_file,
74 buffer,
75 )
76 } else {
77 vec![]
78 };
79
80 Some(Self {
81 excerpt,
82 excerpt_text,
83 cursor_offset_in_excerpt,
84 snippets,
85 })
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// With the cursor on the first `process_data` call in `c.rs`, the gathered
    /// snippets should be exactly those named `main` and `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Drive pending tasks to completion before taking the snapshot.
        cx.run_until_parked();

        // first process_data call site
        // Row 8 (zero-based) of c.rs is `let result = process_data(data);` in `main`;
        // column 21 falls inside the `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    // Small byte budget keeps the excerpt to a few lines around the cursor.
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            // `None` would mean no excerpt could be selected around the cursor.
            .unwrap();

        // Sort so the assertion does not depend on snippet scoring order.
        let mut snippet_identifiers = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // Explicitly hold the buffer handle until the end of the test.
        drop(buffer);
    }

    /// Test setup: installs test settings and language support, creates a fake
    /// filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`, opens it as a
    /// project, registers a Rust language, and creates a `SyntaxIndex` for the
    /// project.
    ///
    /// Returns the project, the index entity, and the registered Rust language id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                // NOTE: `test_call_site` hard-codes a cursor position into this file;
                // keep its layout in sync with that test.
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        // Presumably lets `.rs` buffers in the project be parsed as Rust.
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Let any work the index started (presumably its initial scan) finish.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` matched by the `rs` path suffix, using the
    /// tree-sitter Rust grammar plus the highlights and outline queries from the
    /// `languages` crate's Rust definition.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}