zeta2_tools.rs

  1use std::{
  2    collections::hash_map::Entry,
  3    ffi::OsStr,
  4    path::{Path, PathBuf},
  5    str::FromStr,
  6    sync::Arc,
  7    time::Duration,
  8};
  9
 10use chrono::TimeDelta;
 11use client::{Client, UserStore};
 12use collections::HashMap;
 13use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
 14use futures::StreamExt as _;
 15use gpui::{
 16    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 17    prelude::*,
 18};
 19use language::{Buffer, DiskState};
 20use project::{Project, WorktreeId};
 21use ui::prelude::*;
 22use ui_input::SingleLineInput;
 23use util::ResultExt;
 24use workspace::{Item, SplitDirection, Workspace};
 25use zeta2::{Zeta, ZetaOptions};
 26
 27use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
 28
 29actions!(
 30    dev,
 31    [
 32        /// Opens the language server protocol logs viewer.
 33        OpenZeta2Inspector
 34    ]
 35);
 36
 37pub fn init(cx: &mut App) {
 38    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 39        workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
 40            let project = workspace.project();
 41            workspace.split_item(
 42                SplitDirection::Right,
 43                Box::new(cx.new(|cx| {
 44                    Zeta2Inspector::new(
 45                        &project,
 46                        workspace.client(),
 47                        workspace.user_store(),
 48                        window,
 49                        cx,
 50                    )
 51                })),
 52                window,
 53                cx,
 54            );
 55        });
 56    })
 57    .detach();
 58}
 59
// TODO: show the included diagnostics and events.
 61
/// Developer view for inspecting Zeta2 edit predictions.
///
/// Shows editable Zeta options, timing stats for the latest prediction, and
/// either its retrieved context or its prompt/model response.
pub struct Zeta2Inspector {
    focus_handle: FocusHandle,
    project: Entity<Project>,
    /// What is currently shown for the most recent prediction; `None` when there
    /// is nothing to show.
    last_prediction: Option<LastPredictionState>,
    /// Input for `excerpt.max_bytes`.
    max_excerpt_bytes_input: Entity<SingleLineInput>,
    /// Input for `excerpt.min_bytes`.
    min_excerpt_bytes_input: Entity<SingleLineInput>,
    /// Input for `excerpt.target_before_cursor_over_total_bytes`.
    cursor_context_ratio_input: Entity<SingleLineInput>,
    /// Input for `max_prompt_bytes`.
    max_prompt_bytes_input: Entity<SingleLineInput>,
    /// Which panel of the last prediction is displayed.
    active_view: ActiveView,
    zeta: Entity<Zeta>,
    // Always `None` in this file; kept so the struct owns any future subscription.
    _active_editor_subscription: Option<Subscription>,
    /// Task that builds the editors for the latest incoming prediction.
    _update_state_task: Task<()>,
    /// Task consuming Zeta's prediction debug stream.
    _receive_task: Task<()>,
}
 76
 77#[derive(PartialEq)]
 78enum ActiveView {
 79    Context,
 80    Inference,
 81}
 82
/// What the inspector currently shows for the most recent prediction.
enum LastPredictionState {
    /// An error message displayed in place of a prediction.
    Failed(SharedString),
    /// A prediction whose debug info has been rendered into editors.
    Success(LastPrediction),
    /// Options changed: the previous prediction stays visible (dimmed) while
    /// `_task` re-requests a prediction at the same buffer position.
    Replaying {
        prediction: LastPrediction,
        _task: Task<()>,
    },
}
 91
