Calculate prediction diff and display timings

Agus Zubiaga , Ben Kunkle , and Max Brunsfeld created

Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/zeta2/src/zeta2.rs   | 117 +++++++++++++++++++-------------------
crates/zeta_cli/src/main.rs |  95 +++++++++++++++++++++++++------
2 files changed, 135 insertions(+), 77 deletions(-)

Detailed changes

crates/zeta2/src/zeta2.rs 🔗

@@ -35,8 +35,8 @@ use std::str::FromStr as _;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 use thiserror::Error;
-use util::ResultExt as _;
 use util::rel_path::RelPathBuf;
+use util::{LogErrorFuture, TryFutureExt};
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
 pub mod merge_excerpts;
@@ -183,7 +183,7 @@ struct ZetaProject {
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     current_prediction: Option<CurrentEditPrediction>,
     context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
-    refresh_context_task: Option<Task<Option<()>>>,
+    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
     refresh_context_debounce_task: Option<Task<Option<()>>>,
     refresh_context_timestamp: Option<Instant>,
 }
@@ -1080,7 +1080,11 @@ impl Zeta {
                     log::debug!("refetching edit prediction context after pause");
                 }
                 this.update(cx, |this, cx| {
-                    this.refresh_context(project, buffer, cursor_position, cx);
+                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
+
+                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+                        zeta_project.refresh_context_task = Some(task.log_err());
+                    };
                 })
                 .ok()
             }
@@ -1089,73 +1093,68 @@ impl Zeta {
 
     // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
     // and avoid spawning more than one concurrent task.
-    fn refresh_context(
+    pub fn refresh_context(
         &mut self,
         project: Entity<Project>,
         buffer: Entity<language::Buffer>,
         cursor_position: language::Anchor,
         cx: &mut Context<Self>,
-    ) {
-        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
-            return;
-        };
-
-        let debug_tx = self.debug_tx.clone();
-
-        zeta_project
-            .refresh_context_task
-            .get_or_insert(cx.spawn(async move |this, cx| {
-                let related_excerpts = this
-                    .update(cx, |this, cx| {
-                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
-                            return Task::ready(anyhow::Ok(HashMap::default()));
-                        };
+    ) -> Task<Result<()>> {
+        cx.spawn(async move |this, cx| {
+            let related_excerpts_result = this
+                .update(cx, |this, cx| {
+                    let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
+                        return Task::ready(anyhow::Ok(HashMap::default()));
+                    };
 
-                        let ContextMode::Llm(options) = &this.options().context else {
-                            return Task::ready(anyhow::Ok(HashMap::default()));
-                        };
+                    let ContextMode::Llm(options) = &this.options().context else {
+                        return Task::ready(anyhow::Ok(HashMap::default()));
+                    };
 
-                        let mut edit_history_unified_diff = String::new();
+                    let mut edit_history_unified_diff = String::new();
 
-                        for event in zeta_project.events.iter() {
-                            if let Some(event) = event.to_request_event(cx) {
-                                writeln!(&mut edit_history_unified_diff, "{event}").ok();
-                            }
+                    for event in zeta_project.events.iter() {
+                        if let Some(event) = event.to_request_event(cx) {
+                            writeln!(&mut edit_history_unified_diff, "{event}").ok();
                         }
+                    }
 
-                        find_related_excerpts(
-                            buffer.clone(),
-                            cursor_position,
-                            &project,
-                            edit_history_unified_diff,
-                            options,
-                            debug_tx,
-                            cx,
-                        )
-                    })
-                    .ok()?
-                    .await
-                    .log_err()
-                    .unwrap_or_default();
-                this.update(cx, |this, _cx| {
-                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
-                        return;
-                    };
-                    zeta_project.context = Some(related_excerpts);
-                    zeta_project.refresh_context_task.take();
-                    if let Some(debug_tx) = &this.debug_tx {
-                        debug_tx
-                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
-                                ZetaContextRetrievalDebugInfo {
-                                    project,
-                                    timestamp: Instant::now(),
-                                },
-                            ))
-                            .ok();
+                    find_related_excerpts(
+                        buffer.clone(),
+                        cursor_position,
+                        &project,
+                        edit_history_unified_diff,
+                        options,
+                        this.debug_tx.clone(),
+                        cx,
+                    )
+                })?
+                .await;
+
+            this.update(cx, |this, _cx| {
+                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+                    return Ok(());
+                };
+                zeta_project.refresh_context_task.take();
+                if let Some(debug_tx) = &this.debug_tx {
+                    debug_tx
+                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
+                            ZetaContextRetrievalDebugInfo {
+                                project,
+                                timestamp: Instant::now(),
+                            },
+                        ))
+                        .ok();
+                }
+                match related_excerpts_result {
+                    Ok(excerpts) => {
+                        zeta_project.context = Some(excerpts);
+                        Ok(())
                     }
-                })
-                .ok()
-            }));
+                    Err(error) => Err(error),
+                }
+            })?
+        })
     }
 
     fn gather_nearby_diagnostics(

crates/zeta_cli/src/main.rs 🔗

@@ -26,7 +26,7 @@ use project::{Project, ProjectPath, Worktree};
 use reqwest_client::ReqwestClient;
 use serde_json::json;
 use std::io;
-use std::time::Duration;
+use std::time::{Duration, Instant};
 use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 
@@ -411,30 +411,51 @@ async fn zeta2_predict(
 
         Ok(excerpt_offset)
     })??;
