diff --git a/Cargo.lock b/Cargo.lock
index a39ff712e3a5b9cefe42bcec8359fdf297d55f71..f076630a2e36c2fcca70db8cbdbf20c606b7e2c1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3211,6 +3211,7 @@ dependencies = [
  "rustc-hash 2.1.1",
  "schemars 1.0.4",
  "serde",
+ "serde_json",
  "strum 0.27.2",
 ]
 
diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml
index 8be10265cb23e7dd0983c52e7c2d6984b62c4be4..fa8246950f8d03029388e0276954de946efc2346 100644
--- a/crates/cloud_zeta2_prompt/Cargo.toml
+++ b/crates/cloud_zeta2_prompt/Cargo.toml
@@ -19,4 +19,5 @@ ordered-float.workspace = true
 rustc-hash.workspace = true
 schemars.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 strum.workspace = true
diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
index e334674ef8004b485608e3864cf1e4e8d4c97cdb..fd35f63f03ff967491a28d817852f6622e4919ca 100644
--- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
+++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
@@ -40,9 +40,62 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
 pub struct SearchToolInput {
     /// An array of queries to run for gathering context relevant to the next prediction
     #[schemars(length(max = 3))]
+    #[serde(deserialize_with = "deserialize_queries")]
     pub queries: Box<[SearchToolQuery]>,
 }
 
+/// Deserializes the `queries` field of [`SearchToolInput`], tolerating several input shapes.
+///
+/// Accepted shapes, all normalized to one flat boxed slice:
+/// - an array of queries,
+/// - an array of arrays of queries (flattened in order),
+/// - a single query object (wrapped in a one-element slice),
+/// - a string containing JSON for any of the shapes above.
+///
+/// # Errors
+///
+/// Returns a deserialization error if the value matches none of these shapes, or if a
+/// string value does not parse as JSON for one of them.
+fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    use serde::de::Error;
+
+    // The shapes the query list may take once any string wrapping has been removed.
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum QueryCollection {
+        Array(Box<[SearchToolQuery]>),
+        DoubleArray(Box<[Box<[SearchToolQuery]>]>),
+        Single(SearchToolQuery),
+    }
+
+    // The field holds either the collection itself or a string with the collection as JSON.
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum MaybeDoubleEncoded {
+        SingleEncoded(QueryCollection),
+        DoubleEncoded(String),
+    }
+
+    let result = MaybeDoubleEncoded::deserialize(deserializer)?;
+
+    let normalized = match result {
+        MaybeDoubleEncoded::SingleEncoded(value) => value,
+        MaybeDoubleEncoded::DoubleEncoded(value) => {
+            // Decode the inner JSON text, reporting a failure as this deserializer's error.
+            serde_json::from_str(&value).map_err(D::Error::custom)?
+        }
+    };
+
+    Ok(match normalized {
+        QueryCollection::Array(items) => items,
+        QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
+        QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
+    })
+}
+
 /// Search for relevant code by path, syntax hierarchy, and content.
 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
 pub struct SearchToolQuery {
@@ -92,3 +145,115 @@ const TOOL_USE_REMINDER: &str = indoc! {"
     --
     Analyze the user's intent in one to two sentences, then call the `search` tool.
 "};
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_deserialize_queries() {
+        let single_query_json = indoc! {r#"{
+            "queries": {
+                "glob": "**/*.rs",
+                "syntax_node": ["fn test"],
+                "content": "assert"
+            }
+        }"#};
+
+        let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
+        assert_eq!(flat_input.queries.len(), 1);
+        assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+        assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+
+        let flat_json = indoc! {r#"{
+            "queries": [
+                {
+                    "glob": "**/*.rs",
+                    "syntax_node": ["fn test"],
+                    "content": "assert"
+                },
+                {
+                    "glob": "**/*.ts",
+                    "syntax_node": [],
+                    "content": null
+                }
+            ]
+        }"#};
+
+        let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
+        assert_eq!(flat_input.queries.len(), 2);
+        assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+        assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+        assert_eq!(flat_input.queries[1].glob, "**/*.ts");
+        assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(flat_input.queries[1].content, None);
+
+        let nested_json = indoc! {r#"{
+            "queries": [
+                [
+                    {
+                        "glob": "**/*.rs",
+                        "syntax_node": ["fn test"],
+                        "content": "assert"
+                    }
+                ],
+                [
+                    {
+                        "glob": "**/*.ts",
+                        "syntax_node": [],
+                        "content": null
+                    }
+                ]
+            ]
+        }"#};
+
+        let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
+
+        assert_eq!(nested_input.queries.len(), 2);
+
+        assert_eq!(nested_input.queries[0].glob, "**/*.rs");
+        assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
+        assert_eq!(nested_input.queries[1].glob, "**/*.ts");
+        assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(nested_input.queries[1].content, None);
+
+        let double_encoded_queries = serde_json::to_string(&json!({
+            "queries": serde_json::to_string(&json!([
+                {
+                    "glob": "**/*.rs",
+                    "syntax_node": ["fn test"],
+                    "content": "assert"
+                },
+                {
+                    "glob": "**/*.ts",
+                    "syntax_node": [],
+                    "content": null
+                }
+            ])).unwrap()
+        }))
+        .unwrap();
+
+        let double_encoded_input: SearchToolInput =
+            serde_json::from_str(&double_encoded_queries).unwrap();
+
+        assert_eq!(double_encoded_input.queries.len(), 2);
+
+        assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
+        assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
+        assert_eq!(
+            double_encoded_input.queries[0].content,
+            Some("assert".to_string())
+        );
+        assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
+        assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
+        assert_eq!(double_encoded_input.queries[1].content, None);
+
+        // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
+        // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
+    }
+}