assistant_panel.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use client::zed_urls;
  6use gpui::{
  7    list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
  8    FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription,
  9    Task, View, ViewContext, WeakView, WindowContext,
 10};
 11use language_model::{LanguageModelRegistry, Role};
 12use language_model_selector::LanguageModelSelector;
 13use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
 14use workspace::dock::{DockPosition, Panel, PanelEvent};
 15use workspace::Workspace;
 16
 17use crate::message_editor::MessageEditor;
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 21
 22pub fn init(cx: &mut AppContext) {
 23    cx.observe_new_views(
 24        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 25            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 26                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 27            });
 28        },
 29    )
 30    .detach();
 31}
 32
 33pub struct AssistantPanel {
 34    workspace: WeakView<Workspace>,
 35    #[allow(unused)]
 36    thread_store: Model<ThreadStore>,
 37    thread: Model<Thread>,
 38    thread_messages: Vec<MessageId>,
 39    thread_list_state: ListState,
 40    message_editor: View<MessageEditor>,
 41    tools: Arc<ToolWorkingSet>,
 42    last_error: Option<ThreadError>,
 43    _subscriptions: Vec<Subscription>,
 44}
 45
 46impl AssistantPanel {
 47    pub fn load(
 48        workspace: WeakView<Workspace>,
 49        cx: AsyncWindowContext,
 50    ) -> Task<Result<View<Self>>> {
 51        cx.spawn(|mut cx| async move {
 52            let tools = Arc::new(ToolWorkingSet::default());
 53            let thread_store = workspace
 54                .update(&mut cx, |workspace, cx| {
 55                    let project = workspace.project().clone();
 56                    ThreadStore::new(project, tools.clone(), cx)
 57                })?
 58                .await?;
 59
 60            workspace.update(&mut cx, |workspace, cx| {
 61                cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
 62            })
 63        })
 64    }
 65
 66    fn new(
 67        workspace: &Workspace,
 68        thread_store: Model<ThreadStore>,
 69        tools: Arc<ToolWorkingSet>,
 70        cx: &mut ViewContext<Self>,
 71    ) -> Self {
 72        let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
 73        let subscriptions = vec![
 74            cx.observe(&thread, |_, _, cx| cx.notify()),
 75            cx.subscribe(&thread, Self::handle_thread_event),
 76        ];
 77
 78        Self {
 79            workspace: workspace.weak_handle(),
 80            thread_store,
 81            thread: thread.clone(),
 82            thread_messages: Vec::new(),
 83            thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 84                let this = cx.view().downgrade();
 85                move |ix, cx: &mut WindowContext| {
 86                    this.update(cx, |this, cx| this.render_message(ix, cx))
 87                        .unwrap()
 88                }
 89            }),
 90            message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
 91            tools,
 92            last_error: None,
 93            _subscriptions: subscriptions,
 94        }
 95    }
 96
 97    fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
 98        let tools = self.thread.read(cx).tools().clone();
 99        let thread = cx.new_model(|cx| Thread::new(tools, cx));
