agent_configuration.rs

  1mod add_context_server_modal;
  2mod configure_context_server_modal;
  3mod manage_profiles_modal;
  4mod tool_picker;
  5
  6use std::{sync::Arc, time::Duration};
  7
  8use assistant_settings::AssistantSettings;
  9use assistant_tool::{ToolSource, ToolWorkingSet};
 10use collections::HashMap;
 11use context_server::ContextServerId;
 12use fs::Fs;
 13use gpui::{
 14    Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
 15    Focusable, ScrollHandle, Subscription, pulsating_between,
 16};
 17use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 18use project::context_server_store::{ContextServerStatus, ContextServerStore};
 19use settings::{Settings, update_settings_file};
 20use ui::{
 21    Disclosure, ElevationIndex, Indicator, Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip,
 22    prelude::*,
 23};
 24use util::ResultExt as _;
 25use zed_actions::ExtensionCategoryFilter;
 26
 27pub(crate) use add_context_server_modal::AddContextServerModal;
 28pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
 29pub(crate) use manage_profiles_modal::ManageProfilesModal;
 30
 31use crate::AddContextServer;
 32
/// State for the agent configuration panel, which renders general agent settings,
/// Model Context Protocol (MCP) servers, and language model providers.
pub struct AgentConfiguration {
    /// Filesystem handle passed to `update_settings_file` when a settings switch is toggled.
    fs: Arc<dyn Fs>,
    focus_handle: FocusHandle,
    /// One configuration view per registered provider; entries are added/removed in
    /// response to `language_model::Event::AddedProvider` / `RemovedProvider`.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the MCP server ids and their statuses; also used to start/stop servers.
    context_server_store: Entity<ContextServerStore>,
    /// Whether each context server's tool list is expanded; a missing entry means collapsed.
    expanded_context_server_tools: HashMap<ContextServerId, bool>,
    /// Whether each provider's configuration block is expanded; a missing entry means collapsed.
    expanded_provider_configurations: HashMap<LanguageModelProviderId, bool>,
    /// Tool set queried (via `tools_by_source`) to list the tools each context server provides.
    tools: Entity<ToolWorkingSet>,
    /// Never read; stored only so the registry subscription is not dropped.
    _registry_subscription: Subscription,
    /// Scroll handle for the settings column, shared with `scrollbar_state`.
    scroll_handle: ScrollHandle,
    scrollbar_state: ScrollbarState,
}
 45
 46impl AgentConfiguration {
 47    pub fn new(
 48        fs: Arc<dyn Fs>,
 49        context_server_store: Entity<ContextServerStore>,
 50        tools: Entity<ToolWorkingSet>,
 51        window: &mut Window,
 52        cx: &mut Context<Self>,
 53    ) -> Self {
 54        let focus_handle = cx.focus_handle();
 55
 56        let registry_subscription = cx.subscribe_in(
 57            &LanguageModelRegistry::global(cx),
 58            window,
 59            |this, _, event: &language_model::Event, window, cx| match event {
 60                language_model::Event::AddedProvider(provider_id) => {
 61                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 62                    if let Some(provider) = provider {
 63                        this.add_provider_configuration_view(&provider, window, cx);
 64                    }
 65                }
 66                language_model::Event::RemovedProvider(provider_id) => {
 67                    this.remove_provider_configuration_view(provider_id);
 68                }
 69                _ => {}
 70            },
 71        );
 72
 73        let scroll_handle = ScrollHandle::new();
 74        let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
 75
 76        let mut this = Self {
 77            fs,
 78            focus_handle,
 79            configuration_views_by_provider: HashMap::default(),
 80            context_server_store,
 81            expanded_context_server_tools: HashMap::default(),
 82            expanded_provider_configurations: HashMap::default(),
 83            tools,
 84            _registry_subscription: registry_subscription,
 85            scroll_handle,
 86            scrollbar_state,
 87        };
 88        this.build_provider_configuration_views(window, cx);
 89        this
 90    }
 91
 92    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 93        let providers = LanguageModelRegistry::read_global(cx).providers();
 94        for provider in providers {
 95            self.add_provider_configuration_view(&provider, window, cx);
 96        }
 97    }
 98
 99    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
