zeta2: Targeted retrieval search (#42240)

Agus Zubiaga , Ben , Max , and Max Brunsfeld created

Since we removed the filtering step during context gathering, we want
the model to perform more targeted searches. This PR tweaks search tool
schema allowing the model to search within syntax nodes such as `impl`
blocks or methods.

This is what the query schema looks like now:

```rust
/// Search for relevant code by path, syntax hierarchy, and content.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// 1. A glob pattern to match file paths in the codebase to search in.
    pub glob: String,
    /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
    ///
    /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
    ///
    /// Example: Searching for a `User` class
    ///     ["class\s+User"]
    ///
    /// Example: Searching for a `get_full_name` method under a `User` class
    ///     ["class\s+User", "def\sget_full_name"]
    ///
    /// Skip this field to match on content alone.
    #[schemars(length(max = 3))]
    #[serde(default)]
    pub syntax_node: Vec<String>,
    /// 3. An optional regular expression to match the final content that should appear in the results.
    ///
    /// - Content will be matched within all lines of the matched syntax nodes.
    /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
    /// - If no syntax node regexes are provided, the content will be matched within the entire file.
    pub content: Option<String>,
}
```

We'll need to keep refining this, but the core implementation is ready.

Release Notes:

- N/A

---------

Co-authored-by: Ben <ben@zed.dev>
Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

Cargo.lock                                        |   5 
crates/agent/src/agent.rs                         |   1 
crates/agent/src/outline.rs                       |   9 
crates/agent/src/thread.rs                        |   4 
crates/agent/src/tools/context_server_registry.rs |   2 
crates/cloud_zeta2_prompt/Cargo.toml              |   1 
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs |  64 +-
crates/language/src/buffer.rs                     |  12 
crates/language_model/Cargo.toml                  |   3 
crates/language_model/src/language_model.rs       |  11 
crates/language_model/src/tool_schema.rs          |  12 
crates/project/src/project.rs                     |   2 
crates/zeta2/Cargo.toml                           |   1 
crates/zeta2/src/retrieval_search.rs              | 483 ++++++++++++++--
crates/zeta2/src/udiff.rs                         |  28 
crates/zeta2/src/zeta2.rs                         |  53 +
crates/zeta2_tools/Cargo.toml                     |   2 
crates/zeta2_tools/src/zeta2_context_view.rs      |  45 
crates/zeta_cli/src/predict.rs                    | 191 +++---
19 files changed, 657 insertions(+), 272 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3202,7 +3202,6 @@ dependencies = [
  "rustc-hash 2.1.1",
  "schemars 1.0.4",
  "serde",
- "serde_json",
  "strum 0.27.2",
 ]
 
@@ -8867,6 +8866,7 @@ dependencies = [
  "open_router",
  "parking_lot",
  "proto",
+ "schemars 1.0.4",
  "serde",
  "serde_json",
  "settings",
@@ -21685,6 +21685,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "smol",
  "thiserror 2.0.17",
  "util",
  "uuid",
@@ -21702,6 +21703,7 @@ dependencies = [
  "clap",
  "client",
  "cloud_llm_client",
+ "cloud_zeta2_prompt",
  "collections",
  "edit_prediction_context",
  "editor",
@@ -21715,7 +21717,6 @@ dependencies = [
  "ordered-float 2.10.1",
  "pretty_assertions",
  "project",
- "regex-syntax",
  "serde",
  "serde_json",
  "settings",

crates/agent/src/agent.rs 🔗

@@ -6,7 +6,6 @@ mod native_agent_server;
 pub mod outline;
 mod templates;
 mod thread;
-mod tool_schema;
 mod tools;
 
 #[cfg(test)]

crates/agent/src/outline.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::Result;
 use gpui::{AsyncApp, Entity};
-use language::{Buffer, OutlineItem, ParseStatus};
+use language::{Buffer, OutlineItem};
 use regex::Regex;
 use std::fmt::Write;
 use text::Point;
@@ -30,10 +30,9 @@ pub async fn get_buffer_content_or_outline(
     if file_size > AUTO_OUTLINE_SIZE {
         // For large files, use outline instead of full content
         // Wait until the buffer has been fully parsed, so we can read its outline
-        let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
-        while *parse_status.borrow() != ParseStatus::Idle {
-            parse_status.changed().await?;
-        }
+        buffer
+            .read_with(cx, |buffer, _| buffer.parsing_idle())?
+            .await;
 
         let outline_items = buffer.read_with(cx, |buffer, _| {
             let snapshot = buffer.snapshot();

crates/agent/src/thread.rs 🔗

@@ -2139,7 +2139,7 @@ where
 
     /// Returns the JSON schema that describes the tool's input.
     fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema {
-        crate::tool_schema::root_schema_for::<Self::Input>(format)
+        language_model::tool_schema::root_schema_for::<Self::Input>(format)
     }
 
     /// Some tools rely on a provider for the underlying billing or other reasons.
@@ -2226,7 +2226,7 @@ where
 
     fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
         let mut json = serde_json::to_value(T::input_schema(format))?;
-        crate::tool_schema::adapt_schema_to_format(&mut json, format)?;
+        language_model::tool_schema::adapt_schema_to_format(&mut json, format)?;
         Ok(json)
     }
 

crates/agent/src/tools/context_server_registry.rs 🔗

@@ -165,7 +165,7 @@ impl AnyAgentTool for ContextServerTool {
         format: language_model::LanguageModelToolSchemaFormat,
     ) -> Result<serde_json::Value> {
         let mut schema = self.tool.input_schema.clone();
-        crate::tool_schema::adapt_schema_to_format(&mut schema, format)?;
+        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
         Ok(match schema {
             serde_json::Value::Null => {
                 serde_json::json!({ "type": "object", "properties": [] })

crates/cloud_zeta2_prompt/Cargo.toml 🔗

@@ -19,5 +19,4 @@ ordered-float.workspace = true
 rustc-hash.workspace = true
 schemars.workspace = true
 serde.workspace = true
-serde_json.workspace = true
 strum.workspace = true

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs 🔗

@@ -3,7 +3,7 @@ use cloud_llm_client::predict_edits_v3::{self, Excerpt};
 use indoc::indoc;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{fmt::Write, sync::LazyLock};
+use std::fmt::Write;
 
 use crate::{push_events, write_codeblock};
 
@@ -15,7 +15,7 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
         push_events(&mut prompt, &request.events);
     }
 
-    writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
+    writeln!(&mut prompt, "## Cursor context")?;
     write_codeblock(
         &request.excerpt_path,
         &[Excerpt {
@@ -39,54 +39,56 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
 #[derive(Clone, Deserialize, Serialize, JsonSchema)]
 pub struct SearchToolInput {
     /// An array of queries to run for gathering context relevant to the next prediction
-    #[schemars(length(max = 5))]
+    #[schemars(length(max = 3))]
     pub queries: Box<[SearchToolQuery]>,
 }
 
+/// Search for relevant code by path, syntax hierarchy, and content.
 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
 pub struct SearchToolQuery {
-    /// A glob pattern to match file paths in the codebase
+    /// 1. A glob pattern to match file paths in the codebase to search in.
     pub glob: String,
-    /// A regular expression to match content within the files matched by the glob pattern
-    pub regex: String,
+    /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
+    ///
+    /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
+    ///
+    /// Example: Searching for a `User` class
+    ///     ["class\s+User"]
+    ///
+    /// Example: Searching for a `get_full_name` method under a `User` class
+    ///     ["class\s+User", "def\sget_full_name"]
+    ///
+    /// Skip this field to match on content alone.
+    #[schemars(length(max = 3))]
+    #[serde(default)]
+    pub syntax_node: Vec<String>,
+    /// 3. An optional regular expression to match the final content that should appear in the results.
+    ///
+    /// - Content will be matched within all lines of the matched syntax nodes.
+    /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
+    /// - If no syntax node regexes are provided, the content will be matched within the entire file.
+    pub content: Option<String>,
 }
 
-pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
-    let schema = schemars::schema_for!(SearchToolInput);
-
-    let description = schema
-        .get("description")
-        .and_then(|description| description.as_str())
-        .unwrap()
-        .to_string();
-
-    (schema.into(), description)
-});
-
 pub const TOOL_NAME: &str = "search";
 
 const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
-    ## Task
+    You are part of an edit prediction system in a code editor.
+    Your role is to search for code that will serve as context for predicting the next edit.
 
-    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
-    that will serve as context for predicting the next required edit.
-
-    **Your task:**
     - Analyze the user's recent edits and current cursor context
-    - Use the `search` tool to find code that may be relevant for predicting the next edit
+    - Use the `search` tool to find code that is relevant for predicting the next edit
     - Focus on finding:
        - Code patterns that might need similar changes based on the recent edits
        - Functions, variables, types, and constants referenced in the current cursor context
        - Related implementations, usages, or dependencies that may require consistent updates
-
-    **Important constraints:**
-    - This conversation has exactly 2 turns
-    - You must make ALL search queries in your first response via the `search` tool
-    - All queries will be executed in parallel and results returned together
-    - In the second turn, you will select the most relevant results via the `select` tool.
+       - How items defined in the cursor excerpt are used or altered
+    - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
+    - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
+    - Avoid using wildcard globs if you already know the file path of the content you're looking for
 "#};
 
 const TOOL_USE_REMINDER: &str = indoc! {"
     --
-    Use the `search` tool now
+    Analyze the user's intent in one to two sentences, then call the `search` tool.
 "};

crates/language/src/buffer.rs 🔗

@@ -1618,6 +1618,18 @@ impl Buffer {
         self.parse_status.1.clone()
     }
 
+    /// Wait until the buffer is no longer parsing
+    pub fn parsing_idle(&self) -> impl Future<Output = ()> + use<> {
+        let mut parse_status = self.parse_status();
+        async move {
+            while *parse_status.borrow() != ParseStatus::Idle {
+                if parse_status.changed().await.is_err() {
+                    break;
+                }
+            }
+        }
+    }
+
     /// Assign to the buffer a set of diagnostics created by a given language server.
     pub fn update_diagnostics(
         &mut self,

crates/language_model/Cargo.toml 🔗

@@ -17,7 +17,6 @@ test-support = []
 
 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
-open_router.workspace = true
 anyhow.workspace = true
 base64.workspace = true
 client.workspace = true
@@ -30,8 +29,10 @@ http_client.workspace = true
 icons.workspace = true
 image.workspace = true
 log.workspace = true
+open_router.workspace = true
 parking_lot.workspace = true
 proto.workspace = true
+schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true

crates/language_model/src/language_model.rs 🔗

@@ -4,6 +4,7 @@ mod registry;
 mod request;
 mod role;
 mod telemetry;
+pub mod tool_schema;
 
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;
@@ -35,6 +36,7 @@ pub use crate::registry::*;
 pub use crate::request::*;
 pub use crate::role::*;
 pub use crate::telemetry::*;
+pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 
 pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
     LanguageModelProviderId::new("anthropic");
@@ -409,15 +411,6 @@ impl From<open_router::ApiError> for LanguageModelCompletionError {
     }
 }
 
-/// Indicates the format used to define the input schema for a language model tool.
-#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
-pub enum LanguageModelToolSchemaFormat {
-    /// A JSON schema, see https://json-schema.org
-    JsonSchema,
-    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
-    JsonSchemaSubset,
-}
-
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum StopReason {

crates/agent/src/tool_schema.rs → crates/language_model/src/tool_schema.rs 🔗

@@ -1,5 +1,4 @@
 use anyhow::Result;
-use language_model::LanguageModelToolSchemaFormat;
 use schemars::{
     JsonSchema, Schema,
     generate::SchemaSettings,
@@ -7,7 +6,16 @@ use schemars::{
 };
 use serde_json::Value;
 
-pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
+/// Indicates the format used to define the input schema for a language model tool.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub enum LanguageModelToolSchemaFormat {
+    /// A JSON schema, see https://json-schema.org
+    JsonSchema,
+    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
+    JsonSchemaSubset,
+}
+
+pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
     let mut generator = match format {
         LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
         LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()

crates/project/src/project.rs 🔗

@@ -4060,7 +4060,7 @@ impl Project {
         result_rx
     }
 
-    fn find_search_candidate_buffers(
+    pub fn find_search_candidate_buffers(
         &mut self,
         query: &SearchQuery,
         limit: usize,

crates/zeta2/Cargo.toml 🔗

@@ -33,6 +33,7 @@ project.workspace = true
 release_channel.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
 thiserror.workspace = true
 util.workspace = true
 uuid.workspace = true

crates/zeta2/src/retrieval_search.rs 🔗

@@ -1,18 +1,19 @@
 use std::ops::Range;
 
 use anyhow::Result;
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use collections::HashMap;
-use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
 use futures::{
     StreamExt,
     channel::mpsc::{self, UnboundedSender},
 };
 use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 use project::{
     Project, WorktreeSettings,
     search::{SearchQuery, SearchResult},
 };
+use smol::channel;
 use util::{
     ResultExt as _,
     paths::{PathMatcher, PathStyle},
@@ -21,7 +22,7 @@ use workspace::item::Settings as _;
 
 pub async fn run_retrieval_searches(
     project: Entity<Project>,
-    regex_by_glob: HashMap<String, String>,
+    queries: Vec<SearchToolQuery>,
     cx: &mut AsyncApp,
 ) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
     let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
@@ -37,14 +38,13 @@ pub async fn run_retrieval_searches(
 
     let (results_tx, mut results_rx) = mpsc::unbounded();
 
-    for (glob, regex) in regex_by_glob {
+    for query in queries {
         let exclude_matcher = exclude_matcher.clone();
         let results_tx = results_tx.clone();
         let project = project.clone();
         cx.spawn(async move |cx| {
             run_query(
-                &glob,
-                &regex,
+                query,
                 results_tx.clone(),
                 path_style,
                 exclude_matcher,
@@ -108,87 +108,442 @@ pub async fn run_retrieval_searches(
     .await
 }
 
-const MIN_EXCERPT_LEN: usize = 16;
 const MAX_EXCERPT_LEN: usize = 768;
 const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
 
+struct SearchJob {
+    buffer: Entity<Buffer>,
+    snapshot: BufferSnapshot,
+    ranges: Vec<Range<usize>>,
+    query_ix: usize,
+    jobs_tx: channel::Sender<SearchJob>,
+}
+
 async fn run_query(
-    glob: &str,
-    regex: &str,
+    input_query: SearchToolQuery,
     results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
     path_style: PathStyle,
     exclude_matcher: PathMatcher,
     project: &Entity<Project>,
     cx: &mut AsyncApp,
 ) -> Result<()> {
-    let include_matcher = PathMatcher::new(vec![glob], path_style)?;
-
-    let query = SearchQuery::regex(
-        regex,
-        false,
-        true,
-        false,
-        true,
-        include_matcher,
-        exclude_matcher,
-        true,
-        None,
-    )?;
-
-    let results = project.update(cx, |project, cx| project.search(query, cx))?;
-    futures::pin_mut!(results);
-
-    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
-        if results_tx.is_closed() {
-            break;
-        }
+    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
 
-        if ranges.is_empty() {
-            continue;
-        }
+    let make_search = |regex: &str| -> Result<SearchQuery> {
+        SearchQuery::regex(
+            regex,
+            false,
+            true,
+            false,
+            true,
+            include_matcher.clone(),
+            exclude_matcher.clone(),
+            true,
+            None,
+        )
+    };
 
-        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-        let results_tx = results_tx.clone();
+    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
+        let outer_syntax_query = make_search(outer_syntax_regex)?;
+        let nested_syntax_queries = input_query
+            .syntax_node
+            .into_iter()
+            .skip(1)
+            .map(|query| make_search(&query))
+            .collect::<Result<Vec<_>>>()?;
+        let content_query = input_query
+            .content
+            .map(|regex| make_search(&regex))
+            .transpose()?;
 
-        cx.background_spawn(async move {
-            let mut excerpts = Vec::with_capacity(ranges.len());
-
-            for range in ranges {
-                let offset_range = range.to_offset(&snapshot);
-                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
-
-                let excerpt = EditPredictionExcerpt::select_from_buffer(
-                    query_point,
-                    &snapshot,
-                    &EditPredictionExcerptOptions {
-                        max_bytes: MAX_EXCERPT_LEN,
-                        min_bytes: MIN_EXCERPT_LEN,
-                        target_before_cursor_over_total_bytes: 0.5,
-                    },
-                    None,
-                );
+        let (jobs_tx, jobs_rx) = channel::unbounded();
 
-                if let Some(excerpt) = excerpt
-                    && !excerpt.line_range.is_empty()
-                {
-                    excerpts.push((
-                        snapshot.anchor_after(excerpt.range.start)
-                            ..snapshot.anchor_before(excerpt.range.end),
-                        excerpt.range.len(),
-                    ));
-                }
+        let outer_search_results_rx =
+            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
+
+        let outer_search_task = cx.spawn(async move |cx| {
+            futures::pin_mut!(outer_search_results_rx);
+            while let Some(SearchResult::Buffer { buffer, ranges }) =
+                outer_search_results_rx.next().await
+            {
+                buffer
+                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
+                    .await;
+                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+                let expanded_ranges: Vec<_> = ranges
+                    .into_iter()
+                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
+                    .collect();
+                jobs_tx
+                    .send(SearchJob {
+                        buffer,
+                        snapshot,
+                        ranges: expanded_ranges,
+                        query_ix: 0,
+                        jobs_tx: jobs_tx.clone(),
+                    })
+                    .await?;
             }
+            anyhow::Ok(())
+        });
 
-            let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
+        let n_workers = cx.background_executor().num_cpus();
+        let search_job_task = cx.background_executor().scoped(|scope| {
+            for _ in 0..n_workers {
+                scope.spawn(async {
+                    while let Ok(job) = jobs_rx.recv().await {
+                        process_nested_search_job(
+                            &results_tx,
+                            &nested_syntax_queries,
+                            &content_query,
+                            job,
+                        )
+                        .await;
+                    }
+                });
+            }
+        });
+
+        search_job_task.await;
+        outer_search_task.await?;
+    } else if let Some(content_regex) = &input_query.content {
+        let search_query = make_search(&content_regex)?;
+
+        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
+        futures::pin_mut!(results_rx);
+
+        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
+            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+            let ranges = ranges
+                .into_iter()
+                .map(|range| {
+                    let range = range.to_offset(&snapshot);
+                    let range = expand_to_entire_lines(range, &snapshot);
+                    let size = range.len();
+                    let range =
+                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+                    (range, size)
+                })
+                .collect();
+
+            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
 
             if let Err(err) = send_result
                 && !err.is_disconnected()
             {
                 log::error!("{err}");
             }
-        })
-        .detach();
+        }
+    } else {
+        log::warn!("Context gathering model produced a glob-only search");
     }
 
     anyhow::Ok(())
 }
+
+async fn process_nested_search_job(
+    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
+    queries: &Vec<SearchQuery>,
+    content_query: &Option<SearchQuery>,
+    job: SearchJob,
+) {
+    if let Some(search_query) = queries.get(job.query_ix) {
+        let mut subranges = Vec::new();
+        for range in job.ranges {
+            let start = range.start;
+            let search_results = search_query.search(&job.snapshot, Some(range)).await;
+            for subrange in search_results {
+                let subrange = start + subrange.start..start + subrange.end;
+                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
+            }
+        }
+        job.jobs_tx
+            .send(SearchJob {
+                buffer: job.buffer,
+                snapshot: job.snapshot,
+                ranges: subranges,
+                query_ix: job.query_ix + 1,
+                jobs_tx: job.jobs_tx.clone(),
+            })
+            .await
+            .ok();
+    } else {
+        let ranges = if let Some(content_query) = content_query {
+            let mut subranges = Vec::new();
+            for range in job.ranges {
+                let start = range.start;
+                let search_results = content_query.search(&job.snapshot, Some(range)).await;
+                for subrange in search_results {
+                    let subrange = start + subrange.start..start + subrange.end;
+                    subranges.push(subrange);
+                }
+            }
+            subranges
+        } else {
+            job.ranges
+        };
+
+        let matches = ranges
+            .into_iter()
+            .map(|range| {
+                let snapshot = &job.snapshot;
+                let range = expand_to_entire_lines(range, snapshot);
+                let size = range.len();
+                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+                (range, size)
+            })
+            .collect();
+
+        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
+
+        if let Err(err) = send_result
+            && !err.is_disconnected()
+        {
+            log::error!("{err}");
+        }
+    }
+}
+
+fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
+    let mut point_range = range.to_point(snapshot);
+    point_range.start.column = 0;
+    if point_range.end.column > 0 {
+        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
+    }
+    point_range.to_offset(snapshot)
+}
+
+fn expand_to_parent_range<T: ToPoint + ToOffset>(
+    range: &Range<T>,
+    snapshot: &BufferSnapshot,
+) -> Option<Range<usize>> {
+    let mut line_range = range.to_point(&snapshot);
+    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
+    line_range.end.column = snapshot.line_len(line_range.end.row);
+    // TODO skip result if matched line isn't the first node line?
+
+    let node = snapshot.syntax_ancestor(line_range)?;
+    Some(node.byte_range())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::merge_excerpts::merge_excerpts;
+    use cloud_zeta2_prompt::write_codeblock;
+    use edit_prediction_context::Line;
+    use gpui::TestAppContext;
+    use indoc::indoc;
+    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+    use pretty_assertions::assert_eq;
+    use project::FakeFs;
+    use serde_json::json;
+    use settings::SettingsStore;
+    use std::path::Path;
+    use util::path;
+
+    #[gpui::test]
+    async fn test_retrieval(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "user.rs": indoc!{"
+                    pub struct Organization {
+                        owner: Arc<User>,
+                    }
+
+                    pub struct User {
+                        first_name: String,
+                        last_name: String,
+                    }
+
+                    impl Organization {
+                        pub fn owner(&self) -> Arc<User> {
+                            self.owner.clone()
+                        }
+                    }
+
+                    impl User {
+                        pub fn new(first_name: String, last_name: String) -> Self {
+                            Self {
+                                first_name,
+                                last_name
+                            }
+                        }
+
+                        pub fn first_name(&self) -> String {
+                            self.first_name.clone()
+                        }
+
+                        pub fn last_name(&self) -> String {
+                            self.last_name.clone()
+                        }
+                    }
+                "},
+                "main.rs": indoc!{r#"
+                    fn main() {
+                        let user = User::new(FIRST_NAME.clone(), "doe".into());
+                        println!("user {:?}", user);
+                    }
+                "#},
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
+        project.update(cx, |project, _cx| {
+            project.languages().add(rust_lang().into())
+        });
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "user.rs".into(),
+                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
+                content: None,
+            },
+            indoc! {r#"
+                `````root/user.rs
+                …
+                impl User {
+                …
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                    }
+                …
+                `````
+            "#},
+            cx,
+        )
+        .await;
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "user.rs".into(),
+                syntax_node: vec!["impl\\s+User".into()],
+                content: Some("\\.clone".into()),
+            },
+            indoc! {r#"
+                `````root/user.rs
+                …
+                impl User {
+                …
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                …
+                    pub fn last_name(&self) -> String {
+                        self.last_name.clone()
+                …
+                `````
+            "#},
+            cx,
+        )
+        .await;
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "*.rs".into(),
+                syntax_node: vec![],
+                content: Some("\\.clone".into()),
+            },
+            indoc! {r#"
+                `````root/main.rs
+                fn main() {
+                    let user = User::new(FIRST_NAME.clone(), "doe".into());
+                …
+                `````
+
+                `````root/user.rs
+                …
+                impl Organization {
+                    pub fn owner(&self) -> Arc<User> {
+                        self.owner.clone()
+                …
+                impl User {
+                …
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                …
+                    pub fn last_name(&self) -> String {
+                        self.last_name.clone()
+                …
+                `````
+            "#},
+            cx,
+        )
+        .await;
+    }
+
+    async fn assert_results(
+        project: &Entity<Project>,
+        query: SearchToolQuery,
+        expected_output: &str,
+        cx: &mut TestAppContext,
+    ) {
+        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
+            .await
+            .unwrap();
+
+        let mut results = results.into_iter().collect::<Vec<_>>();
+        results.sort_by_key(|results| {
+            results
+                .0
+                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
+        });
+
+        let mut output = String::new();
+        for (buffer, ranges) in results {
+            buffer.read_with(cx, |buffer, cx| {
+                let excerpts = ranges.into_iter().map(|range| {
+                    let point_range = range.to_point(buffer);
+                    if point_range.end.column > 0 {
+                        Line(point_range.start.row)..Line(point_range.end.row + 1)
+                    } else {
+                        Line(point_range.start.row)..Line(point_range.end.row)
+                    }
+                });
+
+                write_codeblock(
+                    &buffer.file().unwrap().full_path(cx),
+                    merge_excerpts(&buffer.snapshot(), excerpts).iter(),
+                    &[],
+                    Line(buffer.max_point().row),
+                    false,
+                    &mut output,
+                );
+            });
+        }
+        output.pop();
+
+        assert_eq!(output, expected_output);
+    }
+
+    fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::LANGUAGE.into()),
+        )
+        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+        .unwrap()
+    }
+
+    fn init_test(cx: &mut TestAppContext) {
+        cx.update(move |cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            zlog::init_test();
+        });
+    }
+}

crates/zeta2/src/udiff.rs 🔗

@@ -18,10 +18,10 @@ use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSn
 use project::Project;
 
 pub async fn parse_diff<'a>(
-    diff: &'a str,
+    diff_str: &'a str,
     get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
 ) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
-    let mut diff = DiffParser::new(diff);
+    let mut diff = DiffParser::new(diff_str);
     let mut edited_buffer = None;
     let mut edits = Vec::new();
 
@@ -41,7 +41,10 @@ pub async fn parse_diff<'a>(
                     Some(ref current) => current,
                 };
 
-                edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?);
+                edits.extend(
+                    resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
+                        .with_context(|| format!("Diff:\n{diff_str}"))?,
+                );
             }
             DiffEvent::FileEnd { renamed_to } => {
                 let (buffer, _) = edited_buffer
@@ -69,13 +72,13 @@ pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffe
 
 #[must_use]
 pub async fn apply_diff<'a>(
-    diff: &'a str,
+    diff_str: &'a str,
     project: &Entity<Project>,
     cx: &mut AsyncApp,
 ) -> Result<OpenedBuffers<'a>> {
     let mut included_files = HashMap::default();
 
-    for line in diff.lines() {
+    for line in diff_str.lines() {
         let diff_line = DiffLine::parse(line);
 
         if let DiffLine::OldPath { path } = diff_line {
@@ -97,7 +100,7 @@ pub async fn apply_diff<'a>(
 
     let ranges = [Anchor::MIN..Anchor::MAX];
 
-    let mut diff = DiffParser::new(diff);
+    let mut diff = DiffParser::new(diff_str);
     let mut current_file = None;
     let mut edits = vec![];
 
@@ -120,7 +123,10 @@ pub async fn apply_diff<'a>(
                 };
 
                 buffer.read_with(cx, |buffer, _| {
-                    edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?);
+                    edits.extend(
+                        resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
+                            .with_context(|| format!("Diff:\n{diff_str}"))?,
+                    );
                     anyhow::Ok(())
                 })??;
             }
@@ -328,13 +334,7 @@ fn resolve_hunk_edits_in_buffer(
                 offset = Some(range.start + ix);
             }
         }
-        offset.ok_or_else(|| {
-            anyhow!(
-                "Failed to match context:\n{}\n\nBuffer:\n{}",
-                hunk.context,
-                buffer.text(),
-            )
-        })
+        offset.ok_or_else(|| anyhow!("Failed to match context:\n{}", hunk.context))
     }?;
     let iter = hunk.edits.into_iter().flat_map(move |edit| {
         let old_text = buffer

crates/zeta2/src/zeta2.rs 🔗

@@ -7,7 +7,7 @@ use cloud_llm_client::{
     ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
+use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
 use collections::HashMap;
 use edit_prediction_context::{
     DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
@@ -35,7 +35,7 @@ use uuid::Uuid;
 use std::ops::Range;
 use std::path::Path;
 use std::str::FromStr as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
 use std::time::{Duration, Instant};
 use thiserror::Error;
 use util::rel_path::RelPathBuf;
@@ -88,6 +88,9 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
     buffer_change_grouping_interval: Duration::from_secs(1),
 };
 
+static MODEL_ID: LazyLock<String> =
+    LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()));
+
 pub struct Zeta2FeatureFlag;
 
 impl FeatureFlag for Zeta2FeatureFlag {
@@ -180,7 +183,7 @@ pub struct ZetaEditPredictionDebugInfo {
 pub struct ZetaSearchQueryDebugInfo {
     pub project: Entity<Project>,
     pub timestamp: Instant,
-    pub regex_by_glob: HashMap<String, String>,
+    pub search_queries: Vec<SearchToolQuery>,
 }
 
 pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -883,7 +886,7 @@ impl Zeta {
 
                 let (prompt, _) = prompt_result?;
                 let request = open_ai::Request {
-                    model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
+                    model: MODEL_ID.clone(),
                     messages: vec![open_ai::RequestMessage::User {
                         content: open_ai::MessageContent::Plain(prompt),
                     }],
@@ -1226,10 +1229,24 @@ impl Zeta {
                 .ok();
         }
 
-        let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;
+        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
+            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
+                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
+            );
+
+            let description = schema
+                .get("description")
+                .and_then(|description| description.as_str())
+                .unwrap()
+                .to_string();
+
+            (schema.into(), description)
+        });
+
+        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
 
         let request = open_ai::Request {
-            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
+            model: MODEL_ID.clone(),
             messages: vec![open_ai::RequestMessage::User {
                 content: open_ai::MessageContent::Plain(prompt),
             }],
@@ -1242,8 +1259,8 @@ impl Zeta {
             tools: vec![open_ai::ToolDefinition::Function {
                 function: FunctionDefinition {
                     name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
-                    description: Some(tool_description.clone()),
-                    parameters: Some(tool_schema.clone()),
+                    description: Some(tool_description),
+                    parameters: Some(tool_schema),
                 },
             }],
             prompt_cache_key: None,
@@ -1255,7 +1272,6 @@ impl Zeta {
             let response =
                 Self::send_raw_llm_request(client, llm_token, app_version, request).await;
             let mut response = Self::handle_api_response(&this, response, cx)?;
-
             log::trace!("Got search planning response");
 
             let choice = response
@@ -1270,7 +1286,7 @@ impl Zeta {
                 anyhow::bail!("Retrieval response didn't include an assistant message");
             };
 
-            let mut regex_by_glob: HashMap<String, String> = HashMap::default();
+            let mut queries: Vec<SearchToolQuery> = Vec::new();
             for tool_call in tool_calls {
                 let open_ai::ToolCallContent::Function { function } = tool_call.content;
                 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
@@ -1283,13 +1299,7 @@ impl Zeta {
                 }
 
                 let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
-                for query in input.queries {
-                    let regex = regex_by_glob.entry(query.glob).or_default();
-                    if !regex.is_empty() {
-                        regex.push('|');
-                    }
-                    regex.push_str(&query.regex);
-                }
+                queries.extend(input.queries);
             }
 
             if let Some(debug_tx) = &debug_tx {
@@ -1298,16 +1308,16 @@ impl Zeta {
                         ZetaSearchQueryDebugInfo {
                             project: project.clone(),
                             timestamp: Instant::now(),
-                            regex_by_glob: regex_by_glob.clone(),
+                            search_queries: queries.clone(),
                         },
                     ))
                     .ok();
             }
 
-            log::trace!("Running retrieval search: {regex_by_glob:#?}");
+            log::trace!("Running retrieval search: {queries:#?}");
 
             let related_excerpts_result =
-                retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;
+                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
 
             log::trace!("Search queries executed");
 
@@ -1754,7 +1764,8 @@ mod tests {
                                     arguments: serde_json::to_string(&SearchToolInput {
                                         queries: Box::new([SearchToolQuery {
                                             glob: "root/2.txt".to_string(),
-                                            regex: ".".to_string(),
+                                            syntax_node: vec![],
+                                            content: Some(".".into()),
                                         }]),
                                     })
                                     .unwrap(),

crates/zeta2_tools/Cargo.toml 🔗

@@ -16,6 +16,7 @@ anyhow.workspace = true
 chrono.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
 collections.workspace = true
 edit_prediction_context.workspace = true
 editor.workspace = true
@@ -27,7 +28,6 @@ log.workspace = true
 multi_buffer.workspace = true
 ordered-float.workspace = true
 project.workspace = true
-regex-syntax = "0.8.8"
 serde.workspace = true
 serde_json.workspace = true
 telemetry.workspace = true

crates/zeta2_tools/src/zeta2_context_view.rs 🔗

@@ -8,6 +8,7 @@ use std::{
 
 use anyhow::Result;
 use client::{Client, UserStore};
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use editor::{Editor, PathKey};
 use futures::StreamExt as _;
 use gpui::{
@@ -41,19 +42,13 @@ pub struct Zeta2ContextView {
 #[derive(Debug)]
 struct RetrievalRun {
     editor: Entity<Editor>,
-    search_queries: Vec<GlobQueries>,
+    search_queries: Vec<SearchToolQuery>,
     started_at: Instant,
     search_results_generated_at: Option<Instant>,
     search_results_executed_at: Option<Instant>,
     finished_at: Option<Instant>,
 }
 
-#[derive(Debug)]
-struct GlobQueries {
-    glob: String,
-    alternations: Vec<String>,
-}
-
 actions!(
     dev,
     [
@@ -210,23 +205,7 @@ impl Zeta2ContextView {
         };
 
         run.search_results_generated_at = Some(info.timestamp);
-        run.search_queries = info
-            .regex_by_glob
-            .into_iter()
-            .map(|(glob, regex)| {
-                let mut regex_parser = regex_syntax::ast::parse::Parser::new();
-
-                GlobQueries {
-                    glob,
-                    alternations: match regex_parser.parse(&regex) {
-                        Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
-                            alt.asts.iter().map(|ast| ast.to_string()).collect()
-                        }
-                        _ => vec![regex],
-                    },
-                }
-            })
-            .collect();
+        run.search_queries = info.search_queries;
         cx.notify();
     }
 
@@ -292,18 +271,28 @@ impl Zeta2ContextView {
                         .enumerate()
                         .flat_map(|(ix, query)| {
                             std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
-                                .chain(query.alternations.iter().enumerate().map(
-                                    move |(alt_ix, alt)| {
-                                        ListItem::new(ix * 100 + alt_ix)
+                                .chain(query.syntax_node.iter().enumerate().map(
+                                    move |(regex_ix, regex)| {
+                                        ListItem::new(ix * 100 + regex_ix)
                                             .start_slot(
                                                 Icon::new(IconName::MagnifyingGlass)
                                                     .color(Color::Muted)
                                                     .size(IconSize::Small),
                                             )
-                                            .child(alt.clone())
+                                            .child(regex.clone())
                                             .into_any_element()
                                     },
                                 ))
+                                .chain(query.content.as_ref().map(move |regex| {
+                                    ListItem::new(ix * 100 + query.syntax_node.len())
+                                        .start_slot(
+                                            Icon::new(IconName::MagnifyingGlass)
+                                                .color(Color::Muted)
+                                                .size(IconSize::Small),
+                                        )
+                                        .child(regex.clone())
+                                        .into_any_element()
+                                }))
                         }),
                 ),
             )

crates/zeta_cli/src/predict.rs 🔗

@@ -2,11 +2,11 @@ use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
 use crate::paths::LOGS_DIR;
 use ::serde::Serialize;
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
 use clap::Args;
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use futures::StreamExt as _;
-use gpui::AsyncApp;
+use gpui::{AppContext, AsyncApp};
 use project::Project;
 use serde::Deserialize;
 use std::cell::Cell;
@@ -14,6 +14,7 @@ use std::fs;
 use std::io::Write;
 use std::path::PathBuf;
 use std::sync::Arc;
+use std::sync::Mutex;
 use std::time::{Duration, Instant};
 
 #[derive(Debug, Args)]
@@ -103,112 +104,126 @@ pub async fn zeta2_predict(
     let _edited_buffers = example.apply_edit_history(&project, cx).await?;
     let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
 
+    let result = Arc::new(Mutex::new(PredictionDetails::default()));
     let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
 
-    let refresh_task = zeta.update(cx, |zeta, cx| {
-        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-    })?;
+    let debug_task = cx.background_spawn({
+        let result = result.clone();
+        async move {
+            let mut context_retrieval_started_at = None;
+            let mut context_retrieval_finished_at = None;
+            let mut search_queries_generated_at = None;
+            let mut search_queries_executed_at = None;
+            while let Some(event) = debug_rx.next().await {
+                match event {
+                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                        context_retrieval_started_at = Some(info.timestamp);
+                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
+                    }
+                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                        search_queries_generated_at = Some(info.timestamp);
+                        fs::write(
+                            LOGS_DIR.join("search_queries.json"),
+                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
+                        )?;
+                    }
+                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                        search_queries_executed_at = Some(info.timestamp);
+                    }
+                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                        context_retrieval_finished_at = Some(info.timestamp);
+                    }
+                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+                        let prediction_started_at = Instant::now();
+                        fs::write(
+                            LOGS_DIR.join("prediction_prompt.md"),
+                            &request.local_prompt.unwrap_or_default(),
+                        )?;
 
-    let mut context_retrieval_started_at = None;
-    let mut context_retrieval_finished_at = None;
-    let mut search_queries_generated_at = None;
-    let mut search_queries_executed_at = None;
-    let mut prediction_started_at = None;
-    let mut prediction_finished_at = None;
-    let mut excerpts_text = String::new();
-    let mut prediction_task = None;
-    let mut result = PredictionDetails::default();
-    while let Some(event) = debug_rx.next().await {
-        match event {
-            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                context_retrieval_started_at = Some(info.timestamp);
-                fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                search_queries_generated_at = Some(info.timestamp);
-                fs::write(
-                    LOGS_DIR.join("search_queries.json"),
-                    serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
-                )?;
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                search_queries_executed_at = Some(info.timestamp);
-            }
-            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
-                context_retrieval_finished_at = Some(info.timestamp);
+                        {
+                            let mut result = result.lock().unwrap();
 
-                prediction_task = Some(zeta.update(cx, |zeta, cx| {
-                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
-                })?);
-            }
-            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
-                prediction_started_at = Some(Instant::now());
-                fs::write(
-                    LOGS_DIR.join("prediction_prompt.md"),
-                    &request.local_prompt.unwrap_or_default(),
-                )?;
+                            for included_file in request.request.included_files {
+                                let insertions =
+                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
+                                result.excerpts.extend(included_file.excerpts.iter().map(
+                                    |excerpt| ActualExcerpt {
+                                        path: included_file.path.components().skip(1).collect(),
+                                        text: String::from(excerpt.text.as_ref()),
+                                    },
+                                ));
+                                write_codeblock(
+                                    &included_file.path,
+                                    included_file.excerpts.iter(),
+                                    if included_file.path == request.request.excerpt_path {
+                                        &insertions
+                                    } else {
+                                        &[]
+                                    },
+                                    included_file.max_row,
+                                    false,
+                                    &mut result.excerpts_text,
+                                );
+                            }
+                        }
 
-                for included_file in request.request.included_files {
-                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
-                    result
-                        .excerpts
-                        .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
-                            path: included_file.path.components().skip(1).collect(),
-                            text: String::from(excerpt.text.as_ref()),
-                        }));
-                    write_codeblock(
-                        &included_file.path,
-                        included_file.excerpts.iter(),
-                        if included_file.path == request.request.excerpt_path {
-                            &insertions
-                        } else {
-                            &[]
-                        },
-                        included_file.max_row,
-                        false,
-                        &mut excerpts_text,
-                    );
-                }
+                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+                        let response = zeta2::text_from_response(response).unwrap_or_default();
+                        let prediction_finished_at = Instant::now();
+                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
 
-                let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
-                let response = zeta2::text_from_response(response).unwrap_or_default();
-                prediction_finished_at = Some(Instant::now());
-                fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
+                        let mut result = result.lock().unwrap();
 
-                break;
+                        result.planning_search_time = search_queries_generated_at.unwrap()
+                            - context_retrieval_started_at.unwrap();
+                        result.running_search_time = search_queries_executed_at.unwrap()
+                            - search_queries_generated_at.unwrap();
+                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
+                            - search_queries_executed_at.unwrap();
+                        result.prediction_time = prediction_finished_at - prediction_started_at;
+                        result.total_time =
+                            prediction_finished_at - context_retrieval_started_at.unwrap();
+
+                        break;
+                    }
+                }
             }
+            anyhow::Ok(())
         }
-    }
+    });
+
+    zeta.update(cx, |zeta, cx| {
+        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+    })?
+    .await?;
 
-    refresh_task.await.context("context retrieval failed")?;
-    let prediction = prediction_task.unwrap().await?;
+    let prediction = zeta
+        .update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+        })?
+        .await?;
 
+    debug_task.await?;
+
+    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
     result.diff = prediction
         .map(|prediction| {
             let old_text = prediction.snapshot.text();
-            let new_text = prediction.buffer.update(cx, |buffer, cx| {
-                buffer.edit(prediction.edits.iter().cloned(), None, cx);
-                buffer.text()
-            })?;
-            anyhow::Ok(language::unified_diff(&old_text, &new_text))
+            let new_text = prediction
+                .buffer
+                .update(cx, |buffer, cx| {
+                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
+                    buffer.text()
+                })
+                .unwrap();
+            language::unified_diff(&old_text, &new_text)
         })
-        .transpose()?
         .unwrap_or_default();
-    result.excerpts_text = excerpts_text;
-
-    result.planning_search_time =
-        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
-    result.running_search_time =
-        search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
-    result.filtering_search_time =
-        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
-    result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
-    result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
 
     anyhow::Ok(result)
 }
 
-#[derive(Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 pub struct PredictionDetails {
     pub diff: String,
     pub excerpts: Vec<ActualExcerpt>,