1mod declaration;
2mod declaration_scoring;
3mod excerpt;
4mod outline;
5mod reference;
6mod syntax_index;
7mod text_similarity;
8
9use std::time::Instant;
10
11pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
12pub use declaration_scoring::SnippetStyle;
13pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
14
15use gpui::{App, AppContext as _, Entity, Task};
16use language::BufferSnapshot;
17pub use reference::references_in_excerpt;
18pub use syntax_index::SyntaxIndex;
19use text::{Point, ToOffset as _};
20
21use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
22
/// Everything collected around a cursor position to drive an edit prediction.
///
/// Produced by [`EditPredictionContext::gather`].
#[derive(Debug)]
pub struct EditPredictionContext {
    /// The excerpt selected from the buffer around the cursor.
    pub excerpt: EditPredictionExcerpt,
    /// The text of `excerpt`, extracted from the same buffer snapshot.
    pub excerpt_text: EditPredictionExcerptText,
    /// Snippets scored against the references found in the excerpt.
    pub snippets: Vec<ScoredSnippet>,
    /// Time elapsed between the call to `gather` and the completion of the context.
    pub retrieval_duration: std::time::Duration,
}
30
31impl EditPredictionContext {
32 pub fn gather(
33 cursor_point: Point,
34 buffer: BufferSnapshot,
35 excerpt_options: EditPredictionExcerptOptions,
36 syntax_index: Entity<SyntaxIndex>,
37 cx: &mut App,
38 ) -> Task<Option<Self>> {
39 let start = Instant::now();
40 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
41 cx.background_spawn(async move {
42 let index_state = index_state.lock().await;
43
44 let excerpt =
45 EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
46 let excerpt_text = excerpt.text(&buffer);
47 let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
48 let cursor_offset = cursor_point.to_offset(&buffer);
49
50 let snippets = scored_snippets(
51 &index_state,
52 &excerpt,
53 &excerpt_text,
54 references,
55 cursor_offset,
56 &buffer,
57 );
58
59 Some(Self {
60 excerpt,
61 excerpt_text,
62 snippets,
63 retrieval_duration: start.elapsed(),
64 })
65 })
66 }
67}
68
69#[cfg(test)]
70mod tests {
71 use super::*;
72 use std::sync::Arc;
73
74 use gpui::{Entity, TestAppContext};
75 use indoc::indoc;
76 use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
77 use project::{FakeFs, Project};
78 use serde_json::json;
79 use settings::SettingsStore;
80 use util::path;
81
82 use crate::{EditPredictionExcerptOptions, SyntaxIndex};
83
84 #[gpui::test]
85 async fn test_call_site(cx: &mut TestAppContext) {
86 let (project, index, _rust_lang_id) = init_test(cx).await;
87
88 let buffer = project
89 .update(cx, |project, cx| {
90 let project_path = project.find_project_path("c.rs", cx).unwrap();
91 project.open_buffer(project_path, cx)
92 })
93 .await
94 .unwrap();
95
96 cx.run_until_parked();
97
98 // first process_data call site
99 let cursor_point = language::Point::new(8, 21);
100 let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
101
102 let context = cx
103 .update(|cx| {
104 EditPredictionContext::gather(
105 cursor_point,
106 buffer_snapshot,
107 EditPredictionExcerptOptions {
108 max_bytes: 40,
109 min_bytes: 10,
110 target_before_cursor_over_total_bytes: 0.5,
111 include_parent_signatures: false,
112 },
113 index,
114 cx,
115 )
116 })
117 .await
118 .unwrap();
119
120 assert_eq!(context.snippets.len(), 1);
121 assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
122 drop(buffer);
123 }
124
125 async fn init_test(
126 cx: &mut TestAppContext,
127 ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
128 cx.update(|cx| {
129 let settings_store = SettingsStore::test(cx);
130 cx.set_global(settings_store);
131 language::init(cx);
132 Project::init_settings(cx);
133 });
134
135 let fs = FakeFs::new(cx.executor());
136 fs.insert_tree(
137 path!("/root"),
138 json!({
139 "a.rs": indoc! {r#"
140 fn main() {
141 let x = 1;
142 let y = 2;
143 let z = add(x, y);
144 println!("Result: {}", z);
145 }
146
147 fn add(a: i32, b: i32) -> i32 {
148 a + b
149 }
150 "#},
151 "b.rs": indoc! {"
152 pub struct Config {
153 pub name: String,
154 pub value: i32,
155 }
156
157 impl Config {
158 pub fn new(name: String, value: i32) -> Self {
159 Config { name, value }
160 }
161 }
162 "},
163 "c.rs": indoc! {r#"
164 use std::collections::HashMap;
165
166 fn main() {
167 let args: Vec<String> = std::env::args().collect();
168 let data: Vec<i32> = args[1..]
169 .iter()
170 .filter_map(|s| s.parse().ok())
171 .collect();
172 let result = process_data(data);
173 println!("{:?}", result);
174 }
175
176 fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
177 let mut counts = HashMap::new();
178 for value in data {
179 *counts.entry(value).or_insert(0) += 1;
180 }
181 counts
182 }
183
184 #[cfg(test)]
185 mod tests {
186 use super::*;
187
188 #[test]
189 fn test_process_data() {
190 let data = vec![1, 2, 2, 3];
191 let result = process_data(data);
192 assert_eq!(result.get(&2), Some(&2));
193 }
194 }
195 "#}
196 }),
197 )
198 .await;
199 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
200 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
201 let lang = rust_lang();
202 let lang_id = lang.id();
203 language_registry.add(Arc::new(lang));
204
205 let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
206 cx.run_until_parked();
207
208 (project, index, lang_id)
209 }
210
211 fn rust_lang() -> Language {
212 Language::new(
213 LanguageConfig {
214 name: "Rust".into(),
215 matcher: LanguageMatcher {
216 path_suffixes: vec!["rs".to_string()],
217 ..Default::default()
218 },
219 ..Default::default()
220 },
221 Some(tree_sitter_rust::LANGUAGE.into()),
222 )
223 .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
224 .unwrap()
225 .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
226 .unwrap()
227 }
228}