@@ -11,7 +11,7 @@ use edit_prediction_context::{
EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
};
use futures::AsyncReadExt as _;
-use futures::channel::mpsc;
+use futures::channel::{mpsc, oneshot};
use gpui::http_client::Method;
use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
@@ -76,7 +76,7 @@ pub struct Zeta {
projects: HashMap<EntityId, ZetaProject>,
options: ZetaOptions,
update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
+ debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}

#[derive(Debug, Clone, PartialEq)]
@@ -91,9 +91,9 @@ pub struct ZetaOptions {
pub struct PredictionDebugInfo {
pub context: EditPredictionContext,
pub retrieval_time: TimeDelta,
- pub request: RequestDebugInfo,
pub buffer: WeakEntity<Buffer>,
pub position: language::Anchor,
+ pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -204,7 +204,7 @@ impl Zeta {
}
}

- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
+ pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
self.debug_tx = Some(debug_watch_tx);
debug_watch_rx
@@ -537,8 +537,22 @@ impl Zeta {
return Ok(None);
};
- let debug_context = if let Some(debug_tx) = debug_tx {
- Some((debug_tx, context.clone()))
+ let retrieval_time = chrono::Utc::now() - before_retrieval;
+
+ let debug_response_tx = if let Some(debug_tx) = debug_tx {
+ let (response_tx, response_rx) = oneshot::channel();
+ let context = context.clone();
+
+ debug_tx
+ .unbounded_send(PredictionDebugInfo {
+ context,
+ retrieval_time,
+ buffer: buffer.downgrade(),
+ position,
+ response_rx,
+ })
+ .ok();
+ Some(response_tx)
} else {
None
};
@@ -560,32 +574,21 @@ impl Zeta {
diagnostic_groups,
diagnostic_groups_truncated,
None,
- debug_context.is_some(),
+ debug_response_tx.is_some(),
&worktree_snapshots,
index_state.as_deref(),
Some(options.max_prompt_bytes),
options.prompt_format,
);

- let retrieval_time = chrono::Utc::now() - before_retrieval;
let response = Self::perform_request(client, llm_token, app_version, request).await;

- if let Some((debug_tx, context)) = debug_context {
- debug_tx
- .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
- |response| {
- let Some(request) =
- some_or_debug_panic(response.0.debug_info.clone())
- else {
- return Err("Missing debug info".to_string());
- };
- Ok(PredictionDebugInfo {
- context,
- request,
- retrieval_time,
- buffer: buffer.downgrade(),
- position,
- })
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send(response.as_ref().map_err(|err| err.to_string()).and_then(
+ |response| match some_or_debug_panic(response.0.debug_info.clone()) {
+ Some(debug_info) => Ok(debug_info),
+ None => Err("Missing debug info".to_string()),
},
))
.ok();
@@ -5,7 +5,7 @@ use client::{Client, UserStore};
use cloud_llm_client::predict_edits_v3::PromptFormat;
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
-use futures::StreamExt as _;
+use futures::{StreamExt as _, channel::oneshot};
use gpui::{
Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
prelude::*,
@@ -16,7 +16,7 @@ use ui::{ContextMenu, ContextMenuEntry, DropdownMenu, prelude::*};
use ui_input::SingleLineInput;
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
-use zeta2::{DEFAULT_CONTEXT_OPTIONS, Zeta, ZetaOptions};
+use zeta2::{DEFAULT_CONTEXT_OPTIONS, PredictionDebugInfo, Zeta, ZetaOptions};
use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
@@ -56,7 +56,7 @@ pub fn init(cx: &mut App) {
pub struct Zeta2Inspector {
focus_handle: FocusHandle,
project: Entity<Project>,
- last_prediction: Option<LastPredictionState>,
+ last_prediction: Option<LastPrediction>,
max_excerpt_bytes_input: Entity<SingleLineInput>,
min_excerpt_bytes_input: Entity<SingleLineInput>,
cursor_context_ratio_input: Entity<SingleLineInput>,
@@ -74,25 +74,27 @@ enum ActiveView {
Inference,
}

-enum LastPredictionState {
- Failed(SharedString),
- Success(LastPrediction),
- Replaying {
- prediction: LastPrediction,
- _task: Task<()>,
- },
-}
-
struct LastPrediction {
context_editor: Entity<Editor>,
retrieval_time: TimeDelta,
- prompt_planning_time: TimeDelta,
- inference_time: TimeDelta,
- parsing_time: TimeDelta,
- prompt_editor: Entity<Editor>,
- model_response_editor: Entity<Editor>,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
+ state: LastPredictionState,
+ _task: Option<Task<()>>,
+}
+
+enum LastPredictionState {
+ Requested,
+ Success {
+ inference_time: TimeDelta,
+ parsing_time: TimeDelta,
+ prompt_planning_time: TimeDelta,
+ prompt_editor: Entity<Editor>,
+ model_response_editor: Entity<Editor>,
+ },
+ Failed {
+ message: String,
+ },
}

impl Zeta2Inspector {
@@ -107,15 +109,9 @@ impl Zeta2Inspector {
let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
let receive_task = cx.spawn_in(window, async move |this, cx| {
- while let Some(prediction_result) = request_rx.next().await {
- this.update_in(cx, |this, window, cx| match prediction_result {
- Ok(prediction) => {
- this.update_last_prediction(prediction, window, cx);
- }
- Err(err) => {
- this.last_prediction = Some(LastPredictionState::Failed(err.into()));
- cx.notify();
- }
+ while let Some(prediction) = request_rx.next().await {
+ this.update_in(cx, |this, window, cx| {
+ this.update_last_prediction(prediction, window, cx)
})
.ok();
}
@@ -175,16 +171,12 @@ impl Zeta2Inspector {
const THROTTLE_TIME: Duration = Duration::from_millis(100);
- if let Some(
- LastPredictionState::Success(prediction)
- | LastPredictionState::Replaying { prediction, .. },
- ) = self.last_prediction.take()
- {
+ if let Some(prediction) = self.last_prediction.as_mut() {
if let Some(buffer) = prediction.buffer.upgrade() {
let position = prediction.position;
let zeta = self.zeta.clone();
let project = self.project.clone();
- let task = cx.spawn(async move |_this, cx| {
+ prediction._task = Some(cx.spawn(async move |_this, cx| {
cx.background_executor().timer(THROTTLE_TIME).await;
if let Some(task) = zeta
.update(cx, |zeta, cx| {
@@ -194,13 +186,10 @@ impl Zeta2Inspector {
{
task.await.log_err();
}
- });
- self.last_prediction = Some(LastPredictionState::Replaying {
- prediction,
- _task: task,
- });
+ }));
+ prediction.state = LastPredictionState::Requested;
} else {
- self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
+ self.last_prediction.take();
}
}
@@ -383,47 +372,86 @@ impl Zeta2Inspector {
Editor::new(EditorMode::full(), multibuffer, None, window, cx)
});
- let last_prediction = LastPrediction {
+ let PredictionDebugInfo {
+ response_rx,
+ position,
+ buffer,
+ retrieval_time,
+ ..
+ } = prediction;
+
+ let task = cx.spawn_in(window, async move |this, cx| {
+ let response = response_rx.await;
+
+ this.update_in(cx, |this, window, cx| {
+ if let Some(prediction) = this.last_prediction.as_mut() {
+ prediction.state = match response {
+ Ok(Ok(response)) => LastPredictionState::Success {
+ prompt_planning_time: response.prompt_planning_time,
+ inference_time: response.inference_time,
+ parsing_time: response.parsing_time,
+ prompt_editor: cx.new(|cx| {
+ let buffer = cx.new(|cx| {
+ let mut buffer = Buffer::local(response.prompt, cx);
+ buffer.set_language(markdown_language.clone(), cx);
+ buffer
+ });
+ let buffer =
+ cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let mut editor = Editor::new(
+ EditorMode::full(),
+ buffer,
+ None,
+ window,
+ cx,
+ );
+ editor.set_read_only(true);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_gutter(false, cx);
+ editor.set_show_scrollbars(false, cx);
+ editor
+ }),
+ model_response_editor: cx.new(|cx| {
+ let buffer = cx.new(|cx| {
+ let mut buffer =
+ Buffer::local(response.model_response, cx);
+ buffer.set_language(markdown_language, cx);
+ buffer
+ });
+ let buffer =
+ cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let mut editor = Editor::new(
+ EditorMode::full(),
+ buffer,
+ None,
+ window,
+ cx,
+ );
+ editor.set_read_only(true);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_gutter(false, cx);
+ editor.set_show_scrollbars(false, cx);
+ editor
+ }),
+ },
+ Ok(Err(err)) => LastPredictionState::Failed { message: err },
+ Err(oneshot::Canceled) => LastPredictionState::Failed {
+ message: "Canceled".to_string(),
+ },
+ };
+ }
+ })
+ .ok();
+ });
+
+ this.last_prediction = Some(LastPrediction {
context_editor,
- prompt_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(prediction.request.prompt, cx);
- buffer.set_language(markdown_language.clone(), cx);
- buffer
- });
- let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor =
- Editor::new(EditorMode::full(), buffer, None, window, cx);
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- model_response_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer =
- Buffer::local(prediction.request.model_response, cx);
- buffer.set_language(markdown_language, cx);
- buffer
- });
- let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor =
- Editor::new(EditorMode::full(), buffer, None, window, cx);
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- retrieval_time: prediction.retrieval_time,
- prompt_planning_time: prediction.request.prompt_planning_time,
- inference_time: prediction.request.inference_time,
- parsing_time: prediction.request.parsing_time,
- buffer: prediction.buffer,
- position: prediction.position,
- };
- this.last_prediction = Some(LastPredictionState::Success(last_prediction));
+ retrieval_time,
+ buffer,
+ position,
+ state: LastPredictionState::Requested,
+ _task: Some(task),
+ });
cx.notify();
})
.ok();
@@ -514,9 +542,7 @@ impl Zeta2Inspector {
}

fn render_tabs(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
- let Some(LastPredictionState::Success { .. } | LastPredictionState::Replaying { .. }) =
- self.last_prediction.as_ref()
- else {
+ if self.last_prediction.is_none() {
return None;
};
@@ -551,14 +577,26 @@ impl Zeta2Inspector {
}

fn render_stats(&self) -> Option<Div> {
- let Some(
- LastPredictionState::Success(prediction)
- | LastPredictionState::Replaying { prediction, .. },
- ) = self.last_prediction.as_ref()
- else {
+ let Some(prediction) = self.last_prediction.as_ref() else {
return None;
};

+ let (prompt_planning_time, inference_time, parsing_time) = match &prediction.state {
+ LastPredictionState::Success {
+ inference_time,
+ parsing_time,
+ prompt_planning_time,
+ ..
+ } => (
+ Some(*prompt_planning_time),
+ Some(*inference_time),
+ Some(*parsing_time),
+ ),
+ LastPredictionState::Requested | LastPredictionState::Failed { .. } => {
+ (None, None, None)
+ }
+ };
+
Some(
v_flex()
.p_4()
@@ -567,32 +605,30 @@ impl Zeta2Inspector {
.child(Headline::new("Stats").size(HeadlineSize::Small))
.child(Self::render_duration(
"Context retrieval",
- prediction.retrieval_time,
+ Some(prediction.retrieval_time),
))
.child(Self::render_duration(
"Prompt planning",
- prediction.prompt_planning_time,
- ))
- .child(Self::render_duration(
- "Inference",
- prediction.inference_time,
+ prompt_planning_time,
))
- .child(Self::render_duration("Parsing", prediction.parsing_time)),
+ .child(Self::render_duration("Inference", inference_time))
+ .child(Self::render_duration("Parsing", parsing_time)),
)
}

- fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
+ fn render_duration(name: &'static str, time: Option<chrono::TimeDelta>) -> Div {
h_flex()
.gap_1()
.child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
- .child(
- Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
+ .child(match time {
+ Some(time) => Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
format!("{} ms", time.num_milliseconds())
} else {
                    format!("{} µs", time.num_microseconds().unwrap_or(0))
})
.size(LabelSize::Small),
- )
+ None => Label::new("...").size(LabelSize::Small),
+ })
}

fn render_content(&self, cx: &mut Context<Self>) -> AnyElement {
@@ -603,50 +639,55 @@ impl Zeta2Inspector {
.items_center()
.child(Label::new("No prediction").size(LabelSize::Large))
.into_any(),
- Some(LastPredictionState::Success(prediction)) => {
- self.render_last_prediction(prediction, cx).into_any()
- }
- Some(LastPredictionState::Replaying { prediction, _task }) => self
- .render_last_prediction(prediction, cx)
- .opacity(0.6)
- .into_any(),
- Some(LastPredictionState::Failed(err)) => v_flex()
- .p_4()
- .gap_2()
- .child(Label::new(err.clone()).buffer_font(cx))
- .into_any(),
+ Some(prediction) => self.render_last_prediction(prediction, cx).into_any(),
}
}

fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
match &self.active_view {
ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
- ActiveView::Inference => h_flex()
- .items_start()
- .w_full()
- .flex_1()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().editor_background)
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .h_full()
- .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
- .child(prediction.prompt_editor.clone()),
- )
- .child(ui::vertical_divider())
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .h_full()
- .p_4()
- .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
- .child(prediction.model_response_editor.clone()),
- ),
+ ActiveView::Inference => match &prediction.state {
+ LastPredictionState::Success {
+ prompt_editor,
+ model_response_editor,
+ ..
+ } => h_flex()
+ .items_start()
+ .w_full()
+ .flex_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .p_4()
+ .h_full()
+ .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
+ .child(prompt_editor.clone()),
+ )
+ .child(ui::vertical_divider())
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .h_full()
+ .p_4()
+ .child(
+ ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
+ )
+ .child(model_response_editor.clone()),
+ ),
+ LastPredictionState::Requested => v_flex()
+ .p_4()
+ .gap_2()
+ .child(Label::new("Loading...").buffer_font(cx)),
+ LastPredictionState::Failed { message } => v_flex()
+ .p_4()
+ .gap_2()
+ .child(Label::new(message.clone()).buffer_font(cx)),
+ },
}
}
}