100        let subscriptions = vec![
101            cx.observe(&thread, |_, _, cx| cx.notify()),
102            cx.subscribe(&thread, Self::handle_thread_event),
103        ];
104
105        self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
106        self.thread = thread;
107        self._subscriptions = subscriptions;
108
109        self.message_editor.focus_handle(cx).focus(cx);
110    }
111
112    fn handle_thread_event(
113        &mut self,
114        _: Model<Thread>,
115        event: &ThreadEvent,
116        cx: &mut ViewContext<Self>,
117    ) {
118        match event {
119            ThreadEvent::ShowError(error) => {
120                self.last_error = Some(error.clone());
121            }
122            ThreadEvent::StreamedCompletion => {}
123            ThreadEvent::MessageAdded(message_id) => {
124                let old_len = self.thread_messages.len();
125                self.thread_messages.push(*message_id);
126                self.thread_list_state.splice(old_len..old_len, 1);
127                cx.notify();
128            }
129            ThreadEvent::UsePendingTools => {
130                let pending_tool_uses = self
131                    .thread
132                    .read(cx)
133                    .pending_tool_uses()
134                    .into_iter()
135                    .filter(|tool_use| tool_use.status.is_idle())
136                    .cloned()
137                    .collect::<Vec<_>>();
138
139                for tool_use in pending_tool_uses {
140                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
141                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
142
143                        self.thread.update(cx, |thread, cx| {
144                            thread.insert_tool_output(
145                                tool_use.assistant_message_id,
146                                tool_use.id.clone(),
147                                task,
148                                cx,
149                            );
150                        });
151                    }
152                }
153            }
154            ThreadEvent::ToolFinished { .. } => {}
155        }
156    }
157}
158
159impl FocusableView for AssistantPanel {
160    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
161        self.message_editor.focus_handle(cx)
162    }
163}
164
165impl EventEmitter<PanelEvent> for AssistantPanel {}
166
167impl Panel for AssistantPanel {
168    fn persistent_name() -> &'static str {
169        "AssistantPanel2"
170    }
171
172    fn position(&self, _cx: &WindowContext) -> DockPosition {
173        DockPosition::Right
174    }
175
176    fn position_is_valid(&self, _: DockPosition) -> bool {
177        true
178    }
179
180    fn set_position(&mut self, _position: DockPosition, _cx: &mut ViewContext<Self>) {}
181
182    fn size(&self, _cx: &WindowContext) -> Pixels {
183        px(640.)
184    }
185
186    fn set_size(&mut self, _size: Option<Pixels>, _cx: &mut ViewContext<Self>) {}
187
188    fn set_active(&mut self, _active: bool, _cx: &mut ViewContext<Self>) {}
189
190    fn remote_id() -> Option<proto::PanelId> {
191        Some(proto::PanelId::AssistantPanel)
192    }
193
194    fn icon(&self, _cx: &WindowContext) -> Option<IconName> {
195        Some(IconName::ZedAssistant)
196    }
197
198    fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
199        Some("Assistant Panel")
200    }
201
202    fn toggle_action(&self) -> Box<dyn Action> {
203        Box::new(ToggleFocus)
204    }
205}
206
207impl AssistantPanel {
208    fn render_toolbar(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
209        let focus_handle = self.focus_handle(cx);
210
211        h_flex()
212            .id("assistant-toolbar")
213            .justify_between()
214            .gap(DynamicSpacing::Base08.rems(cx))
215            .h(Tab::container_height(cx))
216            .px(DynamicSpacing::Base08.rems(cx))
217            .bg(cx.theme().colors().tab_bar_background)
218            .border_b_1()
219            .border_color(cx.theme().colors().border_variant)
220            .child(h_flex().child(Label::new("Thread Title Goes Here")))
221            .child(
222                h_flex()
223                    .gap(DynamicSpacing::Base08.rems(cx))
224                    .child(self.render_language_model_selector(cx))
225                    .child(Divider::vertical())
226                    .child(
227                        IconButton::new("new-thread", IconName::Plus)
228                            .shape(IconButtonShape::Square)
229                            .icon_size(IconSize::Small)
230                            .style(ButtonStyle::Subtle)
231                            .tooltip({
232                                let focus_handle = focus_handle.clone();
233                                move |cx| {
234                                    Tooltip::for_action_in(
235                                        "New Thread",
236                                        &NewThread,
237                                        &focus_handle,
238                                        cx,
239                                    )
240                                }
241                            })
242                            .on_click(move |_event, _cx| {
243                                println!("New Thread");
244                            }),
245                    )
246                    .child(
247                        IconButton::new("open-history", IconName::HistoryRerun)
248                            .shape(IconButtonShape::Square)
249                            .icon_size(IconSize::Small)
250                            .style(ButtonStyle::Subtle)
251                            .tooltip(move |cx| Tooltip::text("Open History", cx))
252                            .on_click(move |_event, _cx| {
253                                println!("Open History");
254                            }),
255                    )
256                    .child(
257                        IconButton::new("configure-assistant", IconName::Settings)
258                            .shape(IconButtonShape::Square)
259                            .icon_size(IconSize::Small)
260                            .style(ButtonStyle::Subtle)
261                            .tooltip(move |cx| Tooltip::text("Configure Assistant", cx))
262                            .on_click(move |_event, _cx| {
263                                println!("Configure Assistant");
264                            }),
265                    ),
266            )
267    }
268
269    fn render_language_model_selector(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
270        let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
271        let active_model = LanguageModelRegistry::read_global(cx).active_model();
272
273        LanguageModelSelector::new(
274            |model, _cx| {
275                println!("Selected {:?}", model.name());
276            },
277            ButtonLike::new("active-model")
278                .style(ButtonStyle::Subtle)
279                .child(
280                    h_flex()
281                        .w_full()
282                        .gap_0p5()
283                        .child(
284                            div()
285                                .overflow_x_hidden()
286                                .flex_grow()
287                                .whitespace_nowrap()
288                                .child(match (active_provider, active_model) {
289                                    (Some(provider), Some(model)) => h_flex()
290                                        .gap_1()
291                                        .child(
292                                            Icon::new(
293                                                model.icon().unwrap_or_else(|| provider.icon()),
294                                            )
295                                            .color(Color::Muted)
296                                            .size(IconSize::XSmall),
297                                        )
298                                        .child(
299                                            Label::new(model.name().0)
300                                                .size(LabelSize::Small)
301                                                .color(Color::Muted),
302                                        )
303                                        .into_any_element(),
304                                    _ => Label::new("No model selected")
305                                        .size(LabelSize::Small)
306                                        .color(Color::Muted)
307                                        .into_any_element(),
308                                }),
309                        )
310                        .child(
311                            Icon::new(IconName::ChevronDown)
312                                .color(Color::Muted)
313                                .size(IconSize::XSmall),
314                        ),
315                )
316                .tooltip(move |cx| Tooltip::for_action("Change Model", &ToggleModelSelector, cx)),
317        )
318    }
319
320    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
321        let message_id = self.thread_messages[ix];
322        let Some(message) = self.thread.read(cx).message(message_id) else {
323            return Empty.into_any();
324        };
325
326        let (role_icon, role_name) = match message.role {
327            Role::User => (IconName::Person, "You"),
328            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
329            Role::System => (IconName::Settings, "System"),
330        };
331
332        div()
333            .id(("message-container", ix))
334            .p_2()
335            .child(
336                v_flex()
337                    .border_1()
338                    .border_color(cx.theme().colors().border_variant)
339                    .rounded_md()
340                    .child(
341                        h_flex()
342                            .justify_between()
343                            .p_1p5()
344                            .border_b_1()
345                            .border_color(cx.theme().colors().border_variant)
346                            .child(
347                                h_flex()
348                                    .gap_2()
349                                    .child(Icon::new(role_icon).size(IconSize::Small))
350                                    .child(Label::new(role_name).size(LabelSize::Small)),
351                            ),
352                    )
353                    .child(v_flex().p_1p5().child(Label::new(message.text.clone()))),
354            )
355            .into_any()
356    }
357
358    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
359        let last_error = self.last_error.as_ref()?;
360
361        Some(
362            div()
363                .absolute()
364                .right_3()
365                .bottom_12()
366                .max_w_96()
367                .py_2()
368                .px_3()
369                .elevation_2(cx)
370                .occlude()
371                .child(match last_error {
372                    ThreadError::PaymentRequired => self.render_payment_required_error(cx),
373                    ThreadError::MaxMonthlySpendReached => {
374                        self.render_max_monthly_spend_reached_error(cx)
375                    }
376                    ThreadError::Message(error_message) => {
377                        self.render_error_message(error_message, cx)
378                    }
379                })
380                .into_any(),
381        )
382    }
383
384    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
385        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
386
387        v_flex()
388            .gap_0p5()
389            .child(
390                h_flex()
391                    .gap_1p5()
392                    .items_center()
393                    .child(Icon::new(IconName::XCircle).color(Color::Error))
394                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
395            )
396            .child(
397                div()
398                    .id("error-message")
399                    .max_h_24()
400                    .overflow_y_scroll()
401                    .child(Label::new(ERROR_MESSAGE)),
402            )
403            .child(
404                h_flex()
405                    .justify_end()
406                    .mt_1()
407                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
408                        |this, _, cx| {
409                            this.last_error = None;
410                            cx.open_url(&zed_urls::account_url(cx));
411                            cx.notify();
412                        },
413                    )))
414                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
415                        |this, _, cx| {
416                            this.last_error = None;
417                            cx.notify();
418                        },
419                    ))),
420            )
421            .into_any()
422    }
423
424    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
425        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
426
427        v_flex()
428            .gap_0p5()
429            .child(
430                h_flex()
431                    .gap_1p5()
432                    .items_center()
433                    .child(Icon::new(IconName::XCircle).color(Color::Error))
434                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
435            )
436            .child(
437                div()
438                    .id("error-message")
439                    .max_h_24()
440                    .overflow_y_scroll()
441                    .child(Label::new(ERROR_MESSAGE)),
442            )
443            .child(
444                h_flex()
445                    .justify_end()
446                    .mt_1()
447                    .child(
448                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
449                            cx.listener(|this, _, cx| {
450                                this.last_error = None;
451                                cx.open_url(&zed_urls::account_url(cx));
452                                cx.notify();
453                            }),
454                        ),
455                    )
456                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
457                        |this, _, cx| {
458                            this.last_error = None;
459                            cx.notify();
460                        },
461                    ))),
462            )
463            .into_any()
464    }
465
466    fn render_error_message(
467        &self,
468        error_message: &SharedString,
469        cx: &mut ViewContext<Self>,
470    ) -> AnyElement {
471        v_flex()
472            .gap_0p5()
473            .child(
474                h_flex()
475                    .gap_1p5()
476                    .items_center()
477                    .child(Icon::new(IconName::XCircle).color(Color::Error))
478                    .child(
479                        Label::new("Error interacting with language model")
480                            .weight(FontWeight::MEDIUM),
481                    ),
482            )
483            .child(
484                div()
485                    .id("error-message")
486                    .max_h_32()
487                    .overflow_y_scroll()
488                    .child(Label::new(error_message.clone())),
489            )
490            .child(
491                h_flex()
492                    .justify_end()
493                    .mt_1()
494                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
495                        |this, _, cx| {
496                            this.last_error = None;
497                            cx.notify();
498                        },
499                    ))),
500            )
501            .into_any()
502    }
503}
504
505impl Render for AssistantPanel {
506    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
507        v_flex()
508            .key_context("AssistantPanel2")
509            .justify_between()
510            .size_full()
511            .on_action(cx.listener(|this, _: &NewThread, cx| {
512                this.new_thread(cx);
513            }))
514            .child(self.render_toolbar(cx))
515            .child(list(self.thread_list_state.clone()).flex_1())
516            .child(
517                h_flex()
518                    .border_t_1()
519                    .border_color(cx.theme().colors().border_variant)
520                    .child(self.message_editor.clone()),
521            )
522            .children(self.render_last_error(cx))
523    }
524}