-    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
 
+    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
     let cursor_anchor =
         cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
-    zeta.update(cx, |zeta, cx| {
+    let refresh_task = zeta.update(cx, |zeta, cx| {
         zeta.register_buffer(&cursor_buffer, &project, cx);
+        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
     })?;
 
-    let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| {
-        let receiver = zeta.debug_info();
-        let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);
-        (prediction_task, receiver)
-    })?;
-
-    let mut response = None;
-
+    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+    let mut context_retrieval_started_at = None;
+    let mut context_retrieval_finished_at = None;
+    let mut search_queries_generated_at = None;
+    let mut search_queries_executed_at = None;
+    let mut prediction_started_at = None;
+    let mut prediction_finished_at = None;
     let mut excerpts_text = String::new();
+    let mut prediction_task = None;
     while let Some(event) = debug_rx.next().await {
+        dbg!(&event);
         match event {
+            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+                context_retrieval_started_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+                search_queries_generated_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+                search_queries_executed_at = Some(info.timestamp);
+            }
+            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                context_retrieval_finished_at = Some(info.timestamp);
+
+                prediction_task = Some(zeta.update(cx, |zeta, cx| {
+                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+                })?);
+            }
             zeta2::ZetaDebugInfo::EditPredicted(request) => {
-                response = Some(request.response_rx.await?);
+                prediction_started_at = Some(Instant::now());
+                request.response_rx.await?.map_err(|err| anyhow!(err))?;
+                prediction_finished_at = Some(Instant::now());
+
                 for included_file in request.request.included_files {
                     let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
                     write_codeblock(
@@ -456,17 +477,55 @@ async fn zeta2_predict(
         }
     }
 
-    prediction_task.await.context("No prediction")?;
+    refresh_task.await.context("context retrieval failed")?;
+    let prediction = prediction_task.unwrap().await?.context("No prediction")?;
 
     println!("## Excerpts\n");
     println!("{excerpts_text}");
 
+    let old_text = prediction.snapshot.text();
+    let new_text = prediction.buffer.update(cx, |buffer, cx| {
+        buffer.edit(prediction.edits.iter().cloned(), None, cx);
+        buffer.text()
+    })?;
+    let diff = language::unified_diff(&old_text, &new_text);
+
     println!("## Prediction\n");
-    let response = response
-        .unwrap()
-        .map(|r| r.debug_info.unwrap().model_response)
-        .unwrap_or_else(|s| s);
-    println!("{response}");
+    println!("{diff}");
+
+    println!("## Time\n");
+
+    let planning_search_time =
+        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+
+    println!("Planning searches: {}ms", planning_search_time.as_millis());
+    println!(
+        "Running searches: {}ms",
+        (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
+    );
+
+    let filtering_search_time =
+        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+    println!(
+        "Filtering context results: {}ms",
+        filtering_search_time.as_millis()
+    );
+
+    let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+    println!("Making Prediction: {}ms", prediction_time.as_millis());
+
+    println!("-------------------");
+    let total_time =
+        (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
+    println!("Total: {}ms", total_time);
+
+    let inference_time =
+        (planning_search_time + filtering_search_time + prediction_time).as_millis();
+    println!(
+        "Inference: {}ms ({:.2}%)",
+        inference_time,
+        (inference_time as f64 / total_time as f64) * 100.
+    );
 
     anyhow::Ok(())
 }