zeta2 inspector: Plan prompt locally (#39811)

Created by Agus Zubiaga and Michael Sloan

Plans the prompt locally and displays it before the cloud response arrives. Helpful when debugging prompt planning.
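The local plan runs against the exact request payload that would otherwise only be sent to the cloud. A minimal sketch of the planning step, using `PlannedPrompt` as it appears in the diff below (the error type and the `.0` tuple field are inferred from the diff, not from `cloud_zeta2_prompt` documentation):

```rust
// Sketch: plan the prompt locally from the already-built cloud request.
// Shapes inferred from the diff below, not from the crate's docs.
let local_prompt: Result<String, String> = PlannedPrompt::populate(&request)
    .and_then(|planned| planned.to_prompt_string().map(|out| out.0))
    .map_err(|err| err.to_string());
```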

Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <mgsloan@gmail.com>

Change summary

crates/zeta2/src/zeta2.rs             |  56 ++++--
crates/zeta2_tools/src/zeta2_tools.rs | 235 ++++++++++++++++------------
2 files changed, 168 insertions(+), 123 deletions(-)

Detailed changes

crates/zeta2/src/zeta2.rs

@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
 use cloud_llm_client::{
     EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
+use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
 use edit_prediction_context::{
     DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
     EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
@@ -93,6 +93,7 @@ pub struct PredictionDebugInfo {
     pub retrieval_time: TimeDelta,
     pub buffer: WeakEntity<Buffer>,
     pub position: language::Anchor,
+    pub local_prompt: Result<String, String>,
     pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
 }
 
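Because `local_prompt` is filled in eagerly while `response_rx` resolves later, a consumer can render the prompt twice per prediction: once from the local plan, and again once the server confirms it. A hypothetical consumer loop (`debug_rx` and the `show_*` helpers are illustrative, not part of this change; only `PredictionDebugInfo`'s fields come from the diff):

```rust
// Hypothetical consumer of the debug channel.
while let Some(info) = debug_rx.next().await {
    // The locally planned prompt (or the planning error) is available
    // immediately, before any network round trip.
    match &info.local_prompt {
        Ok(prompt) => show_prompt(prompt),
        Err(message) => show_error(message),
    }
    // The server-confirmed prompt arrives later over the oneshot channel.
    if let Ok(Ok(response)) = info.response_rx.await {
        show_prompt(&response.prompt);
    }
}
```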
@@ -539,24 +540,6 @@ impl Zeta {
 
                 let retrieval_time = chrono::Utc::now() - before_retrieval;
 
-                let debug_response_tx = if let Some(debug_tx) = debug_tx {
-                    let (response_tx, response_rx) = oneshot::channel();
-                    let context = context.clone();
-
-                    debug_tx
-                        .unbounded_send(PredictionDebugInfo {
-                            context,
-                            retrieval_time,
-                            buffer: buffer.downgrade(),
-                            position,
-                            response_rx,
-                        })
-                        .ok();
-                    Some(response_tx)
-                } else {
-                    None
-                };
-
                 let (diagnostic_groups, diagnostic_groups_truncated) =
                     Self::gather_nearby_diagnostics(
                         cursor_offset,
@@ -565,6 +548,8 @@ impl Zeta {
                         options.max_diagnostic_bytes,
                     );
 
+                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
+
                 let request = make_cloud_request(
                     excerpt_path,
                     context,
@@ -574,13 +559,44 @@ impl Zeta {
                     diagnostic_groups,
                     diagnostic_groups_truncated,
                     None,
-                    debug_response_tx.is_some(),
+                    debug_context.is_some(),
                     &worktree_snapshots,
                     index_state.as_deref(),
                     Some(options.max_prompt_bytes),
                     options.prompt_format,
                 );
 
+                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
+                    let (response_tx, response_rx) = oneshot::channel();
+
+                    let local_prompt = PlannedPrompt::populate(&request)
+                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
+                        .map_err(|err| err.to_string());
+
+                    debug_tx
+                        .unbounded_send(PredictionDebugInfo {
+                            context,
+                            retrieval_time,
+                            buffer: buffer.downgrade(),
+                            local_prompt,
+                            position,
+                            response_rx,
+                        })
+                        .ok();
+                    Some(response_tx)
+                } else {
+                    None
+                };
+
+                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
+                    if let Some(debug_response_tx) = debug_response_tx {
+                        debug_response_tx
+                            .send(Err("Request skipped".to_string()))
+                            .ok();
+                    }
+                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
+                }
+
                 let response = Self::perform_request(client, llm_token, app_version, request).await;
 
                 if let Some(debug_response_tx) = debug_response_tx {

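The `ZED_ZETA2_SKIP_REQUEST` gate above only fires in debug builds, and only the variable's presence is checked, not its value. The same pattern reduced to a self-contained example (std only):

```rust
// Debug-only escape hatch: cfg!(debug_assertions) evaluates to false in
// release builds (so the branch is optimized away), and env::var(..).is_ok()
// tests for the variable's presence rather than any particular value.
fn maybe_skip_request() -> Result<(), String> {
    if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
        return Err("Request skipped".to_string());
    }
    Ok(())
}

fn main() {
    // In a debug build, `ZED_ZETA2_SKIP_REQUEST=1 cargo run` prints Err(..).
    println!("{:?}", maybe_skip_request());
}
```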
crates/zeta2_tools/src/zeta2_tools.rs

@@ -76,6 +76,7 @@ enum ActiveView {
 
 struct LastPrediction {
     context_editor: Entity<Editor>,
+    prompt_editor: Entity<Editor>,
     retrieval_time: TimeDelta,
     buffer: WeakEntity<Buffer>,
     position: language::Anchor,
@@ -89,7 +90,6 @@ enum LastPredictionState {
         inference_time: TimeDelta,
         parsing_time: TimeDelta,
         prompt_planning_time: TimeDelta,
-        prompt_editor: Entity<Editor>,
         model_response_editor: Entity<Editor>,
     },
     Failed {
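The net effect of these two hunks is that the prompt editor moves off the `Success` state and onto `LastPrediction` itself, so it exists for the prediction's whole lifetime. A sketch of the resulting shape, reconstructed from the diff (unrelated fields elided into comments):

```rust
// Reconstructed from the diff; not a verbatim copy of the source.
struct LastPrediction {
    context_editor: Entity<Editor>,
    prompt_editor: Entity<Editor>, // seeded eagerly from the local plan
    state: LastPredictionState,
    // retrieval_time, buffer, position, ...
}

enum LastPredictionState {
    Requested,
    Success {
        model_response_editor: Entity<Editor>,
        // prompt_planning_time, inference_time, parsing_time
    },
    Failed { message: String },
}
```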
@@ -377,75 +377,92 @@ impl Zeta2Inspector {
                         position,
                         buffer,
                         retrieval_time,
+                        local_prompt,
                         ..
                     } = prediction;
 
-                    let task = cx.spawn_in(window, async move |this, cx| {
-                        let response = response_rx.await;
-
-                        this.update_in(cx, |this, window, cx| {
-                            if let Some(prediction) = this.last_prediction.as_mut() {
-                                prediction.state = match response {
-                                    Ok(Ok(response)) => LastPredictionState::Success {
-                                        prompt_planning_time: response.prompt_planning_time,
-                                        inference_time: response.inference_time,
-                                        parsing_time: response.parsing_time,
-                                        prompt_editor: cx.new(|cx| {
-                                            let buffer = cx.new(|cx| {
-                                                let mut buffer = Buffer::local(response.prompt, cx);
-                                                buffer.set_language(markdown_language.clone(), cx);
-                                                buffer
-                                            });
-                                            let buffer =
-                                                cx.new(|cx| MultiBuffer::singleton(buffer, cx));
-                                            let mut editor = Editor::new(
-                                                EditorMode::full(),
-                                                buffer,
-                                                None,
-                                                window,
-                                                cx,
-                                            );
-                                            editor.set_read_only(true);
-                                            editor.set_show_line_numbers(false, cx);
-                                            editor.set_show_gutter(false, cx);
-                                            editor.set_show_scrollbars(false, cx);
-                                            editor
-                                        }),
-                                        model_response_editor: cx.new(|cx| {
-                                            let buffer = cx.new(|cx| {
-                                                let mut buffer =
-                                                    Buffer::local(response.model_response, cx);
-                                                buffer.set_language(markdown_language, cx);
-                                                buffer
-                                            });
-                                            let buffer =
-                                                cx.new(|cx| MultiBuffer::singleton(buffer, cx));
-                                            let mut editor = Editor::new(
-                                                EditorMode::full(),
-                                                buffer,
-                                                None,
-                                                window,
+                    let task = cx.spawn_in(window, {
+                        let markdown_language = markdown_language.clone();
+                        async move |this, cx| {
+                            let response = response_rx.await;
+
+                            this.update_in(cx, |this, window, cx| {
+                                if let Some(prediction) = this.last_prediction.as_mut() {
+                                    prediction.state = match response {
+                                        Ok(Ok(response)) => {
+                                            prediction.prompt_editor.update(
                                                 cx,
+                                                |prompt_editor, cx| {
+                                                    prompt_editor.set_text(
+                                                        response.prompt,
+                                                        window,
+                                                        cx,
+                                                    );
+                                                },
                                             );
-                                            editor.set_read_only(true);
-                                            editor.set_show_line_numbers(false, cx);
-                                            editor.set_show_gutter(false, cx);
-                                            editor.set_show_scrollbars(false, cx);
-                                            editor
-                                        }),
-                                    },
-                                    Ok(Err(err)) => LastPredictionState::Failed { message: err },
-                                    Err(oneshot::Canceled) => LastPredictionState::Failed {
-                                        message: "Canceled".to_string(),
-                                    },
-                                };
-                            }
-                        })
-                        .ok();
+
+                                            LastPredictionState::Success {
+                                                prompt_planning_time: response.prompt_planning_time,
+                                                inference_time: response.inference_time,
+                                                parsing_time: response.parsing_time,
+                                                model_response_editor: cx.new(|cx| {
+                                                    let buffer = cx.new(|cx| {
+                                                        let mut buffer = Buffer::local(
+                                                            response.model_response,
+                                                            cx,
+                                                        );
+                                                        buffer.set_language(markdown_language, cx);
+                                                        buffer
+                                                    });
+                                                    let buffer = cx.new(|cx| {
+                                                        MultiBuffer::singleton(buffer, cx)
+                                                    });
+                                                    let mut editor = Editor::new(
+                                                        EditorMode::full(),
+                                                        buffer,
+                                                        None,
+                                                        window,
+                                                        cx,
+                                                    );
+                                                    editor.set_read_only(true);
+                                                    editor.set_show_line_numbers(false, cx);
+                                                    editor.set_show_gutter(false, cx);
+                                                    editor.set_show_scrollbars(false, cx);
+                                                    editor
+                                                }),
+                                            }
+                                        }
+                                        Ok(Err(err)) => {
+                                            LastPredictionState::Failed { message: err }
+                                        }
+                                        Err(oneshot::Canceled) => LastPredictionState::Failed {
+                                            message: "Canceled".to_string(),
+                                        },
+                                    };
+                                }
+                            })
+                            .ok();
+                        }
                     });
 
                     this.last_prediction = Some(LastPrediction {
                         context_editor,
+                        prompt_editor: cx.new(|cx| {
+                            let buffer = cx.new(|cx| {
+                                let mut buffer =
+                                    Buffer::local(local_prompt.unwrap_or_else(|err| err), cx);
+                                buffer.set_language(markdown_language.clone(), cx);
+                                buffer
+                            });
+                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+                            let mut editor =
+                                Editor::new(EditorMode::full(), buffer, None, window, cx);
+                            editor.set_read_only(true);
+                            editor.set_show_line_numbers(false, cx);
+                            editor.set_show_gutter(false, cx);
+                            editor.set_show_scrollbars(false, cx);
+                            editor
+                        }),
                         retrieval_time,
                         buffer,
                         position,
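Two details in this hunk are worth calling out: the prompt editor is seeded with `local_prompt.unwrap_or_else(|err| err)`, so a planning failure shows up as text in the same pane, and on success the server prompt replaces the local one in place via `set_text` rather than rebuilding the editor. Condensed:

```rust
// Condensed from the hunk above.
// Ok and Err both carry a String, so either way there is text to display.
let initial_text = local_prompt.unwrap_or_else(|err| err);

// Later, on a successful response, overwrite the same editor in place.
prediction.prompt_editor.update(cx, |editor, cx| {
    editor.set_text(response.prompt, window, cx);
});
```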
@@ -646,48 +663,60 @@ impl Zeta2Inspector {
     fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
         match &self.active_view {
             ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
-            ActiveView::Inference => match &prediction.state {
-                LastPredictionState::Success {
-                    prompt_editor,
-                    model_response_editor,
-                    ..
-                } => h_flex()
-                    .items_start()
-                    .w_full()
-                    .flex_1()
-                    .border_t_1()
-                    .border_color(cx.theme().colors().border)
-                    .bg(cx.theme().colors().editor_background)
-                    .child(
-                        v_flex()
-                            .flex_1()
-                            .gap_2()
-                            .p_4()
-                            .h_full()
-                            .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
-                            .child(prompt_editor.clone()),
-                    )
-                    .child(ui::vertical_divider())
-                    .child(
-                        v_flex()
-                            .flex_1()
-                            .gap_2()
-                            .h_full()
-                            .p_4()
-                            .child(
-                                ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
-                            )
-                            .child(model_response_editor.clone()),
-                    ),
-                LastPredictionState::Requested => v_flex()
-                    .p_4()
-                    .gap_2()
-                    .child(Label::new("Loading...").buffer_font(cx)),
-                LastPredictionState::Failed { message } => v_flex()
-                    .p_4()
-                    .gap_2()
-                    .child(Label::new(message.clone()).buffer_font(cx)),
-            },
+            ActiveView::Inference => h_flex()
+                .items_start()
+                .w_full()
+                .flex_1()
+                .border_t_1()
+                .border_color(cx.theme().colors().border)
+                .bg(cx.theme().colors().editor_background)
+                .child(
+                    v_flex()
+                        .flex_1()
+                        .gap_2()
+                        .p_4()
+                        .h_full()
+                        .child(
+                            h_flex()
+                                .justify_between()
+                                .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
+                                .child(match prediction.state {
+                                    LastPredictionState::Requested
+                                    | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
+                                        .bg_color(cx.theme().status().warning_background)
+                                        .label_color(Color::Warning),
+                                    LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
+                                        .bg_color(cx.theme().status().success_background)
+                                        .label_color(Color::Success),
+                                }),
+                        )
+                        .child(prediction.prompt_editor.clone()),
+                )
+                .child(ui::vertical_divider())
+                .child(
+                    v_flex()
+                        .flex_1()
+                        .gap_2()
+                        .h_full()
+                        .p_4()
+                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
+                        .child(match &prediction.state {
+                            LastPredictionState::Success {
+                                model_response_editor,
+                                ..
+                            } => model_response_editor.clone().into_any_element(),
+                            LastPredictionState::Requested => v_flex()
+                                .p_4()
+                                .gap_2()
+                                .child(Label::new("Loading...").buffer_font(cx))
+                                .into_any(),
+                            LastPredictionState::Failed { message } => v_flex()
+                                .p_4()
+                                .gap_2()
+                                .child(Label::new(message.clone()).buffer_font(cx))
+                                .into_any(),
+                        }),
+                ),
         }
     }
 }
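The inference view no longer branches on state at the top level: the prompt pane always renders `prediction.prompt_editor`, a chip marks its provenance, and only the response column varies by state, with the heterogeneous children unified through `into_any_element`. Condensed (layout and styling omitted; assumes gpui's usual `IntoElement` impls):

```rust
// Condensed from render_last_prediction above.
let provenance = match prediction.state {
    LastPredictionState::Requested | LastPredictionState::Failed { .. } => "Local",
    LastPredictionState::Success { .. } => "Cloud",
};
let response_body = match &prediction.state {
    LastPredictionState::Success { model_response_editor, .. } => {
        model_response_editor.clone().into_any_element()
    }
    LastPredictionState::Requested => Label::new("Loading...").into_any_element(),
    LastPredictionState::Failed { message } => {
        Label::new(message.clone()).into_any_element()
    }
};
```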