1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
10pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
11use gpui::{App, AppContext as _, Entity, Task};
12use language::BufferSnapshot;
13pub use reference::references_in_excerpt;
14pub use syntax_index::SyntaxIndex;
15use text::{Point, ToOffset as _};
16
17use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
18
19pub struct EditPredictionContext {
20 pub excerpt: EditPredictionExcerpt,
21 pub excerpt_text: EditPredictionExcerptText,
22 pub snippets: Vec<ScoredSnippet>,
23}
24
25impl EditPredictionContext {
26 pub fn gather(
27 cursor_point: Point,
28 buffer: BufferSnapshot,
29 excerpt_options: EditPredictionExcerptOptions,
30 syntax_index: Entity<SyntaxIndex>,
31 cx: &mut App,
32 ) -> Task<Self> {
33 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
34 cx.background_spawn(async move {
35 let index_state = index_state.lock().await;
36
37 let excerpt =
38 EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
39 .unwrap();
40 let excerpt_text = excerpt.text(&buffer);
41 let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
42 let cursor_offset = cursor_point.to_offset(&buffer);
43
44 let snippets = scored_snippets(
45 &index_state,
46 &excerpt,
47 &excerpt_text,
48 references,
49 cursor_offset,
50 &buffer,
51 );
52
53 Self {
54 excerpt,
55 excerpt_text,
56 snippets,
57 }
58 })
59 }
60}
61
#[cfg(test)]
mod tests {
    //! Tests for `EditPredictionContext::gather` against a small in-memory Rust project.

    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathering context at a call site of `process_data` in `c.rs` should yield exactly
    /// one snippet, for the `process_data` declaration.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // NOTE(review): presumably lets pending work on the opened buffer (e.g. parsing)
        // settle before the snapshot is taken; confirm.
        cx.run_until_parked();

        // first process_data call site
        // Zero-based row 8 of `c.rs` is the `let result = process_data(data);` line in `main`.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 40,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                        include_parent_signatures: false,
                    },
                    index,
                    cx,
                )
            })
            .await;

        assert_eq!(context.snippets.len(), 1);
        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
        // Explicitly release the buffer entity handle at the end of the test.
        drop(buffer);
    }

    /// Builds the shared test fixture.
    ///
    /// The fixture consists of:
    /// - a test `SettingsStore` with language and project settings initialized;
    /// - a `FakeFs` containing `a.rs`, `b.rs` and `c.rs` under `/root`;
    /// - a `Project` over that tree, with a Rust language registered;
    /// - a `SyntaxIndex` for the project.
    ///
    /// Returns the project, the index and the registered Rust language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        // Capture the id before the language is moved into the registry.
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // NOTE(review): presumably lets the index finish its initial pass over the
        // project's files before the fixture is returned; confirm.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` that matches `.rs` files.
    ///
    /// Uses the tree-sitter Rust grammar plus the highlights and outline queries from
    /// `languages/src/rust`.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}