diff --git a/Cargo.lock b/Cargo.lock
index 08ead914af69281a21c763b874bc57e8c84ac90d..faae1259d9d5c08559ec6ba02463367e84b3aa4d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3202,7 +3202,6 @@ dependencies = [
  "rustc-hash 2.1.1",
  "schemars 1.0.4",
  "serde",
- "serde_json",
  "strum 0.27.2",
 ]
@@ -8867,6 +8866,7 @@ dependencies = [
  "open_router",
  "parking_lot",
  "proto",
+ "schemars 1.0.4",
  "serde",
  "serde_json",
  "settings",
@@ -21685,6 +21685,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "smol",
  "thiserror 2.0.17",
  "util",
  "uuid",
@@ -21702,6 +21703,7 @@ dependencies = [
  "clap",
  "client",
  "cloud_llm_client",
+ "cloud_zeta2_prompt",
  "collections",
  "edit_prediction_context",
  "editor",
@@ -21715,7 +21717,6 @@ dependencies = [
  "ordered-float 2.10.1",
  "pretty_assertions",
  "project",
- "regex-syntax",
  "serde",
  "serde_json",
  "settings",
diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs
index 0e9372373a65ac5fee9870cf58e2b0d9c11427d2..fc0b66f4073ea137f53b29286b0c17b53d11bf83 100644
--- a/crates/agent/src/agent.rs
+++ b/crates/agent/src/agent.rs
@@ -6,7 +6,6 @@ mod native_agent_server;
 pub mod outline;
 mod templates;
 mod thread;
-mod tool_schema;
 mod tools;
 
 #[cfg(test)]
diff --git a/crates/agent/src/outline.rs b/crates/agent/src/outline.rs
index bc78290fb52ae208742b9dea0e6dbbe560022419..262fa8d3d139a5c8f5900d0dd55348f9dc716167 100644
--- a/crates/agent/src/outline.rs
+++ b/crates/agent/src/outline.rs
@@ -1,6 +1,6 @@
 use anyhow::Result;
 use gpui::{AsyncApp, Entity};
-use language::{Buffer, OutlineItem, ParseStatus};
+use language::{Buffer, OutlineItem};
 use regex::Regex;
 use std::fmt::Write;
 use text::Point;
@@ -30,10 +30,9 @@ pub async fn get_buffer_content_or_outline(
     if file_size > AUTO_OUTLINE_SIZE {
         // For large files, use outline instead of full content
         // Wait until the buffer has been fully parsed, so we can read its outline
-        let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
-        while *parse_status.borrow() != ParseStatus::Idle {
-            parse_status.changed().await?;
-        }
+        buffer
+            .read_with(cx, |buffer, _| buffer.parsing_idle())?
+            .await;
 
         let outline_items = buffer.read_with(cx, |buffer, _| {
             let snapshot = buffer.snapshot();
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 4c0fb00163744e66b5644a0fe76b1aa853fb8237..78f20152b4daf461de40cfa7746216092f82cf41 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -2139,7 +2139,7 @@ where
 
     /// Returns the JSON schema that describes the tool's input.
     fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema {
-        crate::tool_schema::root_schema_for::<Self::Input>(format)
+        language_model::tool_schema::root_schema_for::<Self::Input>(format)
     }
 
     /// Some tools rely on a provider for the underlying billing or other reasons.
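Note: `root_schema_for` and `adapt_schema_to_format` now live in `language_model::tool_schema` (see the file rename further down), so any crate that depends on `language_model` can derive a tool schema. A minimal sketch of the resulting call pattern, assuming a hypothetical `ExampleInput` type:

```rust
use language_model::LanguageModelToolSchemaFormat;
use language_model::tool_schema::{adapt_schema_to_format, root_schema_for};
use schemars::JsonSchema;
use serde::Deserialize;

// Hypothetical tool input; any type deriving `JsonSchema` works.
#[derive(Deserialize, JsonSchema)]
struct ExampleInput {
    /// Path of the file to operate on
    path: String,
}

fn example_schema(format: LanguageModelToolSchemaFormat) -> anyhow::Result<serde_json::Value> {
    // Pick draft-07 or the OpenAPI 3.0 subset depending on the provider.
    let schema = root_schema_for::<ExampleInput>(format);
    let mut json = serde_json::to_value(schema)?;
    // Rewrites constructs the target format does not support, mirroring the
    // call sites in `thread.rs` and `context_server_registry.rs` below.
    adapt_schema_to_format(&mut json, format)?;
    Ok(json)
}
```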
@@ -2226,7 +2226,7 @@ where
 
     fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
         let mut json = serde_json::to_value(T::input_schema(format))?;
-        crate::tool_schema::adapt_schema_to_format(&mut json, format)?;
+        language_model::tool_schema::adapt_schema_to_format(&mut json, format)?;
         Ok(json)
     }
 
diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs
index 382d2ba9be74b4518de853037c858fd054366d5d..03a0ef84e73d4cbca83d61077d568ec58cd7ae2b 100644
--- a/crates/agent/src/tools/context_server_registry.rs
+++ b/crates/agent/src/tools/context_server_registry.rs
@@ -165,7 +165,7 @@ impl AnyAgentTool for ContextServerTool {
         format: language_model::LanguageModelToolSchemaFormat,
     ) -> Result<serde_json::Value> {
         let mut schema = self.tool.input_schema.clone();
-        crate::tool_schema::adapt_schema_to_format(&mut schema, format)?;
+        language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
         Ok(match schema {
             serde_json::Value::Null => {
                 serde_json::json!({ "type": "object", "properties": [] })
diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml
index fa8246950f8d03029388e0276954de946efc2346..8be10265cb23e7dd0983c52e7c2d6984b62c4be4 100644
--- a/crates/cloud_zeta2_prompt/Cargo.toml
+++ b/crates/cloud_zeta2_prompt/Cargo.toml
@@ -19,5 +19,4 @@ ordered-float.workspace = true
 rustc-hash.workspace = true
 schemars.workspace = true
 serde.workspace = true
-serde_json.workspace = true
 strum.workspace = true
diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
index 54ef1999729f6976bd77d280508f8c370d54488e..7fbc3834dfd0f4bbfc4085d696b7fbf755e6dd3d 100644
--- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
+++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
@@ -3,7 +3,7 @@ use cloud_llm_client::predict_edits_v3::{self, Excerpt};
 use indoc::indoc;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{fmt::Write, sync::LazyLock};
+use std::fmt::Write;
 
 use crate::{push_events, write_codeblock};
 
@@ -15,7 +15,7 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
         push_events(&mut prompt, &request.events);
     }
 
-    writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
+    writeln!(&mut prompt, "## Cursor context")?;
     write_codeblock(
         &request.excerpt_path,
         &[Excerpt {
@@ -39,54 +39,56 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
 #[derive(Clone, Deserialize, Serialize, JsonSchema)]
 pub struct SearchToolInput {
     /// An array of queries to run for gathering context relevant to the next prediction
-    #[schemars(length(max = 5))]
+    #[schemars(length(max = 3))]
     pub queries: Box<[SearchToolQuery]>,
 }
 
+/// Search for relevant code by path, syntax hierarchy, and content.
 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
 pub struct SearchToolQuery {
-    /// A glob pattern to match file paths in the codebase
+    /// 1. A glob pattern to match file paths in the codebase to search in.
     pub glob: String,
-    /// A regular expression to match content within the files matched by the glob pattern
-    pub regex: String,
+    /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
+    ///
+    /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
+    ///
+    /// Example: Searching for a `User` class
+    /// ["class\s+User"]
+    ///
+    /// Example: Searching for a `get_full_name` method under a `User` class
+    /// ["class\s+User", "def\sget_full_name"]
+    ///
+    /// Skip this field to match on content alone.
+    #[schemars(length(max = 3))]
+    #[serde(default)]
+    pub syntax_node: Vec<String>,
+    /// 3. An optional regular expression to match the final content that should appear in the results.
+    ///
+    /// - Content will be matched within all lines of the matched syntax nodes.
+    /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
+    /// - If no syntax node regexes are provided, the content will be matched within the entire file.
+    pub content: Option<String>,
 }
 
-pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
-    let schema = schemars::schema_for!(SearchToolInput);
-
-    let description = schema
-        .get("description")
-        .and_then(|description| description.as_str())
-        .unwrap()
-        .to_string();
-
-    (schema.into(), description)
-});
-
 pub const TOOL_NAME: &str = "search";
 
 const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
-    ## Task
+    You are part of an edit prediction system in a code editor.
+    Your role is to search for code that will serve as context for predicting the next edit.
 
-    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
-    that will serve as context for predicting the next required edit.
-
-    **Your task:**
     - Analyze the user's recent edits and current cursor context
-    - Use the `search` tool to find code that may be relevant for predicting the next edit
+    - Use the `search` tool to find code that is relevant for predicting the next edit
     - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates
-
-    **Important constraints:**
-    - This conversation has exactly 2 turns
-    - You must make ALL search queries in your first response via the `search` tool
-    - All queries will be executed in parallel and results returned together
-    - In the second turn, you will select the most relevant results via the `select` tool.
+      - How items defined in the cursor excerpt are used or altered
+    - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
+    - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
+    - Avoid using wildcard globs if you already know the file path of the content you're looking for
 "#};
 
 const TOOL_USE_REMINDER: &str = indoc! {"
     --
-    Use the `search` tool now
+    Analyze the user's intent in one to two sentences, then call the `search` tool.
 "};
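For orientation, a tool call that satisfies the revised `SearchToolInput` schema looks like this; a sketch with illustrative values (the glob and regexes are not from this patch):

```rust
use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;

fn parse_example() -> anyhow::Result<()> {
    // One query: find `get_full_name` under the `User` class, keeping only
    // lines that mention `last_name`. `syntax_node` and `content` may both
    // be omitted thanks to `#[serde(default)]` and `Option`.
    let input: SearchToolInput = serde_json::from_str(
        r#"{"queries": [{
            "glob": "src/models/*.py",
            "syntax_node": ["class\\s+User", "def\\s+get_full_name"],
            "content": "last_name"
        }]}"#,
    )?;
    assert_eq!(input.queries.len(), 1);
    Ok(())
}
```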
"}; diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 4dd90c15d9387327a75ece2d82385e406e5840d6..ea2405d04c32cba45963bc32747ee0b94292ffd9 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1618,6 +1618,18 @@ impl Buffer { self.parse_status.1.clone() } + /// Wait until the buffer is no longer parsing + pub fn parsing_idle(&self) -> impl Future + use<> { + let mut parse_status = self.parse_status(); + async move { + while *parse_status.borrow() != ParseStatus::Idle { + if parse_status.changed().await.is_err() { + break; + } + } + } + } + /// Assign to the buffer a set of diagnostics created by a given language server. pub fn update_diagnostics( &mut self, diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index f572561f6a78b3cf2d9bfc2f7272895836f11614..4d40a063b604b405f7bcb29a3457956e1dd5541d 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -17,7 +17,6 @@ test-support = [] [dependencies] anthropic = { workspace = true, features = ["schemars"] } -open_router.workspace = true anyhow.workspace = true base64.workspace = true client.workspace = true @@ -30,8 +29,10 @@ http_client.workspace = true icons.workspace = true image.workspace = true log.workspace = true +open_router.workspace = true parking_lot.workspace = true proto.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 24f9b84afcfa7b9a40b4a1b7684e9a9b036a5a85..94f6ec33f15062dd53b4122ca9d9dcac3fbff83d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -4,6 +4,7 @@ mod registry; mod request; mod role; mod telemetry; +pub mod tool_schema; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; @@ -35,6 +36,7 @@ pub use crate::registry::*; pub use crate::request::*; pub use crate::role::*; pub use crate::telemetry::*; +pub use crate::tool_schema::LanguageModelToolSchemaFormat; pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("anthropic"); @@ -409,15 +411,6 @@ impl From for LanguageModelCompletionError { } } -/// Indicates the format used to define the input schema for a language model tool. 
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index f572561f6a78b3cf2d9bfc2f7272895836f11614..4d40a063b604b405f7bcb29a3457956e1dd5541d 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -17,7 +17,6 @@ test-support = []
 
 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
-open_router.workspace = true
 anyhow.workspace = true
 base64.workspace = true
 client.workspace = true
@@ -30,8 +29,10 @@ http_client.workspace = true
 icons.workspace = true
 image.workspace = true
 log.workspace = true
+open_router.workspace = true
 parking_lot.workspace = true
 proto.workspace = true
+schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 24f9b84afcfa7b9a40b4a1b7684e9a9b036a5a85..94f6ec33f15062dd53b4122ca9d9dcac3fbff83d 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -4,6 +4,7 @@ mod registry;
 mod request;
 mod role;
 mod telemetry;
+pub mod tool_schema;
 
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;
@@ -35,6 +36,7 @@ pub use crate::registry::*;
 pub use crate::request::*;
 pub use crate::role::*;
 pub use crate::telemetry::*;
+pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 
 pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
     LanguageModelProviderId::new("anthropic");
@@ -409,15 +411,6 @@ impl From<AnthropicError> for LanguageModelCompletionError {
     }
 }
 
-/// Indicates the format used to define the input schema for a language model tool.
-#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
-pub enum LanguageModelToolSchemaFormat {
-    /// A JSON schema, see https://json-schema.org
-    JsonSchema,
-    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
-    JsonSchemaSubset,
-}
-
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum StopReason {
diff --git a/crates/agent/src/tool_schema.rs b/crates/language_model/src/tool_schema.rs
similarity index 95%
rename from crates/agent/src/tool_schema.rs
rename to crates/language_model/src/tool_schema.rs
index 4b0de3e5c63fb0c5ccafbb89a22dad8a33072b35..f9402c28dc316f9ccdacc58afaa0eebd6699f92d 100644
--- a/crates/agent/src/tool_schema.rs
+++ b/crates/language_model/src/tool_schema.rs
@@ -1,5 +1,4 @@
 use anyhow::Result;
-use language_model::LanguageModelToolSchemaFormat;
 use schemars::{
     JsonSchema, Schema,
     generate::SchemaSettings,
@@ -7,7 +6,16 @@ use schemars::{
 };
 use serde_json::Value;
 
-pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
+/// Indicates the format used to define the input schema for a language model tool.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub enum LanguageModelToolSchemaFormat {
+    /// A JSON schema, see https://json-schema.org
+    JsonSchema,
+    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
+    JsonSchemaSubset,
+}
+
+pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
     let mut generator = match format {
         LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
         LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index 9cc95da3b4c86daffd32af89d4f26509c97269fa..13ed42847d522c371226988d8ca133a1748d5fec 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -4060,7 +4060,7 @@ impl Project {
         result_rx
     }
 
-    fn find_search_candidate_buffers(
+    pub fn find_search_candidate_buffers(
         &mut self,
         query: &SearchQuery,
         limit: usize,
diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml
index cbde212dd104bdc909dda19de403f815ff4f6386..3f394cd5ef2ab5d5bce05430a717312c9e3c0f5c 100644
--- a/crates/zeta2/Cargo.toml
+++ b/crates/zeta2/Cargo.toml
@@ -33,6 +33,7 @@ project.workspace = true
 release_channel.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
 thiserror.workspace = true
 util.workspace = true
 uuid.workspace = true
diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs
index e2e78c3e3b295549ca2c294818f935f1d7d8a9f9..f735f44cad9623711e5ed9a1293a74e34e084888 100644
--- a/crates/zeta2/src/retrieval_search.rs
+++ b/crates/zeta2/src/retrieval_search.rs
@@ -1,18 +1,19 @@
 use std::ops::Range;
 
 use anyhow::Result;
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use collections::HashMap;
-use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
 use futures::{
     StreamExt,
     channel::mpsc::{self, UnboundedSender},
 };
 use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 use project::{
     Project, WorktreeSettings,
     search::{SearchQuery, SearchResult},
 };
+use smol::channel;
 use util::{
     ResultExt as _,
     paths::{PathMatcher, PathStyle},
 };
 use workspace::item::Settings as _;
@@ -21,7 +22,7 @@ use workspace::item::Settings as _;
 
 pub async fn run_retrieval_searches(
     project: Entity<Project>,
-    regex_by_glob: HashMap<String, String>,
+    queries: Vec<SearchToolQuery>,
     cx: &mut AsyncApp,
 ) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
     let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
@@ -37,14 +38,13 @@ pub async fn run_retrieval_searches(
 
     let (results_tx, mut results_rx) = mpsc::unbounded();
 
-    for (glob, regex) in regex_by_glob {
+    for query in queries {
         let exclude_matcher = exclude_matcher.clone();
         let results_tx = results_tx.clone();
         let project = project.clone();
         cx.spawn(async move |cx| {
             run_query(
-                &glob,
-                &regex,
+                query,
                 results_tx.clone(),
                 path_style,
                 exclude_matcher,
@@ -108,87 +108,442 @@ pub async fn run_retrieval_searches(
         .await
 }
 
-const MIN_EXCERPT_LEN: usize = 16;
 const MAX_EXCERPT_LEN: usize = 768;
 const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
 
+struct SearchJob {
+    buffer: Entity<Buffer>,
+    snapshot: BufferSnapshot,
+    ranges: Vec<Range<usize>>,
+    query_ix: usize,
+    jobs_tx: channel::Sender<SearchJob>,
+}
+
 async fn run_query(
-    glob: &str,
-    regex: &str,
+    input_query: SearchToolQuery,
     results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
     path_style: PathStyle,
     exclude_matcher: PathMatcher,
     project: &Entity<Project>,
     cx: &mut AsyncApp,
 ) -> Result<()> {
-    let include_matcher = PathMatcher::new(vec![glob], path_style)?;
-
-    let query = SearchQuery::regex(
-        regex,
-        false,
-        true,
-        false,
-        true,
-        include_matcher,
-        exclude_matcher,
-        true,
-        None,
-    )?;
-
-    let results = project.update(cx, |project, cx| project.search(query, cx))?;
-    futures::pin_mut!(results);
-
-    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
-        if results_tx.is_closed() {
-            break;
-        }
+    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
 
-        if ranges.is_empty() {
-            continue;
-        }
+    let make_search = |regex: &str| -> Result<SearchQuery> {
+        SearchQuery::regex(
+            regex,
+            false,
+            true,
+            false,
+            true,
+            include_matcher.clone(),
+            exclude_matcher.clone(),
+            true,
+            None,
+        )
+    };
 
-        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-        let results_tx = results_tx.clone();
+    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
+        let outer_syntax_query = make_search(outer_syntax_regex)?;
+        let nested_syntax_queries = input_query
+            .syntax_node
+            .into_iter()
+            .skip(1)
+            .map(|query| make_search(&query))
+            .collect::<Result<Vec<_>>>()?;
+        let content_query = input_query
+            .content
+            .map(|regex| make_search(&regex))
+            .transpose()?;
 
-        cx.background_spawn(async move {
-            let mut excerpts = Vec::with_capacity(ranges.len());
-
-            for range in ranges {
-                let offset_range = range.to_offset(&snapshot);
-                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
-
-                let excerpt = EditPredictionExcerpt::select_from_buffer(
-                    query_point,
-                    &snapshot,
-                    &EditPredictionExcerptOptions {
-                        max_bytes: MAX_EXCERPT_LEN,
-                        min_bytes: MIN_EXCERPT_LEN,
-                        target_before_cursor_over_total_bytes: 0.5,
-                    },
-                    None,
-                );
+        let (jobs_tx, jobs_rx) = channel::unbounded();
 
-                if let Some(excerpt) = excerpt
-                    && !excerpt.line_range.is_empty()
-                {
-                    excerpts.push((
-                        snapshot.anchor_after(excerpt.range.start)
-                            ..snapshot.anchor_before(excerpt.range.end),
-                        excerpt.range.len(),
-                    ));
-                }
+        let outer_search_results_rx =
+            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
+
+        let outer_search_task = cx.spawn(async move |cx| {
+            futures::pin_mut!(outer_search_results_rx);
+            while let Some(SearchResult::Buffer { buffer, ranges }) =
+                outer_search_results_rx.next().await
+            {
+                buffer
+                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
+                    .await;
+                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+                let expanded_ranges: Vec<_> = ranges
+                    .into_iter()
+                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
+                    .collect();
+                jobs_tx
+                    .send(SearchJob {
+                        buffer,
+                        snapshot,
+                        ranges: expanded_ranges,
+                        query_ix: 0,
+                        jobs_tx: jobs_tx.clone(),
+                    })
+                    .await?;
             }
+            anyhow::Ok(())
+        });
 
-        let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
+        let n_workers = cx.background_executor().num_cpus();
+        let search_job_task = cx.background_executor().scoped(|scope| {
+            for _ in 0..n_workers {
+                scope.spawn(async {
+                    while let Ok(job) = jobs_rx.recv().await {
+                        process_nested_search_job(
+                            &results_tx,
+                            &nested_syntax_queries,
+                            &content_query,
+                            job,
+                        )
+                        .await;
+                    }
+                });
+            }
+        });
+
+        search_job_task.await;
+        outer_search_task.await?;
+    } else if let Some(content_regex) = &input_query.content {
+        let search_query = make_search(&content_regex)?;
+
+        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
+        futures::pin_mut!(results_rx);
+
+        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
+            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+            let ranges = ranges
+                .into_iter()
+                .map(|range| {
+                    let range = range.to_offset(&snapshot);
+                    let range = expand_to_entire_lines(range, &snapshot);
+                    let size = range.len();
+                    let range =
+                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+                    (range, size)
+                })
+                .collect();
+
+            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
 
             if let Err(err) = send_result
                 && !err.is_disconnected()
             {
                 log::error!("{err}");
             }
-        })
-        .detach();
+        }
+    } else {
+        log::warn!("Context gathering model produced a glob-only search");
     }
 
     anyhow::Ok(())
 }
+
+async fn process_nested_search_job(
+    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
+    queries: &Vec<SearchQuery>,
+    content_query: &Option<SearchQuery>,
+    job: SearchJob,
+) {
+    if let Some(search_query) = queries.get(job.query_ix) {
+        let mut subranges = Vec::new();
+        for range in job.ranges {
+            let start = range.start;
+            let search_results = search_query.search(&job.snapshot, Some(range)).await;
+            for subrange in search_results {
+                let subrange = start + subrange.start..start + subrange.end;
+                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
+            }
+        }
+        job.jobs_tx
+            .send(SearchJob {
+                buffer: job.buffer,
+                snapshot: job.snapshot,
+                ranges: subranges,
+                query_ix: job.query_ix + 1,
+                jobs_tx: job.jobs_tx.clone(),
+            })
+            .await
+            .ok();
+    } else {
+        let ranges = if let Some(content_query) = content_query {
+            let mut subranges = Vec::new();
+            for range in job.ranges {
+                let start = range.start;
+                let search_results = content_query.search(&job.snapshot, Some(range)).await;
+                for subrange in search_results {
+                    let subrange = start + subrange.start..start + subrange.end;
+                    subranges.push(subrange);
+                }
+            }
+            subranges
+        } else {
+            job.ranges
+        };
+
+        let matches = ranges
+            .into_iter()
+            .map(|range| {
+                let snapshot = &job.snapshot;
+                let range = expand_to_entire_lines(range, snapshot);
+                let size = range.len();
+                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+                (range, size)
+            })
+            .collect();
+
+        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
+
+        if let Err(err) = send_result
+            && !err.is_disconnected()
+        {
+            log::error!("{err}");
+        }
+    }
+}
+
+fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
+    let mut point_range = range.to_point(snapshot);
+    point_range.start.column = 0;
+    if point_range.end.column > 0 {
+        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
+    }
+    point_range.to_offset(snapshot)
+}
+
+fn expand_to_parent_range(
+    range: &Range<impl ToPoint>,
+    snapshot: &BufferSnapshot,
+) -> Option<Range<usize>> {
+    let mut line_range = range.to_point(&snapshot);
+    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
+    line_range.end.column = snapshot.line_len(line_range.end.row);
+    // TODO skip result if matched line isn't the first node line?
+
+    let node = snapshot.syntax_ancestor(line_range)?;
+    Some(node.byte_range())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::merge_excerpts::merge_excerpts;
+    use cloud_zeta2_prompt::write_codeblock;
+    use edit_prediction_context::Line;
+    use gpui::TestAppContext;
+    use indoc::indoc;
+    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+    use pretty_assertions::assert_eq;
+    use project::FakeFs;
+    use serde_json::json;
+    use settings::SettingsStore;
+    use std::path::Path;
+    use util::path;
+
+    #[gpui::test]
+    async fn test_retrieval(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "user.rs": indoc!{"
+                    pub struct Organization {
+                        owner: Arc<User>,
+                    }
+
+                    pub struct User {
+                        first_name: String,
+                        last_name: String,
+                    }
+
+                    impl Organization {
+                        pub fn owner(&self) -> Arc<User> {
+                            self.owner.clone()
+                        }
+                    }
+
+                    impl User {
+                        pub fn new(first_name: String, last_name: String) -> Self {
+                            Self {
+                                first_name,
+                                last_name
+                            }
+                        }
+
+                        pub fn first_name(&self) -> String {
+                            self.first_name.clone()
+                        }
+
+                        pub fn last_name(&self) -> String {
+                            self.last_name.clone()
+                        }
+                    }
+                "},
+                "main.rs": indoc!{r#"
+                    fn main() {
+                        let user = User::new(FIRST_NAME.clone(), "doe".into());
+                        println!("user {:?}", user);
+                    }
+                "#},
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
+        project.update(cx, |project, _cx| {
+            project.languages().add(rust_lang().into())
+        });
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "user.rs".into(),
+                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
+                content: None,
+            },
+            indoc! {r#"
+                `````root/user.rs
+                …
+                impl User {
+                …
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                    }
+                …
+                `````
+            "#},
+            cx,
+        )
+        .await;
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "user.rs".into(),
+                syntax_node: vec!["impl\\s+User".into()],
+                content: Some("\\.clone".into()),
+            },
+            indoc! {r#"
+                `````root/user.rs
+                …
+                impl User {
+                …
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                …
+                    pub fn last_name(&self) -> String {
+                        self.last_name.clone()
+                …
+                `````
+            "#},
+            cx,
+        )
+        .await;
+
+        assert_results(
+            &project,
+            SearchToolQuery {
+                glob: "*.rs".into(),
+                syntax_node: vec![],
+                content: Some("\\.clone".into()),
+            },
+            indoc! {r#"
{r#" + `````root/main.rs + fn main() { + let user = User::new(FIRST_NAME.clone(), "doe".into()); + … + ````` + + `````root/user.rs + … + impl Organization { + pub fn owner(&self) -> Arc { + self.owner.clone() + … + impl User { + … + pub fn first_name(&self) -> String { + self.first_name.clone() + … + pub fn last_name(&self) -> String { + self.last_name.clone() + … + ````` + "#}, + cx, + ) + .await; + } + + async fn assert_results( + project: &Entity, + query: SearchToolQuery, + expected_output: &str, + cx: &mut TestAppContext, + ) { + let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async()) + .await + .unwrap(); + + let mut results = results.into_iter().collect::>(); + results.sort_by_key(|results| { + results + .0 + .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone()) + }); + + let mut output = String::new(); + for (buffer, ranges) in results { + buffer.read_with(cx, |buffer, cx| { + let excerpts = ranges.into_iter().map(|range| { + let point_range = range.to_point(buffer); + if point_range.end.column > 0 { + Line(point_range.start.row)..Line(point_range.end.row + 1) + } else { + Line(point_range.start.row)..Line(point_range.end.row) + } + }); + + write_codeblock( + &buffer.file().unwrap().full_path(cx), + merge_excerpts(&buffer.snapshot(), excerpts).iter(), + &[], + Line(buffer.max_point().row), + false, + &mut output, + ); + }); + } + output.pop(); + + assert_eq!(output, expected_output); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + zlog::init_test(); + }); + } +} diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta2/src/udiff.rs index b30eb22741a1e701e2e744445f2a01c1f0ed0d03..d765a64345f839b9314632444d209fa79e9ca5ce 100644 --- a/crates/zeta2/src/udiff.rs +++ b/crates/zeta2/src/udiff.rs @@ -18,10 +18,10 @@ use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSn use project::Project; pub async fn parse_diff<'a>( - diff: &'a str, + diff_str: &'a str, get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, ) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { - let mut diff = DiffParser::new(diff); + let mut diff = DiffParser::new(diff_str); let mut edited_buffer = None; let mut edits = Vec::new(); @@ -41,7 +41,10 @@ pub async fn parse_diff<'a>( Some(ref current) => current, }; - edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?); + edits.extend( + resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges) + .with_context(|| format!("Diff:\n{diff_str}"))?, + ); } DiffEvent::FileEnd { renamed_to } => { let (buffer, _) = edited_buffer @@ -69,13 +72,13 @@ pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap, Entity( - diff: &'a str, + diff_str: &'a str, project: &Entity, cx: &mut AsyncApp, ) -> Result> { let mut included_files = HashMap::default(); - for line in diff.lines() { + for line in diff_str.lines() { let diff_line = DiffLine::parse(line); if let DiffLine::OldPath { path } = diff_line { @@ -97,7 +100,7 @@ pub async fn apply_diff<'a>( let ranges = [Anchor::MIN..Anchor::MAX]; - let mut diff = DiffParser::new(diff); + 
diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta2/src/udiff.rs
index b30eb22741a1e701e2e744445f2a01c1f0ed0d03..d765a64345f839b9314632444d209fa79e9ca5ce 100644
--- a/crates/zeta2/src/udiff.rs
+++ b/crates/zeta2/src/udiff.rs
@@ -18,10 +18,10 @@ use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSn
 use project::Project;
 
 pub async fn parse_diff<'a>(
-    diff: &'a str,
+    diff_str: &'a str,
     get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
 ) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
-    let mut diff = DiffParser::new(diff);
+    let mut diff = DiffParser::new(diff_str);
     let mut edited_buffer = None;
     let mut edits = Vec::new();
@@ -41,7 +41,10 @@ pub async fn parse_diff<'a>(
                 Some(ref current) => current,
             };
 
-            edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?);
+            edits.extend(
+                resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
+                    .with_context(|| format!("Diff:\n{diff_str}"))?,
+            );
         }
         DiffEvent::FileEnd { renamed_to } => {
             let (buffer, _) = edited_buffer
@@ -69,13 +72,13 @@ pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Arc<Path>, Entity<Buffer
 
 pub async fn apply_diff<'a>(
-    diff: &'a str,
+    diff_str: &'a str,
     project: &Entity<Project>,
     cx: &mut AsyncApp,
 ) -> Result<OpenedBuffers<'a>> {
     let mut included_files = HashMap::default();
 
-    for line in diff.lines() {
+    for line in diff_str.lines() {
         let diff_line = DiffLine::parse(line);
 
         if let DiffLine::OldPath { path } = diff_line {
@@ -97,7 +100,7 @@ pub async fn apply_diff<'a>(
 
     let ranges = [Anchor::MIN..Anchor::MAX];
 
-    let mut diff = DiffParser::new(diff);
+    let mut diff = DiffParser::new(diff_str);
     let mut current_file = None;
     let mut edits = vec![];
 
@@ -120,7 +123,10 @@ pub async fn apply_diff<'a>(
             };
 
             buffer.read_with(cx, |buffer, _| {
-                edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?);
+                edits.extend(
+                    resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
+                        .with_context(|| format!("Diff:\n{diff_str}"))?,
+                );
                 anyhow::Ok(())
             })??;
         }
@@ -328,13 +334,7 @@ fn resolve_hunk_edits_in_buffer(
                 offset = Some(range.start + ix);
             }
         }
-        offset.ok_or_else(|| {
-            anyhow!(
-                "Failed to match context:\n{}\n\nBuffer:\n{}",
-                hunk.context,
-                buffer.text(),
-            )
-        })
+        offset.ok_or_else(|| anyhow!("Failed to match context:\n{}", hunk.context))
     }?;
     let iter = hunk.edits.into_iter().flat_map(move |edit| {
         let old_text = buffer
diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs
index 503964c88f18562dbf10197bfc330ffe49add8d5..ff0ff4f1ba2af59f32cddee96e4b9c0dd25af22d 100644
--- a/crates/zeta2/src/zeta2.rs
+++ b/crates/zeta2/src/zeta2.rs
@@ -7,7 +7,7 @@ use cloud_llm_client::{
     ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
+use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
 use collections::HashMap;
 use edit_prediction_context::{
     DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
@@ -35,7 +35,7 @@ use uuid::Uuid;
 use std::ops::Range;
 use std::path::Path;
 use std::str::FromStr as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
 use std::time::{Duration, Instant};
 use thiserror::Error;
 use util::rel_path::RelPathBuf;
@@ -88,6 +88,9 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
     buffer_change_grouping_interval: Duration::from_secs(1),
 };
 
+static MODEL_ID: LazyLock<String> =
+    LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()));
+
 pub struct Zeta2FeatureFlag;
 
 impl FeatureFlag for Zeta2FeatureFlag {
@@ -180,7 +183,7 @@ pub struct ZetaEditPredictionDebugInfo {
 pub struct ZetaSearchQueryDebugInfo {
     pub project: Entity<Project>,
     pub timestamp: Instant,
-    pub regex_by_glob: HashMap<String, String>,
+    pub search_queries: Vec<SearchToolQuery>,
 }
 
 pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -883,7 +886,7 @@ impl Zeta {
         let (prompt, _) = prompt_result?;
 
         let request = open_ai::Request {
-            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
+            model: MODEL_ID.clone(),
             messages: vec![open_ai::RequestMessage::User {
                 content: open_ai::MessageContent::Plain(prompt),
             }],
@@ -1226,10 +1229,24 @@ impl Zeta {
                 .ok();
         }
 
-        let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;
+        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
+            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
+                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
+            );
+
+            let description = schema
+                .get("description")
+                .and_then(|description| description.as_str())
+                .unwrap()
+                .to_string();
+
+            (schema.into(), description)
+        });
+
+        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
 
         let request = open_ai::Request {
-            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
+            model: MODEL_ID.clone(),
             messages: vec![open_ai::RequestMessage::User {
                 content: open_ai::MessageContent::Plain(prompt),
             }],
@@ -1242,8 +1259,8 @@ impl Zeta {
             tools: vec![open_ai::ToolDefinition::Function {
                 function: FunctionDefinition {
                     name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
-                    description: Some(tool_description.clone()),
-                    parameters: Some(tool_schema.clone()),
+                    description: Some(tool_description),
+                    parameters: Some(tool_schema),
                 },
             }],
             prompt_cache_key: None,
@@ -1255,7 +1272,6 @@ impl Zeta {
 
         let response = Self::send_raw_llm_request(client, llm_token, app_version, request).await;
         let mut response = Self::handle_api_response(&this, response, cx)?;
-
         log::trace!("Got search planning response");
 
         let choice = response
@@ -1270,7 +1286,7 @@ impl Zeta {
             anyhow::bail!("Retrieval response didn't include an assistant message");
         };
 
-        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
+        let mut queries: Vec<SearchToolQuery> = Vec::new();
 
         for tool_call in tool_calls {
             let open_ai::ToolCallContent::Function { function } = tool_call.content;
             if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
@@ -1283,13 +1299,7 @@ impl Zeta {
             }
 
             let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
-            for query in input.queries {
-                let regex = regex_by_glob.entry(query.glob).or_default();
-                if !regex.is_empty() {
-                    regex.push('|');
-                }
-                regex.push_str(&query.regex);
-            }
+            queries.extend(input.queries);
         }
 
         if let Some(debug_tx) = &debug_tx {
@@ -1298,16 +1308,16 @@ impl Zeta {
                 ZetaSearchQueryDebugInfo {
                     project: project.clone(),
                     timestamp: Instant::now(),
-                    regex_by_glob: regex_by_glob.clone(),
+                    search_queries: queries.clone(),
                 },
             ))
             .ok();
         }
 
-        log::trace!("Running retrieval search: {regex_by_glob:#?}");
+        log::trace!("Running retrieval search: {queries:#?}");
 
         let related_excerpts_result =
-            retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;
+            retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
 
         log::trace!("Search queries executed");
 
@@ -1754,7 +1764,8 @@ mod tests {
                 arguments: serde_json::to_string(&SearchToolInput {
                     queries: Box::new([SearchToolQuery {
                         glob: "root/2.txt".to_string(),
-                        regex: ".".to_string(),
+                        syntax_node: vec![],
+                        content: Some(".".into()),
                     }]),
                 })
                 .unwrap(),
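One behavioral nuance of hoisting the env lookup into a `LazyLock`: `ZED_ZETA2_MODEL` is now read once per process rather than once per request, and both request sites share a single default id (previously `yqvev8r3` and `2327jz9q`). The pattern, restated as a standalone sketch:

```rust
use std::sync::LazyLock;

static MODEL_ID: LazyLock<String> =
    LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()));

fn model_for_request() -> String {
    // First access runs the closure; later accesses reuse the cached value,
    // so changing the variable afterwards (e.g. in tests) has no effect.
    MODEL_ID.clone()
}
```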
diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml
index 89d0ce8338624906d2262a7d8314700f6720cff1..3a9b1ccbf9340dfdaa06030e59c2112b9cda6307 100644
--- a/crates/zeta2_tools/Cargo.toml
+++ b/crates/zeta2_tools/Cargo.toml
@@ -16,6 +16,7 @@ anyhow.workspace = true
 chrono.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
 collections.workspace = true
 edit_prediction_context.workspace = true
 editor.workspace = true
@@ -27,7 +28,6 @@ log.workspace = true
 multi_buffer.workspace = true
 ordered-float.workspace = true
 project.workspace = true
-regex-syntax = "0.8.8"
 serde.workspace = true
 serde_json.workspace = true
 telemetry.workspace = true
diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs
index 685029cc4a2993227725c17e283c660da5c1d5e1..1826bd22df6d08ce717ef9bdf0070f88ad63c433 100644
--- a/crates/zeta2_tools/src/zeta2_context_view.rs
+++ b/crates/zeta2_tools/src/zeta2_context_view.rs
@@ -8,6 +8,7 @@ use std::{
 
 use anyhow::Result;
 use client::{Client, UserStore};
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use editor::{Editor, PathKey};
 use futures::StreamExt as _;
 use gpui::{
@@ -41,19 +42,13 @@ pub struct Zeta2ContextView {
 #[derive(Debug)]
 struct RetrievalRun {
     editor: Entity<Editor>,
-    search_queries: Vec<GlobQueries>,
+    search_queries: Vec<SearchToolQuery>,
     started_at: Instant,
     search_results_generated_at: Option<Instant>,
     search_results_executed_at: Option<Instant>,
     finished_at: Option<Instant>,
 }
 
-#[derive(Debug)]
-struct GlobQueries {
-    glob: String,
-    alternations: Vec<String>,
-}
-
 actions!(
     dev,
     [
@@ -210,23 +205,7 @@ impl Zeta2ContextView {
         };
 
         run.search_results_generated_at = Some(info.timestamp);
-        run.search_queries = info
-            .regex_by_glob
-            .into_iter()
-            .map(|(glob, regex)| {
-                let mut regex_parser = regex_syntax::ast::parse::Parser::new();
-
-                GlobQueries {
-                    glob,
-                    alternations: match regex_parser.parse(&regex) {
-                        Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
-                            alt.asts.iter().map(|ast| ast.to_string()).collect()
-                        }
-                        _ => vec![regex],
-                    },
-                }
-            })
-            .collect();
+        run.search_queries = info.search_queries;
 
         cx.notify();
     }
@@ -292,18 +271,28 @@ impl Zeta2ContextView {
                     .enumerate()
                     .flat_map(|(ix, query)| {
                         std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
-                            .chain(query.alternations.iter().enumerate().map(
-                                move |(alt_ix, alt)| {
-                                    ListItem::new(ix * 100 + alt_ix)
+                            .chain(query.syntax_node.iter().enumerate().map(
+                                move |(regex_ix, regex)| {
+                                    ListItem::new(ix * 100 + regex_ix)
                                         .start_slot(
                                             Icon::new(IconName::MagnifyingGlass)
                                                 .color(Color::Muted)
                                                 .size(IconSize::Small),
                                         )
-                                        .child(alt.clone())
+                                        .child(regex.clone())
                                         .into_any_element()
                                 },
                             ))
+                            .chain(query.content.as_ref().map(move |regex| {
+                                ListItem::new(ix * 100 + query.syntax_node.len())
+                                    .start_slot(
+                                        Icon::new(IconName::MagnifyingGlass)
+                                            .color(Color::Muted)
+                                            .size(IconSize::Small),
+                                    )
+                                    .child(regex.clone())
+                                    .into_any_element()
+                            }))
                     }),
             ),
         )
diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs
index 1bc2411c825a2fa7147ff0c657804908b687d9ff..a593a1b12ceb2b72a316463076657f35ac2c4e9d 100644
--- a/crates/zeta_cli/src/predict.rs
+++ b/crates/zeta_cli/src/predict.rs
@@ -2,11 +2,11 @@ use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
 use crate::paths::LOGS_DIR;
 use ::serde::Serialize;
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
 use clap::Args;
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use futures::StreamExt as _;
-use gpui::AsyncApp;
+use gpui::{AppContext, AsyncApp};
 use project::Project;
 use serde::Deserialize;
 use std::cell::Cell;
@@ -14,6 +14,7 @@ use std::fs;
 use std::io::Write;
 use std::path::PathBuf;
 use std::sync::Arc;
+use std::sync::Mutex;
 use std::time::{Duration, Instant};
 
 #[derive(Debug, Args)]
@@ -103,112 +104,126 @@ pub async fn zeta2_predict(
     let _edited_buffers = example.apply_edit_history(&project, cx).await?;
     let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
 
+    let result = Arc::new(Mutex::new(PredictionDetails::default()));
     let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
-    let refresh_task = zeta.update(cx, |zeta, cx| {
-        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
-    })?;
+    let debug_task = cx.background_spawn({
+        let result = result.clone();
+        async move {
+            let mut context_retrieval_started_at = None;
+            let mut context_retrieval_finished_at = None;
+            let mut search_queries_generated_at = None;
+            let mut search_queries_executed_at = None;
+            while let Some(event) = debug_rx.next().await {
+                match event {
+                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                        context_retrieval_started_at = Some(info.timestamp);
+                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
+                    }
+                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                        search_queries_generated_at = Some(info.timestamp);
+                        fs::write(
+                            LOGS_DIR.join("search_queries.json"),
+                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
+                        )?;
+                    }
+                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                        search_queries_executed_at = Some(info.timestamp);
+                    }
+                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                        context_retrieval_finished_at = Some(info.timestamp);
+                    }
+                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+                        let prediction_started_at = Instant::now();
+                        fs::write(
+                            LOGS_DIR.join("prediction_prompt.md"),
+                            &request.local_prompt.unwrap_or_default(),
+                        )?;
 
-    let mut context_retrieval_started_at = None;
-    let mut context_retrieval_finished_at = None;
-    let mut search_queries_generated_at = None;
-    let mut search_queries_executed_at = None;
-    let mut prediction_started_at = None;
-    let mut prediction_finished_at = None;
-    let mut excerpts_text = String::new();
-    let mut prediction_task = None;
-    let mut result = PredictionDetails::default();
-    while let Some(event) = debug_rx.next().await {
-        match event {
-            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                context_retrieval_started_at = Some(info.timestamp);
-                fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                search_queries_generated_at = Some(info.timestamp);
-                fs::write(
-                    LOGS_DIR.join("search_queries.json"),
-                    serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
-                )?;
-            }
-            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                search_queries_executed_at = Some(info.timestamp);
-            }
-            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
-                context_retrieval_finished_at = Some(info.timestamp);
+                        {
+                            let mut result = result.lock().unwrap();
 
-                prediction_task = Some(zeta.update(cx, |zeta, cx| {
-                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
-                })?);
-            }
-            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
-                prediction_started_at = Some(Instant::now());
-                fs::write(
-                    LOGS_DIR.join("prediction_prompt.md"),
-                    &request.local_prompt.unwrap_or_default(),
-                )?;
+                            for included_file in request.request.included_files {
+                                let insertions =
+                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
+                                result.excerpts.extend(included_file.excerpts.iter().map(
+                                    |excerpt| ActualExcerpt {
+                                        path: included_file.path.components().skip(1).collect(),
+                                        text: String::from(excerpt.text.as_ref()),
+                                    },
+                                ));
+                                write_codeblock(
+                                    &included_file.path,
+                                    included_file.excerpts.iter(),
+                                    if included_file.path == request.request.excerpt_path {
+                                        &insertions
+                                    } else {
+                                        &[]
+                                    },
+                                    included_file.max_row,
+                                    false,
+                                    &mut result.excerpts_text,
+                                );
+                            }
+                        }
 
-                for included_file in request.request.included_files {
-                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
-                    result
-                        .excerpts
-                        .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
-                            path: included_file.path.components().skip(1).collect(),
-                            text: String::from(excerpt.text.as_ref()),
-                        }));
-                    write_codeblock(
-                        &included_file.path,
-                        included_file.excerpts.iter(),
-                        if included_file.path == request.request.excerpt_path {
-                            &insertions
-                        } else {
-                            &[]
-                        },
-                        included_file.max_row,
-                        false,
-                        &mut excerpts_text,
-                    );
-                }
+                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+                        let response = zeta2::text_from_response(response).unwrap_or_default();
+                        let prediction_finished_at = Instant::now();
+                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
 
-                let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
-                let response = zeta2::text_from_response(response).unwrap_or_default();
-                prediction_finished_at = Some(Instant::now());
-                fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
+                        let mut result = result.lock().unwrap();
 
-                break;
+                        result.planning_search_time = search_queries_generated_at.unwrap()
+                            - context_retrieval_started_at.unwrap();
+                        result.running_search_time = search_queries_executed_at.unwrap()
+                            - search_queries_generated_at.unwrap();
+                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
+                            - search_queries_executed_at.unwrap();
+                        result.prediction_time = prediction_finished_at - prediction_started_at;
+                        result.total_time =
+                            prediction_finished_at - context_retrieval_started_at.unwrap();
+
+                        break;
+                    }
+                }
             }
+            anyhow::Ok(())
         }
-    }
+    });
+
+    zeta.update(cx, |zeta, cx| {
+        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+    })?
+    .await?;
 
-    refresh_task.await.context("context retrieval failed")?;
-    let prediction = prediction_task.unwrap().await?;
+    let prediction = zeta
+        .update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+        })?
+        .await?;
+    debug_task.await?;
+
+    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
 
     result.diff = prediction
         .map(|prediction| {
             let old_text = prediction.snapshot.text();
-            let new_text = prediction.buffer.update(cx, |buffer, cx| {
-                buffer.edit(prediction.edits.iter().cloned(), None, cx);
-                buffer.text()
-            })?;
-            anyhow::Ok(language::unified_diff(&old_text, &new_text))
+            let new_text = prediction
+                .buffer
+                .update(cx, |buffer, cx| {
+                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
+                    buffer.text()
+                })
+                .unwrap();
+            language::unified_diff(&old_text, &new_text)
         })
-        .transpose()?
         .unwrap_or_default();
 
-    result.excerpts_text = excerpts_text;
-
-    result.planning_search_time =
-        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
-    result.running_search_time =
-        search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
-    result.filtering_search_time =
-        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
-    result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
-    result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
 
     anyhow::Ok(result)
 }
 
-#[derive(Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 pub struct PredictionDetails {
     pub diff: String,
     pub excerpts: Vec<ActualExcerpt>,