assistant_panel.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use client::zed_urls;
  6use gpui::{
  7    prelude::*, px, svg, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter,
  8    FocusHandle, FocusableView, FontWeight, Model, Pixels, Task, View, ViewContext, WeakView,
  9    WindowContext,
 10};
 11use language::LanguageRegistry;
 12use language_model::LanguageModelRegistry;
 13use language_model_selector::LanguageModelSelector;
 14use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, KeyBinding, ListItem, Tab, Tooltip};
 15use workspace::dock::{DockPosition, Panel, PanelEvent};
 16use workspace::Workspace;
 17
 18use crate::active_thread::ActiveThread;
 19use crate::message_editor::MessageEditor;
 20use crate::thread::{Thread, ThreadError, ThreadId};
 21use crate::thread_store::ThreadStore;
 22use crate::{NewThread, OpenHistory, ToggleFocus, ToggleModelSelector};
 23
 24pub fn init(cx: &mut AppContext) {
 25    cx.observe_new_views(
 26        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 27            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 28                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 29            });
 30        },
 31    )
 32    .detach();
 33}
 34
 35pub struct AssistantPanel {
 36    workspace: WeakView<Workspace>,
 37    language_registry: Arc<LanguageRegistry>,
 38    thread_store: Model<ThreadStore>,
 39    thread: Option<View<ActiveThread>>,
 40    message_editor: View<MessageEditor>,
 41    tools: Arc<ToolWorkingSet>,
 42}
 43
 44impl AssistantPanel {
 45    pub fn load(
 46        workspace: WeakView<Workspace>,
 47        cx: AsyncWindowContext,
 48    ) -> Task<Result<View<Self>>> {
 49        cx.spawn(|mut cx| async move {
 50            let tools = Arc::new(ToolWorkingSet::default());
 51            let thread_store = workspace
 52                .update(&mut cx, |workspace, cx| {
 53                    let project = workspace.project().clone();
 54                    ThreadStore::new(project, tools.clone(), cx)
 55                })?
 56                .await?;
 57
 58            workspace.update(&mut cx, |workspace, cx| {
 59                cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
 60            })
 61        })
 62    }
 63
 64    fn new(
 65        workspace: &Workspace,
 66        thread_store: Model<ThreadStore>,
 67        tools: Arc<ToolWorkingSet>,
 68        cx: &mut ViewContext<Self>,
 69    ) -> Self {
 70        let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
 71
 72        Self {
 73            workspace: workspace.weak_handle(),
 74            language_registry: workspace.project().read(cx).languages().clone(),
 75            thread_store,
 76            thread: None,
 77            message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
 78            tools,
 79        }
 80    }
 81
 82    fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
 83        let thread = self
 84            .thread_store
 85            .update(cx, |this, cx| this.create_thread(cx));
 86
 87        self.thread = Some(cx.new_view(|cx| {
 88            ActiveThread::new(
 89                thread.clone(),
 90                self.workspace.clone(),
 91                self.language_registry.clone(),
 92                self.tools.clone(),
 93                cx,
 94            )
 95        }));
 96        self.message_editor = cx.new_view(|cx| MessageEditor::new(thread, cx));
 97        self.message_editor.focus_handle(cx).focus(cx);
 98    }
 99
