@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
+use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
use edit_prediction_context::{
DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
@@ -93,6 +93,7 @@ pub struct PredictionDebugInfo {
pub retrieval_time: TimeDelta,
pub buffer: WeakEntity<Buffer>,
pub position: language::Anchor,
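+    /// The prompt planned locally from the request, or the planning error message.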
+ pub local_prompt: Result<String, String>,
pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}
@@ -539,24 +540,6 @@ impl Zeta {
let retrieval_time = chrono::Utc::now() - before_retrieval;
- let debug_response_tx = if let Some(debug_tx) = debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
- let context = context.clone();
-
- debug_tx
- .unbounded_send(PredictionDebugInfo {
- context,
- retrieval_time,
- buffer: buffer.downgrade(),
- position,
- response_rx,
- })
- .ok();
- Some(response_tx)
- } else {
- None
- };
-
let (diagnostic_groups, diagnostic_groups_truncated) =
Self::gather_nearby_diagnostics(
cursor_offset,
@@ -565,6 +548,8 @@ impl Zeta {
options.max_diagnostic_bytes,
);
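+            // Hold on to the debug sender (plus a clone of the context) until the request is built, so the local prompt can be planned from it.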
+ let debug_context = debug_tx.map(|tx| (tx, context.clone()));
+
let request = make_cloud_request(
excerpt_path,
context,
@@ -574,13 +559,44 @@ impl Zeta {
diagnostic_groups,
diagnostic_groups_truncated,
None,
- debug_response_tx.is_some(),
+ debug_context.is_some(),
&worktree_snapshots,
index_state.as_deref(),
Some(options.max_prompt_bytes),
options.prompt_format,
);
+ let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
+ let (response_tx, response_rx) = oneshot::channel();
+
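+                // Plan the prompt locally from the same request so the inspector can show it before the cloud responds.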
+ let local_prompt = PlannedPrompt::populate(&request)
+ .and_then(|p| p.to_prompt_string().map(|p| p.0))
+ .map_err(|err| err.to_string());
+
+ debug_tx
+ .unbounded_send(PredictionDebugInfo {
+ context,
+ retrieval_time,
+ buffer: buffer.downgrade(),
+ local_prompt,
+ position,
+ response_rx,
+ })
+ .ok();
+ Some(response_tx)
+ } else {
+ None
+ };
+
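+            // Debug-only escape hatch: skip the cloud request entirely when ZED_ZETA2_SKIP_REQUEST is set, leaving just the locally planned prompt to inspect.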
+ if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send(Err("Request skipped".to_string()))
+ .ok();
+ }
+ anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
+ }
+
let response = Self::perform_request(client, llm_token, app_version, request).await;
if let Some(debug_response_tx) = debug_response_tx {
@@ -76,6 +76,7 @@ enum ActiveView {
struct LastPrediction {
context_editor: Entity<Editor>,
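+    /// Shows the locally planned prompt until the server's prompt replaces it on success.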
+ prompt_editor: Entity<Editor>,
retrieval_time: TimeDelta,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
@@ -89,7 +90,6 @@ enum LastPredictionState {
inference_time: TimeDelta,
parsing_time: TimeDelta,
prompt_planning_time: TimeDelta,
- prompt_editor: Entity<Editor>,
model_response_editor: Entity<Editor>,
},
Failed {
@@ -377,75 +377,92 @@ impl Zeta2Inspector {
position,
buffer,
retrieval_time,
+ local_prompt,
..
} = prediction;
- let task = cx.spawn_in(window, async move |this, cx| {
- let response = response_rx.await;
-
- this.update_in(cx, |this, window, cx| {
- if let Some(prediction) = this.last_prediction.as_mut() {
- prediction.state = match response {
- Ok(Ok(response)) => LastPredictionState::Success {
- prompt_planning_time: response.prompt_planning_time,
- inference_time: response.inference_time,
- parsing_time: response.parsing_time,
- prompt_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(response.prompt, cx);
- buffer.set_language(markdown_language.clone(), cx);
- buffer
- });
- let buffer =
- cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor = Editor::new(
- EditorMode::full(),
- buffer,
- None,
- window,
- cx,
- );
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- model_response_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer =
- Buffer::local(response.model_response, cx);
- buffer.set_language(markdown_language, cx);
- buffer
- });
- let buffer =
- cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor = Editor::new(
- EditorMode::full(),
- buffer,
- None,
- window,
+ let task = cx.spawn_in(window, {
+ let markdown_language = markdown_language.clone();
+ async move |this, cx| {
+ let response = response_rx.await;
+
+ this.update_in(cx, |this, window, cx| {
+ if let Some(prediction) = this.last_prediction.as_mut() {
+ prediction.state = match response {
+ Ok(Ok(response)) => {
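+                                // Swap the locally planned prompt for the prompt the server actually used.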
+ prediction.prompt_editor.update(
cx,
+ |prompt_editor, cx| {
+ prompt_editor.set_text(
+ response.prompt,
+ window,
+ cx,
+ );
+ },
);
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- },
- Ok(Err(err)) => LastPredictionState::Failed { message: err },
- Err(oneshot::Canceled) => LastPredictionState::Failed {
- message: "Canceled".to_string(),
- },
- };
- }
- })
- .ok();
+
+ LastPredictionState::Success {
+ prompt_planning_time: response.prompt_planning_time,
+ inference_time: response.inference_time,
+ parsing_time: response.parsing_time,
+ model_response_editor: cx.new(|cx| {
+ let buffer = cx.new(|cx| {
+ let mut buffer = Buffer::local(
+ response.model_response,
+ cx,
+ );
+ buffer.set_language(markdown_language, cx);
+ buffer
+ });
+ let buffer = cx.new(|cx| {
+ MultiBuffer::singleton(buffer, cx)
+ });
+ let mut editor = Editor::new(
+ EditorMode::full(),
+ buffer,
+ None,
+ window,
+ cx,
+ );
+ editor.set_read_only(true);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_gutter(false, cx);
+ editor.set_show_scrollbars(false, cx);
+ editor
+ }),
+ }
+ }
+ Ok(Err(err)) => {
+ LastPredictionState::Failed { message: err }
+ }
+ Err(oneshot::Canceled) => LastPredictionState::Failed {
+ message: "Canceled".to_string(),
+ },
+ };
+ }
+ })
+ .ok();
+ }
});
this.last_prediction = Some(LastPrediction {
context_editor,
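+            // Seed the prompt editor with the locally planned prompt (or the planning error).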
+ prompt_editor: cx.new(|cx| {
+ let buffer = cx.new(|cx| {
+ let mut buffer =
+ Buffer::local(local_prompt.unwrap_or_else(|err| err), cx);
+ buffer.set_language(markdown_language.clone(), cx);
+ buffer
+ });
+ let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let mut editor =
+ Editor::new(EditorMode::full(), buffer, None, window, cx);
+ editor.set_read_only(true);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_gutter(false, cx);
+ editor.set_show_scrollbars(false, cx);
+ editor
+ }),
retrieval_time,
buffer,
position,
@@ -646,48 +663,60 @@ impl Zeta2Inspector {
fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
match &self.active_view {
ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
- ActiveView::Inference => match &prediction.state {
- LastPredictionState::Success {
- prompt_editor,
- model_response_editor,
- ..
- } => h_flex()
- .items_start()
- .w_full()
- .flex_1()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().editor_background)
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .h_full()
- .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
- .child(prompt_editor.clone()),
- )
- .child(ui::vertical_divider())
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .h_full()
- .p_4()
- .child(
- ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
- )
- .child(model_response_editor.clone()),
- ),
- LastPredictionState::Requested => v_flex()
- .p_4()
- .gap_2()
- .child(Label::new("Loading...").buffer_font(cx)),
- LastPredictionState::Failed { message } => v_flex()
- .p_4()
- .gap_2()
- .child(Label::new(message.clone()).buffer_font(cx)),
- },
+ ActiveView::Inference => h_flex()
+ .items_start()
+ .w_full()
+ .flex_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .p_4()
+ .h_full()
+ .child(
+ h_flex()
+ .justify_between()
+ .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
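+                                // The chip indicates whether the displayed prompt was planned locally or returned by the cloud.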
+                                .child(match &prediction.state {
+                                    LastPredictionState::Requested
+                                    | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
+                                        .bg_color(cx.theme().status().warning_background)
+                                        .label_color(Color::Warning),
+ LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
+ .bg_color(cx.theme().status().success_background)
+ .label_color(Color::Success),
+ }),
+ )
+ .child(prediction.prompt_editor.clone()),
+ )
+ .child(ui::vertical_divider())
+ .child(
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .h_full()
+ .p_4()
+ .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
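+                    // The model response is only available once the cloud request has succeeded.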
+ .child(match &prediction.state {
+ LastPredictionState::Success {
+ model_response_editor,
+ ..
+ } => model_response_editor.clone().into_any_element(),
+ LastPredictionState::Requested => v_flex()
+ .p_4()
+ .gap_2()
+ .child(Label::new("Loading...").buffer_font(cx))
+ .into_any(),
+ LastPredictionState::Failed { message } => v_flex()
+ .p_4()
+ .gap_2()
+ .child(Label::new(message.clone()).buffer_font(cx))
+ .into_any(),
+ }),
+ ),
}
}
}