diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 27fb7e836ee3ae9ec0197e4bcf0238b70167a73e..92e64f7f332accddbca46ee631f64e5b14be376d 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -35,8 +35,8 @@ use std::str::FromStr as _; use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; -use util::ResultExt as _; use util::rel_path::RelPathBuf; +use util::{LogErrorFuture, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod merge_excerpts; @@ -183,7 +183,7 @@ struct ZetaProject { registered_buffers: HashMap, current_prediction: Option, context: Option, Vec>>>, - refresh_context_task: Option>>, + refresh_context_task: Option>>>, refresh_context_debounce_task: Option>>, refresh_context_timestamp: Option, } @@ -1080,7 +1080,11 @@ impl Zeta { log::debug!("refetching edit prediction context after pause"); } this.update(cx, |this, cx| { - this.refresh_context(project, buffer, cursor_position, cx); + let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); + + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.refresh_context_task = Some(task.log_err()); + }; }) .ok() } @@ -1089,73 +1093,68 @@ impl Zeta { // Refresh the related excerpts asynchronously. Ensure the task runs to completion, // and avoid spawning more than one concurrent task. - fn refresh_context( + pub fn refresh_context( &mut self, project: Entity, buffer: Entity, cursor_position: language::Anchor, cx: &mut Context, - ) { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - let debug_tx = self.debug_tx.clone(); - - zeta_project - .refresh_context_task - .get_or_insert(cx.spawn(async move |this, cx| { - let related_excerpts = this - .update(cx, |this, cx| { - let Some(zeta_project) = this.projects.get(&project.entity_id()) else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + ) -> Task> { + cx.spawn(async move |this, cx| { + let related_excerpts_result = this + .update(cx, |this, cx| { + let Some(zeta_project) = this.projects.get(&project.entity_id()) else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; - let ContextMode::Llm(options) = &this.options().context else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + let ContextMode::Llm(options) = &this.options().context else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; - let mut edit_history_unified_diff = String::new(); + let mut edit_history_unified_diff = String::new(); - for event in zeta_project.events.iter() { - if let Some(event) = event.to_request_event(cx) { - writeln!(&mut edit_history_unified_diff, "{event}").ok(); - } + for event in zeta_project.events.iter() { + if let Some(event) = event.to_request_event(cx) { + writeln!(&mut edit_history_unified_diff, "{event}").ok(); } + } - find_related_excerpts( - buffer.clone(), - cursor_position, - &project, - edit_history_unified_diff, - options, - debug_tx, - cx, - ) - }) - .ok()? - .await - .log_err() - .unwrap_or_default(); - this.update(cx, |this, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { - return; - }; - zeta_project.context = Some(related_excerpts); - zeta_project.refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalDebugInfo { - project, - timestamp: Instant::now(), - }, - )) - .ok(); + find_related_excerpts( + buffer.clone(), + cursor_position, + &project, + edit_history_unified_diff, + options, + this.debug_tx.clone(), + cx, + ) + })? + .await; + + this.update(cx, |this, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { + return Ok(()); + }; + zeta_project.refresh_context_task.take(); + if let Some(debug_tx) = &this.debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalDebugInfo { + project, + timestamp: Instant::now(), + }, + )) + .ok(); + } + match related_excerpts_result { + Ok(excerpts) => { + zeta_project.context = Some(excerpts); + Ok(()) } - }) - .ok() - })); + Err(error) => Err(error), + } + })? + }) } fn gather_nearby_diagnostics( diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index b8e7c9f5e1df79aaf9ad0485838d7db65dfabce0..1c6dbee7f8f3e900e3837495d81bd559b03a60a4 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -26,7 +26,7 @@ use project::{Project, ProjectPath, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; use std::io; -use std::time::Duration; +use std::time::{Duration, Instant}; use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc}; use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; @@ -411,30 +411,51 @@ async fn zeta2_predict( Ok(excerpt_offset) })??; - let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; let cursor_anchor = cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; - zeta.update(cx, |zeta, cx| { + let refresh_task = zeta.update(cx, |zeta, cx| { zeta.register_buffer(&cursor_buffer, &project, cx); + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) })?; - let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| { - let receiver = zeta.debug_info(); - let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx); - (prediction_task, receiver) - })?; - - let mut response = None; - + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + let mut context_retrieval_started_at = None; + let mut context_retrieval_finished_at = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + let mut prediction_started_at = None; + let mut prediction_finished_at = None; let mut excerpts_text = String::new(); + let mut prediction_task = None; while let Some(event) = debug_rx.next().await { + dbg!(&event); match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + context_retrieval_started_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { + context_retrieval_finished_at = Some(info.timestamp); + + prediction_task = Some(zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })?); + } zeta2::ZetaDebugInfo::EditPredicted(request) => { - response = Some(request.response_rx.await?); + prediction_started_at = Some(Instant::now()); + request.response_rx.await?.map_err(|err| anyhow!(err))?; + prediction_finished_at = Some(Instant::now()); + for included_file in request.request.included_files { let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; write_codeblock( @@ -456,17 +477,55 @@ async fn zeta2_predict( } } - prediction_task.await.context("No prediction")?; + refresh_task.await.context("context retrieval failed")?; + let prediction = prediction_task.unwrap().await?.context("No prediction")?; println!("## Excerpts\n"); println!("{excerpts_text}"); + let old_text = prediction.snapshot.text(); + let new_text = prediction.buffer.update(cx, |buffer, cx| { + buffer.edit(prediction.edits.iter().cloned(), None, cx); + buffer.text() + })?; + let diff = language::unified_diff(&old_text, &new_text); + println!("## Prediction\n"); - let response = response - .unwrap() - .map(|r| r.debug_info.unwrap().model_response) - .unwrap_or_else(|s| s); - println!("{response}"); + println!("{diff}"); + + println!("## Time\n"); + + let planning_search_time = + search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap(); + + println!("Planning searches: {}ms", planning_search_time.as_millis()); + println!( + "Running searches: {}ms", + (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis() + ); + + let filtering_search_time = + context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap(); + println!( + "Filtering context results: {}ms", + filtering_search_time.as_millis() + ); + + let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap(); + println!("Making Prediction: {}ms", prediction_time.as_millis()); + + println!("-------------------"); + let total_time = + (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis(); + println!("Total: {}ms", total_time); + + let inference_time = + (planning_search_time + filtering_search_time + prediction_time).as_millis(); + println!( + "Inference: {}ms ({:.2}%)", + inference_time, + (inference_time as f64 / total_time as f64) * 100. + ); anyhow::Ok(()) }