1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use gpui::{App, AppContext as _, Entity, Task};
10use language::BufferSnapshot;
11use text::{Point, ToOffset as _};
12
13pub use declaration::*;
14pub use declaration_scoring::*;
15pub use excerpt::*;
16pub use reference::*;
17pub use syntax_index::*;
18
19#[derive(Debug)]
20pub struct EditPredictionContext {
21 pub excerpt: EditPredictionExcerpt,
22 pub excerpt_text: EditPredictionExcerptText,
23 pub snippets: Vec<ScoredSnippet>,
24}
25
26impl EditPredictionContext {
27 pub fn gather_context_in_background(
28 cursor_point: Point,
29 buffer: BufferSnapshot,
30 excerpt_options: EditPredictionExcerptOptions,
31 syntax_index: Entity<SyntaxIndex>,
32 cx: &mut App,
33 ) -> Task<Option<Self>> {
34 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
35 cx.background_spawn(async move {
36 let index_state = index_state.lock().await;
37 Self::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
38 })
39 }
40
41 pub fn gather_context(
42 cursor_point: Point,
43 buffer: &BufferSnapshot,
44 excerpt_options: &EditPredictionExcerptOptions,
45 index_state: &SyntaxIndexState,
46 ) -> Option<Self> {
47 let excerpt = EditPredictionExcerpt::select_from_buffer(
48 cursor_point,
49 buffer,
50 excerpt_options,
51 Some(index_state),
52 )?;
53 let excerpt_text = excerpt.text(buffer);
54 let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
55 let cursor_offset = cursor_point.to_offset(buffer);
56
57 let snippets = scored_snippets(
58 &index_state,
59 &excerpt,
60 &excerpt_text,
61 references,
62 cursor_offset,
63 buffer,
64 );
65
66 Some(Self {
67 excerpt,
68 excerpt_text,
69 snippets,
70 })
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call in `c.rs`.
    /// Asserts that the scored snippets are exactly the `main` and `process_data`
    /// declarations.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Let pending work triggered by opening the buffer settle before snapshotting.
        // NOTE(review): presumably this lets parsing/indexing of `c.rs` finish — confirm.
        cx.run_until_parked();

        // first process_data call site
        // Row 8 of `c.rs` is `    let result = process_data(data);`. Column 21 falls
        // inside the `process_data` identifier.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    index,
                    cx,
                )
            })
            .await
            .unwrap();

        let mut snippet_identifiers = context
            .snippets
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        // Sort so the assertion does not depend on the order of `context.snippets`.
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // NOTE(review): presumably this keeps the buffer entity alive until the
        // assertions have run — confirm.
        drop(buffer);
    }

    /// Builds the shared test fixture:
    /// - initializes the test settings store, `language`, and project settings;
    /// - creates a fake filesystem with `a.rs`, `b.rs` and `c.rs` under `/root`;
    /// - opens a test `Project` on `/root` and registers the Rust language from
    ///   [`rust_lang`] with its language registry;
    /// - creates a `SyntaxIndex` over the project.
    ///
    /// Returns the project, the index entity, and the id of the registered Rust
    /// language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        // Capture the id before `lang` is moved into the registry.
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // NOTE(review): presumably lets the index finish its initial pass over the
        // project files before tests query it — confirm.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a `Rust` [`Language`] that matches `.rs` files.
    ///
    /// It uses the tree-sitter Rust grammar plus the highlights and outline queries
    /// from `languages/src/rust`.
    ///
    /// # Panics
    ///
    /// Panics if either query fails to load (both results are `unwrap`ped).
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}