zeta2_tools.rs

  1use std::{collections::hash_map::Entry, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
  2
  3use chrono::TimeDelta;
  4use client::{Client, UserStore};
  5use cloud_llm_client::predict_edits_v3::PromptFormat;
  6use collections::HashMap;
  7use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
  8use futures::StreamExt as _;
  9use gpui::{
 10    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 11    prelude::*,
 12};
 13use language::{Buffer, DiskState};
 14use project::{Project, WorktreeId};
 15use ui::{ContextMenu, ContextMenuEntry, DropdownMenu, prelude::*};
 16use ui_input::SingleLineInput;
 17use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
 18use workspace::{Item, SplitDirection, Workspace};
 19use zeta2::{Zeta, ZetaOptions};
 20
 21use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
 22
 23actions!(
 24    dev,
 25    [
 26        /// Opens the language server protocol logs viewer.
 27        OpenZeta2Inspector
 28    ]
 29);
 30
 31pub fn init(cx: &mut App) {
 32    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 33        workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
 34            let project = workspace.project();
 35            workspace.split_item(
 36                SplitDirection::Right,
 37                Box::new(cx.new(|cx| {
 38                    Zeta2Inspector::new(
 39                        &project,
 40                        workspace.client(),
 41                        workspace.user_store(),
 42                        window,
 43                        cx,
 44                    )
 45                })),
 46                window,
 47                cx,
 48            );
 49        });
 50    })
 51    .detach();
 52}
 53
// TODO: show the included diagnostics and events.
 55
/// Dev-tool pane that shows what Zeta2 used and produced for the most recent
/// edit prediction (retrieved context, prompt, model response, timings) and lets
/// the prediction options be edited live.
pub struct Zeta2Inspector {
    focus_handle: FocusHandle,
    /// Used to find a worktree, the language registry, and declaration paths.
    project: Entity<Project>,
    /// State of the most recent prediction; `None` until one has been received.
    last_prediction: Option<LastPredictionState>,
    // Text inputs mirroring the numeric fields of `ZetaOptions`.
    max_excerpt_bytes_input: Entity<SingleLineInput>,
    min_excerpt_bytes_input: Entity<SingleLineInput>,
    cursor_context_ratio_input: Entity<SingleLineInput>,
    max_prompt_bytes_input: Entity<SingleLineInput>,
    /// Which tab (context or inference) the lower pane displays.
    active_view: ActiveView,
    zeta: Entity<Zeta>,
    // NOTE(review): only ever initialized to `None` in this file; appears unused.
    _active_editor_subscription: Option<Subscription>,
    /// Builds the editors for a received prediction; replaced for each new one.
    _update_state_task: Task<()>,
    /// Drains Zeta's debug-info stream for the lifetime of the inspector.
    _receive_task: Task<()>,
}
 70
/// Which part of the last prediction the inspector's lower pane shows.
#[derive(PartialEq)]
enum ActiveView {
    /// The cursor excerpt followed by the retrieved declaration snippets.
    Context,
    /// The prompt sent to the model side by side with the model's response.
    Inference,
}
 76
/// Outcome of the most recent prediction as displayed by the inspector.
enum LastPredictionState {
    /// The request failed (or its buffer was dropped); holds the message shown.
    Failed(SharedString),
    /// A prediction was received and its editors have been built.
    Success(LastPrediction),
    /// Options changed and the prediction is being re-requested; the previous
    /// result keeps being rendered (at reduced opacity) meanwhile.
    Replaying {
        prediction: LastPrediction,
        // Throttled refresh task; kept here so it lives as long as this state.
        _task: Task<()>,
    },
}
 85
