Cargo.lock 🔗
@@ -3211,6 +3211,7 @@ dependencies = [
"rustc-hash 2.1.1",
"schemars 1.0.4",
"serde",
+ "serde_json",
"strum 0.27.2",
]
Ben Kunkle , Agus , and Max created
Closes #ISSUE
Release Notes:
- N/A *or* Added/Fixed/Improved ...
---------
Co-authored-by: Agus <agus@zed.dev>
Co-authored-by: Max <max@zed.dev>
Cargo.lock | 1
crates/cloud_zeta2_prompt/Cargo.toml | 1
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs | 150 +++++++++++++++++
3 files changed, 152 insertions(+)
@@ -3211,6 +3211,7 @@ dependencies = [
"rustc-hash 2.1.1",
"schemars 1.0.4",
"serde",
+ "serde_json",
"strum 0.27.2",
]
@@ -19,4 +19,5 @@ ordered-float.workspace = true
rustc-hash.workspace = true
schemars.workspace = true
serde.workspace = true
+serde_json.workspace = true
strum.workspace = true
@@ -40,9 +40,47 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
pub struct SearchToolInput {
/// An array of queries to run for gathering context relevant to the next prediction
#[schemars(length(max = 3))]
+ #[serde(deserialize_with = "deserialize_queries")]
pub queries: Box<[SearchToolQuery]>,
}
+fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
+where
+ D: serde::Deserializer<'de>,
+{
+ use serde::de::Error;
+
+ #[derive(Deserialize)]
+ #[serde(untagged)]
+ enum QueryCollection {
+ Array(Box<[SearchToolQuery]>),
+ DoubleArray(Box<[Box<[SearchToolQuery]>]>),
+ Single(SearchToolQuery),
+ }
+
+ #[derive(Deserialize)]
+ #[serde(untagged)]
+ enum MaybeDoubleEncoded {
+ SingleEncoded(QueryCollection),
+ DoubleEncoded(String),
+ }
+
+ let result = MaybeDoubleEncoded::deserialize(deserializer)?;
+
+ let normalized = match result {
+ MaybeDoubleEncoded::SingleEncoded(value) => value,
+ MaybeDoubleEncoded::DoubleEncoded(value) => {
+ serde_json::from_str(&value).map_err(D::Error::custom)?
+ }
+ };
+
+ Ok(match normalized {
+ QueryCollection::Array(items) => items,
+ QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
+ QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
+ })
+}
+
/// Search for relevant code by path, syntax hierarchy, and content.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct SearchToolQuery {
@@ -92,3 +130,115 @@ const TOOL_USE_REMINDER: &str = indoc! {"
--
Analyze the user's intent in one to two sentences, then call the `search` tool.
"};
+
+#[cfg(test)]
+mod tests {
+ use serde_json::json;
+
+ use super::*;
+
+ #[test]
+ fn test_deserialize_queries() {
+ let single_query_json = indoc! {r#"{
+ "queries": {
+ "glob": "**/*.rs",
+ "syntax_node": ["fn test"],
+ "content": "assert"
+ }
+ }"#};
+
+ let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
+ assert_eq!(flat_input.queries.len(), 1);
+ assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+ assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+ assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+
+ let flat_json = indoc! {r#"{
+ "queries": [
+ {
+ "glob": "**/*.rs",
+ "syntax_node": ["fn test"],
+ "content": "assert"
+ },
+ {
+ "glob": "**/*.ts",
+ "syntax_node": [],
+ "content": null
+ }
+ ]
+ }"#};
+
+ let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
+ assert_eq!(flat_input.queries.len(), 2);
+ assert_eq!(flat_input.queries[0].glob, "**/*.rs");
+ assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
+ assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
+ assert_eq!(flat_input.queries[1].glob, "**/*.ts");
+ assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
+ assert_eq!(flat_input.queries[1].content, None);
+
+ let nested_json = indoc! {r#"{
+ "queries": [
+ [
+ {
+ "glob": "**/*.rs",
+ "syntax_node": ["fn test"],
+ "content": "assert"
+ }
+ ],
+ [
+ {
+ "glob": "**/*.ts",
+ "syntax_node": [],
+ "content": null
+ }
+ ]
+ ]
+ }"#};
+
+ let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
+
+ assert_eq!(nested_input.queries.len(), 2);
+
+ assert_eq!(nested_input.queries[0].glob, "**/*.rs");
+ assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
+ assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
+ assert_eq!(nested_input.queries[1].glob, "**/*.ts");
+ assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
+ assert_eq!(nested_input.queries[1].content, None);
+
+ let double_encoded_queries = serde_json::to_string(&json!({
+ "queries": serde_json::to_string(&json!([
+ {
+ "glob": "**/*.rs",
+ "syntax_node": ["fn test"],
+ "content": "assert"
+ },
+ {
+ "glob": "**/*.ts",
+ "syntax_node": [],
+ "content": null
+ }
+ ])).unwrap()
+ }))
+ .unwrap();
+
+ let double_encoded_input: SearchToolInput =
+ serde_json::from_str(&double_encoded_queries).unwrap();
+
+ assert_eq!(double_encoded_input.queries.len(), 2);
+
+ assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
+ assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
+ assert_eq!(
+ double_encoded_input.queries[0].content,
+ Some("assert".to_string())
+ );
+ assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
+ assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
+ assert_eq!(double_encoded_input.queries[1].content, None);
+
+ // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
+ // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
+ }
+}