diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs
index 30ef2d79da05b87c730ccf0c87c4061225d1c723..aa81c09237305e6f7edd77f1d033169857217e2e 100644
--- a/crates/zeta2/src/zeta2.rs
+++ b/crates/zeta2/src/zeta2.rs
@@ -11,7 +11,7 @@ use edit_prediction_context::{
     EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
 };
 use futures::AsyncReadExt as _;
-use futures::channel::mpsc;
+use futures::channel::{mpsc, oneshot};
 use gpui::http_client::Method;
 use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
@@ -76,7 +76,7 @@ pub struct Zeta {
     projects: HashMap<EntityId, ZetaProject>,
     options: ZetaOptions,
     update_required: bool,
-    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
+    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -91,9 +91,9 @@ pub struct ZetaOptions {
 pub struct PredictionDebugInfo {
     pub context: EditPredictionContext,
     pub retrieval_time: TimeDelta,
-    pub request: RequestDebugInfo,
     pub buffer: WeakEntity<Buffer>,
     pub position: language::Anchor,
+    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
 }
 
 pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -204,7 +204,7 @@ impl Zeta {
         }
     }
 
-    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
+    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
         let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
         self.debug_tx = Some(debug_watch_tx);
         debug_watch_rx
@@ -537,8 +537,22 @@ impl Zeta {
             return Ok(None);
         };
 
-        let debug_context = if let Some(debug_tx) = debug_tx {
-            Some((debug_tx, context.clone()))
+        let retrieval_time = chrono::Utc::now() - before_retrieval;
+
+        let debug_response_tx = if let Some(debug_tx) = debug_tx {
+            let (response_tx, response_rx) = oneshot::channel();
+            let context = context.clone();
+
+            debug_tx
+                .unbounded_send(PredictionDebugInfo {
+                    context,
+                    retrieval_time,
+                    buffer: buffer.downgrade(),
+                    position,
+                    response_rx,
+                })
+                .ok();
+            Some(response_tx)
         } else {
             None
         };
@@ -560,32 +574,21 @@
             diagnostic_groups,
             diagnostic_groups_truncated,
             None,
-            debug_context.is_some(),
+            debug_response_tx.is_some(),
             &worktree_snapshots,
             index_state.as_deref(),
             Some(options.max_prompt_bytes),
             options.prompt_format,
         );
 
-        let retrieval_time = chrono::Utc::now() - before_retrieval;
         let response = Self::perform_request(client, llm_token, app_version, request).await;
 
-        if let Some((debug_tx, context)) = debug_context {
-            debug_tx
-                .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
-                    |response| {
-                        let Some(request) =
-                            some_or_debug_panic(response.0.debug_info.clone())
-                        else {
-                            return Err("Missing debug info".to_string());
-                        };
-                        Ok(PredictionDebugInfo {
-                            context,
-                            request,
-                            retrieval_time,
-                            buffer: buffer.downgrade(),
-                            position,
-                        })
-                    },
-                ))
-                .ok();
+        if let Some(debug_response_tx) = debug_response_tx {
+            debug_response_tx
+                .send(response.as_ref().map_err(|err| err.to_string()).and_then(
+                    |response| match some_or_debug_panic(response.0.debug_info.clone()) {
+                        Some(debug_info) => Ok(debug_info),
+                        None => Err("Missing debug info".to_string()),
+                    },
+                ))
+                .ok();
diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs
index 40315265df4c9a4aec3dfee37185d94249841eda..e957cce380266aa8586e7fa283da35b259227f20 100644
--- a/crates/zeta2_tools/src/zeta2_tools.rs
+++ b/crates/zeta2_tools/src/zeta2_tools.rs
@@ -5,7 +5,7 @@ use client::{Client, UserStore};
 use cloud_llm_client::predict_edits_v3::PromptFormat;
 use collections::HashMap;
 use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
-use futures::StreamExt as _;
+use futures::{StreamExt as _, channel::oneshot};
 use gpui::{
     Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
     prelude::*,
@@ -16,7 +16,7 @@ use ui::{ContextMenu, ContextMenuEntry, DropdownMenu, prelude::*};
 use ui_input::SingleLineInput;
 use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
 use workspace::{Item, SplitDirection, Workspace};
-use zeta2::{DEFAULT_CONTEXT_OPTIONS, Zeta, ZetaOptions};
+use zeta2::{DEFAULT_CONTEXT_OPTIONS, PredictionDebugInfo, Zeta, ZetaOptions};
 
 use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
 
@@ -56,7 +56,7 @@ pub fn init(cx: &mut App) {
 pub struct Zeta2Inspector {
     focus_handle: FocusHandle,
     project: Entity<Project>,
-    last_prediction: Option<LastPredictionState>,
+    last_prediction: Option<LastPrediction>,
     max_excerpt_bytes_input: Entity<SingleLineInput>,
     min_excerpt_bytes_input: Entity<SingleLineInput>,
     cursor_context_ratio_input: Entity<SingleLineInput>,
@@ -74,25 +74,27 @@ enum ActiveView {
     Inference,
 }
 
-enum LastPredictionState {
-    Failed(SharedString),
-    Success(LastPrediction),
-    Replaying {
-        prediction: LastPrediction,
-        _task: Task<()>,
-    },
-}
-
 struct LastPrediction {
     context_editor: Entity<Editor>,
     retrieval_time: TimeDelta,
-    prompt_planning_time: TimeDelta,
-    inference_time: TimeDelta,
-    parsing_time: TimeDelta,
-    prompt_editor: Entity<Editor>,
-    model_response_editor: Entity<Editor>,
     buffer: WeakEntity<Buffer>,
     position: language::Anchor,
+    state: LastPredictionState,
+    _task: Option<Task<()>>,
+}
+
+enum LastPredictionState {
+    Requested,
+    Success {
+        inference_time: TimeDelta,
+        parsing_time: TimeDelta,
+        prompt_planning_time: TimeDelta,
+        prompt_editor: Entity<Editor>,
+        model_response_editor: Entity<Editor>,
+    },
+    Failed {
+        message: String,
+    },
 }
 
 impl Zeta2Inspector {
@@ -107,15 +109,9 @@ impl Zeta2Inspector {
         let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
 
         let receive_task = cx.spawn_in(window, async move |this, cx| {
-            while let Some(prediction_result) = request_rx.next().await {
-                this.update_in(cx, |this, window, cx| match prediction_result {
-                    Ok(prediction) => {
-                        this.update_last_prediction(prediction, window, cx);
-                    }
-                    Err(err) => {
-                        this.last_prediction = Some(LastPredictionState::Failed(err.into()));
-                        cx.notify();
-                    }
+            while let Some(prediction) = request_rx.next().await {
+                this.update_in(cx, |this, window, cx| {
+                    this.update_last_prediction(prediction, window, cx)
                 })
                 .ok();
             }
@@ -175,16 +171,12 @@
 
         const THROTTLE_TIME: Duration = Duration::from_millis(100);
 
-        if let Some(
-            LastPredictionState::Success(prediction)
-            | LastPredictionState::Replaying { prediction, .. },
-        ) = self.last_prediction.take()
-        {
+        if let Some(prediction) = self.last_prediction.as_mut() {
             if let Some(buffer) = prediction.buffer.upgrade() {
                 let position = prediction.position;
                 let zeta = self.zeta.clone();
                 let project = self.project.clone();
-                let task = cx.spawn(async move |_this, cx| {
+                prediction._task = Some(cx.spawn(async move |_this, cx| {
                     cx.background_executor().timer(THROTTLE_TIME).await;
                     if let Some(task) = zeta
                         .update(cx, |zeta, cx| {
@@ -194,13 +186,10 @@
                     {
                         task.await.log_err();
                     }
-                });
-                self.last_prediction = Some(LastPredictionState::Replaying {
-                    prediction,
-                    _task: task,
-                });
+                }));
+                prediction.state = LastPredictionState::Requested;
             } else {
-                self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
+                self.last_prediction.take();
             }
         }
 
@@ -383,47 +372,86 @@
             Editor::new(EditorMode::full(), multibuffer, None, window, cx)
         });
 
-        let last_prediction = LastPrediction {
+        let PredictionDebugInfo {
+            response_rx,
+            position,
+            buffer,
+            retrieval_time,
+            ..
+        } = prediction;
+
+        let task = cx.spawn_in(window, async move |this, cx| {
+            let response = response_rx.await;
+
+            this.update_in(cx, |this, window, cx| {
+                if let Some(prediction) = this.last_prediction.as_mut() {
+                    prediction.state = match response {
+                        Ok(Ok(response)) => LastPredictionState::Success {
+                            prompt_planning_time: response.prompt_planning_time,
+                            inference_time: response.inference_time,
+                            parsing_time: response.parsing_time,
+                            prompt_editor: cx.new(|cx| {
+                                let buffer = cx.new(|cx| {
+                                    let mut buffer = Buffer::local(response.prompt, cx);
+                                    buffer.set_language(markdown_language.clone(), cx);
+                                    buffer
+                                });
+                                let buffer =
+                                    cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+                                let mut editor = Editor::new(
+                                    EditorMode::full(),
+                                    buffer,
+                                    None,
+                                    window,
+                                    cx,
+                                );
+                                editor.set_read_only(true);
+                                editor.set_show_line_numbers(false, cx);
+                                editor.set_show_gutter(false, cx);
+                                editor.set_show_scrollbars(false, cx);
+                                editor
+                            }),
+                            model_response_editor: cx.new(|cx| {
+                                let buffer = cx.new(|cx| {
+                                    let mut buffer =
+                                        Buffer::local(response.model_response, cx);
+                                    buffer.set_language(markdown_language, cx);
+                                    buffer
+                                });
+                                let buffer =
+                                    cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+                                let mut editor = Editor::new(
+                                    EditorMode::full(),
+                                    buffer,
+                                    None,
+                                    window,
+                                    cx,
+                                );
+                                editor.set_read_only(true);
+                                editor.set_show_line_numbers(false, cx);
+                                editor.set_show_gutter(false, cx);
+                                editor.set_show_scrollbars(false, cx);
+                                editor
+                            }),
+                        },
+                        Ok(Err(err)) => LastPredictionState::Failed { message: err },
+                        Err(oneshot::Canceled) => LastPredictionState::Failed {
+                            message: "Canceled".to_string(),
+                        },
+                    };
+                }
+            })
+            .ok();
+        });
+
+        this.last_prediction = Some(LastPrediction {
             context_editor,
-            prompt_editor: cx.new(|cx| {
-                let buffer = cx.new(|cx| {
-                    let mut buffer = Buffer::local(prediction.request.prompt, cx);
-                    buffer.set_language(markdown_language.clone(), cx);
-                    buffer
-                });
-                let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
-                let mut editor =
-                    Editor::new(EditorMode::full(), buffer, None, window, cx);
-                editor.set_read_only(true);
-                editor.set_show_line_numbers(false, cx);
-                editor.set_show_gutter(false, cx);
-                editor.set_show_scrollbars(false, cx);
-                editor
-            }),
-            model_response_editor: cx.new(|cx| {
-                let buffer = cx.new(|cx| {
-                    let mut buffer =
-                        Buffer::local(prediction.request.model_response, cx);
-                    buffer.set_language(markdown_language, cx);
-                    buffer
-                });
-                let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
-                let mut editor =
-                    Editor::new(EditorMode::full(), buffer, None, window, cx);
-                editor.set_read_only(true);
-                editor.set_show_line_numbers(false, cx);
-                editor.set_show_gutter(false, cx);
-                editor.set_show_scrollbars(false, cx);
-                editor
-            }),
-            retrieval_time: prediction.retrieval_time,
-            prompt_planning_time: prediction.request.prompt_planning_time,
-            inference_time: prediction.request.inference_time,
-            parsing_time: prediction.request.parsing_time,
-            buffer: prediction.buffer,
-            position: prediction.position,
-        };
-        this.last_prediction = Some(LastPredictionState::Success(last_prediction));
+            retrieval_time,
+            buffer,
+            position,
+            state: LastPredictionState::Requested,
+            _task: Some(task),
+        });
         cx.notify();
     })
     .ok();
@@ -514,9 +542,7 @@
     }
 
     fn render_tabs(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
-        let Some(LastPredictionState::Success { .. } | LastPredictionState::Replaying { .. }) =
-            self.last_prediction.as_ref()
-        else {
+        if self.last_prediction.is_none() {
             return None;
         };
 
@@ -551,14 +577,26 @@
     }
 
     fn render_stats(&self) -> Option<Div> {
-        let Some(
-            LastPredictionState::Success(prediction)
-            | LastPredictionState::Replaying { prediction, .. },
-        ) = self.last_prediction.as_ref()
-        else {
+        let Some(prediction) = self.last_prediction.as_ref() else {
             return None;
         };
 
+        let (prompt_planning_time, inference_time, parsing_time) = match &prediction.state {
+            LastPredictionState::Success {
+                inference_time,
+                parsing_time,
+                prompt_planning_time,
+                ..
+            } => (
+                Some(*prompt_planning_time),
+                Some(*inference_time),
+                Some(*parsing_time),
+            ),
+            LastPredictionState::Requested | LastPredictionState::Failed { .. } => {
+                (None, None, None)
+            }
+        };
+
         Some(
             v_flex()
                 .p_4()
@@ -567,32 +605,30 @@
                 .child(Headline::new("Stats").size(HeadlineSize::Small))
                 .child(Self::render_duration(
                     "Context retrieval",
-                    prediction.retrieval_time,
+                    Some(prediction.retrieval_time),
                 ))
                 .child(Self::render_duration(
                     "Prompt planning",
-                    prediction.prompt_planning_time,
-                ))
-                .child(Self::render_duration(
-                    "Inference",
-                    prediction.inference_time,
+                    prompt_planning_time,
                 ))
-                .child(Self::render_duration("Parsing", prediction.parsing_time)),
+                .child(Self::render_duration("Inference", inference_time))
+                .child(Self::render_duration("Parsing", parsing_time)),
         )
     }
 
-    fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
+    fn render_duration(name: &'static str, time: Option<chrono::TimeDelta>) -> Div {
         h_flex()
             .gap_1()
             .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
-            .child(
-                Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
+            .child(match time {
+                Some(time) => Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
                     format!("{} ms", time.num_milliseconds())
                 } else {
                     format!("{} µs", time.num_microseconds().unwrap_or(0))
                 })
                 .size(LabelSize::Small),
-            )
+                None => Label::new("...").size(LabelSize::Small),
+            })
     }
 
     fn render_content(&self, cx: &mut Context<Self>) -> AnyElement {
@@ -603,50 +639,55 @@
                 .items_center()
                 .child(Label::new("No prediction").size(LabelSize::Large))
                 .into_any(),
-            Some(LastPredictionState::Success(prediction)) => {
-                self.render_last_prediction(prediction, cx).into_any()
-            }
-            Some(LastPredictionState::Replaying { prediction, _task }) => self
-                .render_last_prediction(prediction, cx)
-                .opacity(0.6)
-                .into_any(),
-            Some(LastPredictionState::Failed(err)) => v_flex()
-                .p_4()
-                .gap_2()
-                .child(Label::new(err.clone()).buffer_font(cx))
-                .into_any(),
+            Some(prediction) => self.render_last_prediction(prediction, cx).into_any(),
         }
     }
 
     fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
         match &self.active_view {
             ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
-            ActiveView::Inference => h_flex()
-                .items_start()
-                .w_full()
-                .flex_1()
-                .border_t_1()
-                .border_color(cx.theme().colors().border)
-                .bg(cx.theme().colors().editor_background)
-                .child(
-                    v_flex()
-                        .flex_1()
-                        .gap_2()
-                        .p_4()
-                        .h_full()
-                        .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
-                        .child(prediction.prompt_editor.clone()),
-                )
-                .child(ui::vertical_divider())
-                .child(
-                    v_flex()
-                        .flex_1()
-                        .gap_2()
-                        .h_full()
-                        .p_4()
-                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
-                        .child(prediction.model_response_editor.clone()),
-                ),
+            ActiveView::Inference => match &prediction.state {
+                LastPredictionState::Success {
+                    prompt_editor,
+                    model_response_editor,
+                    ..
+                } => h_flex()
+                    .items_start()
+                    .w_full()
+                    .flex_1()
+                    .border_t_1()
+                    .border_color(cx.theme().colors().border)
+                    .bg(cx.theme().colors().editor_background)
+                    .child(
+                        v_flex()
+                            .flex_1()
+                            .gap_2()
+                            .p_4()
+                            .h_full()
+                            .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
+                            .child(prompt_editor.clone()),
+                    )
+                    .child(ui::vertical_divider())
+                    .child(
+                        v_flex()
+                            .flex_1()
+                            .gap_2()
+                            .h_full()
+                            .p_4()
+                            .child(
+                                ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
+                            )
+                            .child(model_response_editor.clone()),
+                    ),
+                LastPredictionState::Requested => v_flex()
+                    .p_4()
+                    .gap_2()
+                    .child(Label::new("Loading...").buffer_font(cx)),
+                LastPredictionState::Failed { message } => v_flex()
+                    .p_4()
+                    .gap_2()
+                    .child(Label::new(message.clone()).buffer_font(cx)),
+            },
         }
     }
 }
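
For readers following the control flow: the patch replaces the old "send everything once the request completes" debug channel with a two-stage handshake, where retrieval-time data is delivered immediately over the mpsc channel and the request outcome arrives later over a oneshot. Below is a minimal, self-contained sketch of that pattern using only the futures crate; the DebugInfo struct and the string payloads are simplified stand-ins for PredictionDebugInfo and RequestDebugInfo, not the real types from this patch.

use futures::StreamExt as _;
use futures::channel::{mpsc, oneshot};

// Simplified stand-in for PredictionDebugInfo: the context is available as
// soon as retrieval finishes; the response arrives later via the oneshot.
struct DebugInfo {
    context: String,
    response_rx: oneshot::Receiver<Result<String, String>>,
}

fn main() {
    let (debug_tx, mut debug_rx) = mpsc::unbounded::<DebugInfo>();

    // Producer side (Zeta::request_prediction in the patch): emit the
    // retrieval-time info up front, keeping the oneshot sender to resolve
    // once the request finishes.
    let (response_tx, response_rx) = oneshot::channel();
    debug_tx
        .unbounded_send(DebugInfo {
            context: "retrieved context".into(),
            response_rx,
        })
        .ok();

    // ...after the request completes, resolve the oneshot with its outcome.
    // Dropping response_tx instead would surface as oneshot::Canceled.
    response_tx.send(Ok("model response".into())).ok();

    // Consumer side (Zeta2Inspector in the patch): render the context right
    // away, then await the response separately to fill in the details.
    futures::executor::block_on(async {
        if let Some(info) = debug_rx.next().await {
            println!("context ready: {}", info.context);
            match info.response_rx.await {
                Ok(Ok(response)) => println!("success: {response}"),
                Ok(Err(message)) => println!("failed: {message}"),
                Err(oneshot::Canceled) => println!("canceled"),
            }
        }
    });
}

This mirrors why the UI side gains a Requested state: the inspector now learns about a prediction before any response exists, and fills in Success or Failed only when the oneshot resolves.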