assistant_panel.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use client::zed_urls;
  6use collections::HashMap;
  7use gpui::{
  8    list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
  9    FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels,
 10    StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
 11    WindowContext,
 12};
 13use language::LanguageRegistry;
 14use language_model::{LanguageModelRegistry, Role};
 15use language_model_selector::LanguageModelSelector;
 16use markdown::{Markdown, MarkdownStyle};
 17use settings::Settings;
 18use theme::ThemeSettings;
 19use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
 20use workspace::dock::{DockPosition, Panel, PanelEvent};
 21use workspace::Workspace;
 22
 23use crate::message_editor::MessageEditor;
 24use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 25use crate::thread_store::ThreadStore;
 26use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 27
 28pub fn init(cx: &mut AppContext) {
 29    cx.observe_new_views(
 30        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 31            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 32                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 33            });
 34        },
 35    )
 36    .detach();
 37}
 38
/// Dock panel that shows the messages of the active [`Thread`] and hosts a
/// [`MessageEditor`] for composing the next message.
pub struct AssistantPanel {
    /// Weak handle to the owning workspace; handed to tools when they run.
    workspace: WeakView<Workspace>,
    /// Language registry passed to every message's [`Markdown`] view.
    language_registry: Arc<LanguageRegistry>,
    // Held by the panel but not read anywhere in this file yet, hence the lint allowance.
    #[allow(unused)]
    thread_store: Model<ThreadStore>,
    /// The thread currently displayed and edited.
    thread: Model<Thread>,
    /// Message ids in display order; index `ix` of the list renders `thread_messages[ix]`.
    thread_messages: Vec<MessageId>,
    /// Markdown view for each message, created when the message is added and
    /// appended to as assistant text streams in.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Backing state for the virtualized message list.
    thread_list_state: ListState,
    /// Editor used to compose messages for `thread`.
    message_editor: View<MessageEditor>,
    /// Tools that can be invoked for pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// Most recent error reported by the thread; shown until dismissed.
    last_error: Option<ThreadError>,
    /// Keeps the observation/subscription on `thread` alive.
    _subscriptions: Vec<Subscription>,
}
 53
 54impl AssistantPanel {
 55    pub fn load(
 56        workspace: WeakView<Workspace>,
 57        cx: AsyncWindowContext,
 58    ) -> Task<Result<View<Self>>> {
 59        cx.spawn(|mut cx| async move {
 60            let tools = Arc::new(ToolWorkingSet::default());
 61            let thread_store = workspace
 62                .update(&mut cx, |workspace, cx| {
 63                    let project = workspace.project().clone();
 64                    ThreadStore::new(project, tools.clone(), cx)
 65                })?
 66                .await?;
 67
 68            workspace.update(&mut cx, |workspace, cx| {
 69                cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
 70            })
 71        })
 72    }
 73
 74    fn new(
 75        workspace: &Workspace,
 76        thread_store: Model<ThreadStore>,
 77        tools: Arc<ToolWorkingSet>,
 78        cx: &mut ViewContext<Self>,
 79    ) -> Self {
 80        let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
 81        let subscriptions = vec![
 82            cx.observe(&thread, |_, _, cx| cx.notify()),
 83            cx.subscribe(&thread, Self::handle_thread_event),
 84        ];
 85
 86        Self {
 87            workspace: workspace.weak_handle(),
 88            language_registry: workspace.project().read(cx).languages().clone(),
 89            thread_store,
 90            thread: thread.clone(),
 91            thread_messages: Vec::new(),
 92            rendered_messages_by_id: HashMap::default(),
 93            thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 94                let this = cx.view().downgrade();
 95                move |ix, cx: &mut WindowContext| {
 96                    this.update(cx, |this, cx| this.render_message(ix, cx))
 97                        .unwrap()
 98                }
 99            }),
100            message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
101            tools,
102            last_error: None,
103            _subscriptions: subscriptions,
104        }
105    }
106
107    fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
108        let tools = self.thread.read(cx).tools().clone();
109        let thread = cx.new_model(|cx| Thread::new(tools, cx));
110        let subscriptions = vec![
111            cx.observe(&thread, |_, _, cx| cx.notify()),
112            cx.subscribe(&thread, Self::handle_thread_event),
113        ];
114
115        self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
116        self.thread = thread;
117        self.thread_messages.clear();
118        self.thread_list_state.reset(0);
119        self.rendered_messages_by_id.clear();
120        self._subscriptions = subscriptions;
121
122        self.message_editor.focus_handle(cx).focus(cx);
123    }
124
125    fn handle_thread_event(
126        &mut self,
127        _: Model<Thread>,
128        event: &ThreadEvent,
129        cx: &mut ViewContext<Self>,
130    ) {
131        match event {
132            ThreadEvent::ShowError(error) => {
133                self.last_error = Some(error.clone());
134            }
135            ThreadEvent::StreamedCompletion => {}
136            ThreadEvent::StreamedAssistantText(message_id, text) => {
137                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
138                    markdown.update(cx, |markdown, cx| {
139                        markdown.append(text, cx);
140                    });
141                }
142            }
143            ThreadEvent::MessageAdded(message_id) => {
144                let old_len = self.thread_messages.len();
145                self.thread_messages.push(*message_id);
146                self.thread_list_state.splice(old_len..old_len, 1);
147
148                if let Some(message_text) = self
149                    .thread
150                    .read(cx)
151                    .message(*message_id)
152                    .map(|message| message.text.clone())
153                {
154                    let theme_settings = ThemeSettings::get_global(cx);
155                    let ui_font_size = TextSize::Default.rems(cx);
156                    let buffer_font_size = theme_settings.buffer_font_size;
157
158                    let mut text_style = cx.text_style();
159                    text_style.refine(&TextStyleRefinement {
160                        font_family: Some(theme_settings.ui_font.family.clone()),
161                        font_size: Some(ui_font_size.into()),
162                        color: Some(cx.theme().colors().text),
163                        ..Default::default()
164                    });
165
166                    let markdown_style = MarkdownStyle {
167                        base_text_style: text_style,
168                        syntax: cx.theme().syntax().clone(),
169                        selection_background_color: cx.theme().players().local().selection,
170                        code_block: StyleRefinement {
171                            text: Some(TextStyleRefinement {
172                                font_family: Some(theme_settings.buffer_font.family.clone()),
173                                font_size: Some(buffer_font_size.into()),
174                                ..Default::default()
175                            }),
176                            ..Default::default()
177                        },
178                        inline_code: TextStyleRefinement {
179                            font_family: Some(theme_settings.buffer_font.family.clone()),
180                            font_size: Some(ui_font_size.into()),
181                            background_color: Some(cx.theme().colors().editor_background),
182                            ..Default::default()
183                        },
184                        ..Default::default()
185                    };
186
187                    let markdown = cx.new_view(|cx| {
188                        Markdown::new(
189                            message_text,
190                            markdown_style,
191                            Some(self.language_registry.clone()),
192                            None,
193                            cx,
194                        )
195                    });
196                    self.rendered_messages_by_id.insert(*message_id, markdown);
197                }
198
199                cx.notify();
200            }
201            ThreadEvent::UsePendingTools => {
202                let pending_tool_uses = self
203                    .thread
204                    .read(cx)
205                    .pending_tool_uses()
206                    .into_iter()
207                    .filter(|tool_use| tool_use.status.is_idle())
208                    .cloned()
209                    .collect::<Vec<_>>();
210
211                for tool_use in pending_tool_uses {
212                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
213                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
214
215                        self.thread.update(cx, |thread, cx| {
216                            thread.insert_tool_output(
217                                tool_use.assistant_message_id,
218                                tool_use.id.clone(),
219                                task,
220                                cx,
221                            );
222                        });
223                    }
224                }
225            }
226            ThreadEvent::ToolFinished { .. } => {}
227        }
228    }
229}
230
231impl FocusableView for AssistantPanel {
232    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
233        self.message_editor.focus_handle(cx)
234    }
235}
236
// Lets the dock subscribe to `PanelEvent`s from this panel (no `cx.emit` calls appear in this file).
impl EventEmitter<PanelEvent> for AssistantPanel {}
238
239impl Panel for AssistantPanel {
240    fn persistent_name() -> &'static str {
241        "AssistantPanel2"
242    }
243
244    fn position(&self, _cx: &WindowContext) -> DockPosition {
245        DockPosition::Right
246    }
247
248    fn position_is_valid(&self, _: DockPosition) -> bool {
249        true
250    }
251
252    fn set_position(&mut self, _position: DockPosition, _cx: &mut ViewContext<Self>) {}
253
254    fn size(&self, _cx: &WindowContext) -> Pixels {
255        px(640.)
256    }
257
258    fn set_size(&mut self, _size: Option<Pixels>, _cx: &mut ViewContext<Self>) {}
259
260    fn set_active(&mut self, _active: bool, _cx: &mut ViewContext<Self>) {}
261
262    fn remote_id() -> Option<proto::PanelId> {
263        Some(proto::PanelId::AssistantPanel)
264    }
265
266    fn icon(&self, _cx: &WindowContext) -> Option<IconName> {
267        Some(IconName::ZedAssistant)
268    }
269
270    fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
271        Some("Assistant Panel")
272    }
273
274    fn toggle_action(&self) -> Box<dyn Action> {
275        Box::new(ToggleFocus)
276    }
277}
278
279impl AssistantPanel {
280    fn render_toolbar(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
281        let focus_handle = self.focus_handle(cx);
282
283        h_flex()
284            .id("assistant-toolbar")
285            .justify_between()
286            .gap(DynamicSpacing::Base08.rems(cx))
287            .h(Tab::container_height(cx))
288            .px(DynamicSpacing::Base08.rems(cx))
289            .bg(cx.theme().colors().tab_bar_background)
290            .border_b_1()
291            .border_color(cx.theme().colors().border_variant)
292            .child(h_flex().child(Label::new("Thread Title Goes Here")))
293            .child(
294                h_flex()
295                    .gap(DynamicSpacing::Base08.rems(cx))
296                    .child(self.render_language_model_selector(cx))
297                    .child(Divider::vertical())
298                    .child(
299                        IconButton::new("new-thread", IconName::Plus)
300                            .shape(IconButtonShape::Square)
301                            .icon_size(IconSize::Small)
302                            .style(ButtonStyle::Subtle)
303                            .tooltip({
304                                let focus_handle = focus_handle.clone();
305                                move |cx| {
306                                    Tooltip::for_action_in(
307                                        "New Thread",
308                                        &NewThread,
309                                        &focus_handle,
310                                        cx,
311                                    )
312                                }
313                            })
314                            .on_click(move |_event, _cx| {
315                                println!("New Thread");
316                            }),
317                    )
318                    .child(
319                        IconButton::new("open-history", IconName::HistoryRerun)
320                            .shape(IconButtonShape::Square)
321                            .icon_size(IconSize::Small)
322                            .style(ButtonStyle::Subtle)
323                            .tooltip(move |cx| Tooltip::text("Open History", cx))
324                            .on_click(move |_event, _cx| {
325                                println!("Open History");
326                            }),
327                    )
328                    .child(
329                        IconButton::new("configure-assistant", IconName::Settings)
330                            .shape(IconButtonShape::Square)
331                            .icon_size(IconSize::Small)
332                            .style(ButtonStyle::Subtle)
333                            .tooltip(move |cx| Tooltip::text("Configure Assistant", cx))
334                            .on_click(move |_event, _cx| {
335                                println!("Configure Assistant");
336                            }),
337                    ),
338            )
339    }
340
341    fn render_language_model_selector(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
342        let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
343        let active_model = LanguageModelRegistry::read_global(cx).active_model();
344
345        LanguageModelSelector::new(
346            |model, _cx| {
347                println!("Selected {:?}", model.name());
348            },
349            ButtonLike::new("active-model")
350                .style(ButtonStyle::Subtle)
351                .child(
352                    h_flex()
353                        .w_full()
354                        .gap_0p5()
355                        .child(
356                            div()
357                                .overflow_x_hidden()
358                                .flex_grow()
359                                .whitespace_nowrap()
360                                .child(match (active_provider, active_model) {
361                                    (Some(provider), Some(model)) => h_flex()
362                                        .gap_1()
363                                        .child(
364                                            Icon::new(
365                                                model.icon().unwrap_or_else(|| provider.icon()),
366                                            )
367                                            .color(Color::Muted)
368                                            .size(IconSize::XSmall),
369                                        )
370                                        .child(
371                                            Label::new(model.name().0)
372                                                .size(LabelSize::Small)
373                                                .color(Color::Muted),
374                                        )
375                                        .into_any_element(),
376                                    _ => Label::new("No model selected")
377                                        .size(LabelSize::Small)
378                                        .color(Color::Muted)
379                                        .into_any_element(),
380                                }),
381                        )
382                        .child(
383                            Icon::new(IconName::ChevronDown)
384                                .color(Color::Muted)
385                                .size(IconSize::XSmall),
386                        ),
387                )
388                .tooltip(move |cx| Tooltip::for_action("Change Model", &ToggleModelSelector, cx)),
389        )
390    }
391
392    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
393        let message_id = self.thread_messages[ix];
394        let Some(message) = self.thread.read(cx).message(message_id) else {
395            return Empty.into_any();
396        };
397
398        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
399            return Empty.into_any();
400        };
401
402        let (role_icon, role_name) = match message.role {
403            Role::User => (IconName::Person, "You"),
404            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
405            Role::System => (IconName::Settings, "System"),
406        };
407
408        div()
409            .id(("message-container", ix))
410            .p_2()
411            .child(
412                v_flex()
413                    .border_1()
414                    .border_color(cx.theme().colors().border_variant)
415                    .rounded_md()
416                    .child(
417                        h_flex()
418                            .justify_between()
419                            .p_1p5()
420                            .border_b_1()
421                            .border_color(cx.theme().colors().border_variant)
422                            .child(
423                                h_flex()
424                                    .gap_2()
425                                    .child(Icon::new(role_icon).size(IconSize::Small))
426                                    .child(Label::new(role_name).size(LabelSize::Small)),
427                            ),
428                    )
429                    .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
430            )
431            .into_any()
432    }
433
434    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
435        let last_error = self.last_error.as_ref()?;
436
437        Some(
438            div()
439                .absolute()
440                .right_3()
441                .bottom_12()
442                .max_w_96()
443                .py_2()
444                .px_3()
445                .elevation_2(cx)
446                .occlude()
447                .child(match last_error {
448                    ThreadError::PaymentRequired => self.render_payment_required_error(cx),
449                    ThreadError::MaxMonthlySpendReached => {
450                        self.render_max_monthly_spend_reached_error(cx)
451                    }
452                    ThreadError::Message(error_message) => {
453                        self.render_error_message(error_message, cx)
454                    }
455                })
456                .into_any(),
457        )
458    }
459
460    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
461        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
462
463        v_flex()
464            .gap_0p5()
465            .child(
466                h_flex()
467                    .gap_1p5()
468                    .items_center()
469                    .child(Icon::new(IconName::XCircle).color(Color::Error))
470                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
471            )
472            .child(
473                div()
474                    .id("error-message")
475                    .max_h_24()
476                    .overflow_y_scroll()
477                    .child(Label::new(ERROR_MESSAGE)),
478            )
479            .child(
480                h_flex()
481                    .justify_end()
482                    .mt_1()
483                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
484                        |this, _, cx| {
485                            this.last_error = None;
486                            cx.open_url(&zed_urls::account_url(cx));
487                            cx.notify();
488                        },
489                    )))
490                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
491                        |this, _, cx| {
492                            this.last_error = None;
493                            cx.notify();
494                        },
495                    ))),
496            )
497            .into_any()
498    }
499
500    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
501        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
502
503        v_flex()
504            .gap_0p5()
505            .child(
506                h_flex()
507                    .gap_1p5()
508                    .items_center()
509                    .child(Icon::new(IconName::XCircle).color(Color::Error))
510                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
511            )
512            .child(
513                div()
514                    .id("error-message")
515                    .max_h_24()
516                    .overflow_y_scroll()
517                    .child(Label::new(ERROR_MESSAGE)),
518            )
519            .child(
520                h_flex()
521                    .justify_end()
522                    .mt_1()
523                    .child(
524                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
525                            cx.listener(|this, _, cx| {
526                                this.last_error = None;
527                                cx.open_url(&zed_urls::account_url(cx));
528                                cx.notify();
529                            }),
530                        ),
531                    )
532                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
533                        |this, _, cx| {
534                            this.last_error = None;
535                            cx.notify();
536                        },
537                    ))),
538            )
539            .into_any()
540    }
541
542    fn render_error_message(
543        &self,
544        error_message: &SharedString,
545        cx: &mut ViewContext<Self>,
546    ) -> AnyElement {
547        v_flex()
548            .gap_0p5()
549            .child(
550                h_flex()
551                    .gap_1p5()
552                    .items_center()
553                    .child(Icon::new(IconName::XCircle).color(Color::Error))
554                    .child(
555                        Label::new("Error interacting with language model")
556                            .weight(FontWeight::MEDIUM),
557                    ),
558            )
559            .child(
560                div()
561                    .id("error-message")
562                    .max_h_32()
563                    .overflow_y_scroll()
564                    .child(Label::new(error_message.clone())),
565            )
566            .child(
567                h_flex()
568                    .justify_end()
569                    .mt_1()
570                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
571                        |this, _, cx| {
572                            this.last_error = None;
573                            cx.notify();
574                        },
575                    ))),
576            )
577            .into_any()
578    }
579}
580
581impl Render for AssistantPanel {
582    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
583        v_flex()
584            .key_context("AssistantPanel2")
585            .justify_between()
586            .size_full()
587            .on_action(cx.listener(|this, _: &NewThread, cx| {
588                this.new_thread(cx);
589            }))
590            .child(self.render_toolbar(cx))
591            .child(list(self.thread_list_state.clone()).flex_1())
592            .child(
593                h_flex()
594                    .border_t_1()
595                    .border_color(cx.theme().colors().border_variant)
596                    .child(self.message_editor.clone()),
597            )
598            .children(self.render_last_error(cx))
599    }
600}