1use anyhow::Result;
2use cloud_llm_client::predict_edits_v3::{self, Excerpt};
3use indoc::indoc;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use std::fmt::Write;
7
8use crate::{push_events, write_codeblock};
9
10pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
11 let mut prompt = SEARCH_INSTRUCTIONS.to_string();
12
13 if !request.events.is_empty() {
14 writeln!(&mut prompt, "\n## User Edits\n\n")?;
15 push_events(&mut prompt, &request.events);
16 }
17
18 writeln!(&mut prompt, "## Cursor context\n")?;
19 write_codeblock(
20 &request.excerpt_path,
21 &[Excerpt {
22 start_line: request.excerpt_line_range.start,
23 text: request.excerpt.into(),
24 }],
25 &[],
26 request.cursor_file_max_row,
27 true,
28 &mut prompt,
29 );
30
31 writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
32
33 Ok(prompt)
34}
35
36/// Search for relevant code
37///
38/// For the best results, run multiple queries at once with a single invocation of this tool.
39#[derive(Clone, Deserialize, Serialize, JsonSchema)]
40pub struct SearchToolInput {
41 /// An array of queries to run for gathering context relevant to the next prediction
42 #[schemars(length(max = 3))]
43 #[serde(deserialize_with = "deserialize_queries")]
44 pub queries: Box<[SearchToolQuery]>,
45}
46
47fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
48where
49 D: serde::Deserializer<'de>,
50{
51 use serde::de::Error;
52
53 #[derive(Deserialize)]
54 #[serde(untagged)]
55 enum QueryCollection {
56 Array(Box<[SearchToolQuery]>),
57 DoubleArray(Box<[Box<[SearchToolQuery]>]>),
58 Single(SearchToolQuery),
59 }
60
61 #[derive(Deserialize)]
62 #[serde(untagged)]
63 enum MaybeDoubleEncoded {
64 SingleEncoded(QueryCollection),
65 DoubleEncoded(String),
66 }
67
68 let result = MaybeDoubleEncoded::deserialize(deserializer)?;
69
70 let normalized = match result {
71 MaybeDoubleEncoded::SingleEncoded(value) => value,
72 MaybeDoubleEncoded::DoubleEncoded(value) => {
73 serde_json::from_str(&value).map_err(D::Error::custom)?
74 }
75 };
76
77 Ok(match normalized {
78 QueryCollection::Array(items) => items,
79 QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
80 QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
81 })
82}
83
84/// Search for relevant code by path, syntax hierarchy, and content.
85#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
86pub struct SearchToolQuery {
87 /// 1. A glob pattern to match file paths in the codebase to search in.
88 pub glob: String,
89 /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
90 ///
91 /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
92 ///
93 /// Example: Searching for a `User` class
94 /// ["class\s+User"]
95 ///
96 /// Example: Searching for a `get_full_name` method under a `User` class
97 /// ["class\s+User", "def\sget_full_name"]
98 ///
99 /// Skip this field to match on content alone.
100 #[schemars(length(max = 3))]
101 #[serde(default)]
102 pub syntax_node: Vec<String>,
103 /// 3. An optional regular expression to match the final content that should appear in the results.
104 ///
105 /// - Content will be matched within all lines of the matched syntax nodes.
106 /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
107 /// - If no syntax node regexes are provided, the content will be matched within the entire file.
108 pub content: Option<String>,
109}
110
/// Name of the search tool, matching the `search` tool that the prompt text in this
/// file tells the model to call.
///
/// NOTE(review): presumably the name under which the tool taking [`SearchToolInput`]
/// is registered; confirm at the call site.
pub const TOOL_NAME: &str = "search";
112
/// Instruction text that `build_prompt` places at the very start of the prompt.
///
/// The text is runtime prompt content: whitespace and wording changes alter what the
/// model sees. `indoc!` removes the indentation shared by all lines of the literal.
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
 You are part of an edit prediction system in a code editor.
 Your role is to search for code that will serve as context for predicting the next edit.

 - Analyze the user's recent edits and current cursor context
 - Use the `search` tool to find code that is relevant for predicting the next edit
 - Focus on finding:
 - Code patterns that might need similar changes based on the recent edits
 - Functions, variables, types, and constants referenced in the current cursor context
 - Related implementations, usages, or dependencies that may require consistent updates
 - How items defined in the cursor excerpt are used or altered
 - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
 - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
 - Avoid using wildcard globs if you already know the file path of the content you're looking for