100        self.configuration_views_by_provider.remove(provider_id);
101        self.expanded_provider_configurations.remove(provider_id);
102    }
103
104    fn add_provider_configuration_view(
105        &mut self,
106        provider: &Arc<dyn LanguageModelProvider>,
107        window: &mut Window,
108        cx: &mut Context<Self>,
109    ) {
110        let configuration_view = provider.configuration_view(window, cx);
111        self.configuration_views_by_provider
112            .insert(provider.id(), configuration_view);
113    }
114}
115
116impl Focusable for AgentConfiguration {
117    fn focus_handle(&self, _: &App) -> FocusHandle {
118        self.focus_handle.clone()
119    }
120}
121
/// Events emitted by [`AgentConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when the user clicks "Start New Thread" on an authenticated
    /// provider's block; carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AgentConfiguration {}
127
128impl AgentConfiguration {
129    fn render_provider_configuration_block(
130        &mut self,
131        provider: &Arc<dyn LanguageModelProvider>,
132        cx: &mut Context<Self>,
133    ) -> impl IntoElement + use<> {
134        let provider_id = provider.id().0.clone();
135        let provider_name = provider.name().0.clone();
136        let configuration_view = self
137            .configuration_views_by_provider
138            .get(&provider.id())
139            .cloned();
140
141        let is_expanded = self
142            .expanded_provider_configurations
143            .get(&provider.id())
144            .copied()
145            .unwrap_or(false);
146
147        v_flex()
148            .pt_3()
149            .gap_1p5()
150            .border_t_1()
151            .border_color(cx.theme().colors().border.opacity(0.6))
152            .child(
153                h_flex()
154                    .justify_between()
155                    .child(
156                        h_flex()
157                            .gap_2()
158                            .child(
159                                Icon::new(provider.icon())
160                                    .size(IconSize::Small)
161                                    .color(Color::Muted),
162                            )
163                            .child(Label::new(provider_name.clone()).size(LabelSize::Large))
164                            .when(provider.is_authenticated(cx) && !is_expanded, |parent| {
165                                parent.child(Icon::new(IconName::Check).color(Color::Success))
166                            }),
167                    )
168                    .child(
169                        h_flex()
170                            .gap_1()
171                            .when(provider.is_authenticated(cx), |parent| {
172                                parent.child(
173                                    Button::new(
174                                        SharedString::from(format!("new-thread-{provider_id}")),
175                                        "Start New Thread",
176                                    )
177                                    .icon_position(IconPosition::Start)
178                                    .icon(IconName::Plus)
179                                    .icon_size(IconSize::Small)
180                                    .layer(ElevationIndex::ModalSurface)
181                                    .label_size(LabelSize::Small)
182                                    .on_click(cx.listener({
183                                        let provider = provider.clone();
184                                        move |_this, _event, _window, cx| {
185                                            cx.emit(AssistantConfigurationEvent::NewThread(
186                                                provider.clone(),
187                                            ))
188                                        }
189                                    })),
190                                )
191                            })
192                            .child(
193                                Disclosure::new(
194                                    SharedString::from(format!(
195                                        "provider-disclosure-{provider_id}"
196                                    )),
197                                    is_expanded,
198                                )
199                                .opened_icon(IconName::ChevronUp)
200                                .closed_icon(IconName::ChevronDown)
201                                .on_click(cx.listener({
202                                    let provider_id = provider.id().clone();
203                                    move |this, _event, _window, _cx| {
204                                        let is_expanded = this
205                                            .expanded_provider_configurations
206                                            .entry(provider_id.clone())
207                                            .or_insert(false);
208
209                                        *is_expanded = !*is_expanded;
210                                    }
211                                })),
212                            ),
213                    ),
214            )
215            .when(is_expanded, |parent| match configuration_view {
216                Some(configuration_view) => parent.child(configuration_view),
217                None => parent.child(Label::new(format!(
218                    "No configuration view for {provider_name}",
219                ))),
220            })
221    }
222
223    fn render_provider_configuration_section(
224        &mut self,
225        cx: &mut Context<Self>,
226    ) -> impl IntoElement {
227        let providers = LanguageModelRegistry::read_global(cx).providers();
228
229        v_flex()
230            .p(DynamicSpacing::Base16.rems(cx))
231            .pr(DynamicSpacing::Base20.rems(cx))
232            .gap_4()
233            .border_b_1()
234            .border_color(cx.theme().colors().border)
235            .child(
236                v_flex()
237                    .gap_0p5()
238                    .child(Headline::new("LLM Providers"))
239                    .child(
240                        Label::new("Add at least one provider to use AI-powered features.")
241                            .color(Color::Muted),
242                    ),
243            )
244            .children(
245                providers
246                    .into_iter()
247                    .map(|provider| self.render_provider_configuration_block(&provider, cx)),
248            )
249    }
250
251    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
252        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
253
254        h_flex()
255            .gap_4()
256            .justify_between()
257            .flex_wrap()
258            .child(
259                v_flex()
260                    .gap_0p5()
261                    .max_w_5_6()
262                    .child(Label::new("Allow running editing tools without asking for confirmation"))
263                    .child(
264                        Label::new(
265                            "The agent can perform potentially destructive actions without asking for your confirmation.",
266                        )
267                        .color(Color::Muted),
268                    ),
269            )
270            .child(
271                Switch::new(
272                    "always-allow-tool-actions-switch",
273                    always_allow_tool_actions.into(),
274                )
275                .color(SwitchColor::Accent)
276                .on_click({
277                    let fs = self.fs.clone();
278                    move |state, _window, cx| {
279                        let allow = state == &ToggleState::Selected;
280                        update_settings_file::<AssistantSettings>(
281                            fs.clone(),
282                            cx,
283                            move |settings, _| {
284                                settings.set_always_allow_tool_actions(allow);
285                            },
286                        );
287                    }
288                }),
289            )
290    }
291
292    fn render_single_file_review(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
293        let single_file_review = AssistantSettings::get_global(cx).single_file_review;
294
295        h_flex()
296            .gap_4()
297            .justify_between()
298            .flex_wrap()
299            .child(
300                v_flex()
301                    .gap_0p5()
302                    .max_w_5_6()
303                    .child(Label::new("Enable single-file agent reviews"))
304                    .child(
305                        Label::new(
306                            "Agent edits are also displayed in single-file editors for review.",
307                        )
308                        .color(Color::Muted),
309                    ),
310            )
311            .child(
312                Switch::new("single-file-review-switch", single_file_review.into())
313                    .color(SwitchColor::Accent)
314                    .on_click({
315                        let fs = self.fs.clone();
316                        move |state, _window, cx| {
317                            let allow = state == &ToggleState::Selected;
318                            update_settings_file::<AssistantSettings>(
319                                fs.clone(),
320                                cx,
321                                move |settings, _| {
322                                    settings.set_single_file_review(allow);
323                                },
324                            );
325                        }
326                    }),
327            )
328    }
329
330    fn render_sound_notification(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
331        let play_sound_when_agent_done =
332            AssistantSettings::get_global(cx).play_sound_when_agent_done;
333
334        h_flex()
335            .gap_4()
336            .justify_between()
337            .flex_wrap()
338            .child(
339                v_flex()
340                    .gap_0p5()
341                    .max_w_5_6()
342                    .child(Label::new("Play sound when finished generating"))
343                    .child(
344                        Label::new(
345                            "Hear a notification sound when the agent is done generating changes or needs your input.",
346                        )
347                        .color(Color::Muted),
348                    ),
349            )
350            .child(
351                Switch::new("play-sound-notification-switch", play_sound_when_agent_done.into())
352                    .color(SwitchColor::Accent)
353                    .on_click({
354                        let fs = self.fs.clone();
355                        move |state, _window, cx| {
356                            let allow = state == &ToggleState::Selected;
357                            update_settings_file::<AssistantSettings>(
358                                fs.clone(),
359                                cx,
360                                move |settings, _| {
361                                    settings.set_play_sound_when_agent_done(allow);
362                                },
363                            );
364                        }
365                    }),
366            )
367    }
368
369    fn render_general_settings_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
370        v_flex()
371            .p(DynamicSpacing::Base16.rems(cx))
372            .pr(DynamicSpacing::Base20.rems(cx))
373            .gap_2p5()
374            .border_b_1()
375            .border_color(cx.theme().colors().border)
376            .child(Headline::new("General Settings"))
377            .child(self.render_command_permission(cx))
378            .child(self.render_single_file_review(cx))
379            .child(self.render_sound_notification(cx))
380    }
381
382    fn render_context_servers_section(
383        &mut self,
384        window: &mut Window,
385        cx: &mut Context<Self>,
386    ) -> impl IntoElement {
387        let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
388
389        v_flex()
390            .p(DynamicSpacing::Base16.rems(cx))
391            .pr(DynamicSpacing::Base20.rems(cx))
392            .gap_2()
393            .border_b_1()
394            .border_color(cx.theme().colors().border)
395            .child(
396                v_flex()
397                    .gap_0p5()
398                    .child(Headline::new("Model Context Protocol (MCP) Servers"))
399                    .child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)),
400            )
401            .children(
402                context_server_ids.into_iter().map(|context_server_id| {
403                    self.render_context_server(context_server_id, window, cx)
404                }),
405            )
406            .child(
407                h_flex()
408                    .justify_between()
409                    .gap_2()
410                    .child(
411                        h_flex().w_full().child(
412                            Button::new("add-context-server", "Add Custom Server")
413                                .style(ButtonStyle::Filled)
414                                .layer(ElevationIndex::ModalSurface)
415                                .full_width()
416                                .icon(IconName::Plus)
417                                .icon_size(IconSize::Small)
418                                .icon_position(IconPosition::Start)
419                                .on_click(|_event, window, cx| {
420                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
421                                }),
422                        ),
423                    )
424                    .child(
425                        h_flex().w_full().child(
426                            Button::new(
427                                "install-context-server-extensions",
428                                "Install MCP Extensions",
429                            )
430                            .style(ButtonStyle::Filled)
431                            .layer(ElevationIndex::ModalSurface)
432                            .full_width()
433                            .icon(IconName::Hammer)
434                            .icon_size(IconSize::Small)
435                            .icon_position(IconPosition::Start)
436                            .on_click(|_event, window, cx| {
437                                window.dispatch_action(
438                                    zed_actions::Extensions {
439                                        category_filter: Some(
440                                            ExtensionCategoryFilter::ContextServers,
441                                        ),
442                                    }
443                                    .boxed_clone(),
444                                    cx,
445                                )
446                            }),
447                        ),
448                    ),
449            )
450    }
451
452    fn render_context_server(
453        &self,
454        context_server_id: ContextServerId,
455        window: &mut Window,
456        cx: &mut Context<Self>,
457    ) -> impl use<> + IntoElement {
458        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
459        let server_status = self
460            .context_server_store
461            .read(cx)
462            .status_for_server(&context_server_id)
463            .unwrap_or(ContextServerStatus::Stopped);
464
465        let is_running = matches!(server_status, ContextServerStatus::Running);
466        let item_id = SharedString::from(context_server_id.0.clone());
467
468        let error = if let ContextServerStatus::Error(error) = server_status.clone() {
469            Some(error)
470        } else {
471            None
472        };
473
474        let are_tools_expanded = self
475            .expanded_context_server_tools
476            .get(&context_server_id)
477            .copied()
478            .unwrap_or_default();
479
480        let tools = tools_by_source
481            .get(&ToolSource::ContextServer {
482                id: context_server_id.0.clone().into(),
483            })
484            .map_or([].as_slice(), |tools| tools.as_slice());
485        let tool_count = tools.len();
486
487        let border_color = cx.theme().colors().border.opacity(0.6);
488        let success_color = Color::Success.color(cx);
489
490        let (status_indicator, tooltip_text) = match server_status {
491            ContextServerStatus::Starting => (
492                Indicator::dot()
493                    .color(Color::Success)
494                    .with_animation(
495                        SharedString::from(format!("{}-starting", context_server_id.0.clone(),)),
496                        Animation::new(Duration::from_secs(2))
497                            .repeat()
498                            .with_easing(pulsating_between(0.4, 1.)),
499                        move |this, delta| this.color(success_color.alpha(delta).into()),
500                    )
501                    .into_any_element(),
502                "Server is starting.",
503            ),
504            ContextServerStatus::Running => (
505                Indicator::dot().color(Color::Success).into_any_element(),
506                "Server is running.",
507            ),
508            ContextServerStatus::Error(_) => (
509                Indicator::dot().color(Color::Error).into_any_element(),
510                "Server has an error.",
511            ),
512            ContextServerStatus::Stopped => (
513                Indicator::dot().color(Color::Muted).into_any_element(),
514                "Server is stopped.",
515            ),
516        };
517
518        v_flex()
519            .id(item_id.clone())
520            .border_1()
521            .rounded_md()
522            .border_color(border_color)
523            .bg(cx.theme().colors().background.opacity(0.2))
524            .overflow_hidden()
525            .child(
526                h_flex()
527                    .p_1()
528                    .justify_between()
529                    .when(
530                        error.is_some() || are_tools_expanded && tool_count > 1,
531                        |element| element.border_b_1().border_color(border_color),
532                    )
533                    .child(
534                        h_flex()
535                            .gap_1p5()
536                            .child(
537                                Disclosure::new(
538                                    "tool-list-disclosure",
539                                    are_tools_expanded || error.is_some(),
540                                )
541                                .disabled(tool_count == 0)
542                                .on_click(cx.listener({
543                                    let context_server_id = context_server_id.clone();
544                                    move |this, _event, _window, _cx| {
545                                        let is_open = this
546                                            .expanded_context_server_tools
547                                            .entry(context_server_id.clone())
548                                            .or_insert(false);
549
550                                        *is_open = !*is_open;
551                                    }
552                                })),
553                            )
554                            .child(
555                                div()
556                                    .id(item_id.clone())
557                                    .tooltip(Tooltip::text(tooltip_text))
558                                    .child(status_indicator),
559                            )
560                            .child(Label::new(context_server_id.0.clone()).ml_0p5())
561                            .when(is_running, |this| {
562                                this.child(
563                                    Label::new(if tool_count == 1 {
564                                        SharedString::from("1 tool")
565                                    } else {
566                                        SharedString::from(format!("{} tools", tool_count))
567                                    })
568                                    .color(Color::Muted)
569                                    .size(LabelSize::Small),
570                                )
571                            }),
572                    )
573                    .child(
574                        Switch::new("context-server-switch", is_running.into())
575                            .color(SwitchColor::Accent)
576                            .on_click({
577                                let context_server_manager = self.context_server_store.clone();
578                                let context_server_id = context_server_id.clone();
579                                move |state, _window, cx| match state {
580                                    ToggleState::Unselected | ToggleState::Indeterminate => {
581                                        context_server_manager.update(cx, |this, cx| {
582                                            this.stop_server(&context_server_id, cx).log_err();
583                                        });
584                                    }
585                                    ToggleState::Selected => {
586                                        context_server_manager.update(cx, |this, cx| {
587                                            if let Some(server) =
588                                                this.get_server(&context_server_id)
589                                            {
590                                                this.start_server(server, cx).log_err();
591                                            }
592                                        })
593                                    }
594                                }
595                            }),
596                    ),
597            )
598            .map(|parent| {
599                if let Some(error) = error {
600                    return parent.child(
601                        h_flex()
602                            .p_2()
603                            .gap_2()
604                            .items_start()
605                            .child(
606                                h_flex()
607                                    .flex_none()
608                                    .h(window.line_height() / 1.6_f32)
609                                    .justify_center()
610                                    .child(
611                                        Icon::new(IconName::XCircle)
612                                            .size(IconSize::XSmall)
613                                            .color(Color::Error),
614                                    ),
615                            )
616                            .child(
617                                div().w_full().child(
618                                    Label::new(error)
619                                        .buffer_font(cx)
620                                        .color(Color::Muted)
621                                        .size(LabelSize::Small),
622                                ),
623                            ),
624                    );
625                }
626
627                if !are_tools_expanded || tools.is_empty() {
628                    return parent;
629                }
630
631                parent.child(v_flex().py_1p5().px_1().gap_1().children(
632                    tools.into_iter().enumerate().map(|(ix, tool)| {
633                        h_flex()
634                            .id(("tool-item", ix))
635                            .px_1()
636                            .gap_2()
637                            .justify_between()
638                            .hover(|style| style.bg(cx.theme().colors().element_hover))
639                            .rounded_sm()
640                            .child(
641                                Label::new(tool.name())
642                                    .buffer_font(cx)
643                                    .size(LabelSize::Small),
644                            )
645                            .child(
646                                Icon::new(IconName::Info)
647                                    .size(IconSize::Small)
648                                    .color(Color::Ignored),
649                            )
650                            .tooltip(Tooltip::text(tool.description()))
651                    }),
652                ))
653            })
654    }
655}
656
657impl Render for AgentConfiguration {
658    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
659        v_flex()
660            .id("assistant-configuration")
661            .key_context("AgentConfiguration")
662            .track_focus(&self.focus_handle(cx))
663            .relative()
664            .size_full()
665            .pb_8()
666            .bg(cx.theme().colors().panel_background)
667            .child(
668                v_flex()
669                    .id("assistant-configuration-content")
670                    .track_scroll(&self.scroll_handle)
671                    .size_full()
672                    .overflow_y_scroll()
673                    .child(self.render_general_settings_section(cx))
674                    .child(self.render_context_servers_section(window, cx))
675                    .child(self.render_provider_configuration_section(cx)),
676            )
677            .child(
678                div()
679                    .id("assistant-configuration-scrollbar")
680                    .occlude()
681                    .absolute()
682                    .right(px(3.))
683                    .top_0()
684                    .bottom_0()
685                    .pb_6()
686                    .w(px(12.))
687                    .cursor_default()
688                    .on_mouse_move(cx.listener(|_, _, _window, cx| {
689                        cx.notify();
690                        cx.stop_propagation()
691                    }))
692                    .on_hover(|_, _window, cx| {
693                        cx.stop_propagation();
694                    })
695                    .on_any_mouse_down(|_, _window, cx| {
696                        cx.stop_propagation();
697                    })
698                    .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
699                        cx.notify();
700                    }))
701                    .children(Scrollbar::vertical(self.scrollbar_state.clone())),
702            )
703    }
704}