assistant_panel.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use client::zed_urls;
  6use gpui::{
  7    prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
  8    FocusableView, FontWeight, Model, Pixels, Subscription, Task, View, ViewContext, WeakView,
  9    WindowContext,
 10};
 11use language_model::{LanguageModelRegistry, Role};
 12use language_model_selector::LanguageModelSelector;
 13use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
 14use workspace::dock::{DockPosition, Panel, PanelEvent};
 15use workspace::Workspace;
 16
 17use crate::message_editor::MessageEditor;
 18use crate::thread::{Message, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 21
 22pub fn init(cx: &mut AppContext) {
 23    cx.observe_new_views(
 24        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 25            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 26                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 27            });
 28        },
 29    )
 30    .detach();
 31}
 32
 33pub struct AssistantPanel {
 34    workspace: WeakView<Workspace>,
 35    #[allow(unused)]
 36    thread_store: Model<ThreadStore>,
 37    thread: Model<Thread>,
 38    message_editor: View<MessageEditor>,
 39    tools: Arc<ToolWorkingSet>,
 40    last_error: Option<ThreadError>,
 41    _subscriptions: Vec<Subscription>,
 42}
 43
 44impl AssistantPanel {
 45    pub fn load(
 46        workspace: WeakView<Workspace>,
 47        cx: AsyncWindowContext,
 48    ) -> Task<Result<View<Self>>> {
 49        cx.spawn(|mut cx| async move {
 50            let tools = Arc::new(ToolWorkingSet::default());
 51            let thread_store = workspace
 52                .update(&mut cx, |workspace, cx| {
 53                    let project = workspace.project().clone();
 54                    ThreadStore::new(project, tools.clone(), cx)
 55                })?
 56                .await?;
 57
 58            workspace.update(&mut cx, |workspace, cx| {
 59                cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
 60            })
 61        })
 62    }
 63
 64    fn new(
 65        workspace: &Workspace,
 66        thread_store: Model<ThreadStore>,
 67        tools: Arc<ToolWorkingSet>,
 68        cx: &mut ViewContext<Self>,
 69    ) -> Self {
 70        let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
 71        let subscriptions = vec![
 72            cx.observe(&thread, |_, _, cx| cx.notify()),
 73            cx.subscribe(&thread, Self::handle_thread_event),
 74        ];
 75
 76        Self {
 77            workspace: workspace.weak_handle(),
 78            thread_store,
 79            thread: thread.clone(),
 80            message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
 81            tools,
 82            last_error: None,
 83            _subscriptions: subscriptions,
 84        }
 85    }
 86
 87    fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
 88        let tools = self.thread.read(cx).tools().clone();
 89        let thread = cx.new_model(|cx| Thread::new(tools, cx));
 90        let subscriptions = vec![
 91            cx.observe(&thread, |_, _, cx| cx.notify()),
 92            cx.subscribe(&thread, Self::handle_thread_event),
 93        ];
 94
 95        self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
 96        self.thread = thread;
 97        self._subscriptions = subscriptions;
 98
 99        self.message_editor.focus_handle(cx).focus(cx);
100    }
101
102    fn handle_thread_event(
103        &mut self,
104        _: Model<Thread>,
105        event: &ThreadEvent,
106        cx: &mut ViewContext<Self>,
107    ) {
108        match event {
109            ThreadEvent::ShowError(error) => {
110                self.last_error = Some(error.clone());
111            }
112            ThreadEvent::StreamedCompletion => {}
113            ThreadEvent::UsePendingTools => {
114                let pending_tool_uses = self
115                    .thread
116                    .read(cx)
117                    .pending_tool_uses()
118                    .into_iter()
119                    .filter(|tool_use| tool_use.status.is_idle())
120                    .cloned()
121                    .collect::<Vec<_>>();
122
123                for tool_use in pending_tool_uses {
124                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
125                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
126
127                        self.thread.update(cx, |thread, cx| {
128                            thread.insert_tool_output(
129                                tool_use.assistant_message_id,
130                                tool_use.id.clone(),
131                                task,
132                                cx,
133                            );
134                        });
135                    }
136                }
137            }
138            ThreadEvent::ToolFinished { .. } => {}
139        }
140    }
141}
142
143impl FocusableView for AssistantPanel {
144    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
145        self.message_editor.focus_handle(cx)
146    }
147}
148
149impl EventEmitter<PanelEvent> for AssistantPanel {}
150
151impl Panel for AssistantPanel {
152    fn persistent_name() -> &'static str {
153        "AssistantPanel2"
154    }
155
156    fn position(&self, _cx: &WindowContext) -> DockPosition {
157        DockPosition::Right
158    }
159
160    fn position_is_valid(&self, _: DockPosition) -> bool {
161        true
162    }
163
164    fn set_position(&mut self, _position: DockPosition, _cx: &mut ViewContext<Self>) {}
165
166    fn size(&self, _cx: &WindowContext) -> Pixels {
167        px(640.)
168    }
169
170    fn set_size(&mut self, _size: Option<Pixels>, _cx: &mut ViewContext<Self>) {}
171
172    fn set_active(&mut self, _active: bool, _cx: &mut ViewContext<Self>) {}
173
174    fn remote_id() -> Option<proto::PanelId> {
175        Some(proto::PanelId::AssistantPanel)
176    }
177
178    fn icon(&self, _cx: &WindowContext) -> Option<IconName> {
179        Some(IconName::ZedAssistant)
180    }
181
182    fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
183        Some("Assistant Panel")
184    }
185
186    fn toggle_action(&self) -> Box<dyn Action> {
187        Box::new(ToggleFocus)
188    }
189}
190
191impl AssistantPanel {
192    fn render_toolbar(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
193        let focus_handle = self.focus_handle(cx);
194
195        h_flex()
196            .id("assistant-toolbar")
197            .justify_between()
198            .gap(DynamicSpacing::Base08.rems(cx))
199            .h(Tab::container_height(cx))
200            .px(DynamicSpacing::Base08.rems(cx))
201            .bg(cx.theme().colors().tab_bar_background)
202            .border_b_1()
203            .border_color(cx.theme().colors().border_variant)
204            .child(h_flex().child(Label::new("Thread Title Goes Here")))
205            .child(
206                h_flex()
207                    .gap(DynamicSpacing::Base08.rems(cx))
208                    .child(self.render_language_model_selector(cx))
209                    .child(Divider::vertical())
210                    .child(
211                        IconButton::new("new-thread", IconName::Plus)
212                            .shape(IconButtonShape::Square)
213                            .icon_size(IconSize::Small)
214                            .style(ButtonStyle::Subtle)
215                            .tooltip({
216                                let focus_handle = focus_handle.clone();
217                                move |cx| {
218                                    Tooltip::for_action_in(
219                                        "New Thread",
220                                        &NewThread,
221                                        &focus_handle,
222                                        cx,
223                                    )
224                                }
225                            })
226                            .on_click(move |_event, _cx| {
227                                println!("New Thread");
228                            }),
229                    )
230                    .child(
231                        IconButton::new("open-history", IconName::HistoryRerun)
232                            .shape(IconButtonShape::Square)
233                            .icon_size(IconSize::Small)
234                            .style(ButtonStyle::Subtle)
235                            .tooltip(move |cx| Tooltip::text("Open History", cx))
236                            .on_click(move |_event, _cx| {
237                                println!("Open History");
238                            }),
239                    )
240                    .child(
241                        IconButton::new("configure-assistant", IconName::Settings)
242                            .shape(IconButtonShape::Square)
243                            .icon_size(IconSize::Small)
244                            .style(ButtonStyle::Subtle)
245                            .tooltip(move |cx| Tooltip::text("Configure Assistant", cx))
246                            .on_click(move |_event, _cx| {
247                                println!("Configure Assistant");
248                            }),
249                    ),
250            )
251    }
252
253    fn render_language_model_selector(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
254        let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
255        let active_model = LanguageModelRegistry::read_global(cx).active_model();
256
257        LanguageModelSelector::new(
258            |model, _cx| {
259                println!("Selected {:?}", model.name());
260            },
261            ButtonLike::new("active-model")
262                .style(ButtonStyle::Subtle)
263                .child(
264                    h_flex()
265                        .w_full()
266                        .gap_0p5()
267                        .child(
268                            div()
269                                .overflow_x_hidden()
270                                .flex_grow()
271                                .whitespace_nowrap()
272                                .child(match (active_provider, active_model) {
273                                    (Some(provider), Some(model)) => h_flex()
274                                        .gap_1()
275                                        .child(
276                                            Icon::new(
277                                                model.icon().unwrap_or_else(|| provider.icon()),
278                                            )
279                                            .color(Color::Muted)
280                                            .size(IconSize::XSmall),
281                                        )
282                                        .child(
283                                            Label::new(model.name().0)
284                                                .size(LabelSize::Small)
285                                                .color(Color::Muted),
286                                        )
287                                        .into_any_element(),
288                                    _ => Label::new("No model selected")
289                                        .size(LabelSize::Small)
290                                        .color(Color::Muted)
291                                        .into_any_element(),
292                                }),
293                        )
294                        .child(
295                            Icon::new(IconName::ChevronDown)
296                                .color(Color::Muted)
297                                .size(IconSize::XSmall),
298                        ),
299                )
300                .tooltip(move |cx| Tooltip::for_action("Change Model", &ToggleModelSelector, cx)),
301        )
302    }
303
304    fn render_message(&self, message: Message, cx: &mut ViewContext<Self>) -> impl IntoElement {
305        let (role_icon, role_name) = match message.role {
306            Role::User => (IconName::Person, "You"),
307            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
308            Role::System => (IconName::Settings, "System"),
309        };
310
311        v_flex()
312            .border_1()
313            .border_color(cx.theme().colors().border_variant)
314            .rounded_md()
315            .child(
316                h_flex()
317                    .justify_between()
318                    .p_1p5()
319                    .border_b_1()
320                    .border_color(cx.theme().colors().border_variant)
321                    .child(
322                        h_flex()
323                            .gap_2()
324                            .child(Icon::new(role_icon).size(IconSize::Small))
325                            .child(Label::new(role_name).size(LabelSize::Small)),
326                    ),
327            )
328            .child(v_flex().p_1p5().child(Label::new(message.text.clone())))
329    }
330
331    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
332        let last_error = self.last_error.as_ref()?;
333
334        Some(
335            div()
336                .absolute()
337                .right_3()
338                .bottom_12()
339                .max_w_96()
340                .py_2()
341                .px_3()
342                .elevation_2(cx)
343                .occlude()
344                .child(match last_error {
345                    ThreadError::PaymentRequired => self.render_payment_required_error(cx),
346                    ThreadError::MaxMonthlySpendReached => {
347                        self.render_max_monthly_spend_reached_error(cx)
348                    }
349                    ThreadError::Message(error_message) => {
350                        self.render_error_message(error_message, cx)
351                    }
352                })
353                .into_any(),
354        )
355    }
356
357    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
358        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
359
360        v_flex()
361            .gap_0p5()
362            .child(
363                h_flex()
364                    .gap_1p5()
365                    .items_center()
366                    .child(Icon::new(IconName::XCircle).color(Color::Error))
367                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
368            )
369            .child(
370                div()
371                    .id("error-message")
372                    .max_h_24()
373                    .overflow_y_scroll()
374                    .child(Label::new(ERROR_MESSAGE)),
375            )
376            .child(
377                h_flex()
378                    .justify_end()
379                    .mt_1()
380                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
381                        |this, _, cx| {
382                            this.last_error = None;
383                            cx.open_url(&zed_urls::account_url(cx));
384                            cx.notify();
385                        },
386                    )))
387                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
388                        |this, _, cx| {
389                            this.last_error = None;
390                            cx.notify();
391                        },
392                    ))),
393            )
394            .into_any()
395    }
396
397    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
398        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
399
400        v_flex()
401            .gap_0p5()
402            .child(
403                h_flex()
404                    .gap_1p5()
405                    .items_center()
406                    .child(Icon::new(IconName::XCircle).color(Color::Error))
407                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
408            )
409            .child(
410                div()
411                    .id("error-message")
412                    .max_h_24()
413                    .overflow_y_scroll()
414                    .child(Label::new(ERROR_MESSAGE)),
415            )
416            .child(
417                h_flex()
418                    .justify_end()
419                    .mt_1()
420                    .child(
421                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
422                            cx.listener(|this, _, cx| {
423                                this.last_error = None;
424                                cx.open_url(&zed_urls::account_url(cx));
425                                cx.notify();
426                            }),
427                        ),
428                    )
429                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
430                        |this, _, cx| {
431                            this.last_error = None;
432                            cx.notify();
433                        },
434                    ))),
435            )
436            .into_any()
437    }
438
439    fn render_error_message(
440        &self,
441        error_message: &SharedString,
442        cx: &mut ViewContext<Self>,
443    ) -> AnyElement {
444        v_flex()
445            .gap_0p5()
446            .child(
447                h_flex()
448                    .gap_1p5()
449                    .items_center()
450                    .child(Icon::new(IconName::XCircle).color(Color::Error))
451                    .child(
452                        Label::new("Error interacting with language model")
453                            .weight(FontWeight::MEDIUM),
454                    ),
455            )
456            .child(
457                div()
458                    .id("error-message")
459                    .max_h_32()
460                    .overflow_y_scroll()
461                    .child(Label::new(error_message.clone())),
462            )
463            .child(
464                h_flex()
465                    .justify_end()
466                    .mt_1()
467                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
468                        |this, _, cx| {
469                            this.last_error = None;
470                            cx.notify();
471                        },
472                    ))),
473            )
474            .into_any()
475    }
476}
477
478impl Render for AssistantPanel {
479    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
480        let messages = self.thread.read(cx).messages().cloned().collect::<Vec<_>>();
481
482        v_flex()
483            .key_context("AssistantPanel2")
484            .justify_between()
485            .size_full()
486            .on_action(cx.listener(|this, _: &NewThread, cx| {
487                this.new_thread(cx);
488            }))
489            .child(self.render_toolbar(cx))
490            .child(
491                v_flex()
492                    .id("message-list")
493                    .gap_2()
494                    .size_full()
495                    .p_2()
496                    .overflow_y_scroll()
497                    .bg(cx.theme().colors().panel_background)
498                    .children(
499                        messages
500                            .into_iter()
501                            .map(|message| self.render_message(message, cx)),
502                    ),
503            )
504            .child(
505                h_flex()
506                    .border_t_1()
507                    .border_color(cx.theme().colors().border_variant)
508                    .child(self.message_editor.clone()),
509            )
510            .children(self.render_last_error(cx))
511    }
512}