zeta2_tools.rs

  1use std::{collections::hash_map::Entry, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
  2
  3use chrono::TimeDelta;
  4use client::{Client, UserStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
  7use futures::StreamExt as _;
  8use gpui::{
  9    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 10    prelude::*,
 11};
 12use language::{Buffer, DiskState};
 13use project::{Project, WorktreeId};
 14use ui::prelude::*;
 15use ui_input::SingleLineInput;
 16use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
 17use workspace::{Item, SplitDirection, Workspace};
 18use zeta2::{Zeta, ZetaOptions};
 19
 20use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
 21
 22actions!(
 23    dev,
 24    [
 25        /// Opens the language server protocol logs viewer.
 26        OpenZeta2Inspector
 27    ]
 28);
 29
 30pub fn init(cx: &mut App) {
 31    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 32        workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
 33            let project = workspace.project();
 34            workspace.split_item(
 35                SplitDirection::Right,
 36                Box::new(cx.new(|cx| {
 37                    Zeta2Inspector::new(
 38                        &project,
 39                        workspace.client(),
 40                        workspace.user_store(),
 41                        window,
 42                        cx,
 43                    )
 44                })),
 45                window,
 46                cx,
 47            );
 48        });
 49    })
 50    .detach();
 51}
 52
// TODO: show the included diagnostics and events.
 54
 55pub struct Zeta2Inspector {
 56    focus_handle: FocusHandle,
 57    project: Entity<Project>,
 58    last_prediction: Option<LastPredictionState>,
 59    max_excerpt_bytes_input: Entity<SingleLineInput>,
 60    min_excerpt_bytes_input: Entity<SingleLineInput>,
 61    cursor_context_ratio_input: Entity<SingleLineInput>,
 62    max_prompt_bytes_input: Entity<SingleLineInput>,
 63    active_view: ActiveView,
 64    zeta: Entity<Zeta>,
 65    _active_editor_subscription: Option<Subscription>,
 66    _update_state_task: Task<()>,
 67    _receive_task: Task<()>,
 68}
 69
 70#[derive(PartialEq)]
 71enum ActiveView {
 72    Context,
 73    Inference,
 74}
 75
 76enum LastPredictionState {
 77    Failed(SharedString),
 78    Success(LastPrediction),
 79    Replaying {
 80        prediction: LastPrediction,
 81        _task: Task<()>,
 82    },
 83}
 84
 85struct LastPrediction {
 86    context_editor: Entity<Editor>,
 87    retrieval_time: TimeDelta,
 88    prompt_planning_time: TimeDelta,
 89    inference_time: TimeDelta,
 90    parsing_time: TimeDelta,
 91    prompt_editor: Entity<Editor>,
 92    model_response_editor: Entity<Editor>,
 93    buffer: WeakEntity<Buffer>,
 94    position: language::Anchor,
 95}
 96
 97impl Zeta2Inspector {
 98    pub fn new(
 99        project: &Entity<Project>,
100        client: &Arc<Client>,
101        user_store: &Entity<UserStore>,
102        window: &mut Window,
103        cx: &mut Context<Self>,
104    ) -> Self {
105        let zeta = Zeta::global(client, user_store, cx);
106        let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
107
108        let receive_task = cx.spawn_in(window, async move |this, cx| {
109            while let Some(prediction_result) = request_rx.next().await {
110                this.update_in(cx, |this, window, cx| match prediction_result {
111                    Ok(prediction) => {
112                        this.update_last_prediction(prediction, window, cx);
113                    }
114                    Err(err) => {
115                        this.last_prediction = Some(LastPredictionState::Failed(err.into()));
116                        cx.notify();
117                    }
118                })
119                .ok();
120            }
121        });
122
123        let mut this = Self {
124            focus_handle: cx.focus_handle(),
125            project: project.clone(),
126            last_prediction: None,
127            active_view: ActiveView::Context,
128            max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
129            min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
130            cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
131            max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
132            zeta: zeta.clone(),
133            _active_editor_subscription: None,
134            _update_state_task: Task::ready(()),
135            _receive_task: receive_task,
136        };
137        this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
138        this
139    }
140
141    fn set_input_options(
142        &mut self,
143        options: &ZetaOptions,
144        window: &mut Window,
145        cx: &mut Context<Self>,
146    ) {
147        self.max_excerpt_bytes_input.update(cx, |input, cx| {
148            input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
149        });
150        self.min_excerpt_bytes_input.update(cx, |input, cx| {
151            input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
152        });
153        self.cursor_context_ratio_input.update(cx, |input, cx| {
154            input.set_text(
155                format!(
156                    "{:.2}",
157                    options.excerpt.target_before_cursor_over_total_bytes
158                ),
159                window,
160                cx,
161            );
162        });
163        self.max_prompt_bytes_input.update(cx, |input, cx| {
164            input.set_text(options.max_prompt_bytes.to_string(), window, cx);
165        });
166        cx.notify();
167    }
168
169    fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
170        self.zeta.update(cx, |this, _cx| this.set_options(options));
171
172        const THROTTLE_TIME: Duration = Duration::from_millis(100);
173
174        if let Some(
175            LastPredictionState::Success(prediction)
176            | LastPredictionState::Replaying { prediction, .. },
177        ) = self.last_prediction.take()
178        {
179            if let Some(buffer) = prediction.buffer.upgrade() {
180                let position = prediction.position;
181                let zeta = self.zeta.clone();
182                let project = self.project.clone();
183                let task = cx.spawn(async move |_this, cx| {
184                    cx.background_executor().timer(THROTTLE_TIME).await;
185                    if let Some(task) = zeta
186                        .update(cx, |zeta, cx| {
187                            zeta.request_prediction(&project, &buffer, position, cx)
188                        })
189                        .ok()
190                    {
191                        task.await.log_err();
192                    }
193                });
194                self.last_prediction = Some(LastPredictionState::Replaying {
195                    prediction,
196                    _task: task,
197                });
198            } else {
199                self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
200            }
201        }
202
203        cx.notify();
204    }
205
206    fn number_input(
207        label: &'static str,
208        window: &mut Window,
209        cx: &mut Context<Self>,
210    ) -> Entity<SingleLineInput> {
211        let input = cx.new(|cx| {
212            SingleLineInput::new(window, cx, "")
213                .label(label)
214                .label_min_width(px(64.))
215        });
216
217        cx.subscribe_in(
218            &input.read(cx).editor().clone(),
219            window,
220            |this, _, event, _window, cx| {
221                let EditorEvent::BufferEdited = event else {
222                    return;
223                };
224
225                fn number_input_value<T: FromStr + Default>(
226                    input: &Entity<SingleLineInput>,
227                    cx: &App,
228                ) -> T {
229                    input
230                        .read(cx)
231                        .editor()
232                        .read(cx)
233                        .text(cx)
234                        .parse::<T>()
235                        .unwrap_or_default()
236                }
237
238                let excerpt_options = EditPredictionExcerptOptions {
239                    max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
240                    min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
241                    target_before_cursor_over_total_bytes: number_input_value(
242                        &this.cursor_context_ratio_input,
243                        cx,
244                    ),
245                };
246
247                let zeta_options = this.zeta.read(cx).options();
248                this.set_options(
249                    ZetaOptions {
250                        excerpt: excerpt_options,
251                        max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
252                        max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
253                        prompt_format: zeta_options.prompt_format,
254                    },
255                    cx,
256                );
257            },
258        )
259        .detach();
260        input
261    }
262
263    fn update_last_prediction(
264        &mut self,
265        prediction: zeta2::PredictionDebugInfo,
266        window: &mut Window,
267        cx: &mut Context<Self>,
268    ) {
269        let project = self.project.read(cx);
270        let path_style = project.path_style(cx);
271        let Some(worktree_id) = project
272            .worktrees(cx)
273            .next()
274            .map(|worktree| worktree.read(cx).id())
275        else {
276            log::error!("Open a worktree to use edit prediction debug view");
277            self.last_prediction.take();
278            return;
279        };
280
281        self._update_state_task = cx.spawn_in(window, {
282            let language_registry = self.project.read(cx).languages().clone();
283            async move |this, cx| {
284                let mut languages = HashMap::default();
285                for lang_id in prediction
286                    .context
287                    .snippets
288                    .iter()
289                    .map(|snippet| snippet.declaration.identifier().language_id)
290                    .chain(prediction.context.excerpt_text.language_id)
291                {
292                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
293                        // Most snippets are gonna be the same language,
294                        // so we think it's fine to do this sequentially for now
295                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
296                    }
297                }
298
299                let markdown_language = language_registry
300                    .language_for_name("Markdown")
301                    .await
302                    .log_err();
303
304                this.update_in(cx, |this, window, cx| {
305                    let context_editor = cx.new(|cx| {
306                        let multibuffer = cx.new(|cx| {
307                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
308                            let excerpt_file = Arc::new(ExcerptMetadataFile {
309                                title: RelPath::new("Cursor Excerpt").unwrap().into(),
310                                path_style,
311                                worktree_id,
312                            });
313
314                            let excerpt_buffer = cx.new(|cx| {
315                                let mut buffer =
316                                    Buffer::local(prediction.context.excerpt_text.body, cx);
317                                if let Some(language) = prediction
318                                    .context
319                                    .excerpt_text
320                                    .language_id
321                                    .as_ref()
322                                    .and_then(|id| languages.get(id))
323                                {
324                                    buffer.set_language(language.clone(), cx);
325                                }
326                                buffer.file_updated(excerpt_file, cx);
327                                buffer
328                            });
329
330                            multibuffer.push_excerpts(
331                                excerpt_buffer,
332                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
333                                cx,
334                            );
335
336                            for snippet in &prediction.context.snippets {
337                                let path = this
338                                    .project
339                                    .read(cx)
340                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
341
342                                let snippet_file = Arc::new(ExcerptMetadataFile {
343                                    title: RelPath::new(&format!(
344                                        "{} (Score density: {})",
345                                        path.map(|p| p.path.display(path_style).to_string())
346                                            .unwrap_or_else(|| "".to_string()),
347                                        snippet.score_density(SnippetStyle::Declaration)
348                                    ))
349                                    .unwrap()
350                                    .into(),
351                                    path_style,
352                                    worktree_id,
353                                });
354
355                                let excerpt_buffer = cx.new(|cx| {
356                                    let mut buffer =
357                                        Buffer::local(snippet.declaration.item_text().0, cx);
358                                    buffer.file_updated(snippet_file, cx);
359                                    if let Some(language) =
360                                        languages.get(&snippet.declaration.identifier().language_id)
361                                    {
362                                        buffer.set_language(language.clone(), cx);
363                                    }
364                                    buffer
365                                });
366
367                                multibuffer.push_excerpts(
368                                    excerpt_buffer,
369                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
370                                    cx,
371                                );
372                            }
373
374                            multibuffer
375                        });
376
377                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
378                    });
379
380                    let last_prediction = LastPrediction {
381                        context_editor,
382                        prompt_editor: cx.new(|cx| {
383                            let buffer = cx.new(|cx| {
384                                let mut buffer = Buffer::local(prediction.request.prompt, cx);
385                                buffer.set_language(markdown_language.clone(), cx);
386                                buffer
387                            });
388                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
389                            let mut editor =
390                                Editor::new(EditorMode::full(), buffer, None, window, cx);
391                            editor.set_read_only(true);
392                            editor.set_show_line_numbers(false, cx);
393                            editor.set_show_gutter(false, cx);
394                            editor.set_show_scrollbars(false, cx);
395                            editor
396                        }),
397                        model_response_editor: cx.new(|cx| {
398                            let buffer = cx.new(|cx| {
399                                let mut buffer =
400                                    Buffer::local(prediction.request.model_response, cx);
401                                buffer.set_language(markdown_language, cx);
402                                buffer
403                            });
404                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
405                            let mut editor =
406                                Editor::new(EditorMode::full(), buffer, None, window, cx);
407                            editor.set_read_only(true);
408                            editor.set_show_line_numbers(false, cx);
409                            editor.set_show_gutter(false, cx);
410                            editor.set_show_scrollbars(false, cx);
411                            editor
412                        }),
413                        retrieval_time: prediction.retrieval_time,
414                        prompt_planning_time: prediction.request.prompt_planning_time,
415                        inference_time: prediction.request.inference_time,
416                        parsing_time: prediction.request.parsing_time,
417                        buffer: prediction.buffer,
418                        position: prediction.position,
419                    };
420                    this.last_prediction = Some(LastPredictionState::Success(last_prediction));
421                    cx.notify();
422                })
423                .ok();
424            }
425        });
426    }
427
428    fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
429        h_flex()
430            .gap_1()
431            .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
432            .child(
433                Label::new(if time.num_microseconds().unwrap_or(0) > 1000 {
434                    format!("{} ms", time.num_milliseconds())
435                } else {
436                    format!("{} µs", time.num_microseconds().unwrap_or(0))
437                })
438                .size(LabelSize::Small),
439            )
440    }
441
442    fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
443        match &self.active_view {
444            ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
445            ActiveView::Inference => h_flex()
446                .items_start()
447                .w_full()
448                .flex_1()
449                .border_t_1()
450                .border_color(cx.theme().colors().border)
451                .bg(cx.theme().colors().editor_background)
452                .child(
453                    v_flex()
454                        .flex_1()
455                        .gap_2()
456                        .p_4()
457                        .h_full()
458                        .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
459                        .child(prediction.prompt_editor.clone()),
460                )
461                .child(ui::vertical_divider())
462                .child(
463                    v_flex()
464                        .flex_1()
465                        .gap_2()
466                        .h_full()
467                        .p_4()
468                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
469                        .child(prediction.model_response_editor.clone()),
470                ),
471        }
472    }
473}
474
475impl Focusable for Zeta2Inspector {
476    fn focus_handle(&self, _cx: &App) -> FocusHandle {
477        self.focus_handle.clone()
478    }
479}
480
481impl Item for Zeta2Inspector {
482    type Event = ();
483
484    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
485        "Zeta2 Inspector".into()
486    }
487}
488
489impl EventEmitter<()> for Zeta2Inspector {}
490
491impl Render for Zeta2Inspector {
492    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
493        let content = match self.last_prediction.as_ref() {
494            None => v_flex()
495                .size_full()
496                .justify_center()
497                .items_center()
498                .child(Label::new("No prediction").size(LabelSize::Large))
499                .into_any(),
500            Some(LastPredictionState::Success(prediction)) => {
501                self.render_last_prediction(prediction, cx).into_any()
502            }
503            Some(LastPredictionState::Replaying { prediction, _task }) => self
504                .render_last_prediction(prediction, cx)
505                .opacity(0.6)
506                .into_any(),
507            Some(LastPredictionState::Failed(err)) => v_flex()
508                .p_4()
509                .gap_2()
510                .child(Label::new(err.clone()).buffer_font(cx))
511                .into_any(),
512        };
513
514        v_flex()
515            .size_full()
516            .bg(cx.theme().colors().editor_background)
517            .child(
518                h_flex()
519                    .w_full()
520                    .child(
521                        v_flex()
522                            .flex_1()
523                            .p_4()
524                            .h_full()
525                            .justify_between()
526                            .child(
527                                v_flex()
528                                    .gap_2()
529                                    .child(Headline::new("Options").size(HeadlineSize::Small))
530                                    .child(
531                                        h_flex()
532                                            .gap_2()
533                                            .items_end()
534                                            .child(self.max_excerpt_bytes_input.clone())
535                                            .child(self.min_excerpt_bytes_input.clone())
536                                            .child(self.cursor_context_ratio_input.clone())
537                                            .child(self.max_prompt_bytes_input.clone())
538                                            .child(
539                                                ui::Button::new("reset-options", "Reset")
540                                                    .disabled(
541                                                        self.zeta.read(cx).options()
542                                                            == &zeta2::DEFAULT_OPTIONS,
543                                                    )
544                                                    .style(ButtonStyle::Outlined)
545                                                    .size(ButtonSize::Large)
546                                                    .on_click(cx.listener(
547                                                        |this, _, window, cx| {
548                                                            this.set_input_options(
549                                                                &zeta2::DEFAULT_OPTIONS,
550                                                                window,
551                                                                cx,
552                                                            );
553                                                        },
554                                                    )),
555                                            ),
556                                    ),
557                            )
558                            .map(|this| {
559                                if let Some(
560                                    LastPredictionState::Success { .. }
561                                    | LastPredictionState::Replaying { .. },
562                                ) = self.last_prediction.as_ref()
563                                {
564                                    this.child(
565                                        ui::ToggleButtonGroup::single_row(
566                                            "prediction",
567                                            [
568                                                ui::ToggleButtonSimple::new(
569                                                    "Context",
570                                                    cx.listener(|this, _, _, cx| {
571                                                        this.active_view = ActiveView::Context;
572                                                        cx.notify();
573                                                    }),
574                                                ),
575                                                ui::ToggleButtonSimple::new(
576                                                    "Inference",
577                                                    cx.listener(|this, _, _, cx| {
578                                                        this.active_view = ActiveView::Inference;
579                                                        cx.notify();
580                                                    }),
581                                                ),
582                                            ],
583                                        )
584                                        .style(ui::ToggleButtonGroupStyle::Outlined)
585                                        .selected_index(
586                                            if self.active_view == ActiveView::Context {
587                                                0
588                                            } else {
589                                                1
590                                            },
591                                        ),
592                                    )
593                                } else {
594                                    this
595                                }
596                            }),
597                    )
598                    .child(ui::vertical_divider())
599                    .map(|this| {
600                        if let Some(
601                            LastPredictionState::Success(prediction)
602                            | LastPredictionState::Replaying { prediction, .. },
603                        ) = self.last_prediction.as_ref()
604                        {
605                            this.child(
606                                v_flex()
607                                    .p_4()
608                                    .gap_2()
609                                    .min_w(px(160.))
610                                    .child(Headline::new("Stats").size(HeadlineSize::Small))
611                                    .child(Self::render_duration(
612                                        "Context retrieval",
613                                        prediction.retrieval_time,
614                                    ))
615                                    .child(Self::render_duration(
616                                        "Prompt planning",
617                                        prediction.prompt_planning_time,
618                                    ))
619                                    .child(Self::render_duration(
620                                        "Inference",
621                                        prediction.inference_time,
622                                    ))
623                                    .child(Self::render_duration(
624                                        "Parsing",
625                                        prediction.parsing_time,
626                                    )),
627                            )
628                        } else {
629                            this
630                        }
631                    }),
632            )
633            .child(content)
634    }
635}
636
// Uses the same approach as the commit view: a placeholder `File` whose path serves as the excerpt's display title.
638
639struct ExcerptMetadataFile {
640    title: Arc<RelPath>,
641    worktree_id: WorktreeId,
642    path_style: PathStyle,
643}
644
645impl language::File for ExcerptMetadataFile {
646    fn as_local(&self) -> Option<&dyn language::LocalFile> {
647        None
648    }
649
650    fn disk_state(&self) -> DiskState {
651        DiskState::New
652    }
653
654    fn path(&self) -> &Arc<RelPath> {
655        &self.title
656    }
657
658    fn full_path(&self, _: &App) -> PathBuf {
659        self.title.as_std_path().to_path_buf()
660    }
661
662    fn file_name<'a>(&'a self, _: &'a App) -> &'a str {
663        self.title.file_name().unwrap()
664    }
665
666    fn path_style(&self, _: &App) -> PathStyle {
667        self.path_style
668    }
669
670    fn worktree_id(&self, _: &App) -> WorktreeId {
671        self.worktree_id
672    }
673
674    fn to_proto(&self, _: &App) -> language::proto::File {
675        unimplemented!()
676    }
677
678    fn is_private(&self) -> bool {
679        false
680    }
681}