/// Display state built from one `zeta2::PredictionDebugInfo`.
struct LastPrediction {
    /// Read-only multibuffer editor: cursor excerpt, then declaration snippets.
    context_editor: Entity<Editor>,
    // Timings reported by Zeta for each stage of the prediction.
    retrieval_time: TimeDelta,
    prompt_planning_time: TimeDelta,
    inference_time: TimeDelta,
    parsing_time: TimeDelta,
    /// Read-only editor holding the prompt text (Markdown highlighting).
    prompt_editor: Entity<Editor>,
    /// Read-only editor holding the raw model response (Markdown highlighting).
    model_response_editor: Entity<Editor>,
    /// Buffer the prediction was requested for; weak so the inspector does not
    /// keep it alive. Used with `position` to replay the request.
    buffer: WeakEntity<Buffer>,
    position: language::Anchor,
}
 97
 98impl Zeta2Inspector {
 99    pub fn new(
100        project: &Entity<Project>,
101        client: &Arc<Client>,
102        user_store: &Entity<UserStore>,
103        window: &mut Window,
104        cx: &mut Context<Self>,
105    ) -> Self {
106        let zeta = Zeta::global(client, user_store, cx);
107        let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
108
109        let receive_task = cx.spawn_in(window, async move |this, cx| {
110            while let Some(prediction_result) = request_rx.next().await {
111                this.update_in(cx, |this, window, cx| match prediction_result {
112                    Ok(prediction) => {
113                        this.update_last_prediction(prediction, window, cx);
114                    }
115                    Err(err) => {
116                        this.last_prediction = Some(LastPredictionState::Failed(err.into()));
117                        cx.notify();
118                    }
119                })
120                .ok();
121            }
122        });
123
124        let mut this = Self {
125            focus_handle: cx.focus_handle(),
126            project: project.clone(),
127            last_prediction: None,
128            active_view: ActiveView::Context,
129            max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
130            min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
131            cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
132            max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
133            zeta: zeta.clone(),
134            _active_editor_subscription: None,
135            _update_state_task: Task::ready(()),
136            _receive_task: receive_task,
137        };
138        this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
139        this
140    }
141
142    fn set_input_options(
143        &mut self,
144        options: &ZetaOptions,
145        window: &mut Window,
146        cx: &mut Context<Self>,
147    ) {
148        self.max_excerpt_bytes_input.update(cx, |input, cx| {
149            input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
150        });
151        self.min_excerpt_bytes_input.update(cx, |input, cx| {
152            input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
153        });
154        self.cursor_context_ratio_input.update(cx, |input, cx| {
155            input.set_text(
156                format!(
157                    "{:.2}",
158                    options.excerpt.target_before_cursor_over_total_bytes
159                ),
160                window,
161                cx,
162            );
163        });
164        self.max_prompt_bytes_input.update(cx, |input, cx| {
165            input.set_text(options.max_prompt_bytes.to_string(), window, cx);
166        });
167        cx.notify();
168    }
169
170    fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
171        self.zeta.update(cx, |this, _cx| this.set_options(options));
172
173        const THROTTLE_TIME: Duration = Duration::from_millis(100);
174
175        if let Some(
176            LastPredictionState::Success(prediction)
177            | LastPredictionState::Replaying { prediction, .. },
178        ) = self.last_prediction.take()
179        {
180            if let Some(buffer) = prediction.buffer.upgrade() {
181                let position = prediction.position;
182                let zeta = self.zeta.clone();
183                let project = self.project.clone();
184                let task = cx.spawn(async move |_this, cx| {
185                    cx.background_executor().timer(THROTTLE_TIME).await;
186                    if let Some(task) = zeta
187                        .update(cx, |zeta, cx| {
188                            zeta.refresh_prediction(&project, &buffer, position, cx)
189                        })
190                        .ok()
191                    {
192                        task.await.log_err();
193                    }
194                });
195                self.last_prediction = Some(LastPredictionState::Replaying {
196                    prediction,
197                    _task: task,
198                });
199            } else {
200                self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
201            }
202        }
203
204        cx.notify();
205    }
206
207    fn number_input(
208        label: &'static str,
209        window: &mut Window,
210        cx: &mut Context<Self>,
211    ) -> Entity<SingleLineInput> {
212        let input = cx.new(|cx| {
213            SingleLineInput::new(window, cx, "")
214                .label(label)
215                .label_min_width(px(64.))
216        });
217
218        cx.subscribe_in(
219            &input.read(cx).editor().clone(),
220            window,
221            |this, _, event, _window, cx| {
222                let EditorEvent::BufferEdited = event else {
223                    return;
224                };
225
226                fn number_input_value<T: FromStr + Default>(
227                    input: &Entity<SingleLineInput>,
228                    cx: &App,
229                ) -> T {
230                    input
231                        .read(cx)
232                        .editor()
233                        .read(cx)
234                        .text(cx)
235                        .parse::<T>()
236                        .unwrap_or_default()
237                }
238
239                let excerpt_options = EditPredictionExcerptOptions {
240                    max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
241                    min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
242                    target_before_cursor_over_total_bytes: number_input_value(
243                        &this.cursor_context_ratio_input,
244                        cx,
245                    ),
246                };
247
248                let zeta_options = this.zeta.read(cx).options();
249                this.set_options(
250                    ZetaOptions {
251                        excerpt: excerpt_options,
252                        max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
253                        max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
254                        prompt_format: zeta_options.prompt_format,
255                    },
256                    cx,
257                );
258            },
259        )
260        .detach();
261        input
262    }
263
264    fn update_last_prediction(
265        &mut self,
266        prediction: zeta2::PredictionDebugInfo,
267        window: &mut Window,
268        cx: &mut Context<Self>,
269    ) {
270        let project = self.project.read(cx);
271        let path_style = project.path_style(cx);
272        let Some(worktree_id) = project
273            .worktrees(cx)
274            .next()
275            .map(|worktree| worktree.read(cx).id())
276        else {
277            log::error!("Open a worktree to use edit prediction debug view");
278            self.last_prediction.take();
279            return;
280        };
281
282        self._update_state_task = cx.spawn_in(window, {
283            let language_registry = self.project.read(cx).languages().clone();
284            async move |this, cx| {
285                let mut languages = HashMap::default();
286                for lang_id in prediction
287                    .context
288                    .declarations
289                    .iter()
290                    .map(|snippet| snippet.declaration.identifier().language_id)
291                    .chain(prediction.context.excerpt_text.language_id)
292                {
293                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
294                        // Most snippets are gonna be the same language,
295                        // so we think it's fine to do this sequentially for now
296                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
297                    }
298                }
299
300                let markdown_language = language_registry
301                    .language_for_name("Markdown")
302                    .await
303                    .log_err();
304
305                this.update_in(cx, |this, window, cx| {
306                    let context_editor = cx.new(|cx| {
307                        let multibuffer = cx.new(|cx| {
308                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
309                            let excerpt_file = Arc::new(ExcerptMetadataFile {
310                                title: RelPath::unix("Cursor Excerpt").unwrap().into(),
311                                path_style,
312                                worktree_id,
313                            });
314
315                            let excerpt_buffer = cx.new(|cx| {
316                                let mut buffer =
317                                    Buffer::local(prediction.context.excerpt_text.body, cx);
318                                if let Some(language) = prediction
319                                    .context
320                                    .excerpt_text
321                                    .language_id
322                                    .as_ref()
323                                    .and_then(|id| languages.get(id))
324                                {
325                                    buffer.set_language(language.clone(), cx);
326                                }
327                                buffer.file_updated(excerpt_file, cx);
328                                buffer
329                            });
330
331                            multibuffer.push_excerpts(
332                                excerpt_buffer,
333                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
334                                cx,
335                            );
336
337                            for snippet in &prediction.context.declarations {
338                                let path = this
339                                    .project
340                                    .read(cx)
341                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
342
343                                let snippet_file = Arc::new(ExcerptMetadataFile {
344                                    title: RelPath::unix(&format!(
345                                        "{} (Score density: {})",
346                                        path.map(|p| p.path.display(path_style).to_string())
347                                            .unwrap_or_else(|| "".to_string()),
348                                        snippet.score_density(DeclarationStyle::Declaration)
349                                    ))
350                                    .unwrap()
351                                    .into(),
352                                    path_style,
353                                    worktree_id,
354                                });
355
356                                let excerpt_buffer = cx.new(|cx| {
357                                    let mut buffer =
358                                        Buffer::local(snippet.declaration.item_text().0, cx);
359                                    buffer.file_updated(snippet_file, cx);
360                                    if let Some(language) =
361                                        languages.get(&snippet.declaration.identifier().language_id)
362                                    {
363                                        buffer.set_language(language.clone(), cx);
364                                    }
365                                    buffer
366                                });
367
368                                multibuffer.push_excerpts(
369                                    excerpt_buffer,
370                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
371                                    cx,
372                                );
373                            }
374
375                            multibuffer
376                        });
377
378                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
379                    });
380
381                    let last_prediction = LastPrediction {
382                        context_editor,
383                        prompt_editor: cx.new(|cx| {
384                            let buffer = cx.new(|cx| {
385                                let mut buffer = Buffer::local(prediction.request.prompt, cx);
386                                buffer.set_language(markdown_language.clone(), cx);
387                                buffer
388                            });
389                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
390                            let mut editor =
391                                Editor::new(EditorMode::full(), buffer, None, window, cx);
392                            editor.set_read_only(true);
393                            editor.set_show_line_numbers(false, cx);
394                            editor.set_show_gutter(false, cx);
395                            editor.set_show_scrollbars(false, cx);
396                            editor
397                        }),
398                        model_response_editor: cx.new(|cx| {
399                            let buffer = cx.new(|cx| {
400                                let mut buffer =
401                                    Buffer::local(prediction.request.model_response, cx);
402                                buffer.set_language(markdown_language, cx);
403                                buffer
404                            });
405                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
406                            let mut editor =
407                                Editor::new(EditorMode::full(), buffer, None, window, cx);
408                            editor.set_read_only(true);
409                            editor.set_show_line_numbers(false, cx);
410                            editor.set_show_gutter(false, cx);
411                            editor.set_show_scrollbars(false, cx);
412                            editor
413                        }),
414                        retrieval_time: prediction.retrieval_time,
415                        prompt_planning_time: prediction.request.prompt_planning_time,
416                        inference_time: prediction.request.inference_time,
417                        parsing_time: prediction.request.parsing_time,
418                        buffer: prediction.buffer,
419                        position: prediction.position,
420                    };
421                    this.last_prediction = Some(LastPredictionState::Success(last_prediction));
422                    cx.notify();
423                })
424                .ok();
425            }
426        });
427    }
428
429    fn render_options(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
430        v_flex()
431            .gap_2()
432            .child(
433                h_flex()
434                    .child(Headline::new("Options").size(HeadlineSize::Small))
435                    .justify_between()
436                    .child(
437                        ui::Button::new("reset-options", "Reset")
438                            .disabled(self.zeta.read(cx).options() == &zeta2::DEFAULT_OPTIONS)
439                            .style(ButtonStyle::Outlined)
440                            .size(ButtonSize::Large)
441                            .on_click(cx.listener(|this, _, window, cx| {
442                                this.set_input_options(&zeta2::DEFAULT_OPTIONS, window, cx);
443                            })),
444                    ),
445            )
446            .child(
447                v_flex()
448                    .gap_2()
449                    .child(
450                        h_flex()
451                            .gap_2()
452                            .items_end()
453                            .child(self.max_excerpt_bytes_input.clone())
454                            .child(self.min_excerpt_bytes_input.clone())
455                            .child(self.cursor_context_ratio_input.clone()),
456                    )
457                    .child(
458                        h_flex()
459                            .gap_2()
460                            .items_end()
461                            .child(self.max_prompt_bytes_input.clone())
462                            .child(self.render_prompt_format_dropdown(window, cx)),
463                    ),
464            )
465    }
466
467    fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
468        let active_format = self.zeta.read(cx).options().prompt_format;
469        let this = cx.weak_entity();
470
471        v_flex()
472            .gap_1p5()
473            .child(
474                Label::new("Prompt Format")
475                    .size(LabelSize::Small)
476                    .color(Color::Muted),
477            )
478            .child(
479                DropdownMenu::new(
480                    "ep-prompt-format",
481                    active_format.to_string(),
482                    ContextMenu::build(window, cx, move |mut menu, _window, _cx| {
483                        for prompt_format in PromptFormat::iter() {
484                            menu = menu.item(
485                                ContextMenuEntry::new(prompt_format.to_string())
486                                    .toggleable(IconPosition::End, active_format == prompt_format)
487                                    .handler({
488                                        let this = this.clone();
489                                        move |_window, cx| {
490                                            this.update(cx, |this, cx| {
491                                                let current_options =
492                                                    this.zeta.read(cx).options().clone();
493                                                let options = ZetaOptions {
494                                                    prompt_format,
495                                                    ..current_options
496                                                };
497                                                this.set_options(options, cx);
498                                            })
499                                            .ok();
500                                        }
501                                    }),
502                            )
503                        }
504                        menu
505                    }),
506                )
507                .style(ui::DropdownStyle::Outlined),
508            )
509    }
510
511    fn render_tabs(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
512        let Some(LastPredictionState::Success { .. } | LastPredictionState::Replaying { .. }) =
513            self.last_prediction.as_ref()
514        else {
515            return None;
516        };
517
518        Some(
519            ui::ToggleButtonGroup::single_row(
520                "prediction",
521                [
522                    ui::ToggleButtonSimple::new(
523                        "Context",
524                        cx.listener(|this, _, _, cx| {
525                            this.active_view = ActiveView::Context;
526                            cx.notify();
527                        }),
528                    ),
529                    ui::ToggleButtonSimple::new(
530                        "Inference",
531                        cx.listener(|this, _, _, cx| {
532                            this.active_view = ActiveView::Inference;
533                            cx.notify();
534                        }),
535                    ),
536                ],
537            )
538            .style(ui::ToggleButtonGroupStyle::Outlined)
539            .selected_index(if self.active_view == ActiveView::Context {
540                0
541            } else {
542                1
543            })
544            .into_any_element(),
545        )
546    }
547
548    fn render_stats(&self) -> Option<Div> {
549        let Some(
550            LastPredictionState::Success(prediction)
551            | LastPredictionState::Replaying { prediction, .. },
552        ) = self.last_prediction.as_ref()
553        else {
554            return None;
555        };
556
557        Some(
558            v_flex()
559                .p_4()
560                .gap_2()
561                .min_w(px(160.))
562                .child(Headline::new("Stats").size(HeadlineSize::Small))
563                .child(Self::render_duration(
564                    "Context retrieval",
565                    prediction.retrieval_time,
566                ))
567                .child(Self::render_duration(
568                    "Prompt planning",
569                    prediction.prompt_planning_time,
570                ))
571                .child(Self::render_duration(
572                    "Inference",
573                    prediction.inference_time,
574                ))
575                .child(Self::render_duration("Parsing", prediction.parsing_time)),
576        )
577    }
578
579    fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
580        h_flex()
581            .gap_1()
582            .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
583            .child(
584                Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 {
585                    format!("{} ms", time.num_milliseconds())
586                } else {
587                    format!("{} ยตs", time.num_microseconds().unwrap_or(0))
588                })
589                .size(LabelSize::Small),
590            )
591    }
592
593    fn render_content(&self, cx: &mut Context<Self>) -> AnyElement {
594        match self.last_prediction.as_ref() {
595            None => v_flex()
596                .size_full()
597                .justify_center()
598                .items_center()
599                .child(Label::new("No prediction").size(LabelSize::Large))
600                .into_any(),
601            Some(LastPredictionState::Success(prediction)) => {
602                self.render_last_prediction(prediction, cx).into_any()
603            }
604            Some(LastPredictionState::Replaying { prediction, _task }) => self
605                .render_last_prediction(prediction, cx)
606                .opacity(0.6)
607                .into_any(),
608            Some(LastPredictionState::Failed(err)) => v_flex()
609                .p_4()
610                .gap_2()
611                .child(Label::new(err.clone()).buffer_font(cx))
612                .into_any(),
613        }
614    }
615
616    fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
617        match &self.active_view {
618            ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
619            ActiveView::Inference => h_flex()
620                .items_start()
621                .w_full()
622                .flex_1()
623                .border_t_1()
624                .border_color(cx.theme().colors().border)
625                .bg(cx.theme().colors().editor_background)
626                .child(
627                    v_flex()
628                        .flex_1()
629                        .gap_2()
630                        .p_4()
631                        .h_full()
632                        .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
633                        .child(prediction.prompt_editor.clone()),
634                )
635                .child(ui::vertical_divider())
636                .child(
637                    v_flex()
638                        .flex_1()
639                        .gap_2()
640                        .h_full()
641                        .p_4()
642                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
643                        .child(prediction.model_response_editor.clone()),
644                ),
645        }
646    }
647}
648
649impl Focusable for Zeta2Inspector {
650    fn focus_handle(&self, _cx: &App) -> FocusHandle {
651        self.focus_handle.clone()
652    }
653}
654
655impl Item for Zeta2Inspector {
656    type Event = ();
657
658    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
659        "Zeta2 Inspector".into()
660    }
661}
662
// Matches `Item::Event = ()`: the inspector emits no meaningful events.
impl EventEmitter<()> for Zeta2Inspector {}
664
665impl Render for Zeta2Inspector {
666    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
667        v_flex()
668            .size_full()
669            .bg(cx.theme().colors().editor_background)
670            .child(
671                h_flex()
672                    .w_full()
673                    .child(
674                        v_flex()
675                            .flex_1()
676                            .p_4()
677                            .h_full()
678                            .justify_between()
679                            .child(self.render_options(window, cx))
680                            .gap_4()
681                            .children(self.render_tabs(cx)),
682                    )
683                    .child(ui::vertical_divider())
684                    .children(self.render_stats()),
685            )
686            .child(self.render_content(cx))
687    }
688}
689
// Uses the same approach as the commit view: each excerpt buffer gets a synthetic
// file whose "path" is only a display title.
691
/// Synthetic `language::File` attached to the inspector's excerpt buffers so
/// each excerpt carries a descriptive title (presumably shown as the excerpt
/// header by the multibuffer — confirm).
struct ExcerptMetadataFile {
    /// Display title returned as the file's path; not a real path on disk.
    title: Arc<RelPath>,
    /// Worktree the synthetic file claims to belong to (the project's first).
    worktree_id: WorktreeId,
    path_style: PathStyle,
}
697
698impl language::File for ExcerptMetadataFile {
699    fn as_local(&self) -> Option<&dyn language::LocalFile> {
700        None
701    }
702
703    fn disk_state(&self) -> DiskState {
704        DiskState::New
705    }
706
707    fn path(&self) -> &Arc<RelPath> {
708        &self.title
709    }
710
711    fn full_path(&self, _: &App) -> PathBuf {
712        self.title.as_std_path().to_path_buf()
713    }
714
715    fn file_name<'a>(&'a self, _: &'a App) -> &'a str {
716        self.title.file_name().unwrap()
717    }
718
719    fn path_style(&self, _: &App) -> PathStyle {
720        self.path_style
721    }
722
723    fn worktree_id(&self, _: &App) -> WorktreeId {
724        self.worktree_id
725    }
726
727    fn to_proto(&self, _: &App) -> language::proto::File {
728        unimplemented!()
729    }
730
731    fn is_private(&self) -> bool {
732        false
733    }
734}