zeta2: Improve queries parsing (#43012)

Ben Kunkle , Agus , and Max created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Agus <agus@zed.dev>
Co-authored-by: Max <max@zed.dev>

Change summary

Cargo.lock                                        |   1 
crates/cloud_zeta2_prompt/Cargo.toml              |   1 
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs | 150 +++++++++++++++++
3 files changed, 152 insertions(+)

Detailed changes

Cargo.lock 🔗

@@ -3211,6 +3211,7 @@ dependencies = [
  "rustc-hash 2.1.1",
  "schemars 1.0.4",
  "serde",
+ "serde_json",
  "strum 0.27.2",
 ]
 

crates/cloud_zeta2_prompt/Cargo.toml 🔗

@@ -19,4 +19,5 @@ ordered-float.workspace = true
 rustc-hash.workspace = true
 schemars.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 strum.workspace = true

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs 🔗

@@ -40,9 +40,47 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
 pub struct SearchToolInput {
     /// An array of queries to run for gathering context relevant to the next prediction
     #[schemars(length(max = 3))]
+    #[serde(deserialize_with = "deserialize_queries")]
     pub queries: Box<[SearchToolQuery]>,
 }
 
+fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    use serde::de::Error;
+
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum QueryCollection {
+        Array(Box<[SearchToolQuery]>),
+        DoubleArray(Box<[Box<[SearchToolQuery]>]>),
+        Single(SearchToolQuery),
+    }
+
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum MaybeDoubleEncoded {
+        SingleEncoded(QueryCollection),
+        DoubleEncoded(String),
+    }
+
+    let result = MaybeDoubleEncoded::deserialize(deserializer)?;
+
+    let normalized = match result {
+        MaybeDoubleEncoded::SingleEncoded(value) => value,
+        MaybeDoubleEncoded::DoubleEncoded(value) => {
+            serde_json::from_str(&value).map_err(D::Error::custom)?
+        }
+    };
+
+    Ok(match normalized {
+        QueryCollection::Array(items) => items,
+        QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
+        QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
+    })
+}
+
 /// Search for relevant code by path, syntax hierarchy, and content.
 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
 pub struct SearchToolQuery {
@@ -92,3 +130,115 @@ const TOOL_USE_REMINDER: &str = indoc! {"
     --
     Analyze the user's intent in one to two sentences, then call the `search` tool.
 "};
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_deserialize_queries() {
+        let single_query_json = indoc! {r#"{
+            "queries": {
+                "glob": "**/*.rs",
+                "syntax_node": ["fn test"],
+                "content": "assert"
+            }
+        }"#};
+
+        let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
+        assert_eq!(flat_input.queries.len(), 1);
+        assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+        assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+
+        let flat_json = indoc! {r#"{
+            "queries": [
+                {
+                    "glob": "**/*.rs",
+                    "syntax_node": ["fn test"],
+                    "content": "assert"
+                },
+                {
+                    "glob": "**/*.ts",
+                    "syntax_node": [],
+                    "content": null
+                }
+            ]
+        }"#};
+
+        let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
+        assert_eq!(flat_input.queries.len(), 2);
+        assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+        assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+        assert_eq!(flat_input.queries[1].glob, "**/*.ts");
+        assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(flat_input.queries[1].content, None);
+
+        let nested_json = indoc! {r#"{
+            "queries": [
+                [
+                    {
+                        "glob": "**/*.rs",
+                        "syntax_node": ["fn test"],
+                        "content": "assert"
+                    }
+                ],
+                [
+                    {
+                        "glob": "**/*.ts",
+                        "syntax_node": [],
+                        "content": null
+                    }
+                ]
+            ]
+        }"#};
+
+        let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
+
+        assert_eq!(nested_input.queries.len(), 2);
+
+        assert_eq!(nested_input.queries[0].glob, "**/*.rs");
+        assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
+        assert_eq!(nested_input.queries[1].glob, "**/*.ts");
+        assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(nested_input.queries[1].content, None);
+
+        let double_encoded_queries = serde_json::to_string(&json!({
+            "queries": serde_json::to_string(&json!([
+                {
+                    "glob": "**/*.rs",
+                    "syntax_node": ["fn test"],
+                    "content": "assert"
+                },
+                {
+                    "glob": "**/*.ts",
+                    "syntax_node": [],
+                    "content": null
+                }
+            ])).unwrap()
+        }))
+        .unwrap();
+
+        let double_encoded_input: SearchToolInput =
+            serde_json::from_str(&double_encoded_queries).unwrap();
+
+        assert_eq!(double_encoded_input.queries.len(), 2);
+
+        assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
+        assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(
+            double_encoded_input.queries[0].content,
+            Some("assert".to_string())
+        );
+        assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
+        assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(double_encoded_input.queries[1].content, None);
+
+        // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
+        // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
+    }
+}