diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 97846a8edf63cae577eb17d49ee835b43295be35..8e940894704be753846330c60885023f8808740d 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -1397,6 +1397,13 @@ "end": "settings_editor::FocusLastNavEntry" } }, + { + "context": "Zeta2Inspector > Editor", + "bindings": { + "ctrl-h": "dev::Zeta2InspectorPrevious", + "ctrl-l": "dev::Zeta2InspectorNext" + } + }, { "context": "Zeta2Feedback > Editor", "bindings": { diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 42eb565502e6568491e820dfb5c0921e4d56039b..1deb0064a6588c13df3f2b146fd1c6af57300027 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -12,8 +12,9 @@ use edit_prediction_context::{ EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState, }; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; -use futures::AsyncReadExt as _; use futures::channel::{mpsc, oneshot}; +use futures::future::{BoxFuture, Shared}; +use futures::{AsyncReadExt as _, FutureExt}; use gpui::http_client::{AsyncBody, Method}; use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, @@ -89,7 +90,7 @@ pub struct Zeta { projects: HashMap, options: ZetaOptions, update_required: bool, - debug_tx: Option>, + observe_tx: Option>, } #[derive(Debug, Clone, PartialEq)] @@ -101,13 +102,15 @@ pub struct ZetaOptions { pub file_indexing_parallelism: usize, } -pub struct PredictionDebugInfo { - pub request: predict_edits_v3::PredictEditsRequest, +#[derive(Clone)] +pub struct ObservedPredictionRequest { + pub full_request: predict_edits_v3::PredictEditsRequest, pub retrieval_time: TimeDelta, pub buffer: WeakEntity, pub position: language::Anchor, pub local_prompt: Result, - pub response_rx: oneshot::Receiver>, + pub response: + Shared>>, } pub type RequestDebugInfo = predict_edits_v3::DebugInfo; @@ -224,14 +227,14 @@ impl Zeta { }, ), update_required: false, - debug_tx: None, + observe_tx: None, } } - pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { - let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); - self.debug_tx = Some(debug_watch_tx); - debug_watch_rx + pub fn observe_requests(&mut self) -> mpsc::UnboundedReceiver { + let (tx, rx) = mpsc::unbounded(); + self.observe_tx = Some(tx); + rx } pub fn options(&self) -> &ZetaOptions { @@ -518,7 +521,7 @@ impl Zeta { .worktrees(cx) .map(|worktree| worktree.read(cx).snapshot()) .collect::>(); - let debug_tx = self.debug_tx.clone(); + let observe_tx = self.observe_tx.clone(); let events = project_state .map(|state| { @@ -616,28 +619,35 @@ impl Zeta { diagnostic_groups, diagnostic_groups_truncated, None, - debug_tx.is_some(), + observe_tx.is_some(), &worktree_snapshots, index_state.as_deref(), Some(options.max_prompt_bytes), options.prompt_format, ); - let debug_response_tx = if let Some(debug_tx) = &debug_tx { + let observe_response_tx = if let Some(observe_tx) = &observe_tx { let (response_tx, response_rx) = oneshot::channel(); let local_prompt = PlannedPrompt::populate(&request) .and_then(|p| p.to_prompt_string().map(|p| p.0)) .map_err(|err| err.to_string()); - debug_tx - .unbounded_send(PredictionDebugInfo { - request: request.clone(), + observe_tx + .unbounded_send(ObservedPredictionRequest { + full_request: request.clone(), retrieval_time, buffer: buffer.downgrade(), local_prompt, position, - response_rx, + response: async move { + let resp = response_rx + .await + .map_err(|_: oneshot::Canceled| "Canceled".to_string()); + Ok(resp??) + } + .boxed() + .shared(), }) .ok(); Some(response_tx) @@ -646,8 +656,8 @@ impl Zeta { }; if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx + if let Some(observe_response_tx) = observe_response_tx { + observe_response_tx .send(Err("Request skipped".to_string())) .ok(); } @@ -657,7 +667,7 @@ impl Zeta { let response = Self::send_prediction_request(client, llm_token, app_version, request).await; - if let Some(debug_response_tx) = debug_response_tx { + if let Some(debug_response_tx) = observe_response_tx { debug_response_tx .send( response diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 7b806e2b9a4ba7c7dbda41bb0f5750e5d2b9ff97..844d1abe4a8c98b81184358ac069e29a2682c3da 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -1,14 +1,16 @@ -use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; +use std::{ + cmp::Reverse, collections::hash_map::Entry, ops::Add as _, path::PathBuf, str::FromStr, + sync::Arc, time::Duration, +}; -use chrono::TimeDelta; use client::{Client, UserStore}; use cloud_llm_client::predict_edits_v3::{ - self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat, + self, DeclarationScoreComponents, PredictEditsResponse, PromptFormat, }; -use collections::HashMap; +use collections::{HashMap, HashSet}; use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer}; use feature_flags::FeatureFlagAppExt as _; -use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; +use futures::{FutureExt, StreamExt as _, future::Shared}; use gpui::{ CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions, prelude::*, @@ -20,7 +22,7 @@ use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, pr use ui_input::SingleLineInput; use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; -use zeta2::{PredictionDebugInfo, Zeta, Zeta2FeatureFlag, ZetaOptions}; +use zeta2::{ObservedPredictionRequest, Zeta, Zeta2FeatureFlag, ZetaOptions}; use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions}; @@ -33,6 +35,10 @@ actions!( Zeta2RatePredictionPositive, /// Rate prediction as negative. Zeta2RatePredictionNegative, + /// Go to the previous request in the zeta2 inspector + Zeta2InspectorPrevious, + /// Go to the next request in the zeta2 inspector + Zeta2InspectorNext, ] ); @@ -64,7 +70,8 @@ pub fn init(cx: &mut App) { pub struct Zeta2Inspector { focus_handle: FocusHandle, project: Entity, - last_prediction: Option, + requests: Vec, + current: Option, max_excerpt_bytes_input: Entity, min_excerpt_bytes_input: Entity, cursor_context_ratio_input: Entity, @@ -77,21 +84,24 @@ pub struct Zeta2Inspector { _receive_task: Task<()>, } +struct InspectorObservedPredictionRequest { + observed_request: ObservedPredictionRequest, + project_snapshot: Shared>>, +} + #[derive(PartialEq)] enum ActiveView { Context, Inference, } -struct LastPrediction { +struct CurrentRequest { + index: usize, context_editor: Entity, prompt_editor: Entity, - retrieval_time: TimeDelta, buffer: WeakEntity, position: language::Anchor, - state: LastPredictionState, - request: PredictEditsRequest, - project_snapshot: Shared>>, + state: CurrentPredictionState, _task: Option>, } @@ -101,7 +111,7 @@ enum Feedback { Negative, } -enum LastPredictionState { +enum CurrentPredictionState { Requested, Success { model_response_editor: Entity, @@ -123,12 +133,21 @@ impl Zeta2Inspector { cx: &mut Context, ) -> Self { let zeta = Zeta::global(client, user_store, cx); - let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info()); + let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.observe_requests()); let receive_task = cx.spawn_in(window, async move |this, cx| { - while let Some(prediction) = request_rx.next().await { + while let Some(request) = request_rx.next().await { this.update_in(cx, |this, window, cx| { - this.update_last_prediction(prediction, window, cx) + let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx); + this.requests.push(InspectorObservedPredictionRequest { + observed_request: request, + project_snapshot: cx + .foreground_executor() + .spawn(async move { Arc::new(project_snapshot_task.await) }) + .shared(), + }); + + this.set_current_predition(this.requests.len() - 1, window, cx) }) .ok(); } @@ -137,7 +156,8 @@ impl Zeta2Inspector { let mut this = Self { focus_handle: cx.focus_handle(), project: project.clone(), - last_prediction: None, + requests: Vec::new(), + current: None, active_view: ActiveView::Inference, max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx), min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx), @@ -196,7 +216,7 @@ impl Zeta2Inspector { const THROTTLE_TIME: Duration = Duration::from_millis(100); - if let Some(prediction) = self.last_prediction.as_mut() { + if let Some(prediction) = self.current.as_mut() { if let Some(buffer) = prediction.buffer.upgrade() { let position = prediction.position; let zeta = self.zeta.clone(); @@ -212,9 +232,9 @@ impl Zeta2Inspector { task.await.log_err(); } })); - prediction.state = LastPredictionState::Requested; + prediction.state = CurrentPredictionState::Requested; } else { - self.last_prediction.take(); + self.current.take(); } } @@ -287,12 +307,7 @@ impl Zeta2Inspector { input } - fn update_last_prediction( - &mut self, - prediction: zeta2::PredictionDebugInfo, - window: &mut Window, - cx: &mut Context, - ) { + fn set_current_predition(&mut self, index: usize, window: &mut Window, cx: &mut Context) { let project = self.project.read(cx); let path_style = project.path_style(cx); let Some(worktree_id) = project @@ -301,7 +316,7 @@ impl Zeta2Inspector { .map(|worktree| worktree.read(cx).id()) else { log::error!("Open a worktree to use edit prediction debug view"); - self.last_prediction.take(); + self.current.take(); return; }; @@ -309,20 +324,42 @@ impl Zeta2Inspector { let language_registry = self.project.read(cx).languages().clone(); async move |this, cx| { let mut languages = HashMap::default(); - for ext in prediction - .request - .referenced_declarations - .iter() - .filter_map(|snippet| snippet.path.extension()) - .chain(prediction.request.excerpt_path.extension()) - { - if !languages.contains_key(ext) { + + let file_extensions = this + .read_with(cx, |this, _| { + let mut file_extensions = HashSet::default(); + let request = &this.requests[index]; + + for ext in request + .observed_request + .full_request + .referenced_declarations + .iter() + .filter_map(|snippet| snippet.path.extension()) + .chain( + request + .observed_request + .full_request + .excerpt_path + .extension(), + ) + { + if !file_extensions.contains(ext) { + file_extensions.insert(ext.to_owned()); + } + } + file_extensions + }) + .unwrap(); + + for ext in file_extensions { + if let Entry::Vacant(entry) = languages.entry(ext) { // Most snippets are gonna be the same language, // so we think it's fine to do this sequentially for now - languages.insert( - ext.to_owned(), + let ext_str = entry.key().to_string_lossy().to_string(); + entry.insert( language_registry - .language_for_name_or_extension(&ext.to_string_lossy()) + .language_for_name_or_extension(&ext_str) .await .ok(), ); @@ -335,6 +372,11 @@ impl Zeta2Inspector { .log_err(); this.update_in(cx, |this, window, cx| { + let InspectorObservedPredictionRequest { + observed_request: request, + .. + } = &this.requests[index]; + let context_editor = cx.new(|cx| { let mut excerpt_score_components = HashMap::default(); @@ -348,9 +390,9 @@ impl Zeta2Inspector { let excerpt_buffer = cx.new(|cx| { let mut buffer = - Buffer::local(prediction.request.excerpt.clone(), cx); - if let Some(language) = prediction - .request + Buffer::local(request.full_request.excerpt.clone(), cx); + if let Some(language) = request + .full_request .excerpt_path .extension() .and_then(|ext| languages.get(ext)) @@ -368,7 +410,7 @@ impl Zeta2Inspector { ); let mut declarations = - prediction.request.referenced_declarations.clone(); + request.full_request.referenced_declarations.clone(); declarations.sort_unstable_by_key(|declaration| { Reverse(OrderedFloat(declaration.declaration_score)) }); @@ -419,24 +461,23 @@ impl Zeta2Inspector { editor }); - let PredictionDebugInfo { - response_rx, + let ObservedPredictionRequest { + response, position, buffer, - retrieval_time, local_prompt, .. - } = prediction; + } = request; + let response_task = response.clone(); let task = cx.spawn_in(window, { let markdown_language = markdown_language.clone(); async move |this, cx| { - let response = response_rx.await; - + let response = response_task.await; this.update_in(cx, |this, window, cx| { - if let Some(prediction) = this.last_prediction.as_mut() { + if let Some(prediction) = this.current.as_mut() { prediction.state = match response { - Ok(Ok(response)) => { + Ok(response) => { if let Some(debug_info) = &response.debug_info { prediction.prompt_editor.update( cx, @@ -488,8 +529,8 @@ impl Zeta2Inspector { |this, editor, ev, window, cx| match ev { EditorEvent::BufferEdited => { if let Some(last_prediction) = - this.last_prediction.as_mut() - && let LastPredictionState::Success { + this.current.as_mut() + && let CurrentPredictionState::Success { feedback: feedback_state, .. } = &mut last_prediction.state @@ -511,7 +552,7 @@ impl Zeta2Inspector { ) .detach(); - LastPredictionState::Success { + CurrentPredictionState::Success { model_response_editor: cx.new(|cx| { let buffer = cx.new(|cx| { let mut buffer = Buffer::local( @@ -548,12 +589,7 @@ impl Zeta2Inspector { response, } } - Ok(Err(err)) => { - LastPredictionState::Failed { message: err } - } - Err(oneshot::Canceled) => LastPredictionState::Failed { - message: "Canceled".to_string(), - }, + Err(err) => CurrentPredictionState::Failed { message: err }, }; } }) @@ -561,14 +597,15 @@ impl Zeta2Inspector { } }); - let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx); - - this.last_prediction = Some(LastPrediction { + this.current = Some(CurrentRequest { + index, context_editor, prompt_editor: cx.new(|cx| { let buffer = cx.new(|cx| { - let mut buffer = - Buffer::local(local_prompt.unwrap_or_else(|err| err), cx); + let mut buffer = Buffer::local( + local_prompt.as_ref().unwrap_or_else(|err| err), + cx, + ); buffer.set_language(markdown_language.clone(), cx); buffer }); @@ -581,15 +618,9 @@ impl Zeta2Inspector { editor.set_show_scrollbars(false, cx); editor }), - retrieval_time, - buffer, - position, - state: LastPredictionState::Requested, - project_snapshot: cx - .foreground_executor() - .spawn(async move { Arc::new(project_snapshot_task.await) }) - .shared(), - request: prediction.request, + buffer: buffer.clone(), + position: *position, + state: CurrentPredictionState::Requested, _task: Some(task), }); cx.notify(); @@ -599,6 +630,41 @@ impl Zeta2Inspector { }); } + fn handle_prev( + &mut self, + _action: &Zeta2InspectorPrevious, + window: &mut Window, + cx: &mut Context, + ) { + self.set_current_predition( + self.current + .as_ref() + .map(|c| c.index) + .unwrap_or_default() + .saturating_sub(1), + window, + cx, + ); + } + + fn handle_next( + &mut self, + _action: &Zeta2InspectorPrevious, + window: &mut Window, + cx: &mut Context, + ) { + self.set_current_predition( + self.current + .as_ref() + .map(|c| c.index) + .unwrap_or_default() + .add(1) + .min(self.requests.len() - 1), + window, + cx, + ); + } + fn handle_rate_positive( &mut self, _action: &Zeta2RatePredictionPositive, @@ -618,23 +684,25 @@ impl Zeta2Inspector { } fn handle_rate(&mut self, kind: Feedback, window: &mut Window, cx: &mut Context) { - let Some(last_prediction) = self.last_prediction.as_mut() else { + let Some(last_prediction) = self.current.as_mut() else { return; }; - if !last_prediction.request.can_collect_data { + let request = &self.requests[last_prediction.index]; + if !request.observed_request.full_request.can_collect_data { return; } - let project_snapshot_task = last_prediction.project_snapshot.clone(); + let project_snapshot_task = request.project_snapshot.clone(); cx.spawn_in(window, async move |this, cx| { let project_snapshot = project_snapshot_task.await; this.update_in(cx, |this, window, cx| { - let Some(last_prediction) = this.last_prediction.as_mut() else { + let Some(last_prediction) = this.current.as_mut() else { return; }; - let LastPredictionState::Success { + // todo! move to Self::requests? + let CurrentPredictionState::Success { feedback: feedback_state, feedback_editor, model_response_editor, @@ -670,12 +738,14 @@ impl Zeta2Inspector { Feedback::Negative => "negative", }; + let request = &this.requests[last_prediction.index]; + telemetry::event!( "Zeta2 Prediction Rated", id = response.request_id, kind = kind, text = text, - request = last_prediction.request, + request = request.observed_request.full_request, response = response, project_snapshot = project_snapshot, ); @@ -686,8 +756,8 @@ impl Zeta2Inspector { } fn focus_feedback(&mut self, window: &mut Window, cx: &mut Context) { - if let Some(last_prediction) = self.last_prediction.as_mut() { - if let LastPredictionState::Success { + if let Some(last_prediction) = self.current.as_mut() { + if let CurrentPredictionState::Success { feedback_editor, .. } = &mut last_prediction.state { @@ -780,7 +850,7 @@ impl Zeta2Inspector { } fn render_tabs(&self, cx: &mut Context) -> Option { - if self.last_prediction.is_none() { + if self.current.is_none() { return None; }; @@ -816,19 +886,19 @@ impl Zeta2Inspector { } fn render_stats(&self) -> Option
{ - let Some(prediction) = self.last_prediction.as_ref() else { + let Some(current) = self.current.as_ref() else { return None; }; let (prompt_planning_time, inference_time, parsing_time) = - if let LastPredictionState::Success { + if let CurrentPredictionState::Success { response: PredictEditsResponse { debug_info: Some(debug_info), .. }, .. - } = &prediction.state + } = ¤t.state { ( Some(debug_info.prompt_planning_time), @@ -847,7 +917,7 @@ impl Zeta2Inspector { .child(Headline::new("Stats").size(HeadlineSize::Small)) .child(Self::render_duration( "Context retrieval", - Some(prediction.retrieval_time), + Some(self.requests[current.index].observed_request.retrieval_time), )) .child(Self::render_duration( "Prompt planning", @@ -878,10 +948,10 @@ impl Zeta2Inspector { return Self::render_message("`zeta2` feature flag is not enabled"); } - match self.last_prediction.as_ref() { + match self.current.as_ref() { None => Self::render_message("No prediction"), Some(prediction) => self - .render_last_prediction(prediction, window, cx) + .render_current_prediction(prediction, window, cx) .into_any(), } } @@ -895,14 +965,14 @@ impl Zeta2Inspector { .into_any() } - fn render_last_prediction( + fn render_current_prediction( &self, - prediction: &LastPrediction, + current: &CurrentRequest, window: &mut Window, cx: &mut Context, ) -> Div { match &self.active_view { - ActiveView::Context => div().size_full().child(prediction.context_editor.clone()), + ActiveView::Context => div().size_full().child(current.context_editor.clone()), ActiveView::Inference => h_flex() .items_start() .w_full() @@ -920,17 +990,21 @@ impl Zeta2Inspector { h_flex() .justify_between() .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall)) - .child(match prediction.state { - LastPredictionState::Requested - | LastPredictionState::Failed { .. } => ui::Chip::new("Local") - .bg_color(cx.theme().status().warning_background) - .label_color(Color::Success), - LastPredictionState::Success { .. } => ui::Chip::new("Cloud") - .bg_color(cx.theme().status().success_background) - .label_color(Color::Success), + .child(match current.state { + CurrentPredictionState::Requested + | CurrentPredictionState::Failed { .. } => { + ui::Chip::new("Local") + .bg_color(cx.theme().status().warning_background) + .label_color(Color::Success) + } + CurrentPredictionState::Success { .. } => { + ui::Chip::new("Cloud") + .bg_color(cx.theme().status().success_background) + .label_color(Color::Success) + } }), ) - .child(prediction.prompt_editor.clone()), + .child(current.prompt_editor.clone()), ) .child(ui::vertical_divider()) .child( @@ -947,16 +1021,16 @@ impl Zeta2Inspector { ui::Headline::new("Model Response") .size(ui::HeadlineSize::XSmall), ) - .child(match &prediction.state { - LastPredictionState::Success { + .child(match ¤t.state { + CurrentPredictionState::Success { model_response_editor, .. } => model_response_editor.clone().into_any_element(), - LastPredictionState::Requested => v_flex() + CurrentPredictionState::Requested => v_flex() .gap_2() .child(Label::new("Loading...").buffer_font(cx)) .into_any_element(), - LastPredictionState::Failed { message } => v_flex() + CurrentPredictionState::Failed { message } => v_flex() .gap_2() .max_w_96() .child(Label::new(message.clone()).buffer_font(cx)) @@ -965,12 +1039,15 @@ impl Zeta2Inspector { ) .child(ui::divider()) .child( - if prediction.request.can_collect_data - && let LastPredictionState::Success { + if self.requests[current.index] + .observed_request + .full_request + .can_collect_data + && let CurrentPredictionState::Success { feedback_editor, feedback: feedback_state, .. - } = &prediction.state + } = ¤t.state { v_flex() .key_context("Zeta2Feedback") @@ -1063,7 +1140,11 @@ impl EventEmitter<()> for Zeta2Inspector {} impl Render for Zeta2Inspector { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() + .track_focus(&self.focus_handle) + .key_context("Zeta2Inspector") .size_full() + .on_action(cx.listener(Self::handle_prev)) + .on_action(cx.listener(Self::handle_next)) .bg(cx.theme().colors().editor_background) .child( h_flex()