"#};
128
/// Closing text that `build_prompt` appends after the cursor context section.
///
/// The text is runtime prompt content; `build_prompt` writes it followed by a newline.
const TOOL_USE_REMINDER: &str = indoc! {"
 --
 Analyze the user's intent in one to two sentences, then call the `search` tool.
"};
133
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Parses `source` as a [`SearchToolInput`], panicking if deserialization fails.
    #[track_caller]
    fn parse_input(source: &str) -> SearchToolInput {
        serde_json::from_str(source).unwrap()
    }

    /// Asserts that `query` is the `**/*.rs` fixture query.
    #[track_caller]
    fn assert_rust_query(query: &SearchToolQuery) {
        assert_eq!(query.glob, "**/*.rs");
        assert_eq!(query.syntax_node, vec!["fn test"]);
        assert_eq!(query.content, Some("assert".to_string()));
    }

    /// Asserts that `query` is the `**/*.ts` fixture query.
    #[track_caller]
    fn assert_typescript_query(query: &SearchToolQuery) {
        assert_eq!(query.glob, "**/*.ts");
        assert_eq!(query.syntax_node.len(), 0);
        assert_eq!(query.content, None);
    }

    /// Asserts that `queries` holds exactly the Rust fixture followed by the TypeScript one.
    #[track_caller]
    fn assert_rust_then_typescript(queries: &[SearchToolQuery]) {
        assert_eq!(queries.len(), 2);
        assert_rust_query(&queries[0]);
        assert_typescript_query(&queries[1]);
    }

    #[test]
    fn test_deserialize_queries() {
        // Shape 1: a single query object where an array is expected.
        let single_input = parse_input(indoc! {r#"{
            "queries": {
                "glob": "**/*.rs",
                "syntax_node": ["fn test"],
                "content": "assert"
            }
        }"#});
        assert_eq!(single_input.queries.len(), 1);
        assert_rust_query(&single_input.queries[0]);

        // Shape 2: the expected flat array of queries.
        let flat_input = parse_input(indoc! {r#"{
            "queries": [
                {
                    "glob": "**/*.rs",
                    "syntax_node": ["fn test"],
                    "content": "assert"
                },
                {
                    "glob": "**/*.ts",
                    "syntax_node": [],
                    "content": null
                }
            ]
        }"#});
        assert_rust_then_typescript(&flat_input.queries);

        // Shape 3: an array of arrays, which must be flattened in order.
        let nested_input = parse_input(indoc! {r#"{
            "queries": [
                [
                    {
                        "glob": "**/*.rs",
                        "syntax_node": ["fn test"],
                        "content": "assert"
                    }
                ],
                [
                    {
                        "glob": "**/*.ts",
                        "syntax_node": [],
                        "content": null
                    }
                ]
            ]
        }"#});
        assert_rust_then_typescript(&nested_input.queries);

        // Shape 4: the array serialized to a JSON string and embedded as the field value.
        let encoded_queries = serde_json::to_string(&json!([
            {
                "glob": "**/*.rs",
                "syntax_node": ["fn test"],
                "content": "assert"
            },
            {
                "glob": "**/*.ts",
                "syntax_node": [],
                "content": null
            }
        ]))
        .unwrap();
        let double_encoded_source =
            serde_json::to_string(&json!({ "queries": encoded_queries })).unwrap();
        let double_encoded_input = parse_input(&double_encoded_source);
        assert_rust_then_typescript(&double_encoded_input.queries);

        // Recorded malformed input that is not covered by the assertions above (a flat array
        // of strings instead of query objects):
        //
        // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
        // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
    }
}