From c241eadbc3fd0d4036db266210b344203a3886bf Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 7 Nov 2025 22:06:12 -0300 Subject: [PATCH] zeta2: Targeted retrieval search (#42240) Since we removed the filtering step during context gathering, we want the model to perform more targeted searches. This PR tweaks search tool schema allowing the model to search within syntax nodes such as `impl` blocks or methods. This is what the query schema looks like now: ```rust /// Search for relevant code by path, syntax hierarchy, and content. #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SearchToolQuery { /// 1. A glob pattern to match file paths in the codebase to search in. pub glob: String, /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy. /// /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes. /// /// Example: Searching for a `User` class /// ["class\s+User"] /// /// Example: Searching for a `get_full_name` method under a `User` class /// ["class\s+User", "def\sget_full_name"] /// /// Skip this field to match on content alone. #[schemars(length(max = 3))] #[serde(default)] pub syntax_node: Vec, /// 3. An optional regular expression to match the final content that should appear in the results. /// /// - Content will be matched within all lines of the matched syntax nodes. /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible. /// - If no syntax node regexes are provided, the content will be matched within the entire file. pub content: Option, } ``` We'll need to keep refining this, but the core implementation is ready. Release Notes: - N/A --------- Co-authored-by: Ben Co-authored-by: Max Co-authored-by: Max Brunsfeld --- Cargo.lock | 5 +- crates/agent/src/agent.rs | 1 - crates/agent/src/outline.rs | 9 +- crates/agent/src/thread.rs | 4 +- .../src/tools/context_server_registry.rs | 2 +- crates/cloud_zeta2_prompt/Cargo.toml | 1 - .../src/retrieval_prompt.rs | 64 +-- crates/language/src/buffer.rs | 12 + crates/language_model/Cargo.toml | 3 +- crates/language_model/src/language_model.rs | 11 +- .../src/tool_schema.rs | 12 +- crates/project/src/project.rs | 2 +- crates/zeta2/Cargo.toml | 1 + crates/zeta2/src/retrieval_search.rs | 483 +++++++++++++++--- crates/zeta2/src/udiff.rs | 28 +- crates/zeta2/src/zeta2.rs | 53 +- crates/zeta2_tools/Cargo.toml | 2 +- crates/zeta2_tools/src/zeta2_context_view.rs | 45 +- crates/zeta_cli/src/predict.rs | 191 +++---- 19 files changed, 657 insertions(+), 272 deletions(-) rename crates/{agent => language_model}/src/tool_schema.rs (95%) diff --git a/Cargo.lock b/Cargo.lock index 08ead914af69281a21c763b874bc57e8c84ac90d..faae1259d9d5c08559ec6ba02463367e84b3aa4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3202,7 +3202,6 @@ dependencies = [ "rustc-hash 2.1.1", "schemars 1.0.4", "serde", - "serde_json", "strum 0.27.2", ] @@ -8867,6 +8866,7 @@ dependencies = [ "open_router", "parking_lot", "proto", + "schemars 1.0.4", "serde", "serde_json", "settings", @@ -21685,6 +21685,7 @@ dependencies = [ "serde", "serde_json", "settings", + "smol", "thiserror 2.0.17", "util", "uuid", @@ -21702,6 +21703,7 @@ dependencies = [ "clap", "client", "cloud_llm_client", + "cloud_zeta2_prompt", "collections", "edit_prediction_context", "editor", @@ -21715,7 +21717,6 @@ dependencies = [ "ordered-float 2.10.1", "pretty_assertions", "project", - "regex-syntax", "serde", "serde_json", "settings", diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 0e9372373a65ac5fee9870cf58e2b0d9c11427d2..fc0b66f4073ea137f53b29286b0c17b53d11bf83 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -6,7 +6,6 @@ mod native_agent_server; pub mod outline; mod templates; mod thread; -mod tool_schema; mod tools; #[cfg(test)] diff --git a/crates/agent/src/outline.rs b/crates/agent/src/outline.rs index bc78290fb52ae208742b9dea0e6dbbe560022419..262fa8d3d139a5c8f5900d0dd55348f9dc716167 100644 --- a/crates/agent/src/outline.rs +++ b/crates/agent/src/outline.rs @@ -1,6 +1,6 @@ use anyhow::Result; use gpui::{AsyncApp, Entity}; -use language::{Buffer, OutlineItem, ParseStatus}; +use language::{Buffer, OutlineItem}; use regex::Regex; use std::fmt::Write; use text::Point; @@ -30,10 +30,9 @@ pub async fn get_buffer_content_or_outline( if file_size > AUTO_OUTLINE_SIZE { // For large files, use outline instead of full content // Wait until the buffer has been fully parsed, so we can read its outline - let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while *parse_status.borrow() != ParseStatus::Idle { - parse_status.changed().await?; - } + buffer + .read_with(cx, |buffer, _| buffer.parsing_idle())? + .await; let outline_items = buffer.read_with(cx, |buffer, _| { let snapshot = buffer.snapshot(); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 4c0fb00163744e66b5644a0fe76b1aa853fb8237..78f20152b4daf461de40cfa7746216092f82cf41 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2139,7 +2139,7 @@ where /// Returns the JSON schema that describes the tool's input. fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema { - crate::tool_schema::root_schema_for::(format) + language_model::tool_schema::root_schema_for::(format) } /// Some tools rely on a provider for the underlying billing or other reasons. @@ -2226,7 +2226,7 @@ where fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { let mut json = serde_json::to_value(T::input_schema(format))?; - crate::tool_schema::adapt_schema_to_format(&mut json, format)?; + language_model::tool_schema::adapt_schema_to_format(&mut json, format)?; Ok(json) } diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 382d2ba9be74b4518de853037c858fd054366d5d..03a0ef84e73d4cbca83d61077d568ec58cd7ae2b 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -165,7 +165,7 @@ impl AnyAgentTool for ContextServerTool { format: language_model::LanguageModelToolSchemaFormat, ) -> Result { let mut schema = self.tool.input_schema.clone(); - crate::tool_schema::adapt_schema_to_format(&mut schema, format)?; + language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?; Ok(match schema { serde_json::Value::Null => { serde_json::json!({ "type": "object", "properties": [] }) diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml index fa8246950f8d03029388e0276954de946efc2346..8be10265cb23e7dd0983c52e7c2d6984b62c4be4 100644 --- a/crates/cloud_zeta2_prompt/Cargo.toml +++ b/crates/cloud_zeta2_prompt/Cargo.toml @@ -19,5 +19,4 @@ ordered-float.workspace = true rustc-hash.workspace = true schemars.workspace = true serde.workspace = true -serde_json.workspace = true strum.workspace = true diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs index 54ef1999729f6976bd77d280508f8c370d54488e..7fbc3834dfd0f4bbfc4085d696b7fbf755e6dd3d 100644 --- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs @@ -3,7 +3,7 @@ use cloud_llm_client::predict_edits_v3::{self, Excerpt}; use indoc::indoc; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{fmt::Write, sync::LazyLock}; +use std::fmt::Write; use crate::{push_events, write_codeblock}; @@ -15,7 +15,7 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R push_events(&mut prompt, &request.events); } - writeln!(&mut prompt, "## Excerpt around the cursor\n")?; + writeln!(&mut prompt, "## Cursor context")?; write_codeblock( &request.excerpt_path, &[Excerpt { @@ -39,54 +39,56 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R #[derive(Clone, Deserialize, Serialize, JsonSchema)] pub struct SearchToolInput { /// An array of queries to run for gathering context relevant to the next prediction - #[schemars(length(max = 5))] + #[schemars(length(max = 3))] pub queries: Box<[SearchToolQuery]>, } +/// Search for relevant code by path, syntax hierarchy, and content. #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SearchToolQuery { - /// A glob pattern to match file paths in the codebase + /// 1. A glob pattern to match file paths in the codebase to search in. pub glob: String, - /// A regular expression to match content within the files matched by the glob pattern - pub regex: String, + /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy. + /// + /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes. + /// + /// Example: Searching for a `User` class + /// ["class\s+User"] + /// + /// Example: Searching for a `get_full_name` method under a `User` class + /// ["class\s+User", "def\sget_full_name"] + /// + /// Skip this field to match on content alone. + #[schemars(length(max = 3))] + #[serde(default)] + pub syntax_node: Vec, + /// 3. An optional regular expression to match the final content that should appear in the results. + /// + /// - Content will be matched within all lines of the matched syntax nodes. + /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible. + /// - If no syntax node regexes are provided, the content will be matched within the entire file. + pub content: Option, } -pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { - let schema = schemars::schema_for!(SearchToolInput); - - let description = schema - .get("description") - .and_then(|description| description.as_str()) - .unwrap() - .to_string(); - - (schema.into(), description) -}); - pub const TOOL_NAME: &str = "search"; const SEARCH_INSTRUCTIONS: &str = indoc! {r#" - ## Task + You are part of an edit prediction system in a code editor. + Your role is to search for code that will serve as context for predicting the next edit. - You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations - that will serve as context for predicting the next required edit. - - **Your task:** - Analyze the user's recent edits and current cursor context - - Use the `search` tool to find code that may be relevant for predicting the next edit + - Use the `search` tool to find code that is relevant for predicting the next edit - Focus on finding: - Code patterns that might need similar changes based on the recent edits - Functions, variables, types, and constants referenced in the current cursor context - Related implementations, usages, or dependencies that may require consistent updates - - **Important constraints:** - - This conversation has exactly 2 turns - - You must make ALL search queries in your first response via the `search` tool - - All queries will be executed in parallel and results returned together - - In the second turn, you will select the most relevant results via the `select` tool. + - How items defined in the cursor excerpt are used or altered + - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible + - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function + - Avoid using wildcard globs if you already know the file path of the content you're looking for "#}; const TOOL_USE_REMINDER: &str = indoc! {" -- - Use the `search` tool now + Analyze the user's intent in one to two sentences, then call the `search` tool. "}; diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 4dd90c15d9387327a75ece2d82385e406e5840d6..ea2405d04c32cba45963bc32747ee0b94292ffd9 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1618,6 +1618,18 @@ impl Buffer { self.parse_status.1.clone() } + /// Wait until the buffer is no longer parsing + pub fn parsing_idle(&self) -> impl Future + use<> { + let mut parse_status = self.parse_status(); + async move { + while *parse_status.borrow() != ParseStatus::Idle { + if parse_status.changed().await.is_err() { + break; + } + } + } + } + /// Assign to the buffer a set of diagnostics created by a given language server. pub fn update_diagnostics( &mut self, diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index f572561f6a78b3cf2d9bfc2f7272895836f11614..4d40a063b604b405f7bcb29a3457956e1dd5541d 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -17,7 +17,6 @@ test-support = [] [dependencies] anthropic = { workspace = true, features = ["schemars"] } -open_router.workspace = true anyhow.workspace = true base64.workspace = true client.workspace = true @@ -30,8 +29,10 @@ http_client.workspace = true icons.workspace = true image.workspace = true log.workspace = true +open_router.workspace = true parking_lot.workspace = true proto.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 24f9b84afcfa7b9a40b4a1b7684e9a9b036a5a85..94f6ec33f15062dd53b4122ca9d9dcac3fbff83d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -4,6 +4,7 @@ mod registry; mod request; mod role; mod telemetry; +pub mod tool_schema; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; @@ -35,6 +36,7 @@ pub use crate::registry::*; pub use crate::request::*; pub use crate::role::*; pub use crate::telemetry::*; +pub use crate::tool_schema::LanguageModelToolSchemaFormat; pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("anthropic"); @@ -409,15 +411,6 @@ impl From for LanguageModelCompletionError { } } -/// Indicates the format used to define the input schema for a language model tool. -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum LanguageModelToolSchemaFormat { - /// A JSON schema, see https://json-schema.org - JsonSchema, - /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema - JsonSchemaSubset, -} - #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum StopReason { diff --git a/crates/agent/src/tool_schema.rs b/crates/language_model/src/tool_schema.rs similarity index 95% rename from crates/agent/src/tool_schema.rs rename to crates/language_model/src/tool_schema.rs index 4b0de3e5c63fb0c5ccafbb89a22dad8a33072b35..f9402c28dc316f9ccdacc58afaa0eebd6699f92d 100644 --- a/crates/agent/src/tool_schema.rs +++ b/crates/language_model/src/tool_schema.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use language_model::LanguageModelToolSchemaFormat; use schemars::{ JsonSchema, Schema, generate::SchemaSettings, @@ -7,7 +6,16 @@ use schemars::{ }; use serde_json::Value; -pub(crate) fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { +/// Indicates the format used to define the input schema for a language model tool. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub enum LanguageModelToolSchemaFormat { + /// A JSON schema, see https://json-schema.org + JsonSchema, + /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema + JsonSchemaSubset, +} + +pub fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { let mut generator = match format { LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(), LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3() diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 9cc95da3b4c86daffd32af89d4f26509c97269fa..13ed42847d522c371226988d8ca133a1748d5fec 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -4060,7 +4060,7 @@ impl Project { result_rx } - fn find_search_candidate_buffers( + pub fn find_search_candidate_buffers( &mut self, query: &SearchQuery, limit: usize, diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index cbde212dd104bdc909dda19de403f815ff4f6386..3f394cd5ef2ab5d5bce05430a717312c9e3c0f5c 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -33,6 +33,7 @@ project.workspace = true release_channel.workspace = true serde.workspace = true serde_json.workspace = true +smol.workspace = true thiserror.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs index e2e78c3e3b295549ca2c294818f935f1d7d8a9f9..f735f44cad9623711e5ed9a1293a74e34e084888 100644 --- a/crates/zeta2/src/retrieval_search.rs +++ b/crates/zeta2/src/retrieval_search.rs @@ -1,18 +1,19 @@ use std::ops::Range; use anyhow::Result; +use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use collections::HashMap; -use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; use futures::{ StreamExt, channel::mpsc::{self, UnboundedSender}, }; use gpui::{AppContext, AsyncApp, Entity}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint}; use project::{ Project, WorktreeSettings, search::{SearchQuery, SearchResult}, }; +use smol::channel; use util::{ ResultExt as _, paths::{PathMatcher, PathStyle}, @@ -21,7 +22,7 @@ use workspace::item::Settings as _; pub async fn run_retrieval_searches( project: Entity, - regex_by_glob: HashMap, + queries: Vec, cx: &mut AsyncApp, ) -> Result, Vec>>> { let (exclude_matcher, path_style) = project.update(cx, |project, cx| { @@ -37,14 +38,13 @@ pub async fn run_retrieval_searches( let (results_tx, mut results_rx) = mpsc::unbounded(); - for (glob, regex) in regex_by_glob { + for query in queries { let exclude_matcher = exclude_matcher.clone(); let results_tx = results_tx.clone(); let project = project.clone(); cx.spawn(async move |cx| { run_query( - &glob, - ®ex, + query, results_tx.clone(), path_style, exclude_matcher, @@ -108,87 +108,442 @@ pub async fn run_retrieval_searches( .await } -const MIN_EXCERPT_LEN: usize = 16; const MAX_EXCERPT_LEN: usize = 768; const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; +struct SearchJob { + buffer: Entity, + snapshot: BufferSnapshot, + ranges: Vec>, + query_ix: usize, + jobs_tx: channel::Sender, +} + async fn run_query( - glob: &str, - regex: &str, + input_query: SearchToolQuery, results_tx: UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, path_style: PathStyle, exclude_matcher: PathMatcher, project: &Entity, cx: &mut AsyncApp, ) -> Result<()> { - let include_matcher = PathMatcher::new(vec![glob], path_style)?; - - let query = SearchQuery::regex( - regex, - false, - true, - false, - true, - include_matcher, - exclude_matcher, - true, - None, - )?; - - let results = project.update(cx, |project, cx| project.search(query, cx))?; - futures::pin_mut!(results); - - while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { - if results_tx.is_closed() { - break; - } + let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?; - if ranges.is_empty() { - continue; - } + let make_search = |regex: &str| -> Result { + SearchQuery::regex( + regex, + false, + true, + false, + true, + include_matcher.clone(), + exclude_matcher.clone(), + true, + None, + ) + }; - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let results_tx = results_tx.clone(); + if let Some(outer_syntax_regex) = input_query.syntax_node.first() { + let outer_syntax_query = make_search(outer_syntax_regex)?; + let nested_syntax_queries = input_query + .syntax_node + .into_iter() + .skip(1) + .map(|query| make_search(&query)) + .collect::>>()?; + let content_query = input_query + .content + .map(|regex| make_search(®ex)) + .transpose()?; - cx.background_spawn(async move { - let mut excerpts = Vec::with_capacity(ranges.len()); - - for range in ranges { - let offset_range = range.to_offset(&snapshot); - let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot); - - let excerpt = EditPredictionExcerpt::select_from_buffer( - query_point, - &snapshot, - &EditPredictionExcerptOptions { - max_bytes: MAX_EXCERPT_LEN, - min_bytes: MIN_EXCERPT_LEN, - target_before_cursor_over_total_bytes: 0.5, - }, - None, - ); + let (jobs_tx, jobs_rx) = channel::unbounded(); - if let Some(excerpt) = excerpt - && !excerpt.line_range.is_empty() - { - excerpts.push(( - snapshot.anchor_after(excerpt.range.start) - ..snapshot.anchor_before(excerpt.range.end), - excerpt.range.len(), - )); - } + let outer_search_results_rx = + project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?; + + let outer_search_task = cx.spawn(async move |cx| { + futures::pin_mut!(outer_search_results_rx); + while let Some(SearchResult::Buffer { buffer, ranges }) = + outer_search_results_rx.next().await + { + buffer + .read_with(cx, |buffer, _| buffer.parsing_idle())? + .await; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let expanded_ranges: Vec<_> = ranges + .into_iter() + .filter_map(|range| expand_to_parent_range(&range, &snapshot)) + .collect(); + jobs_tx + .send(SearchJob { + buffer, + snapshot, + ranges: expanded_ranges, + query_ix: 0, + jobs_tx: jobs_tx.clone(), + }) + .await?; } + anyhow::Ok(()) + }); - let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts)); + let n_workers = cx.background_executor().num_cpus(); + let search_job_task = cx.background_executor().scoped(|scope| { + for _ in 0..n_workers { + scope.spawn(async { + while let Ok(job) = jobs_rx.recv().await { + process_nested_search_job( + &results_tx, + &nested_syntax_queries, + &content_query, + job, + ) + .await; + } + }); + } + }); + + search_job_task.await; + outer_search_task.await?; + } else if let Some(content_regex) = &input_query.content { + let search_query = make_search(&content_regex)?; + + let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?; + futures::pin_mut!(results_rx); + + while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let ranges = ranges + .into_iter() + .map(|range| { + let range = range.to_offset(&snapshot); + let range = expand_to_entire_lines(range, &snapshot); + let size = range.len(); + let range = + snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); + (range, size) + }) + .collect(); + + let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges)); if let Err(err) = send_result && !err.is_disconnected() { log::error!("{err}"); } - }) - .detach(); + } + } else { + log::warn!("Context gathering model produced a glob-only search"); } anyhow::Ok(()) } + +async fn process_nested_search_job( + results_tx: &UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + queries: &Vec, + content_query: &Option, + job: SearchJob, +) { + if let Some(search_query) = queries.get(job.query_ix) { + let mut subranges = Vec::new(); + for range in job.ranges { + let start = range.start; + let search_results = search_query.search(&job.snapshot, Some(range)).await; + for subrange in search_results { + let subrange = start + subrange.start..start + subrange.end; + subranges.extend(expand_to_parent_range(&subrange, &job.snapshot)); + } + } + job.jobs_tx + .send(SearchJob { + buffer: job.buffer, + snapshot: job.snapshot, + ranges: subranges, + query_ix: job.query_ix + 1, + jobs_tx: job.jobs_tx.clone(), + }) + .await + .ok(); + } else { + let ranges = if let Some(content_query) = content_query { + let mut subranges = Vec::new(); + for range in job.ranges { + let start = range.start; + let search_results = content_query.search(&job.snapshot, Some(range)).await; + for subrange in search_results { + let subrange = start + subrange.start..start + subrange.end; + subranges.push(subrange); + } + } + subranges + } else { + job.ranges + }; + + let matches = ranges + .into_iter() + .map(|range| { + let snapshot = &job.snapshot; + let range = expand_to_entire_lines(range, snapshot); + let size = range.len(); + let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); + (range, size) + }) + .collect(); + + let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches)); + + if let Err(err) = send_result + && !err.is_disconnected() + { + log::error!("{err}"); + } + } +} + +fn expand_to_entire_lines(range: Range, snapshot: &BufferSnapshot) -> Range { + let mut point_range = range.to_point(snapshot); + point_range.start.column = 0; + if point_range.end.column > 0 { + point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0)); + } + point_range.to_offset(snapshot) +} + +fn expand_to_parent_range( + range: &Range, + snapshot: &BufferSnapshot, +) -> Option> { + let mut line_range = range.to_point(&snapshot); + line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len; + line_range.end.column = snapshot.line_len(line_range.end.row); + // TODO skip result if matched line isn't the first node line? + + let node = snapshot.syntax_ancestor(line_range)?; + Some(node.byte_range()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::merge_excerpts::merge_excerpts; + use cloud_zeta2_prompt::write_codeblock; + use edit_prediction_context::Line; + use gpui::TestAppContext; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use pretty_assertions::assert_eq; + use project::FakeFs; + use serde_json::json; + use settings::SettingsStore; + use std::path::Path; + use util::path; + + #[gpui::test] + async fn test_retrieval(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "user.rs": indoc!{" + pub struct Organization { + owner: Arc, + } + + pub struct User { + first_name: String, + last_name: String, + } + + impl Organization { + pub fn owner(&self) -> Arc { + self.owner.clone() + } + } + + impl User { + pub fn new(first_name: String, last_name: String) -> Self { + Self { + first_name, + last_name + } + } + + pub fn first_name(&self) -> String { + self.first_name.clone() + } + + pub fn last_name(&self) -> String { + self.last_name.clone() + } + } + "}, + "main.rs": indoc!{r#" + fn main() { + let user = User::new(FIRST_NAME.clone(), "doe".into()); + println!("user {:?}", user); + } + "#}, + }), + ) + .await; + + let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await; + project.update(cx, |project, _cx| { + project.languages().add(rust_lang().into()) + }); + + assert_results( + &project, + SearchToolQuery { + glob: "user.rs".into(), + syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()], + content: None, + }, + indoc! {r#" + `````root/user.rs + … + impl User { + … + pub fn first_name(&self) -> String { + self.first_name.clone() + } + … + ````` + "#}, + cx, + ) + .await; + + assert_results( + &project, + SearchToolQuery { + glob: "user.rs".into(), + syntax_node: vec!["impl\\s+User".into()], + content: Some("\\.clone".into()), + }, + indoc! {r#" + `````root/user.rs + … + impl User { + … + pub fn first_name(&self) -> String { + self.first_name.clone() + … + pub fn last_name(&self) -> String { + self.last_name.clone() + … + ````` + "#}, + cx, + ) + .await; + + assert_results( + &project, + SearchToolQuery { + glob: "*.rs".into(), + syntax_node: vec![], + content: Some("\\.clone".into()), + }, + indoc! {r#" + `````root/main.rs + fn main() { + let user = User::new(FIRST_NAME.clone(), "doe".into()); + … + ````` + + `````root/user.rs + … + impl Organization { + pub fn owner(&self) -> Arc { + self.owner.clone() + … + impl User { + … + pub fn first_name(&self) -> String { + self.first_name.clone() + … + pub fn last_name(&self) -> String { + self.last_name.clone() + … + ````` + "#}, + cx, + ) + .await; + } + + async fn assert_results( + project: &Entity, + query: SearchToolQuery, + expected_output: &str, + cx: &mut TestAppContext, + ) { + let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async()) + .await + .unwrap(); + + let mut results = results.into_iter().collect::>(); + results.sort_by_key(|results| { + results + .0 + .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone()) + }); + + let mut output = String::new(); + for (buffer, ranges) in results { + buffer.read_with(cx, |buffer, cx| { + let excerpts = ranges.into_iter().map(|range| { + let point_range = range.to_point(buffer); + if point_range.end.column > 0 { + Line(point_range.start.row)..Line(point_range.end.row + 1) + } else { + Line(point_range.start.row)..Line(point_range.end.row) + } + }); + + write_codeblock( + &buffer.file().unwrap().full_path(cx), + merge_excerpts(&buffer.snapshot(), excerpts).iter(), + &[], + Line(buffer.max_point().row), + false, + &mut output, + ); + }); + } + output.pop(); + + assert_eq!(output, expected_output); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + zlog::init_test(); + }); + } +} diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta2/src/udiff.rs index b30eb22741a1e701e2e744445f2a01c1f0ed0d03..d765a64345f839b9314632444d209fa79e9ca5ce 100644 --- a/crates/zeta2/src/udiff.rs +++ b/crates/zeta2/src/udiff.rs @@ -18,10 +18,10 @@ use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSn use project::Project; pub async fn parse_diff<'a>( - diff: &'a str, + diff_str: &'a str, get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, ) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { - let mut diff = DiffParser::new(diff); + let mut diff = DiffParser::new(diff_str); let mut edited_buffer = None; let mut edits = Vec::new(); @@ -41,7 +41,10 @@ pub async fn parse_diff<'a>( Some(ref current) => current, }; - edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?); + edits.extend( + resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges) + .with_context(|| format!("Diff:\n{diff_str}"))?, + ); } DiffEvent::FileEnd { renamed_to } => { let (buffer, _) = edited_buffer @@ -69,13 +72,13 @@ pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap, Entity( - diff: &'a str, + diff_str: &'a str, project: &Entity, cx: &mut AsyncApp, ) -> Result> { let mut included_files = HashMap::default(); - for line in diff.lines() { + for line in diff_str.lines() { let diff_line = DiffLine::parse(line); if let DiffLine::OldPath { path } = diff_line { @@ -97,7 +100,7 @@ pub async fn apply_diff<'a>( let ranges = [Anchor::MIN..Anchor::MAX]; - let mut diff = DiffParser::new(diff); + let mut diff = DiffParser::new(diff_str); let mut current_file = None; let mut edits = vec![]; @@ -120,7 +123,10 @@ pub async fn apply_diff<'a>( }; buffer.read_with(cx, |buffer, _| { - edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?); + edits.extend( + resolve_hunk_edits_in_buffer(hunk, buffer, ranges) + .with_context(|| format!("Diff:\n{diff_str}"))?, + ); anyhow::Ok(()) })??; } @@ -328,13 +334,7 @@ fn resolve_hunk_edits_in_buffer( offset = Some(range.start + ix); } } - offset.ok_or_else(|| { - anyhow!( - "Failed to match context:\n{}\n\nBuffer:\n{}", - hunk.context, - buffer.text(), - ) - }) + offset.ok_or_else(|| anyhow!("Failed to match context:\n{}", hunk.context)) }?; let iter = hunk.edits.into_iter().flat_map(move |edit| { let old_text = buffer diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 503964c88f18562dbf10197bfc330ffe49add8d5..ff0ff4f1ba2af59f32cddee96e4b9c0dd25af22d 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -7,7 +7,7 @@ use cloud_llm_client::{ ZED_VERSION_HEADER_NAME, }; use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; -use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput; +use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use collections::HashMap; use edit_prediction_context::{ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, @@ -35,7 +35,7 @@ use uuid::Uuid; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use thiserror::Error; use util::rel_path::RelPathBuf; @@ -88,6 +88,9 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { buffer_change_grouping_interval: Duration::from_secs(1), }; +static MODEL_ID: LazyLock = + LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string())); + pub struct Zeta2FeatureFlag; impl FeatureFlag for Zeta2FeatureFlag { @@ -180,7 +183,7 @@ pub struct ZetaEditPredictionDebugInfo { pub struct ZetaSearchQueryDebugInfo { pub project: Entity, pub timestamp: Instant, - pub regex_by_glob: HashMap, + pub search_queries: Vec, } pub type RequestDebugInfo = predict_edits_v3::DebugInfo; @@ -883,7 +886,7 @@ impl Zeta { let (prompt, _) = prompt_result?; let request = open_ai::Request { - model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()), + model: MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { content: open_ai::MessageContent::Plain(prompt), }], @@ -1226,10 +1229,24 @@ impl Zeta { .ok(); } - let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA; + pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { + let schema = language_model::tool_schema::root_schema_for::( + language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, + ); + + let description = schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap() + .to_string(); + + (schema.into(), description) + }); + + let (tool_schema, tool_description) = TOOL_SCHEMA.clone(); let request = open_ai::Request { - model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()), + model: MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { content: open_ai::MessageContent::Plain(prompt), }], @@ -1242,8 +1259,8 @@ impl Zeta { tools: vec![open_ai::ToolDefinition::Function { function: FunctionDefinition { name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(), - description: Some(tool_description.clone()), - parameters: Some(tool_schema.clone()), + description: Some(tool_description), + parameters: Some(tool_schema), }, }], prompt_cache_key: None, @@ -1255,7 +1272,6 @@ impl Zeta { let response = Self::send_raw_llm_request(client, llm_token, app_version, request).await; let mut response = Self::handle_api_response(&this, response, cx)?; - log::trace!("Got search planning response"); let choice = response @@ -1270,7 +1286,7 @@ impl Zeta { anyhow::bail!("Retrieval response didn't include an assistant message"); }; - let mut regex_by_glob: HashMap = HashMap::default(); + let mut queries: Vec = Vec::new(); for tool_call in tool_calls { let open_ai::ToolCallContent::Function { function } = tool_call.content; if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME { @@ -1283,13 +1299,7 @@ impl Zeta { } let input: SearchToolInput = serde_json::from_str(&function.arguments)?; - for query in input.queries { - let regex = regex_by_glob.entry(query.glob).or_default(); - if !regex.is_empty() { - regex.push('|'); - } - regex.push_str(&query.regex); - } + queries.extend(input.queries); } if let Some(debug_tx) = &debug_tx { @@ -1298,16 +1308,16 @@ impl Zeta { ZetaSearchQueryDebugInfo { project: project.clone(), timestamp: Instant::now(), - regex_by_glob: regex_by_glob.clone(), + search_queries: queries.clone(), }, )) .ok(); } - log::trace!("Running retrieval search: {regex_by_glob:#?}"); + log::trace!("Running retrieval search: {queries:#?}"); let related_excerpts_result = - retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await; + retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await; log::trace!("Search queries executed"); @@ -1754,7 +1764,8 @@ mod tests { arguments: serde_json::to_string(&SearchToolInput { queries: Box::new([SearchToolQuery { glob: "root/2.txt".to_string(), - regex: ".".to_string(), + syntax_node: vec![], + content: Some(".".into()), }]), }) .unwrap(), diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index 89d0ce8338624906d2262a7d8314700f6720cff1..3a9b1ccbf9340dfdaa06030e59c2112b9cda6307 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -16,6 +16,7 @@ anyhow.workspace = true chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true collections.workspace = true edit_prediction_context.workspace = true editor.workspace = true @@ -27,7 +28,6 @@ log.workspace = true multi_buffer.workspace = true ordered-float.workspace = true project.workspace = true -regex-syntax = "0.8.8" serde.workspace = true serde_json.workspace = true telemetry.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 685029cc4a2993227725c17e283c660da5c1d5e1..1826bd22df6d08ce717ef9bdf0070f88ad63c433 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -8,6 +8,7 @@ use std::{ use anyhow::Result; use client::{Client, UserStore}; +use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use editor::{Editor, PathKey}; use futures::StreamExt as _; use gpui::{ @@ -41,19 +42,13 @@ pub struct Zeta2ContextView { #[derive(Debug)] struct RetrievalRun { editor: Entity, - search_queries: Vec, + search_queries: Vec, started_at: Instant, search_results_generated_at: Option, search_results_executed_at: Option, finished_at: Option, } -#[derive(Debug)] -struct GlobQueries { - glob: String, - alternations: Vec, -} - actions!( dev, [ @@ -210,23 +205,7 @@ impl Zeta2ContextView { }; run.search_results_generated_at = Some(info.timestamp); - run.search_queries = info - .regex_by_glob - .into_iter() - .map(|(glob, regex)| { - let mut regex_parser = regex_syntax::ast::parse::Parser::new(); - - GlobQueries { - glob, - alternations: match regex_parser.parse(®ex) { - Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => { - alt.asts.iter().map(|ast| ast.to_string()).collect() - } - _ => vec![regex], - }, - } - }) - .collect(); + run.search_queries = info.search_queries; cx.notify(); } @@ -292,18 +271,28 @@ impl Zeta2ContextView { .enumerate() .flat_map(|(ix, query)| { std::iter::once(ListHeader::new(query.glob.clone()).into_any_element()) - .chain(query.alternations.iter().enumerate().map( - move |(alt_ix, alt)| { - ListItem::new(ix * 100 + alt_ix) + .chain(query.syntax_node.iter().enumerate().map( + move |(regex_ix, regex)| { + ListItem::new(ix * 100 + regex_ix) .start_slot( Icon::new(IconName::MagnifyingGlass) .color(Color::Muted) .size(IconSize::Small), ) - .child(alt.clone()) + .child(regex.clone()) .into_any_element() }, )) + .chain(query.content.as_ref().map(move |regex| { + ListItem::new(ix * 100 + query.syntax_node.len()) + .start_slot( + Icon::new(IconName::MagnifyingGlass) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(regex.clone()) + .into_any_element() + })) }), ), ) diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 1bc2411c825a2fa7147ff0c657804908b687d9ff..a593a1b12ceb2b72a316463076657f35ac2c4e9d 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -2,11 +2,11 @@ use crate::example::{ActualExcerpt, NamedExample}; use crate::headless::ZetaCliAppState; use crate::paths::LOGS_DIR; use ::serde::Serialize; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use clap::Args; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use futures::StreamExt as _; -use gpui::AsyncApp; +use gpui::{AppContext, AsyncApp}; use project::Project; use serde::Deserialize; use std::cell::Cell; @@ -14,6 +14,7 @@ use std::fs; use std::io::Write; use std::path::PathBuf; use std::sync::Arc; +use std::sync::Mutex; use std::time::{Duration, Instant}; #[derive(Debug, Args)] @@ -103,112 +104,126 @@ pub async fn zeta2_predict( let _edited_buffers = example.apply_edit_history(&project, cx).await?; let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?; + let result = Arc::new(Mutex::new(PredictionDetails::default())); let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; - let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })?; + let debug_task = cx.background_spawn({ + let result = result.clone(); + async move { + let mut context_retrieval_started_at = None; + let mut context_retrieval_finished_at = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + context_retrieval_started_at = Some(info.timestamp); + fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?; + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + fs::write( + LOGS_DIR.join("search_queries.json"), + serde_json::to_string_pretty(&info.search_queries).unwrap(), + )?; + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { + context_retrieval_finished_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { + let prediction_started_at = Instant::now(); + fs::write( + LOGS_DIR.join("prediction_prompt.md"), + &request.local_prompt.unwrap_or_default(), + )?; - let mut context_retrieval_started_at = None; - let mut context_retrieval_finished_at = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; - let mut prediction_started_at = None; - let mut prediction_finished_at = None; - let mut excerpts_text = String::new(); - let mut prediction_task = None; - let mut result = PredictionDetails::default(); - while let Some(event) = debug_rx.next().await { - match event { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - context_retrieval_started_at = Some(info.timestamp); - fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?; - } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - fs::write( - LOGS_DIR.join("search_queries.json"), - serde_json::to_string_pretty(&info.regex_by_glob).unwrap(), - )?; - } - zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { - context_retrieval_finished_at = Some(info.timestamp); + { + let mut result = result.lock().unwrap(); - prediction_task = Some(zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) - })?); - } - zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { - prediction_started_at = Some(Instant::now()); - fs::write( - LOGS_DIR.join("prediction_prompt.md"), - &request.local_prompt.unwrap_or_default(), - )?; + for included_file in request.request.included_files { + let insertions = + vec![(request.request.cursor_point, CURSOR_MARKER)]; + result.excerpts.extend(included_file.excerpts.iter().map( + |excerpt| ActualExcerpt { + path: included_file.path.components().skip(1).collect(), + text: String::from(excerpt.text.as_ref()), + }, + )); + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut result.excerpts_text, + ); + } + } - for included_file in request.request.included_files { - let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; - result - .excerpts - .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt { - path: included_file.path.components().skip(1).collect(), - text: String::from(excerpt.text.as_ref()), - })); - write_codeblock( - &included_file.path, - included_file.excerpts.iter(), - if included_file.path == request.request.excerpt_path { - &insertions - } else { - &[] - }, - included_file.max_row, - false, - &mut excerpts_text, - ); - } + let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; + let response = zeta2::text_from_response(response).unwrap_or_default(); + let prediction_finished_at = Instant::now(); + fs::write(LOGS_DIR.join("prediction_response.md"), &response)?; - let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = zeta2::text_from_response(response).unwrap_or_default(); - prediction_finished_at = Some(Instant::now()); - fs::write(LOGS_DIR.join("prediction_response.md"), &response)?; + let mut result = result.lock().unwrap(); - break; + result.planning_search_time = search_queries_generated_at.unwrap() + - context_retrieval_started_at.unwrap(); + result.running_search_time = search_queries_executed_at.unwrap() + - search_queries_generated_at.unwrap(); + result.filtering_search_time = context_retrieval_finished_at.unwrap() + - search_queries_executed_at.unwrap(); + result.prediction_time = prediction_finished_at - prediction_started_at; + result.total_time = + prediction_finished_at - context_retrieval_started_at.unwrap(); + + break; + } + } } + anyhow::Ok(()) } - } + }); + + zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })? + .await?; - refresh_task.await.context("context retrieval failed")?; - let prediction = prediction_task.unwrap().await?; + let prediction = zeta + .update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })? + .await?; + debug_task.await?; + + let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); result.diff = prediction .map(|prediction| { let old_text = prediction.snapshot.text(); - let new_text = prediction.buffer.update(cx, |buffer, cx| { - buffer.edit(prediction.edits.iter().cloned(), None, cx); - buffer.text() - })?; - anyhow::Ok(language::unified_diff(&old_text, &new_text)) + let new_text = prediction + .buffer + .update(cx, |buffer, cx| { + buffer.edit(prediction.edits.iter().cloned(), None, cx); + buffer.text() + }) + .unwrap(); + language::unified_diff(&old_text, &new_text) }) - .transpose()? .unwrap_or_default(); - result.excerpts_text = excerpts_text; - - result.planning_search_time = - search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap(); - result.running_search_time = - search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap(); - result.filtering_search_time = - context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap(); - result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap(); - result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap(); anyhow::Ok(result) } -#[derive(Debug, Default, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct PredictionDetails { pub diff: String, pub excerpts: Vec,