1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use gpui::{App, AppContext as _, Entity, Task};
10use language::BufferSnapshot;
11use text::{Point, ToOffset as _};
12
13pub use declaration::*;
14pub use declaration_scoring::*;
15pub use excerpt::*;
16pub use reference::*;
17pub use syntax_index::*;
18
19#[derive(Clone, Debug)]
20pub struct EditPredictionContext {
21 pub excerpt: EditPredictionExcerpt,
22 pub excerpt_text: EditPredictionExcerptText,
23 pub cursor_offset_in_excerpt: usize,
24 pub declarations: Vec<ScoredDeclaration>,
25}
26
27impl EditPredictionContext {
28 pub fn gather_context_in_background(
29 cursor_point: Point,
30 buffer: BufferSnapshot,
31 excerpt_options: EditPredictionExcerptOptions,
32 syntax_index: Option<Entity<SyntaxIndex>>,
33 cx: &mut App,
34 ) -> Task<Option<Self>> {
35 if let Some(syntax_index) = syntax_index {
36 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
37 cx.background_spawn(async move {
38 let index_state = index_state.lock().await;
39 Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
40 })
41 } else {
42 cx.background_spawn(async move {
43 Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
44 })
45 }
46 }
47
48 pub fn gather_context(
49 cursor_point: Point,
50 buffer: &BufferSnapshot,
51 excerpt_options: &EditPredictionExcerptOptions,
52 index_state: Option<&SyntaxIndexState>,
53 ) -> Option<Self> {
54 let excerpt = EditPredictionExcerpt::select_from_buffer(
55 cursor_point,
56 buffer,
57 excerpt_options,
58 index_state,
59 )?;
60 let excerpt_text = excerpt.text(buffer);
61 let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
62
63 let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
64 let adjacent_end = Point::new(cursor_point.row + 1, 0);
65 let adjacent_occurrences = text_similarity::Occurrences::within_string(
66 &buffer
67 .text_for_range(adjacent_start..adjacent_end)
68 .collect::<String>(),
69 );
70
71 let cursor_offset_in_file = cursor_point.to_offset(buffer);
72 // TODO fix this to not need saturating_sub
73 let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
74
75 let declarations = if let Some(index_state) = index_state {
76 let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
77
78 scored_declarations(
79 &index_state,
80 &excerpt,
81 &excerpt_occurrences,
82 &adjacent_occurrences,
83 references,
84 cursor_offset_in_file,
85 buffer,
86 )
87 } else {
88 vec![]
89 };
90
91 Some(Self {
92 excerpt,
93 excerpt_text,
94 cursor_offset_in_excerpt,
95 declarations,
96 })
97 }
98}
99
#[cfg(test)]
mod tests {
    // `use super::*` already provides `Entity`, `EditPredictionExcerptOptions`, and
    // `SyntaxIndex`, so they are not imported again below.
    use super::*;
    use std::sync::Arc;

    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Places the cursor on the first `process_data` call site in `c.rs` and checks that the
    /// gathered declarations are exactly `main` and `process_data` (order-insensitive).
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather_context_in_background(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 60,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    Some(index),
                    cx,
                )
            })
            .await
            .unwrap();

        let mut snippet_identifiers = context
            .declarations
            .iter()
            .map(|snippet| snippet.identifier.name.as_ref())
            .collect::<Vec<_>>();
        snippet_identifiers.sort();
        assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
        // Explicit drop at the end; presumably keeps the buffer handle (and thus the open
        // buffer) alive until the assertions have run.
        drop(buffer);
    }

    /// Sets up globals, a fake filesystem with `a.rs`, `b.rs`, and `c.rs` under `/root`,
    /// a test project with the Rust language registered, and a `SyntaxIndex` over it.
    ///
    /// Returns the project, the index entity, and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` matching `.rs` files, backed by the tree-sitter Rust grammar
    /// and the highlights/outline queries from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}