/// Rendered debug information for a single prediction.
struct LastPrediction {
    /// Read-only multibuffer: the cursor excerpt followed by each retrieved snippet.
    context_editor: Entity<Editor>,
    // Timings reported by the prediction's debug info, shown in the "Stats" panel.
    retrieval_time: TimeDelta,
    prompt_planning_time: TimeDelta,
    inference_time: TimeDelta,
    parsing_time: TimeDelta,
    /// Read-only view of the prompt, highlighted as Markdown.
    prompt_editor: Entity<Editor>,
    /// Read-only view of the model response, highlighted as Markdown.
    model_response_editor: Entity<Editor>,
    /// Buffer and position the prediction was requested for; used to replay the
    /// request when options change.
    buffer: WeakEntity<Buffer>,
    position: language::Anchor,
}
103
104impl Zeta2Inspector {
105    pub fn new(
106        project: &Entity<Project>,
107        client: &Arc<Client>,
108        user_store: &Entity<UserStore>,
109        window: &mut Window,
110        cx: &mut Context<Self>,
111    ) -> Self {
112        let zeta = Zeta::global(client, user_store, cx);
113        let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
114
115        let receive_task = cx.spawn_in(window, async move |this, cx| {
116            while let Some(prediction_result) = request_rx.next().await {
117                this.update_in(cx, |this, window, cx| match prediction_result {
118                    Ok(prediction) => {
119                        this.update_last_prediction(prediction, window, cx);
120                    }
121                    Err(err) => {
122                        this.last_prediction = Some(LastPredictionState::Failed(err.into()));
123                        cx.notify();
124                    }
125                })
126                .ok();
127            }
128        });
129
130        let mut this = Self {
131            focus_handle: cx.focus_handle(),
132            project: project.clone(),
133            last_prediction: None,
134            active_view: ActiveView::Context,
135            max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
136            min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
137            cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
138            max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
139            zeta: zeta.clone(),
140            _active_editor_subscription: None,
141            _update_state_task: Task::ready(()),
142            _receive_task: receive_task,
143        };
144        this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
145        this
146    }
147
148    fn set_input_options(
149        &mut self,
150        options: &ZetaOptions,
151        window: &mut Window,
152        cx: &mut Context<Self>,
153    ) {
154        self.max_excerpt_bytes_input.update(cx, |input, cx| {
155            input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
156        });
157        self.min_excerpt_bytes_input.update(cx, |input, cx| {
158            input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
159        });
160        self.cursor_context_ratio_input.update(cx, |input, cx| {
161            input.set_text(
162                format!(
163                    "{:.2}",
164                    options.excerpt.target_before_cursor_over_total_bytes
165                ),
166                window,
167                cx,
168            );
169        });
170        self.max_prompt_bytes_input.update(cx, |input, cx| {
171            input.set_text(options.max_prompt_bytes.to_string(), window, cx);
172        });
173        cx.notify();
174    }
175
176    fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
177        self.zeta.update(cx, |this, _cx| this.set_options(options));
178
179        const THROTTLE_TIME: Duration = Duration::from_millis(100);
180
181        if let Some(
182            LastPredictionState::Success(prediction)
183            | LastPredictionState::Replaying { prediction, .. },
184        ) = self.last_prediction.take()
185        {
186            if let Some(buffer) = prediction.buffer.upgrade() {
187                let position = prediction.position;
188                let zeta = self.zeta.clone();
189                let project = self.project.clone();
190                let task = cx.spawn(async move |_this, cx| {
191                    cx.background_executor().timer(THROTTLE_TIME).await;
192                    if let Some(task) = zeta
193                        .update(cx, |zeta, cx| {
194                            zeta.request_prediction(&project, &buffer, position, cx)
195                        })
196                        .ok()
197                    {
198                        task.await.log_err();
199                    }
200                });
201                self.last_prediction = Some(LastPredictionState::Replaying {
202                    prediction,
203                    _task: task,
204                });
205            } else {
206                self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
207            }
208        }
209
210        cx.notify();
211    }
212
213    fn number_input(
214        label: &'static str,
215        window: &mut Window,
216        cx: &mut Context<Self>,
217    ) -> Entity<SingleLineInput> {
218        let input = cx.new(|cx| {
219            SingleLineInput::new(window, cx, "")
220                .label(label)
221                .label_min_width(px(64.))
222        });
223
224        cx.subscribe_in(
225            &input.read(cx).editor().clone(),
226            window,
227            |this, _, event, _window, cx| {
228                let EditorEvent::BufferEdited = event else {
229                    return;
230                };
231
232                fn number_input_value<T: FromStr + Default>(
233                    input: &Entity<SingleLineInput>,
234                    cx: &App,
235                ) -> T {
236                    input
237                        .read(cx)
238                        .editor()
239                        .read(cx)
240                        .text(cx)
241                        .parse::<T>()
242                        .unwrap_or_default()
243                }
244
245                let excerpt_options = EditPredictionExcerptOptions {
246                    max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
247                    min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
248                    target_before_cursor_over_total_bytes: number_input_value(
249                        &this.cursor_context_ratio_input,
250                        cx,
251                    ),
252                };
253
254                this.set_options(
255                    ZetaOptions {
256                        excerpt: excerpt_options,
257                        max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
258                        max_diagnostic_bytes: this.zeta.read(cx).options().max_diagnostic_bytes,
259                    },
260                    cx,
261                );
262            },
263        )
264        .detach();
265        input
266    }
267
268    fn update_last_prediction(
269        &mut self,
270        prediction: zeta2::PredictionDebugInfo,
271        window: &mut Window,
272        cx: &mut Context<Self>,
273    ) {
274        let Some(worktree_id) = self
275            .project
276            .read(cx)
277            .worktrees(cx)
278            .next()
279            .map(|worktree| worktree.read(cx).id())
280        else {
281            log::error!("Open a worktree to use edit prediction debug view");
282            self.last_prediction.take();
283            return;
284        };
285
286        self._update_state_task = cx.spawn_in(window, {
287            let language_registry = self.project.read(cx).languages().clone();
288            async move |this, cx| {
289                let mut languages = HashMap::default();
290                for lang_id in prediction
291                    .context
292                    .snippets
293                    .iter()
294                    .map(|snippet| snippet.declaration.identifier().language_id)
295                    .chain(prediction.context.excerpt_text.language_id)
296                {
297                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
298                        // Most snippets are gonna be the same language,
299                        // so we think it's fine to do this sequentially for now
300                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
301                    }
302                }
303
304                let markdown_language = language_registry
305                    .language_for_name("Markdown")
306                    .await
307                    .log_err();
308
309                this.update_in(cx, |this, window, cx| {
310                    let context_editor = cx.new(|cx| {
311                        let multibuffer = cx.new(|cx| {
312                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
313                            let excerpt_file = Arc::new(ExcerptMetadataFile {
314                                title: PathBuf::from("Cursor Excerpt").into(),
315                                worktree_id,
316                            });
317
318                            let excerpt_buffer = cx.new(|cx| {
319                                let mut buffer =
320                                    Buffer::local(prediction.context.excerpt_text.body, cx);
321                                if let Some(language) = prediction
322                                    .context
323                                    .excerpt_text
324                                    .language_id
325                                    .as_ref()
326                                    .and_then(|id| languages.get(id))
327                                {
328                                    buffer.set_language(language.clone(), cx);
329                                }
330                                buffer.file_updated(excerpt_file, cx);
331                                buffer
332                            });
333
334                            multibuffer.push_excerpts(
335                                excerpt_buffer,
336                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
337                                cx,
338                            );
339
340                            for snippet in &prediction.context.snippets {
341                                let path = this
342                                    .project
343                                    .read(cx)
344                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
345
346                                let snippet_file = Arc::new(ExcerptMetadataFile {
347                                    title: PathBuf::from(format!(
348                                        "{} (Score density: {})",
349                                        path.map(|p| p.path.to_string_lossy().to_string())
350                                            .unwrap_or_else(|| "".to_string()),
351                                        snippet.score_density(SnippetStyle::Declaration)
352                                    ))
353                                    .into(),
354                                    worktree_id,
355                                });
356
357                                let excerpt_buffer = cx.new(|cx| {
358                                    let mut buffer =
359                                        Buffer::local(snippet.declaration.item_text().0, cx);
360                                    buffer.file_updated(snippet_file, cx);
361                                    if let Some(language) =
362                                        languages.get(&snippet.declaration.identifier().language_id)
363                                    {
364                                        buffer.set_language(language.clone(), cx);
365                                    }
366                                    buffer
367                                });
368
369                                multibuffer.push_excerpts(
370                                    excerpt_buffer,
371                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
372                                    cx,
373                                );
374                            }
375
376                            multibuffer
377                        });
378
379                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
380                    });
381
382                    let last_prediction = LastPrediction {
383                        context_editor,
384                        prompt_editor: cx.new(|cx| {
385                            let buffer = cx.new(|cx| {
386                                let mut buffer = Buffer::local(prediction.request.prompt, cx);
387                                buffer.set_language(markdown_language.clone(), cx);
388                                buffer
389                            });
390                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
391                            let mut editor =
392                                Editor::new(EditorMode::full(), buffer, None, window, cx);
393                            editor.set_read_only(true);
394                            editor.set_show_line_numbers(false, cx);
395                            editor.set_show_gutter(false, cx);
396                            editor.set_show_scrollbars(false, cx);
397                            editor
398                        }),
399                        model_response_editor: cx.new(|cx| {
400                            let buffer = cx.new(|cx| {
401                                let mut buffer =
402                                    Buffer::local(prediction.request.model_response, cx);
403                                buffer.set_language(markdown_language, cx);
404                                buffer
405                            });
406                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
407                            let mut editor =
408                                Editor::new(EditorMode::full(), buffer, None, window, cx);
409                            editor.set_read_only(true);
410                            editor.set_show_line_numbers(false, cx);
411                            editor.set_show_gutter(false, cx);
412                            editor.set_show_scrollbars(false, cx);
413                            editor
414                        }),
415                        retrieval_time: prediction.retrieval_time,
416                        prompt_planning_time: prediction.request.prompt_planning_time,
417                        inference_time: prediction.request.inference_time,
418                        parsing_time: prediction.request.parsing_time,
419                        buffer: prediction.buffer,
420                        position: prediction.position,
421                    };
422                    this.last_prediction = Some(LastPredictionState::Success(last_prediction));
423                    cx.notify();
424                })
425                .ok();
426            }
427        });
428    }
429
430    fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
431        h_flex()
432            .gap_1()
433            .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
434            .child(
435                Label::new(if time.num_microseconds().unwrap_or(0) > 1000 {
436                    format!("{} ms", time.num_milliseconds())
437                } else {
438                    format!("{} µs", time.num_microseconds().unwrap_or(0))
439                })
440                .size(LabelSize::Small),
441            )
442    }
443
444    fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
445        match &self.active_view {
446            ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
447            ActiveView::Inference => h_flex()
448                .items_start()
449                .w_full()
450                .flex_1()
451                .border_t_1()
452                .border_color(cx.theme().colors().border)
453                .bg(cx.theme().colors().editor_background)
454                .child(
455                    v_flex()
456                        .flex_1()
457                        .gap_2()
458                        .p_4()
459                        .h_full()
460                        .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
461                        .child(prediction.prompt_editor.clone()),
462                )
463                .child(ui::vertical_divider())
464                .child(
465                    v_flex()
466                        .flex_1()
467                        .gap_2()
468                        .h_full()
469                        .p_4()
470                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
471                        .child(prediction.model_response_editor.clone()),
472                ),
473        }
474    }
475}
476
477impl Focusable for Zeta2Inspector {
478    fn focus_handle(&self, _cx: &App) -> FocusHandle {
479        self.focus_handle.clone()
480    }
481}
482
483impl Item for Zeta2Inspector {
484    type Event = ();
485
486    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
487        "Zeta2 Inspector".into()
488    }
489}
490
// Unit event type, matching the `type Event = ()` declared in the `Item` impl.
impl EventEmitter<()> for Zeta2Inspector {}
492
493impl Render for Zeta2Inspector {
494    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
495        let content = match self.last_prediction.as_ref() {
496            None => v_flex()
497                .size_full()
498                .justify_center()
499                .items_center()
500                .child(Label::new("No prediction").size(LabelSize::Large))
501                .into_any(),
502            Some(LastPredictionState::Success(prediction)) => {
503                self.render_last_prediction(prediction, cx).into_any()
504            }
505            Some(LastPredictionState::Replaying { prediction, _task }) => self
506                .render_last_prediction(prediction, cx)
507                .opacity(0.6)
508                .into_any(),
509            Some(LastPredictionState::Failed(err)) => v_flex()
510                .p_4()
511                .gap_2()
512                .child(Label::new(err.clone()).buffer_font(cx))
513                .into_any(),
514        };
515
516        v_flex()
517            .size_full()
518            .bg(cx.theme().colors().editor_background)
519            .child(
520                h_flex()
521                    .w_full()
522                    .child(
523                        v_flex()
524                            .flex_1()
525                            .p_4()
526                            .h_full()
527                            .justify_between()
528                            .child(
529                                v_flex()
530                                    .gap_2()
531                                    .child(Headline::new("Options").size(HeadlineSize::Small))
532                                    .child(
533                                        h_flex()
534                                            .gap_2()
535                                            .items_end()
536                                            .child(self.max_excerpt_bytes_input.clone())
537                                            .child(self.min_excerpt_bytes_input.clone())
538                                            .child(self.cursor_context_ratio_input.clone())
539                                            .child(self.max_prompt_bytes_input.clone())
540                                            .child(
541                                                ui::Button::new("reset-options", "Reset")
542                                                    .disabled(
543                                                        self.zeta.read(cx).options()
544                                                            == &zeta2::DEFAULT_OPTIONS,
545                                                    )
546                                                    .style(ButtonStyle::Outlined)
547                                                    .size(ButtonSize::Large)
548                                                    .on_click(cx.listener(
549                                                        |this, _, window, cx| {
550                                                            this.set_input_options(
551                                                                &zeta2::DEFAULT_OPTIONS,
552                                                                window,
553                                                                cx,
554                                                            );
555                                                        },
556                                                    )),
557                                            ),
558                                    ),
559                            )
560                            .map(|this| {
561                                if let Some(
562                                    LastPredictionState::Success { .. }
563                                    | LastPredictionState::Replaying { .. },
564                                ) = self.last_prediction.as_ref()
565                                {
566                                    this.child(
567                                        ui::ToggleButtonGroup::single_row(
568                                            "prediction",
569                                            [
570                                                ui::ToggleButtonSimple::new(
571                                                    "Context",
572                                                    cx.listener(|this, _, _, cx| {
573                                                        this.active_view = ActiveView::Context;
574                                                        cx.notify();
575                                                    }),
576                                                ),
577                                                ui::ToggleButtonSimple::new(
578                                                    "Inference",
579                                                    cx.listener(|this, _, _, cx| {
580                                                        this.active_view = ActiveView::Inference;
581                                                        cx.notify();
582                                                    }),
583                                                ),
584                                            ],
585                                        )
586                                        .style(ui::ToggleButtonGroupStyle::Outlined)
587                                        .selected_index(
588                                            if self.active_view == ActiveView::Context {
589                                                0
590                                            } else {
591                                                1
592                                            },
593                                        ),
594                                    )
595                                } else {
596                                    this
597                                }
598                            }),
599                    )
600                    .child(ui::vertical_divider())
601                    .map(|this| {
602                        if let Some(
603                            LastPredictionState::Success(prediction)
604                            | LastPredictionState::Replaying { prediction, .. },
605                        ) = self.last_prediction.as_ref()
606                        {
607                            this.child(
608                                v_flex()
609                                    .p_4()
610                                    .gap_2()
611                                    .min_w(px(160.))
612                                    .child(Headline::new("Stats").size(HeadlineSize::Small))
613                                    .child(Self::render_duration(
614                                        "Context retrieval",
615                                        prediction.retrieval_time,
616                                    ))
617                                    .child(Self::render_duration(
618                                        "Prompt planning",
619                                        prediction.prompt_planning_time,
620                                    ))
621                                    .child(Self::render_duration(
622                                        "Inference",
623                                        prediction.inference_time,
624                                    ))
625                                    .child(Self::render_duration(
626                                        "Parsing",
627                                        prediction.parsing_time,
628                                    )),
629                            )
630                        } else {
631                            this
632                        }
633                    }),
634            )
635            .child(content)
636    }
637}
638
// Uses the same approach as the commit view: each excerpt buffer gets a synthetic
// file whose path is the excerpt's title.
640
/// Synthetic [`language::File`] attached to the inspector's read-only excerpt
/// buffers.
///
/// It reports `title` as its path, presumably so the multibuffer displays the
/// title as each excerpt's header — TODO confirm.
struct ExcerptMetadataFile {
    /// Text reported as the file's path, e.g. "Cursor Excerpt" or
    /// "<path> (Score density: …)".
    title: Arc<Path>,
    /// Worktree reported by `worktree_id`; taken from the project's first worktree.
    worktree_id: WorktreeId,
}
645
646impl language::File for ExcerptMetadataFile {
647    fn as_local(&self) -> Option<&dyn language::LocalFile> {
648        None
649    }
650
651    fn disk_state(&self) -> DiskState {
652        DiskState::New
653    }
654
655    fn path(&self) -> &Arc<Path> {
656        &self.title
657    }
658
659    fn full_path(&self, _: &App) -> PathBuf {
660        self.title.as_ref().into()
661    }
662
663    fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
664        self.title.file_name().unwrap()
665    }
666
667    fn worktree_id(&self, _: &App) -> WorktreeId {
668        self.worktree_id
669    }
670
671    fn to_proto(&self, _: &App) -> language::proto::File {
672        unimplemented!()
673    }
674
675    fn is_private(&self) -> bool {
676        false
677    }
678}