1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use std::time::Instant;
10
11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
12pub use declaration_scoring::SnippetStyle;
13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
14
15use gpui::{App, AppContext as _, Entity, Task};
16use language::BufferSnapshot;
17pub use reference::references_in_excerpt;
18pub use syntax_index::SyntaxIndex;
19use text::{Point, ToOffset as _};
20
21use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
22
23#[derive(Debug)]
24pub struct EditPredictionContext {
25 pub excerpt: EditPredictionExcerpt,
26 pub excerpt_text: EditPredictionExcerptText,
27 pub snippets: Vec<ScoredSnippet>,
28 pub retrieval_duration: std::time::Duration,
29}
30
31impl EditPredictionContext {
32 pub fn gather(
33 cursor_point: Point,
34 buffer: BufferSnapshot,
35 excerpt_options: EditPredictionExcerptOptions,
36 syntax_index: Entity<SyntaxIndex>,
37 cx: &mut App,
38 ) -> Task<Option<Self>> {
39 let start = Instant::now();
40 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
41 cx.background_spawn(async move {
42 let index_state = index_state.lock().await;
43
44 let excerpt =
45 EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
46 let excerpt_text = excerpt.text(&buffer);
47 let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
48 let cursor_offset = cursor_point.to_offset(&buffer);
49
50 let snippets = scored_snippets(
51 &index_state,
52 &excerpt,
53 &excerpt_text,
54 references,
55 cursor_offset,
56 &buffer,
57 );
58
59 Some(Self {
60 excerpt,
61 excerpt_text,
62 snippets,
63 retrieval_duration: start.elapsed(),
64 })
65 })
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    /// Gathers context with the cursor on the first `process_data` call site in `c.rs`
    /// and checks that exactly one snippet is returned, for `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Drain pending tasks before snapshotting; presumably so the opened buffer is
        // parsed/indexed first — confirm.
        cx.run_until_parked();

        // first process_data call site
        // Row 8, column 21 of the c.rs fixture falls inside `process_data` on the line
        // `let result = process_data(data);` (assuming zero-based rows/columns). This depends
        // on the fixture's exact layout.
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather(
                    cursor_point,
                    buffer_snapshot,
                    // Deliberately tiny excerpt budget so the excerpt stays near the call site.
                    EditPredictionExcerptOptions {
                        max_bytes: 40,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                        include_parent_signatures: false,
                    },
                    index,
                    cx,
                )
            })
            .await
            .unwrap();

        assert_eq!(context.snippets.len(), 1);
        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
        // Explicit drop of the buffer handle at the end of the test; presumably to make clear
        // the buffer must stay open while `gather` runs — confirm.
        drop(buffer);
    }

    /// Sets up a test project for the tests in this module.
    ///
    /// Steps:
    /// - installs a test `SettingsStore` and initializes language and project settings;
    /// - creates a fake filesystem rooted at `/root` containing `a.rs`, `b.rs` and `c.rs`;
    /// - registers a Rust language (see [`rust_lang`]) with the project's language registry;
    /// - creates a `SyntaxIndex` for the project.
    ///
    /// Returns the project, the index entity, and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            SettingsStore::load_registered_settings(cx);

            language::init(cx);
            Project::init_settings(cx);
        });

        // Fixture contents are dedented by `indoc!`. The cursor position used in
        // `test_call_site` depends on the exact layout of `c.rs`.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        // Capture the id before the language is moved into the registry.
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Let the newly created index process pending work before tests use it.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language` for tests.
    ///
    /// Uses the tree-sitter Rust grammar plus the highlights and outline queries from the
    /// `languages` crate's Rust sources (included at compile time).
    ///
    /// # Panics
    ///
    /// Panics if either query fails to load.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}