@@ -35,8 +35,8 @@ use std::str::FromStr as _;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
-use util::ResultExt as _;
use util::rel_path::RelPathBuf;
+use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod merge_excerpts;
@@ -183,7 +183,7 @@ struct ZetaProject {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
- refresh_context_task: Option<Task<Option<()>>>,
+ refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
refresh_context_timestamp: Option<Instant>,
}
@@ -1080,7 +1080,11 @@ impl Zeta {
log::debug!("refetching edit prediction context after pause");
}
this.update(cx, |this, cx| {
- this.refresh_context(project, buffer, cursor_position, cx);
+ let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
+
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ zeta_project.refresh_context_task = Some(task.log_err());
+ };
})
.ok()
}
@@ -1089,73 +1093,68 @@ impl Zeta {
// Refresh the related excerpts asynchronously. Ensure the task runs to completion,
// and avoid spawning more than one concurrent task.
- fn refresh_context(
+ pub fn refresh_context(
&mut self,
project: Entity<Project>,
buffer: Entity<language::Buffer>,
cursor_position: language::Anchor,
cx: &mut Context<Self>,
- ) {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let debug_tx = self.debug_tx.clone();
-
- zeta_project
- .refresh_context_task
- .get_or_insert(cx.spawn(async move |this, cx| {
- let related_excerpts = this
- .update(cx, |this, cx| {
- let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ ) -> Task<Result<()>> {
+ cx.spawn(async move |this, cx| {
+ let related_excerpts_result = this
+ .update(cx, |this, cx| {
+ let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
- let ContextMode::Llm(options) = &this.options().context else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ let ContextMode::Llm(options) = &this.options().context else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
- let mut edit_history_unified_diff = String::new();
+ let mut edit_history_unified_diff = String::new();
- for event in zeta_project.events.iter() {
- if let Some(event) = event.to_request_event(cx) {
- writeln!(&mut edit_history_unified_diff, "{event}").ok();
- }
+ for event in zeta_project.events.iter() {
+ if let Some(event) = event.to_request_event(cx) {
+ writeln!(&mut edit_history_unified_diff, "{event}").ok();
}
+ }
- find_related_excerpts(
- buffer.clone(),
- cursor_position,
- &project,
- edit_history_unified_diff,
- options,
- debug_tx,
- cx,
- )
- })
- .ok()?
- .await
- .log_err()
- .unwrap_or_default();
- this.update(cx, |this, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
- return;
- };
- zeta_project.context = Some(related_excerpts);
- zeta_project.refresh_context_task.take();
- if let Some(debug_tx) = &this.debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalDebugInfo {
- project,
- timestamp: Instant::now(),
- },
- ))
- .ok();
+ find_related_excerpts(
+ buffer.clone(),
+ cursor_position,
+ &project,
+ edit_history_unified_diff,
+ options,
+ this.debug_tx.clone(),
+ cx,
+ )
+ })?
+ .await;
+
+ this.update(cx, |this, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+ return Ok(());
+ };
+ zeta_project.refresh_context_task.take();
+ if let Some(debug_tx) = &this.debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
+ ZetaContextRetrievalDebugInfo {
+ project,
+ timestamp: Instant::now(),
+ },
+ ))
+ .ok();
+ }
+ match related_excerpts_result {
+ Ok(excerpts) => {
+ zeta_project.context = Some(excerpts);
+ Ok(())
}
- })
- .ok()
- }));
+ Err(error) => Err(error),
+ }
+ })?
+ })
}
fn gather_nearby_diagnostics(
@@ -26,7 +26,7 @@ use project::{Project, ProjectPath, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io;
-use std::time::Duration;
+use std::time::{Duration, Instant};
use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
@@ -411,30 +411,51 @@ async fn zeta2_predict(
Ok(excerpt_offset)
})??;
- let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
- zeta.update(cx, |zeta, cx| {
+ let refresh_task = zeta.update(cx, |zeta, cx| {
zeta.register_buffer(&cursor_buffer, &project, cx);
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?;
- let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| {
- let receiver = zeta.debug_info();
- let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);
- (prediction_task, receiver)
- })?;
-
- let mut response = None;
-
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+ let mut context_retrieval_started_at = None;
+ let mut context_retrieval_finished_at = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ let mut prediction_started_at = None;
+ let mut prediction_finished_at = None;
let mut excerpts_text = String::new();
+ let mut prediction_task = None;
while let Some(event) = debug_rx.next().await {
+ dbg!(&event);
match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ context_retrieval_started_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ context_retrieval_finished_at = Some(info.timestamp);
+
+ prediction_task = Some(zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?);
+ }
zeta2::ZetaDebugInfo::EditPredicted(request) => {
- response = Some(request.response_rx.await?);
+ prediction_started_at = Some(Instant::now());
+ request.response_rx.await?.map_err(|err| anyhow!(err))?;
+ prediction_finished_at = Some(Instant::now());
+
for included_file in request.request.included_files {
let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
write_codeblock(
@@ -456,17 +477,55 @@ async fn zeta2_predict(
}
}
- prediction_task.await.context("No prediction")?;
+ refresh_task.await.context("context retrieval failed")?;
+ let prediction = prediction_task.unwrap().await?.context("No prediction")?;
println!("## Excerpts\n");
println!("{excerpts_text}");
+ let old_text = prediction.snapshot.text();
+ let new_text = prediction.buffer.update(cx, |buffer, cx| {
+ buffer.edit(prediction.edits.iter().cloned(), None, cx);
+ buffer.text()
+ })?;
+ let diff = language::unified_diff(&old_text, &new_text);
+
println!("## Prediction\n");
- let response = response
- .unwrap()
- .map(|r| r.debug_info.unwrap().model_response)
- .unwrap_or_else(|s| s);
- println!("{response}");
+ println!("{diff}");
+
+ println!("## Time\n");
+
+ let planning_search_time =
+ search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+
+ println!("Planning searches: {}ms", planning_search_time.as_millis());
+ println!(
+ "Running searches: {}ms",
+ (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
+ );
+
+ let filtering_search_time =
+ context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+ println!(
+ "Filtering context results: {}ms",
+ filtering_search_time.as_millis()
+ );
+
+ let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+ println!("Making Prediction: {}ms", prediction_time.as_millis());
+
+ println!("-------------------");
+ let total_time =
+ (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
+ println!("Total: {}ms", total_time);
+
+ let inference_time =
+ (planning_search_time + filtering_search_time + prediction_time).as_millis();
+ println!(
+ "Inference: {}ms ({:.2}%)",
+ inference_time,
+ (inference_time as f64 / total_time as f64) * 100.
+ );
anyhow::Ok(())
}