From b607077c08c562bb42023c7316ef57b84371f5b6 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 10 Nov 2025 14:44:54 -0800 Subject: [PATCH] Add old_text/new_text as a zeta2 prompt format (#42171) Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Oleksiy Syvokon Co-authored-by: Ben Kunkle Co-authored-by: Michael Sloan --- .../cloud_llm_client/src/predict_edits_v3.rs | 2 + .../src/cloud_zeta2_prompt.rs | 57 ++++- crates/zeta2/src/xml_edits.rs | 197 ++++++++++++++++++ crates/zeta2/src/zeta2.rs | 64 ++++-- crates/zeta_cli/src/evaluate.rs | 17 +- crates/zeta_cli/src/main.rs | 2 + crates/zeta_cli/src/predict.rs | 134 +++++++++--- 7 files changed, 418 insertions(+), 55 deletions(-) create mode 100644 crates/zeta2/src/xml_edits.rs diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 2e884ae9fcb27530e5579b83767bde95b5df414c..98ca0748934d663d204c64544af8a3e83fcd704d 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -73,6 +73,7 @@ pub enum PromptFormat { MarkedExcerpt, LabeledSections, NumLinesUniDiff, + OldTextNewText, /// Prompt format intended for use via zeta_cli OnlySnippets, } @@ -100,6 +101,7 @@ impl std::fmt::Display for PromptFormat { PromptFormat::LabeledSections => write!(f, "Labeled Sections"), PromptFormat::OnlySnippets => write!(f, "Only Snippets"), PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"), + PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), } } } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 6055c39e16ea95b38754bb26fd7371250d1fc525..3f0bd476c50b9e6f92a9f457af15899fcb33b8ed 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -100,6 +100,54 @@ const UNIFIED_DIFF_REMINDER: &str = indoc! {" to uniquely identify it amongst all excerpts of code provided. "}; +const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#" + # Instructions + + You are an edit prediction agent in a code editor. + Your job is to predict the next edit that the user will make, + based on their last few edits and their current cursor location. + + # Output Format + + You must briefly explain your understanding of the user's goal, in one + or two sentences, and then specify their next edit, using the following + XML format: + + + + OLD TEXT 1 HERE + + + NEW TEXT 1 HERE + + + + OLD TEXT 1 HERE + + + NEW TEXT 1 HERE + + + + - Specify the file to edit using the `path` attribute. + - Use `` and `` tags to replace content + - `` must exactly match existing file content, including indentation + - `` cannot be empty + - Do not escape quotes, newlines, or other characters within tags + - Always close all tags properly + - Don't include the <|user_cursor|> marker in your output. + + # Edit History: + +"#}; + +const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#" + --- + + Remember that the edits in the edit history have already been deployed. + The files are currently as shown in the Code Excerpts section. +"#}; + pub fn build_prompt( request: &predict_edits_v3::PredictEditsRequest, ) -> Result<(String, SectionLabels)> { @@ -121,7 +169,9 @@ pub fn build_prompt( EDITABLE_REGION_END_MARKER_WITH_NEWLINE, ), ], - PromptFormat::LabeledSections | PromptFormat::NumLinesUniDiff => { + PromptFormat::LabeledSections + | PromptFormat::NumLinesUniDiff + | PromptFormat::OldTextNewText => { vec![(request.cursor_point, CURSOR_MARKER)] } PromptFormat::OnlySnippets => vec![], @@ -131,6 +181,7 @@ pub fn build_prompt( PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), + PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), PromptFormat::OnlySnippets => String::new(), }; @@ -186,6 +237,9 @@ pub fn build_prompt( PromptFormat::NumLinesUniDiff => { prompt.push_str(UNIFIED_DIFF_REMINDER); } + PromptFormat::OldTextNewText => { + prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER); + } _ => {} } @@ -611,6 +665,7 @@ impl<'a> SyntaxBasedPrompt<'a> { match self.request.prompt_format { PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets + | PromptFormat::OldTextNewText | PromptFormat::NumLinesUniDiff => { if range.start.0 > 0 && !skipped_last_snippet { output.push_str("…\n"); diff --git a/crates/zeta2/src/xml_edits.rs b/crates/zeta2/src/xml_edits.rs new file mode 100644 index 0000000000000000000000000000000000000000..e8bcc4b1ba7eb2d00cd73b0b2e8d1638a5b00e32 --- /dev/null +++ b/crates/zeta2/src/xml_edits.rs @@ -0,0 +1,197 @@ +use anyhow::{Context as _, Result, anyhow}; +use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot}; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; + +pub async fn parse_xml_edits<'a>( + mut input: &'a str, + get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, +) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { + let edits_tag = parse_tag(&mut input, "edits")?.context("No edits tag")?; + + input = edits_tag.body; + + let file_path = edits_tag + .attributes + .trim_start() + .strip_prefix("path") + .context("no file attribute on edits tag")? + .trim_end() + .strip_prefix('=') + .context("no value for path attribute")? + .trim() + .trim_start_matches('"') + .trim_end_matches('"'); + + let (buffer, context_ranges) = get_buffer(file_path.as_ref()) + .with_context(|| format!("no buffer for file {file_path}"))?; + + let mut edits = vec![]; + while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? { + let new_text_tag = + parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?; + edits.extend(resolve_new_text_old_text_in_buffer( + new_text_tag.body, + old_text_tag.body, + buffer, + context_ranges, + )?); + } + + Ok((buffer, edits)) +} + +fn resolve_new_text_old_text_in_buffer( + new_text: &str, + old_text: &str, + buffer: &TextBufferSnapshot, + ranges: &[Range], +) -> Result, Arc)>, anyhow::Error> { + let context_offset = if old_text.is_empty() { + Ok(0) + } else { + let mut offset = None; + for range in ranges { + let range = range.to_offset(buffer); + let text = buffer.text_for_range(range.clone()).collect::(); + for (match_offset, _) in text.match_indices(old_text) { + if offset.is_some() { + anyhow::bail!("old_text is not unique enough:\n{}", old_text); + } + offset = Some(range.start + match_offset); + } + } + offset.ok_or_else(|| anyhow!("Failed to match old_text:\n{}", old_text)) + }?; + + let edits_within_hunk = language::text_diff(&old_text, &new_text); + Ok(edits_within_hunk + .into_iter() + .map(move |(inner_range, inner_text)| { + ( + buffer.anchor_after(context_offset + inner_range.start) + ..buffer.anchor_before(context_offset + inner_range.end), + inner_text, + ) + })) +} + +struct ParsedTag<'a> { + attributes: &'a str, + body: &'a str, +} + +fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result>> { + let open_tag = format!("<{}", tag); + let close_tag = format!("", tag); + let Some(start_ix) = input.find(&open_tag) else { + return Ok(None); + }; + let start_ix = start_ix + open_tag.len(); + let closing_bracket_ix = start_ix + + input[start_ix..] + .find('>') + .with_context(|| format!("missing > after {tag}"))?; + let attributes = &input[start_ix..closing_bracket_ix].trim(); + let end_ix = closing_bracket_ix + + input[closing_bracket_ix..] + .find(&close_tag) + .with_context(|| format!("no `{close_tag}` tag"))?; + let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix]; + let body = body.strip_prefix('\n').unwrap_or(body); + *input = &input[end_ix + close_tag.len()..]; + Ok(Some(ParsedTag { attributes, body })) +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use indoc::indoc; + use language::Point; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[test] + fn test_parse_tags() { + let mut input = indoc! {r#" + Prelude + + tag value + + "# }; + let parsed = parse_tag(&mut input, "tag").unwrap().unwrap(); + assert_eq!(parsed.attributes, "attr=\"foo\""); + assert_eq!(parsed.body, "tag value\n"); + assert_eq!(input, "\n"); + } + + #[gpui::test] + async fn test_parse_xml_edits(cx: &mut TestAppContext) { + let fs = init_test(cx); + + let buffer_1_text = indoc! {r#" + one two three four + five six seven eight + nine ten eleven twelve + "# }; + + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/file1"), cx) + }) + .await + .unwrap(); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let edits = indoc! {r#" + + + five six seven eight + + + five SIX seven eight! + + + "#}; + + let (buffer, edits) = parse_xml_edits(edits, |_path| { + Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_])) + }) + .await + .unwrap(); + + let edits = edits + .into_iter() + .map(|(range, text)| (range.to_point(&buffer), text)) + .collect::>(); + assert_eq!( + edits, + &[ + (Point::new(1, 5)..Point::new(1, 8), "SIX".into()), + (Point::new(1, 20)..Point::new(1, 20), "!".into()) + ] + ); + } + + fn init_test(cx: &mut TestAppContext) -> Arc { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); + + FakeFs::new(cx.background_executor.clone()) + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index c77c78b6f517bce085a26b2c60d04318b2f3cdae..6139c9c75e16f8805e6529dc1700eef1beacd713 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -47,6 +47,7 @@ mod prediction; mod provider; pub mod retrieval_search; pub mod udiff; +mod xml_edits; use crate::merge_excerpts::merge_excerpts; use crate::prediction::EditPrediction; @@ -948,8 +949,9 @@ impl Zeta { llm_token, app_version, #[cfg(feature = "llm-response-cache")] - llm_response_cache - ).await; + llm_response_cache, + ) + .await; let request_time = chrono::Utc::now() - before_request; log::trace!("Got edit prediction response"); @@ -969,7 +971,7 @@ impl Zeta { let (res, usage) = response?; let request_id = EditPredictionId(res.id.clone().into()); let Some(mut output_text) = text_from_response(res) else { - return Ok((None, usage)) + return Ok((None, usage)); }; if output_text.contains(CURSOR_MARKER) { @@ -977,20 +979,25 @@ impl Zeta { output_text = output_text.replace(CURSOR_MARKER, ""); } + let get_buffer_from_context = |path: &Path| { + included_files + .iter() + .find_map(|(_, buffer, probe_path, ranges)| { + if probe_path.as_ref() == path { + Some((buffer, ranges.as_slice())) + } else { + None + } + }) + }; + let (edited_buffer_snapshot, edits) = match options.prompt_format { PromptFormat::NumLinesUniDiff => { - crate::udiff::parse_diff(&output_text, |path| { - included_files - .iter() - .find_map(|(_, buffer, probe_path, ranges)| { - if probe_path.as_ref() == path { - Some((buffer, ranges.as_slice())) - } else { - None - } - }) - }) - .await? + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + PromptFormat::OldTextNewText => { + crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context) + .await? } _ => { bail!("unsupported prompt format {}", options.prompt_format) @@ -1006,9 +1013,17 @@ impl Zeta { None } }) - .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?; - - anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage)) + .context("Failed to find buffer in included_buffers")?; + + anyhow::Ok(( + Some(( + request_id, + edited_buffer, + edited_buffer_snapshot.clone(), + edits, + )), + usage, + )) } }); @@ -1387,7 +1402,8 @@ impl Zeta { continue; } - let input: SearchToolInput = serde_json::from_str(&function.arguments)?; + let input: SearchToolInput = serde_json::from_str(&function.arguments) + .with_context(|| format!("invalid search json {}", &function.arguments))?; queries.extend(input.queries); } @@ -1447,6 +1463,16 @@ impl Zeta { }) } + pub fn set_context( + &mut self, + project: Entity, + context: HashMap, Vec>>, + ) { + if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) { + zeta_project.context = Some(context); + } + } + fn gather_nearby_diagnostics( cursor_offset: usize, diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 6d5b2da13a4301bfb52cb3cda7662843dea7cd12..b5c23af24845a90d153943f6ee2ccd29bbfaf6a7 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -24,6 +24,8 @@ pub struct EvaluateArguments { skip_cache: bool, #[arg(long, value_enum, default_value_t = PromptFormat::default())] prompt_format: PromptFormat, + #[arg(long)] + use_expected_context: bool, } pub async fn run_evaluate( @@ -39,6 +41,7 @@ pub async fn run_evaluate( &path, args.skip_cache, args.prompt_format, + args.use_expected_context, app_state.clone(), cx, ) @@ -63,13 +66,21 @@ pub async fn run_evaluate_one( example_path: &Path, skip_cache: bool, prompt_format: PromptFormat, + use_expected_context: bool, app_state: Arc, cx: &mut AsyncApp, ) -> Result { let example = NamedExample::load(&example_path).unwrap(); - let predictions = zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx) - .await - .unwrap(); + let predictions = zeta2_predict( + example.clone(), + skip_cache, + prompt_format, + use_expected_context, + &app_state, + cx, + ) + .await + .unwrap(); let evaluation_result = evaluate(&example.example, &predictions); diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 25fb920bab18f374e41b539bc21320faf6c75484..82760d6061d9b96a2da74bf5cb24e43d9ecdba60 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -171,6 +171,7 @@ enum PromptFormat { OnlySnippets, #[default] NumberedLines, + OldTextNewText, } impl Into for PromptFormat { @@ -180,6 +181,7 @@ impl Into for PromptFormat { Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff, + Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, } } } diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index d85f009c9bacc0b6177683c064979740a0709115..4efc82fa8a7c5d5cf6773a7f771d12dd89b4e1ed 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -1,20 +1,23 @@ use crate::PromptFormat; -use crate::example::{ActualExcerpt, NamedExample}; +use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample}; use crate::headless::ZetaCliAppState; use crate::paths::{CACHE_DIR, LOGS_DIR}; use ::serde::Serialize; use anyhow::{Result, anyhow}; use clap::Args; +use collections::HashMap; use gpui::http_client::Url; +use language::{Anchor, Buffer, Point}; // use cloud_llm_client::predict_edits_v3::PromptFormat; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use futures::StreamExt as _; -use gpui::{AppContext, AsyncApp}; +use gpui::{AppContext, AsyncApp, Entity}; use project::Project; use serde::Deserialize; use std::cell::Cell; use std::fs; use std::io::Write; +use std::ops::Range; use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; @@ -25,6 +28,8 @@ use zeta2::LlmResponseCache; pub struct PredictArguments { #[arg(long, value_enum, default_value_t = PromptFormat::default())] prompt_format: PromptFormat, + #[arg(long)] + use_expected_context: bool, #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] format: PredictionsOutputFormat, example_path: PathBuf, @@ -38,15 +43,23 @@ pub enum PredictionsOutputFormat { Md, Diff, } + pub async fn run_zeta2_predict( args: PredictArguments, app_state: &Arc, cx: &mut AsyncApp, ) { let example = NamedExample::load(args.example_path).unwrap(); - let result = zeta2_predict(example, args.skip_cache, args.prompt_format, &app_state, cx) - .await - .unwrap(); + let result = zeta2_predict( + example, + args.skip_cache, + args.prompt_format, + args.use_expected_context, + &app_state, + cx, + ) + .await + .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); } @@ -58,6 +71,7 @@ pub async fn zeta2_predict( example: NamedExample, skip_cache: bool, prompt_format: PromptFormat, + use_expected_context: bool, app_state: &Arc, cx: &mut AsyncApp, ) -> Result { @@ -126,14 +140,13 @@ pub async fn zeta2_predict( let debug_task = cx.background_spawn({ let result = result.clone(); async move { - let mut context_retrieval_started_at = None; - let mut context_retrieval_finished_at = None; + let mut start_time = None; let mut search_queries_generated_at = None; let mut search_queries_executed_at = None; while let Some(event) = debug_rx.next().await { match event { zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - context_retrieval_started_at = Some(info.timestamp); + start_time = Some(info.timestamp); fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?; } zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { @@ -146,11 +159,10 @@ pub async fn zeta2_predict( zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { search_queries_executed_at = Some(info.timestamp); } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { - context_retrieval_finished_at = Some(info.timestamp); - } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { let prediction_started_at = Instant::now(); + start_time.get_or_insert(prediction_started_at); fs::write( LOGS_DIR.join("prediction_prompt.md"), &request.local_prompt.unwrap_or_default(), @@ -190,15 +202,16 @@ pub async fn zeta2_predict( let mut result = result.lock().unwrap(); - result.planning_search_time = search_queries_generated_at.unwrap() - - context_retrieval_started_at.unwrap(); - result.running_search_time = search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(); - result.filtering_search_time = context_retrieval_finished_at.unwrap() - - search_queries_executed_at.unwrap(); + if !use_expected_context { + result.planning_search_time = + Some(search_queries_generated_at.unwrap() - start_time.unwrap()); + result.running_search_time = Some( + search_queries_executed_at.unwrap() + - search_queries_generated_at.unwrap(), + ); + } result.prediction_time = prediction_finished_at - prediction_started_at; - result.total_time = - prediction_finished_at - context_retrieval_started_at.unwrap(); + result.total_time = prediction_finished_at - start_time.unwrap(); break; } @@ -208,13 +221,42 @@ pub async fn zeta2_predict( } }); - zeta.update(cx, |zeta, cx| { + zeta.update(cx, |zeta, _cx| { let mut options = zeta.options().clone(); options.prompt_format = prompt_format.into(); zeta.set_options(options); - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })? - .await?; + })?; + + if use_expected_context { + let context_excerpts_tasks = example + .example + .expected_context + .iter() + .flat_map(|section| { + section.alternatives[0].excerpts.iter().map(|excerpt| { + resolve_context_entry(project.clone(), excerpt.clone(), cx.clone()) + }) + }) + .collect::>(); + let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?; + + let mut context_excerpts = HashMap::default(); + for (buffer, mut excerpts) in context_excerpts_vec { + context_excerpts + .entry(buffer) + .or_insert(Vec::new()) + .append(&mut excerpts); + } + + zeta.update(cx, |zeta, _cx| { + zeta.set_context(project.clone(), context_excerpts) + })?; + } else { + zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })? + .await?; + } let prediction = zeta .update(cx, |zeta, cx| { @@ -242,6 +284,38 @@ pub async fn zeta2_predict( anyhow::Ok(result) } +async fn resolve_context_entry( + project: Entity, + excerpt: ExpectedExcerpt, + mut cx: AsyncApp, +) -> Result<(Entity, Vec>)> { + let buffer = project + .update(&mut cx, |project, cx| { + let project_path = project.find_project_path(&excerpt.path, cx).unwrap(); + project.open_buffer(project_path, cx) + })? + .await?; + + let ranges = buffer.read_with(&mut cx, |buffer, _| { + let full_text = buffer.text(); + let offset = full_text + .find(&excerpt.text) + .expect("Expected context not found"); + let point = buffer.offset_to_point(offset); + excerpt + .required_lines + .iter() + .map(|line| { + let row = point.row + line.0; + let range = Point::new(row, 0)..Point::new(row + 1, 0); + buffer.anchor_after(range.start)..buffer.anchor_before(range.end) + }) + .collect() + })?; + + Ok((buffer, ranges)) +} + struct Cache { skip_cache: bool, } @@ -292,9 +366,8 @@ pub struct PredictionDetails { pub diff: String, pub excerpts: Vec, pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly - pub planning_search_time: Duration, - pub filtering_search_time: Duration, - pub running_search_time: Duration, + pub planning_search_time: Option, + pub running_search_time: Option, pub prediction_time: Duration, pub total_time: Duration, } @@ -311,8 +384,7 @@ impl PredictionDetails { } pub fn to_markdown(&self) -> String { - let inference_time = - self.planning_search_time + self.filtering_search_time + self.prediction_time; + let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time; format!( "## Excerpts\n\n\ @@ -322,16 +394,14 @@ impl PredictionDetails { ## Time\n\n\ Planning searches: {}ms\n\ Running searches: {}ms\n\ - Filtering context results: {}ms\n\ Making Prediction: {}ms\n\n\ -------------------\n\n\ Total: {}ms\n\ Inference: {}ms ({:.2}%)\n", self.excerpts_text, self.diff, - self.planning_search_time.as_millis(), - self.running_search_time.as_millis(), - self.filtering_search_time.as_millis(), + self.planning_search_time.unwrap_or_default().as_millis(), + self.running_search_time.unwrap_or_default().as_millis(), self.prediction_time.as_millis(), self.total_time.as_millis(), inference_time.as_millis(),