retrieval_prompt.rs

  1use anyhow::Result;
  2use cloud_llm_client::predict_edits_v3::{self, Excerpt};
  3use indoc::indoc;
  4use schemars::JsonSchema;
  5use serde::{Deserialize, Serialize};
  6use std::fmt::Write;
  7
  8use crate::{push_events, write_codeblock};
  9
 10pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
 11    let mut prompt = SEARCH_INSTRUCTIONS.to_string();
 12
 13    if !request.events.is_empty() {
 14        writeln!(&mut prompt, "\n## User Edits\n\n")?;
 15        push_events(&mut prompt, &request.events);
 16    }
 17
 18    writeln!(&mut prompt, "## Cursor context\n")?;
 19    write_codeblock(
 20        &request.excerpt_path,
 21        &[Excerpt {
 22            start_line: request.excerpt_line_range.start,
 23            text: request.excerpt.into(),
 24        }],
 25        &[],
 26        request.cursor_file_max_row,
 27        true,
 28        &mut prompt,
 29    );
 30
 31    writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
 32
 33    Ok(prompt)
 34}
 35
 36/// Search for relevant code
 37///
 38/// For the best results, run multiple queries at once with a single invocation of this tool.
 39#[derive(Clone, Deserialize, Serialize, JsonSchema)]
 40pub struct SearchToolInput {
 41    /// An array of queries to run for gathering context relevant to the next prediction
 42    #[schemars(length(max = 3))]
 43    #[serde(deserialize_with = "deserialize_queries")]
 44    pub queries: Box<[SearchToolQuery]>,
 45}
 46
 47fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
 48where
 49    D: serde::Deserializer<'de>,
 50{
 51    use serde::de::Error;
 52
 53    #[derive(Deserialize)]
 54    #[serde(untagged)]
 55    enum QueryCollection {
 56        Array(Box<[SearchToolQuery]>),
 57        DoubleArray(Box<[Box<[SearchToolQuery]>]>),
 58        Single(SearchToolQuery),
 59    }
 60
 61    #[derive(Deserialize)]
 62    #[serde(untagged)]
 63    enum MaybeDoubleEncoded {
 64        SingleEncoded(QueryCollection),
 65        DoubleEncoded(String),
 66    }
 67
 68    let result = MaybeDoubleEncoded::deserialize(deserializer)?;
 69
 70    let normalized = match result {
 71        MaybeDoubleEncoded::SingleEncoded(value) => value,
 72        MaybeDoubleEncoded::DoubleEncoded(value) => {
 73            serde_json::from_str(&value).map_err(D::Error::custom)?
 74        }
 75    };
 76
 77    Ok(match normalized {
 78        QueryCollection::Array(items) => items,
 79        QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
 80        QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
 81    })
 82}
 83
 84/// Search for relevant code by path, syntax hierarchy, and content.
 85#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
 86pub struct SearchToolQuery {
 87    /// 1. A glob pattern to match file paths in the codebase to search in.
 88    pub glob: String,
 89    /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
 90    ///
 91    /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
 92    ///
 93    /// Example: Searching for a `User` class
 94    ///     ["class\s+User"]
 95    ///
 96    /// Example: Searching for a `get_full_name` method under a `User` class
 97    ///     ["class\s+User", "def\sget_full_name"]
 98    ///
 99    /// Skip this field to match on content alone.
100    #[schemars(length(max = 3))]
101    #[serde(default)]
102    pub syntax_node: Vec<String>,
103    /// 3. An optional regular expression to match the final content that should appear in the results.
104    ///
105    /// - Content will be matched within all lines of the matched syntax nodes.
106    /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
107    /// - If no syntax node regexes are provided, the content will be matched within the entire file.
108    pub content: Option<String>,
109}
110
/// Name of the tool the model is asked to call.
///
/// Must stay in sync with the `search` tool named in the prompt text of
/// `SEARCH_INSTRUCTIONS` and `TOOL_USE_REMINDER`.
pub const TOOL_NAME: &str = "search";
112
/// Preamble that `build_prompt` places at the very start of every prompt.
///
/// `indoc!` drops the leading newline and strips the common indentation, so the nested
/// "Focus on finding" bullets keep their extra indent relative to the top-level list.
/// The text ends with a newline.
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
    You are part of an edit prediction system in a code editor.
    Your role is to search for code that will serve as context for predicting the next edit.

    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that is relevant for predicting the next edit
    - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
       - Functions, variables, types, and constants referenced in the current cursor context
       - Related implementations, usages, or dependencies that may require consistent updates
       - How items defined in the cursor excerpt are used or altered
    - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
    - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
    - Avoid using wildcard globs if you already know the file path of the content you're looking for
