zeta2_tools.rs

  1use std::{
  2    collections::hash_map::Entry,
  3    ffi::OsStr,
  4    path::{Path, PathBuf},
  5    str::FromStr,
  6    sync::Arc,
  7    time::Duration,
  8};
  9
 10use chrono::TimeDelta;
 11use client::{Client, UserStore};
 12use collections::HashMap;
 13use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
 14use futures::StreamExt as _;
 15use gpui::{
 16    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 17    prelude::*,
 18};
 19use language::{Buffer, DiskState};
 20use project::{Project, WorktreeId};
 21use ui::prelude::*;
 22use ui_input::SingleLineInput;
 23use util::ResultExt;
 24use workspace::{Item, SplitDirection, Workspace};
 25use zeta2::{Zeta, ZetaOptions};
 26
 27use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
 28
 29actions!(
 30    dev,
 31    [
 32        /// Opens the language server protocol logs viewer.
 33        OpenZeta2Inspector
 34    ]
 35);
 36
 37pub fn init(cx: &mut App) {
 38    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 39        workspace.register_action(move |workspace, _: &OpenZeta2Inspector, window, cx| {
 40            let project = workspace.project();
 41            workspace.split_item(
 42                SplitDirection::Right,
 43                Box::new(cx.new(|cx| {
 44                    Zeta2Inspector::new(
 45                        &project,
 46                        workspace.client(),
 47                        workspace.user_store(),
 48                        window,
 49                        cx,
 50                    )
 51                })),
 52                window,
 53                cx,
 54            );
 55        });
 56    })
 57    .detach();
 58}
 59
/// Debug view for Zeta2 edit predictions.
///
/// Shows the most recent prediction reported by [`Zeta`]'s debug stream and exposes inputs
/// for the excerpt options; editing an option re-requests the displayed prediction.
pub struct Zeta2Inspector {
    focus_handle: FocusHandle,
    project: Entity<Project>,
    /// What the main content area shows; `None` until the first debug message arrives.
    last_prediction: Option<LastPredictionState>,
    /// Input bound to `EditPredictionExcerptOptions::max_bytes`.
    max_bytes_input: Entity<SingleLineInput>,
    /// Input bound to `EditPredictionExcerptOptions::min_bytes`.
    min_bytes_input: Entity<SingleLineInput>,
    /// Input bound to `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    cursor_context_ratio_input: Entity<SingleLineInput>,
    /// Whether the context or the prompt/response pair is displayed.
    active_view: ActiveView,
    zeta: Entity<Zeta>,
    // NOTE(review): only ever set to `None` in this file.
    _active_editor_subscription: Option<Subscription>,
    /// Builds editors for the latest debug message; each new message replaces (drops) it.
    _update_state_task: Task<()>,
    /// Loop reading `Zeta::debug_info` messages; owned here so it lives with the inspector.
    _receive_task: Task<()>,
}
 73
 74#[derive(PartialEq)]
 75enum ActiveView {
 76    Context,
 77    Inference,
 78}
 79
/// State of the inspector's main content area.
enum LastPredictionState {
    /// A request or replay failed; the message is rendered verbatim.
    Failed(SharedString),
    /// A prediction's debug info has been turned into editors and is displayed.
    Success(LastPrediction),
    /// Options changed and the prediction is being re-requested. The previous result stays
    /// visible (rendered dimmed) until the next debug message replaces it.
    Replaying {
        prediction: LastPrediction,
        // Throttled re-request; held here so replacing the state drops it.
        _task: Task<()>,
    },
}
 88
