1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use gpui::{App, AppContext as _, Entity, Task};
10use language::BufferSnapshot;
11use text::{Point, ToOffset as _};
12
13pub use declaration::*;
14pub use declaration_scoring::*;
15pub use excerpt::*;
16pub use reference::*;
17pub use syntax_index::*;
18
/// Context gathered around a cursor position for edit prediction.
///
/// Built by [`EditPredictionContext::gather_context`] (or its background
/// variant [`EditPredictionContext::gather_context_in_background`]).
#[derive(Debug)]
pub struct EditPredictionContext {
    /// Excerpt of the buffer selected around the cursor by
    /// `EditPredictionExcerpt::select_from_buffer`.
    pub excerpt: EditPredictionExcerpt,
    /// Text of `excerpt`, as returned by `EditPredictionExcerpt::text`.
    pub excerpt_text: EditPredictionExcerptText,
    /// Cursor offset relative to the start of `excerpt.range`.
    pub cursor_offset_in_excerpt: usize,
    /// Snippets produced by `scored_snippets`; empty when no syntax index
    /// state was supplied.
    pub snippets: Vec<ScoredSnippet>,
}
26
27impl EditPredictionContext {
28 pub fn gather_context_in_background(
29 cursor_point: Point,
30 buffer: BufferSnapshot,
31 excerpt_options: EditPredictionExcerptOptions,
32 syntax_index: Option<Entity<SyntaxIndex>>,
33 cx: &mut App,
34 ) -> Task<Option<Self>> {
35 if let Some(syntax_index) = syntax_index {
36 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
37 cx.background_spawn(async move {
38 let index_state = index_state.lock().await;
39 Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
40 })
41 } else {
42 cx.background_spawn(async move {
43 Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
44 })
45 }
46 }
47
48 pub fn gather_context(
49 cursor_point: Point,
50 buffer: &BufferSnapshot,
51 excerpt_options: &EditPredictionExcerptOptions,
52 index_state: Option<&SyntaxIndexState>,
53 ) -> Option<Self> {
54 let excerpt = EditPredictionExcerpt::select_from_buffer(
55 cursor_point,
56 buffer,
57 excerpt_options,
58 index_state,
59 )?;
60 let excerpt_text = excerpt.text(buffer);
61 let cursor_offset_in_file = cursor_point.to_offset(buffer);
62 let cursor_offset_in_excerpt = cursor_offset_in_file - excerpt.range.start;
63
64 let snippets = if let Some(index_state) = index_state {
65 let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
66
67 scored_snippets(
68 &index_state,
69 &excerpt,
70 &excerpt_text,
71 references,
72 cursor_offset_in_file,
73 buffer,
74 )
75 } else {
76 vec![]
77 };
78
79 Some(Self {
80 excerpt,
81 excerpt_text,
82 cursor_offset_in_excerpt,
83 snippets,
84 })
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call site
    /// in `c.rs` and checks which identifiers come back as snippets.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        // Row 8 (zero-based) of the `c.rs` fixture is
        // `let result = process_data(data);`; column 21 falls inside
        // `process_data`.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            .unwrap();

        // Sort so the assertion does not depend on the order snippets are
        // returned in.
        let mut snippet_identifiers = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // Keep the buffer handle alive until the end of the test (presumably
        // so the opened buffer is not released early — confirm).
        drop(buffer);
    }

    /// Sets up a test project rooted at `/root` containing `a.rs`, `b.rs`,
    /// and `c.rs`, registers a Rust language, and creates a `SyntaxIndex`
    /// over the project.
    ///
    /// Returns the project, the index entity, and the id of the registered
    /// Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Let the index process the project before tests use it.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` that matches `.rs` files, using the highlights
    /// and outline queries from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}