"#};
128
/// Closing instruction that `build_prompt` appends after the cursor-context code block.
///
/// It begins with a `--` line and asks for a short statement of intent before the
/// `search` tool call.
const TOOL_USE_REMINDER: &str = indoc! {"
    --
    Analyze the user's intent in one to two sentences, then call the `search` tool.
"};
133
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Parses `json` as a complete tool input, panicking with the serde error on failure.
    #[track_caller]
    fn parse(json: &str) -> SearchToolInput {
        match serde_json::from_str(json) {
            Ok(input) => input,
            Err(err) => panic!("failed to parse search tool input: {err}\n{json}"),
        }
    }

    /// Asserts every field of one deserialized query.
    #[track_caller]
    fn assert_query(
        query: &SearchToolQuery,
        glob: &str,
        syntax_node: &[&str],
        content: Option<&str>,
    ) {
        assert_eq!(query.glob, glob);
        assert_eq!(query.syntax_node, syntax_node);
        assert_eq!(query.content.as_deref(), content);
    }

    /// Asserts the `**/*.rs` + `**/*.ts` pair shared by the multi-query fixtures below.
    #[track_caller]
    fn assert_rs_and_ts_queries(queries: &[SearchToolQuery]) {
        assert_eq!(queries.len(), 2);
        assert_query(&queries[0], "**/*.rs", &["fn test"], Some("assert"));
        assert_query(&queries[1], "**/*.ts", &[], None);
    }

    #[test]
    fn single_query_object_becomes_one_element_list() {
        let input = parse(indoc! {r#"{
            "queries": {
                "glob": "**/*.rs",
                "syntax_node": ["fn test"],
                "content": "assert"
            }
        }"#});

        assert_eq!(input.queries.len(), 1);
        assert_query(&input.queries[0], "**/*.rs", &["fn test"], Some("assert"));
    }

    #[test]
    fn flat_array_of_queries() {
        let input = parse(indoc! {r#"{
            "queries": [
                {
                    "glob": "**/*.rs",
                    "syntax_node": ["fn test"],
                    "content": "assert"
                },
                {
                    "glob": "**/*.ts",
                    "syntax_node": [],
                    "content": null
                }
            ]
        }"#});

        assert_rs_and_ts_queries(&input.queries);
    }

    #[test]
    fn nested_arrays_are_flattened() {
        let input = parse(indoc! {r#"{
            "queries": [
                [
                    {
                        "glob": "**/*.rs",
                        "syntax_node": ["fn test"],
                        "content": "assert"
                    }
                ],
                [
                    {
                        "glob": "**/*.ts",
                        "syntax_node": [],
                        "content": null
                    }
                ]
            ]
        }"#});

        assert_rs_and_ts_queries(&input.queries);
    }

    #[test]
    fn double_encoded_array_is_decoded() {
        let queries = json!([
            { "glob": "**/*.rs", "syntax_node": ["fn test"], "content": "assert" },
            { "glob": "**/*.ts", "syntax_node": [], "content": null }
        ]);
        // The model sent the array as a JSON *string* value.
        let input_json = json!({ "queries": queries.to_string() }).to_string();

        assert_rs_and_ts_queries(&parse(&input_json).queries);
    }

    #[test]
    fn empty_array_yields_no_queries() {
        let input = parse(r#"{"queries": []}"#);
        assert!(input.queries.is_empty());
    }

    #[test]
    fn omitted_optional_fields_take_defaults() {
        let input = parse(r#"{"queries": [{"glob": "**/*.md"}]}"#);

        assert_eq!(input.queries.len(), 1);
        assert_query(&input.queries[0], "**/*.md", &[], None);
    }

    #[test]
    fn array_of_bare_strings_is_rejected() {
        // Recorded from run 073 ("Switching from var declarations to lexical declarations"):
        // the model passed a path and two regexes as plain strings instead of query objects.
        // This shape is currently rejected; update this test if support for it is added.
        let json = r#"{"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}"#;
        assert!(serde_json::from_str::<SearchToolInput>(json).is_err());
    }

    #[test]
    fn string_that_is_not_encoded_queries_is_rejected() {
        let json = r#"{"queries": "not json"}"#;
        assert!(serde_json::from_str::<SearchToolInput>(json).is_err());
    }
}