/// Everything displayed for one prediction, plus what is needed to request it again.
struct LastPrediction {
    /// Read-only multibuffer: the cursor excerpt followed by each retrieved snippet.
    context_editor: Entity<Editor>,
    // Timings shown in the "Stats" panel.
    retrieval_time: TimeDelta,
    prompt_planning_time: TimeDelta,
    inference_time: TimeDelta,
    parsing_time: TimeDelta,
    /// Read-only view of the request's prompt.
    prompt_editor: Entity<Editor>,
    /// Read-only view of the model's response.
    model_response_editor: Entity<Editor>,
    /// Buffer and cursor position the prediction was made for; used to replay it.
    buffer: WeakEntity<Buffer>,
    position: language::Anchor,
}
100
101impl Zeta2Inspector {
102    pub fn new(
103        project: &Entity<Project>,
104        client: &Arc<Client>,
105        user_store: &Entity<UserStore>,
106        window: &mut Window,
107        cx: &mut Context<Self>,
108    ) -> Self {
109        let zeta = Zeta::global(client, user_store, cx);
110        let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
111
112        let receive_task = cx.spawn_in(window, async move |this, cx| {
113            while let Some(prediction_result) = request_rx.next().await {
114                this.update_in(cx, |this, window, cx| match prediction_result {
115                    Ok(prediction) => {
116                        this.update_last_prediction(prediction, window, cx);
117                    }
118                    Err(err) => {
119                        this.last_prediction = Some(LastPredictionState::Failed(err.into()));
120                        cx.notify();
121                    }
122                })
123                .ok();
124            }
125        });
126
127        let mut this = Self {
128            focus_handle: cx.focus_handle(),
129            project: project.clone(),
130            last_prediction: None,
131            active_view: ActiveView::Context,
132            max_bytes_input: Self::number_input("Max Bytes", window, cx),
133            min_bytes_input: Self::number_input("Min Bytes", window, cx),
134            cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
135            zeta: zeta.clone(),
136            _active_editor_subscription: None,
137            _update_state_task: Task::ready(()),
138            _receive_task: receive_task,
139        };
140        this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
141        this
142    }
143
144    fn set_input_options(
145        &mut self,
146        options: &ZetaOptions,
147        window: &mut Window,
148        cx: &mut Context<Self>,
149    ) {
150        self.max_bytes_input.update(cx, |input, cx| {
151            input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
152        });
153        self.min_bytes_input.update(cx, |input, cx| {
154            input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
155        });
156        self.cursor_context_ratio_input.update(cx, |input, cx| {
157            input.set_text(
158                format!(
159                    "{:.2}",
160                    options.excerpt.target_before_cursor_over_total_bytes
161                ),
162                window,
163                cx,
164            );
165        });
166        cx.notify();
167    }
168
169    fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
170        self.zeta.update(cx, |this, _cx| this.set_options(options));
171
172        const THROTTLE_TIME: Duration = Duration::from_millis(100);
173
174        if let Some(
175            LastPredictionState::Success(prediction)
176            | LastPredictionState::Replaying { prediction, .. },
177        ) = self.last_prediction.take()
178        {
179            if let Some(buffer) = prediction.buffer.upgrade() {
180                let position = prediction.position;
181                let zeta = self.zeta.clone();
182                let project = self.project.clone();
183                let task = cx.spawn(async move |_this, cx| {
184                    cx.background_executor().timer(THROTTLE_TIME).await;
185                    if let Some(task) = zeta
186                        .update(cx, |zeta, cx| {
187                            zeta.request_prediction(&project, &buffer, position, cx)
188                        })
189                        .ok()
190                    {
191                        task.await.log_err();
192                    }
193                });
194                self.last_prediction = Some(LastPredictionState::Replaying {
195                    prediction,
196                    _task: task,
197                });
198            } else {
199                self.last_prediction = Some(LastPredictionState::Failed("Buffer dropped".into()));
200            }
201        }
202
203        cx.notify();
204    }
205
206    fn number_input(
207        label: &'static str,
208        window: &mut Window,
209        cx: &mut Context<Self>,
210    ) -> Entity<SingleLineInput> {
211        let input = cx.new(|cx| {
212            SingleLineInput::new(window, cx, "")
213                .label(label)
214                .label_min_width(px(64.))
215        });
216
217        cx.subscribe_in(
218            &input.read(cx).editor().clone(),
219            window,
220            |this, _, event, _window, cx| {
221                let EditorEvent::BufferEdited = event else {
222                    return;
223                };
224
225                fn number_input_value<T: FromStr + Default>(
226                    input: &Entity<SingleLineInput>,
227                    cx: &App,
228                ) -> T {
229                    input
230                        .read(cx)
231                        .editor()
232                        .read(cx)
233                        .text(cx)
234                        .parse::<T>()
235                        .unwrap_or_default()
236                }
237
238                let excerpt_options = EditPredictionExcerptOptions {
239                    max_bytes: number_input_value(&this.max_bytes_input, cx),
240                    min_bytes: number_input_value(&this.min_bytes_input, cx),
241                    target_before_cursor_over_total_bytes: number_input_value(
242                        &this.cursor_context_ratio_input,
243                        cx,
244                    ),
245                };
246
247                this.set_options(
248                    ZetaOptions {
249                        excerpt: excerpt_options,
250                        ..this.zeta.read(cx).options().clone()
251                    },
252                    cx,
253                );
254            },
255        )
256        .detach();
257        input
258    }
259
260    fn update_last_prediction(
261        &mut self,
262        prediction: zeta2::PredictionDebugInfo,
263        window: &mut Window,
264        cx: &mut Context<Self>,
265    ) {
266        let Some(worktree_id) = self
267            .project
268            .read(cx)
269            .worktrees(cx)
270            .next()
271            .map(|worktree| worktree.read(cx).id())
272        else {
273            log::error!("Open a worktree to use edit prediction debug view");
274            self.last_prediction.take();
275            return;
276        };
277
278        self._update_state_task = cx.spawn_in(window, {
279            let language_registry = self.project.read(cx).languages().clone();
280            async move |this, cx| {
281                let mut languages = HashMap::default();
282                for lang_id in prediction
283                    .context
284                    .snippets
285                    .iter()
286                    .map(|snippet| snippet.declaration.identifier().language_id)
287                    .chain(prediction.context.excerpt_text.language_id)
288                {
289                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
290                        // Most snippets are gonna be the same language,
291                        // so we think it's fine to do this sequentially for now
292                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
293                    }
294                }
295
296                let markdown_language = language_registry
297                    .language_for_name("Markdown")
298                    .await
299                    .log_err();
300
301                this.update_in(cx, |this, window, cx| {
302                    let context_editor = cx.new(|cx| {
303                        let multibuffer = cx.new(|cx| {
304                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
305                            let excerpt_file = Arc::new(ExcerptMetadataFile {
306                                title: PathBuf::from("Cursor Excerpt").into(),
307                                worktree_id,
308                            });
309
310                            let excerpt_buffer = cx.new(|cx| {
311                                let mut buffer =
312                                    Buffer::local(prediction.context.excerpt_text.body, cx);
313                                if let Some(language) = prediction
314                                    .context
315                                    .excerpt_text
316                                    .language_id
317                                    .as_ref()
318                                    .and_then(|id| languages.get(id))
319                                {
320                                    buffer.set_language(language.clone(), cx);
321                                }
322                                buffer.file_updated(excerpt_file, cx);
323                                buffer
324                            });
325
326                            multibuffer.push_excerpts(
327                                excerpt_buffer,
328                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
329                                cx,
330                            );
331
332                            for snippet in &prediction.context.snippets {
333                                let path = this
334                                    .project
335                                    .read(cx)
336                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
337
338                                let snippet_file = Arc::new(ExcerptMetadataFile {
339                                    title: PathBuf::from(format!(
340                                        "{} (Score density: {})",
341                                        path.map(|p| p.path.to_string_lossy().to_string())
342                                            .unwrap_or_else(|| "".to_string()),
343                                        snippet.score_density(SnippetStyle::Declaration)
344                                    ))
345                                    .into(),
346                                    worktree_id,
347                                });
348
349                                let excerpt_buffer = cx.new(|cx| {
350                                    let mut buffer =
351                                        Buffer::local(snippet.declaration.item_text().0, cx);
352                                    buffer.file_updated(snippet_file, cx);
353                                    if let Some(language) =
354                                        languages.get(&snippet.declaration.identifier().language_id)
355                                    {
356                                        buffer.set_language(language.clone(), cx);
357                                    }
358                                    buffer
359                                });
360
361                                multibuffer.push_excerpts(
362                                    excerpt_buffer,
363                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
364                                    cx,
365                                );
366                            }
367
368                            multibuffer
369                        });
370
371                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
372                    });
373
374                    let last_prediction = LastPrediction {
375                        context_editor,
376                        prompt_editor: cx.new(|cx| {
377                            let buffer = cx.new(|cx| {
378                                let mut buffer = Buffer::local(prediction.request.prompt, cx);
379                                buffer.set_language(markdown_language.clone(), cx);
380                                buffer
381                            });
382                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
383                            let mut editor =
384                                Editor::new(EditorMode::full(), buffer, None, window, cx);
385                            editor.set_read_only(true);
386                            editor.set_show_line_numbers(false, cx);
387                            editor.set_show_gutter(false, cx);
388                            editor.set_show_scrollbars(false, cx);
389                            editor
390                        }),
391                        model_response_editor: cx.new(|cx| {
392                            let buffer = cx.new(|cx| {
393                                let mut buffer =
394                                    Buffer::local(prediction.request.model_response, cx);
395                                buffer.set_language(markdown_language, cx);
396                                buffer
397                            });
398                            let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
399                            let mut editor =
400                                Editor::new(EditorMode::full(), buffer, None, window, cx);
401                            editor.set_read_only(true);
402                            editor.set_show_line_numbers(false, cx);
403                            editor.set_show_gutter(false, cx);
404                            editor.set_show_scrollbars(false, cx);
405                            editor
406                        }),
407                        retrieval_time: prediction.retrieval_time,
408                        prompt_planning_time: prediction.request.prompt_planning_time,
409                        inference_time: prediction.request.inference_time,
410                        parsing_time: prediction.request.parsing_time,
411                        buffer: prediction.buffer,
412                        position: prediction.position,
413                    };
414                    this.last_prediction = Some(LastPredictionState::Success(last_prediction));
415                    cx.notify();
416                })
417                .ok();
418            }
419        });
420    }
421
422    fn render_duration(name: &'static str, time: chrono::TimeDelta) -> Div {
423        h_flex()
424            .gap_1()
425            .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
426            .child(
427                Label::new(if time.num_microseconds().unwrap_or(0) > 1000 {
428                    format!("{} ms", time.num_milliseconds())
429                } else {
430                    format!("{} µs", time.num_microseconds().unwrap_or(0))
431                })
432                .size(LabelSize::Small),
433            )
434    }
435
436    fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
437        match &self.active_view {
438            ActiveView::Context => div().size_full().child(prediction.context_editor.clone()),
439            ActiveView::Inference => h_flex()
440                .items_start()
441                .w_full()
442                .flex_1()
443                .border_t_1()
444                .border_color(cx.theme().colors().border)
445                .bg(cx.theme().colors().editor_background)
446                .child(
447                    v_flex()
448                        .flex_1()
449                        .gap_2()
450                        .p_4()
451                        .h_full()
452                        .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
453                        .child(prediction.prompt_editor.clone()),
454                )
455                .child(ui::vertical_divider())
456                .child(
457                    v_flex()
458                        .flex_1()
459                        .gap_2()
460                        .h_full()
461                        .p_4()
462                        .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall))
463                        .child(prediction.model_response_editor.clone()),
464                ),
465        }
466    }
467}
468
469impl Focusable for Zeta2Inspector {
470    fn focus_handle(&self, _cx: &App) -> FocusHandle {
471        self.focus_handle.clone()
472    }
473}
474
475impl Item for Zeta2Inspector {
476    type Event = ();
477
478    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
479        "Zeta2 Inspector".into()
480    }
481}
482
// Emitter for `()` events, matching `Item::Event = ()` above; nothing in this file emits them.
impl EventEmitter<()> for Zeta2Inspector {}
484
485impl Render for Zeta2Inspector {
486    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
487        let content = match self.last_prediction.as_ref() {
488            None => v_flex()
489                .size_full()
490                .justify_center()
491                .items_center()
492                .child(Label::new("No prediction").size(LabelSize::Large))
493                .into_any(),
494            Some(LastPredictionState::Success(prediction)) => {
495                self.render_last_prediction(prediction, cx).into_any()
496            }
497            Some(LastPredictionState::Replaying { prediction, _task }) => self
498                .render_last_prediction(prediction, cx)
499                .opacity(0.6)
500                .into_any(),
501            Some(LastPredictionState::Failed(err)) => v_flex()
502                .p_4()
503                .gap_2()
504                .child(Label::new(err.clone()).buffer_font(cx))
505                .into_any(),
506        };
507
508        v_flex()
509            .size_full()
510            .bg(cx.theme().colors().editor_background)
511            .child(
512                h_flex()
513                    .w_full()
514                    .child(
515                        v_flex()
516                            .flex_1()
517                            .p_4()
518                            .h_full()
519                            .justify_between()
520                            .child(
521                                v_flex()
522                                    .gap_2()
523                                    .child(
524                                        Headline::new("Excerpt Options").size(HeadlineSize::Small),
525                                    )
526                                    .child(
527                                        h_flex()
528                                            .gap_2()
529                                            .items_end()
530                                            .child(self.max_bytes_input.clone())
531                                            .child(self.min_bytes_input.clone())
532                                            .child(self.cursor_context_ratio_input.clone())
533                                            .child(
534                                                ui::Button::new("reset-options", "Reset")
535                                                    .disabled(
536                                                        self.zeta.read(cx).options()
537                                                            == &zeta2::DEFAULT_OPTIONS,
538                                                    )
539                                                    .style(ButtonStyle::Outlined)
540                                                    .size(ButtonSize::Large)
541                                                    .on_click(cx.listener(
542                                                        |this, _, window, cx| {
543                                                            this.set_input_options(
544                                                                &zeta2::DEFAULT_OPTIONS,
545                                                                window,
546                                                                cx,
547                                                            );
548                                                        },
549                                                    )),
550                                            ),
551                                    ),
552                            )
553                            .map(|this| {
554                                if let Some(
555                                    LastPredictionState::Success { .. }
556                                    | LastPredictionState::Replaying { .. },
557                                ) = self.last_prediction.as_ref()
558                                {
559                                    this.child(
560                                        ui::ToggleButtonGroup::single_row(
561                                            "prediction",
562                                            [
563                                                ui::ToggleButtonSimple::new(
564                                                    "Context",
565                                                    cx.listener(|this, _, _, cx| {
566                                                        this.active_view = ActiveView::Context;
567                                                        cx.notify();
568                                                    }),
569                                                ),
570                                                ui::ToggleButtonSimple::new(
571                                                    "Inference",
572                                                    cx.listener(|this, _, _, cx| {
573                                                        this.active_view = ActiveView::Inference;
574                                                        cx.notify();
575                                                    }),
576                                                ),
577                                            ],
578                                        )
579                                        .style(ui::ToggleButtonGroupStyle::Outlined)
580                                        .selected_index(
581                                            if self.active_view == ActiveView::Context {
582                                                0
583                                            } else {
584                                                1
585                                            },
586                                        ),
587                                    )
588                                } else {
589                                    this
590                                }
591                            }),
592                    )
593                    .child(ui::vertical_divider())
594                    .map(|this| {
595                        if let Some(
596                            LastPredictionState::Success(prediction)
597                            | LastPredictionState::Replaying { prediction, .. },
598                        ) = self.last_prediction.as_ref()
599                        {
600                            this.child(
601                                v_flex()
602                                    .p_4()
603                                    .gap_2()
604                                    .min_w(px(160.))
605                                    .child(Headline::new("Stats").size(HeadlineSize::Small))
606                                    .child(Self::render_duration(
607                                        "Context retrieval",
608                                        prediction.retrieval_time,
609                                    ))
610                                    .child(Self::render_duration(
611                                        "Prompt planning",
612                                        prediction.prompt_planning_time,
613                                    ))
614                                    .child(Self::render_duration(
615                                        "Inference",
616                                        prediction.inference_time,
617                                    ))
618                                    .child(Self::render_duration(
619                                        "Parsing",
620                                        prediction.parsing_time,
621                                    )),
622                            )
623                        } else {
624                            this
625                        }
626                    }),
627            )
628            .child(content)
629    }
630}
631
632// Using same approach as commit view
633
634struct ExcerptMetadataFile {
635    title: Arc<Path>,
636    worktree_id: WorktreeId,
637}
638
639impl language::File for ExcerptMetadataFile {
640    fn as_local(&self) -> Option<&dyn language::LocalFile> {
641        None
642    }
643
644    fn disk_state(&self) -> DiskState {
645        DiskState::New
646    }
647
648    fn path(&self) -> &Arc<Path> {
649        &self.title
650    }
651
652    fn full_path(&self, _: &App) -> PathBuf {
653        self.title.as_ref().into()
654    }
655
656    fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
657        self.title.file_name().unwrap()
658    }
659
660    fn worktree_id(&self, _: &App) -> WorktreeId {
661        self.worktree_id
662    }
663
664    fn to_proto(&self, _: &App) -> language::proto::File {
665        unimplemented!()
666    }
667
668    fn is_private(&self) -> bool {
669        false
670    }
671}