100    fn open_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext<Self>) {
101        let Some(thread) = self
102            .thread_store
103            .update(cx, |this, cx| this.open_thread(thread_id, cx))
104        else {
105            return;
106        };
107
108        self.thread = Some(cx.new_view(|cx| {
109            ActiveThread::new(
110                thread.clone(),
111                self.workspace.clone(),
112                self.language_registry.clone(),
113                self.tools.clone(),
114                cx,
115            )
116        }));
117        self.message_editor = cx.new_view(|cx| MessageEditor::new(thread, cx));
118        self.message_editor.focus_handle(cx).focus(cx);
119    }
120}
121
122impl FocusableView for AssistantPanel {
123    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
124        self.message_editor.focus_handle(cx)
125    }
126}
127
128impl EventEmitter<PanelEvent> for AssistantPanel {}
129
130impl Panel for AssistantPanel {
131    fn persistent_name() -> &'static str {
132        "AssistantPanel2"
133    }
134
135    fn position(&self, _cx: &WindowContext) -> DockPosition {
136        DockPosition::Right
137    }
138
139    fn position_is_valid(&self, _: DockPosition) -> bool {
140        true
141    }
142
143    fn set_position(&mut self, _position: DockPosition, _cx: &mut ViewContext<Self>) {}
144
145    fn size(&self, _cx: &WindowContext) -> Pixels {
146        px(640.)
147    }
148
149    fn set_size(&mut self, _size: Option<Pixels>, _cx: &mut ViewContext<Self>) {}
150
151    fn set_active(&mut self, _active: bool, _cx: &mut ViewContext<Self>) {}
152
153    fn remote_id() -> Option<proto::PanelId> {
154        Some(proto::PanelId::AssistantPanel)
155    }
156
157    fn icon(&self, _cx: &WindowContext) -> Option<IconName> {
158        Some(IconName::ZedAssistant)
159    }
160
161    fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
162        Some("Assistant Panel")
163    }
164
165    fn toggle_action(&self) -> Box<dyn Action> {
166        Box::new(ToggleFocus)
167    }
168}
169
170impl AssistantPanel {
171    fn render_toolbar(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
172        let focus_handle = self.focus_handle(cx);
173
174        h_flex()
175            .id("assistant-toolbar")
176            .justify_between()
177            .gap(DynamicSpacing::Base08.rems(cx))
178            .h(Tab::container_height(cx))
179            .px(DynamicSpacing::Base08.rems(cx))
180            .bg(cx.theme().colors().tab_bar_background)
181            .border_b_1()
182            .border_color(cx.theme().colors().border_variant)
183            .child(h_flex().child(Label::new("Thread Title Goes Here")))
184            .child(
185                h_flex()
186                    .gap(DynamicSpacing::Base08.rems(cx))
187                    .child(self.render_language_model_selector(cx))
188                    .child(Divider::vertical())
189                    .child(
190                        IconButton::new("new-thread", IconName::Plus)
191                            .shape(IconButtonShape::Square)
192                            .icon_size(IconSize::Small)
193                            .style(ButtonStyle::Subtle)
194                            .tooltip({
195                                let focus_handle = focus_handle.clone();
196                                move |cx| {
197                                    Tooltip::for_action_in(
198                                        "New Thread",
199                                        &NewThread,
200                                        &focus_handle,
201                                        cx,
202                                    )
203                                }
204                            })
205                            .on_click(move |_event, cx| {
206                                cx.dispatch_action(NewThread.boxed_clone());
207                            }),
208                    )
209                    .child(
210                        IconButton::new("open-history", IconName::HistoryRerun)
211                            .shape(IconButtonShape::Square)
212                            .icon_size(IconSize::Small)
213                            .style(ButtonStyle::Subtle)
214                            .tooltip({
215                                let focus_handle = focus_handle.clone();
216                                move |cx| {
217                                    Tooltip::for_action_in(
218                                        "Open History",
219                                        &OpenHistory,
220                                        &focus_handle,
221                                        cx,
222                                    )
223                                }
224                            })
225                            .on_click(move |_event, cx| {
226                                cx.dispatch_action(OpenHistory.boxed_clone());
227                            }),
228                    )
229                    .child(
230                        IconButton::new("configure-assistant", IconName::Settings)
231                            .shape(IconButtonShape::Square)
232                            .icon_size(IconSize::Small)
233                            .style(ButtonStyle::Subtle)
234                            .tooltip(move |cx| Tooltip::text("Configure Assistant", cx))
235                            .on_click(move |_event, _cx| {
236                                println!("Configure Assistant");
237                            }),
238                    ),
239            )
240    }
241
242    fn render_language_model_selector(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
243        let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
244        let active_model = LanguageModelRegistry::read_global(cx).active_model();
245
246        LanguageModelSelector::new(
247            |model, _cx| {
248                println!("Selected {:?}", model.name());
249            },
250            ButtonLike::new("active-model")
251                .style(ButtonStyle::Subtle)
252                .child(
253                    h_flex()
254                        .w_full()
255                        .gap_0p5()
256                        .child(
257                            div()
258                                .overflow_x_hidden()
259                                .flex_grow()
260                                .whitespace_nowrap()
261                                .child(match (active_provider, active_model) {
262                                    (Some(provider), Some(model)) => h_flex()
263                                        .gap_1()
264                                        .child(
265                                            Icon::new(
266                                                model.icon().unwrap_or_else(|| provider.icon()),
267                                            )
268                                            .color(Color::Muted)
269                                            .size(IconSize::XSmall),
270                                        )
271                                        .child(
272                                            Label::new(model.name().0)
273                                                .size(LabelSize::Small)
274                                                .color(Color::Muted),
275                                        )
276                                        .into_any_element(),
277                                    _ => Label::new("No model selected")
278                                        .size(LabelSize::Small)
279                                        .color(Color::Muted)
280                                        .into_any_element(),
281                                }),
282                        )
283                        .child(
284                            Icon::new(IconName::ChevronDown)
285                                .color(Color::Muted)
286                                .size(IconSize::XSmall),
287                        ),
288                )
289                .tooltip(move |cx| Tooltip::for_action("Change Model", &ToggleModelSelector, cx)),
290        )
291    }
292
293    fn render_active_thread_or_empty_state(&self, cx: &mut ViewContext<Self>) -> AnyElement {
294        let Some(thread) = self.thread.as_ref() else {
295            return self.render_thread_empty_state(cx).into_any_element();
296        };
297
298        if thread.read(cx).is_empty() {
299            return self.render_thread_empty_state(cx).into_any_element();
300        }
301
302        thread.clone().into_any()
303    }
304
305    fn render_thread_empty_state(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
306        let recent_threads = self
307            .thread_store
308            .update(cx, |this, cx| this.recent_threads(3, cx));
309
310        v_flex()
311            .gap_2()
312            .mx_auto()
313            .child(
314                v_flex().w_full().child(
315                    svg()
316                        .path("icons/logo_96.svg")
317                        .text_color(cx.theme().colors().text)
318                        .w(px(40.))
319                        .h(px(40.))
320                        .mx_auto()
321                        .mb_4(),
322                ),
323            )
324            .child(v_flex())
325            .child(
326                h_flex()
327                    .w_full()
328                    .justify_center()
329                    .child(Label::new("Context Examples:").size(LabelSize::Small)),
330            )
331            .child(
332                h_flex()
333                    .gap_2()
334                    .justify_center()
335                    .child(
336                        h_flex()
337                            .gap_1()
338                            .p_0p5()
339                            .rounded_md()
340                            .border_1()
341                            .border_color(cx.theme().colors().border_variant)
342                            .child(
343                                Icon::new(IconName::Terminal)
344                                    .size(IconSize::Small)
345                                    .color(Color::Disabled),
346                            )
347                            .child(Label::new("Terminal").size(LabelSize::Small)),
348                    )
349                    .child(
350                        h_flex()
351                            .gap_1()
352                            .p_0p5()
353                            .rounded_md()
354                            .border_1()
355                            .border_color(cx.theme().colors().border_variant)
356                            .child(
357                                Icon::new(IconName::Folder)
358                                    .size(IconSize::Small)
359                                    .color(Color::Disabled),
360                            )
361                            .child(Label::new("/src/components").size(LabelSize::Small)),
362                    ),
363            )
364            .child(
365                h_flex()
366                    .w_full()
367                    .justify_center()
368                    .child(Label::new("Recent Threads:").size(LabelSize::Small)),
369            )
370            .child(
371                v_flex().gap_2().children(
372                    recent_threads
373                        .into_iter()
374                        .map(|thread| self.render_past_thread(thread, cx)),
375                ),
376            )
377            .child(
378                h_flex().w_full().justify_center().child(
379                    Button::new("view-all-past-threads", "View All Past Threads")
380                        .style(ButtonStyle::Subtle)
381                        .label_size(LabelSize::Small)
382                        .key_binding(KeyBinding::for_action_in(
383                            &OpenHistory,
384                            &self.focus_handle(cx),
385                            cx,
386                        ))
387                        .on_click(move |_event, cx| {
388                            cx.dispatch_action(OpenHistory.boxed_clone());
389                        }),
390                ),
391            )
392    }
393
394    fn render_past_thread(
395        &self,
396        thread: Model<Thread>,
397        cx: &mut ViewContext<Self>,
398    ) -> impl IntoElement {
399        let id = thread.read(cx).id().clone();
400
401        ListItem::new(("past-thread", thread.entity_id()))
402            .start_slot(Icon::new(IconName::MessageBubbles))
403            .child(Label::new(format!("Thread {id}")))
404            .end_slot(
405                h_flex()
406                    .gap_2()
407                    .child(Label::new("1 hour ago").color(Color::Disabled))
408                    .child(
409                        IconButton::new("delete", IconName::TrashAlt)
410                            .shape(IconButtonShape::Square)
411                            .icon_size(IconSize::Small),
412                    ),
413            )
414            .on_click(cx.listener(move |this, _event, cx| {
415                this.open_thread(&id, cx);
416            }))
417    }
418
419    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
420        let last_error = self.thread.as_ref()?.read(cx).last_error()?;
421
422        Some(
423            div()
424                .absolute()
425                .right_3()
426                .bottom_12()
427                .max_w_96()
428                .py_2()
429                .px_3()
430                .elevation_2(cx)
431                .occlude()
432                .child(match last_error {
433                    ThreadError::PaymentRequired => self.render_payment_required_error(cx),
434                    ThreadError::MaxMonthlySpendReached => {
435                        self.render_max_monthly_spend_reached_error(cx)
436                    }
437                    ThreadError::Message(error_message) => {
438                        self.render_error_message(&error_message, cx)
439                    }
440                })
441                .into_any(),
442        )
443    }
444
445    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
446        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
447
448        v_flex()
449            .gap_0p5()
450            .child(
451                h_flex()
452                    .gap_1p5()
453                    .items_center()
454                    .child(Icon::new(IconName::XCircle).color(Color::Error))
455                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
456            )
457            .child(
458                div()
459                    .id("error-message")
460                    .max_h_24()
461                    .overflow_y_scroll()
462                    .child(Label::new(ERROR_MESSAGE)),
463            )
464            .child(
465                h_flex()
466                    .justify_end()
467                    .mt_1()
468                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
469                        |this, _, cx| {
470                            if let Some(thread) = this.thread.as_ref() {
471                                thread.update(cx, |this, _cx| {
472                                    this.clear_last_error();
473                                });
474                            }
475
476                            cx.open_url(&zed_urls::account_url(cx));
477                            cx.notify();
478                        },
479                    )))
480                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
481                        |this, _, cx| {
482                            if let Some(thread) = this.thread.as_ref() {
483                                thread.update(cx, |this, _cx| {
484                                    this.clear_last_error();
485                                });
486                            }
487
488                            cx.notify();
489                        },
490                    ))),
491            )
492            .into_any()
493    }
494
495    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
496        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
497
498        v_flex()
499            .gap_0p5()
500            .child(
501                h_flex()
502                    .gap_1p5()
503                    .items_center()
504                    .child(Icon::new(IconName::XCircle).color(Color::Error))
505                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
506            )
507            .child(
508                div()
509                    .id("error-message")
510                    .max_h_24()
511                    .overflow_y_scroll()
512                    .child(Label::new(ERROR_MESSAGE)),
513            )
514            .child(
515                h_flex()
516                    .justify_end()
517                    .mt_1()
518                    .child(
519                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
520                            cx.listener(|this, _, cx| {
521                                if let Some(thread) = this.thread.as_ref() {
522                                    thread.update(cx, |this, _cx| {
523                                        this.clear_last_error();
524                                    });
525                                }
526
527                                cx.open_url(&zed_urls::account_url(cx));
528                                cx.notify();
529                            }),
530                        ),
531                    )
532                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
533                        |this, _, cx| {
534                            if let Some(thread) = this.thread.as_ref() {
535                                thread.update(cx, |this, _cx| {
536                                    this.clear_last_error();
537                                });
538                            }
539
540                            cx.notify();
541                        },
542                    ))),
543            )
544            .into_any()
545    }
546
547    fn render_error_message(
548        &self,
549        error_message: &SharedString,
550        cx: &mut ViewContext<Self>,
551    ) -> AnyElement {
552        v_flex()
553            .gap_0p5()
554            .child(
555                h_flex()
556                    .gap_1p5()
557                    .items_center()
558                    .child(Icon::new(IconName::XCircle).color(Color::Error))
559                    .child(
560                        Label::new("Error interacting with language model")
561                            .weight(FontWeight::MEDIUM),
562                    ),
563            )
564            .child(
565                div()
566                    .id("error-message")
567                    .max_h_32()
568                    .overflow_y_scroll()
569                    .child(Label::new(error_message.clone())),
570            )
571            .child(
572                h_flex()
573                    .justify_end()
574                    .mt_1()
575                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
576                        |this, _, cx| {
577                            if let Some(thread) = this.thread.as_ref() {
578                                thread.update(cx, |this, _cx| {
579                                    this.clear_last_error();
580                                });
581                            }
582
583                            cx.notify();
584                        },
585                    ))),
586            )
587            .into_any()
588    }
589}
590
591impl Render for AssistantPanel {
592    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
593        v_flex()
594            .key_context("AssistantPanel2")
595            .justify_between()
596            .size_full()
597            .on_action(cx.listener(|this, _: &NewThread, cx| {
598                this.new_thread(cx);
599            }))
600            .on_action(cx.listener(|_this, _: &OpenHistory, _cx| {
601                println!("Open History");
602            }))
603            .child(self.render_toolbar(cx))
604            .child(self.render_active_thread_or_empty_state(cx))
605            .child(
606                h_flex()
607                    .border_t_1()
608                    .border_color(cx.theme().colors().border_variant)
609                    .child(self.message_editor.clone()),
610            )
611            .children(self.render_last_error(cx))
612    }
613}