assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod tool_picker;
  4
  5use std::sync::Arc;
  6
  7use assistant_settings::AssistantSettings;
  8use assistant_tool::{ToolSource, ToolWorkingSet};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use fs::Fs;
 12use gpui::{
 13    Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
 14};
 15use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 16use settings::{Settings, update_settings_file};
 17use ui::{
 18    Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
 19    Switch, Tooltip, prelude::*,
 20};
 21use util::ResultExt as _;
 22use zed_actions::ExtensionCategoryFilter;
 23
 24pub(crate) use add_context_server_modal::AddContextServerModal;
 25pub(crate) use manage_profiles_modal::ManageProfilesModal;
 26
 27use crate::AddContextServer;
 28
 29pub struct AssistantConfiguration {
 30    fs: Arc<dyn Fs>,
 31    focus_handle: FocusHandle,
 32    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 33    context_server_manager: Entity<ContextServerManager>,
 34    expanded_context_server_tools: HashMap<Arc<str>, bool>,
 35    tools: Entity<ToolWorkingSet>,
 36    _registry_subscription: Subscription,
 37    scroll_handle: ScrollHandle,
 38    scrollbar_state: ScrollbarState,
 39}
 40
 41impl AssistantConfiguration {
 42    pub fn new(
 43        fs: Arc<dyn Fs>,
 44        context_server_manager: Entity<ContextServerManager>,
 45        tools: Entity<ToolWorkingSet>,
 46        window: &mut Window,
 47        cx: &mut Context<Self>,
 48    ) -> Self {
 49        let focus_handle = cx.focus_handle();
 50
 51        let registry_subscription = cx.subscribe_in(
 52            &LanguageModelRegistry::global(cx),
 53            window,
 54            |this, _, event: &language_model::Event, window, cx| match event {
 55                language_model::Event::AddedProvider(provider_id) => {
 56                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 57                    if let Some(provider) = provider {
 58                        this.add_provider_configuration_view(&provider, window, cx);
 59                    }
 60                }
 61                language_model::Event::RemovedProvider(provider_id) => {
 62                    this.remove_provider_configuration_view(provider_id);
 63                }
 64                _ => {}
 65            },
 66        );
 67
 68        let scroll_handle = ScrollHandle::new();
 69        let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
 70
 71        let mut this = Self {
 72            fs,
 73            focus_handle,
 74            configuration_views_by_provider: HashMap::default(),
 75            context_server_manager,
 76            expanded_context_server_tools: HashMap::default(),
 77            tools,
 78            _registry_subscription: registry_subscription,
 79            scroll_handle,
 80            scrollbar_state,
 81        };
 82        this.build_provider_configuration_views(window, cx);
 83        this
 84    }
 85
 86    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 87        let providers = LanguageModelRegistry::read_global(cx).providers();
 88        for provider in providers {
 89            self.add_provider_configuration_view(&provider, window, cx);
 90        }
 91    }
 92
 93    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 94        self.configuration_views_by_provider.remove(provider_id);
 95    }
 96
 97    fn add_provider_configuration_view(
 98        &mut self,
 99        provider: &Arc<dyn LanguageModelProvider>,
100        window: &mut Window,
101        cx: &mut Context<Self>,
102    ) {
103        let configuration_view = provider.configuration_view(window, cx);
104        self.configuration_views_by_provider
105            .insert(provider.id(), configuration_view);
106    }
107}
108
109impl Focusable for AssistantConfiguration {
110    fn focus_handle(&self, _: &App) -> FocusHandle {
111        self.focus_handle.clone()
112    }
113}
114
115pub enum AssistantConfigurationEvent {
116    NewThread(Arc<dyn LanguageModelProvider>),
117}
118
119impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
120
121impl AssistantConfiguration {
122    fn render_provider_configuration_block(
123        &mut self,
124        provider: &Arc<dyn LanguageModelProvider>,
125        cx: &mut Context<Self>,
126    ) -> impl IntoElement + use<> {
127        let provider_id = provider.id().0.clone();
128        let provider_name = provider.name().0.clone();
129        let configuration_view = self
130            .configuration_views_by_provider
131            .get(&provider.id())
132            .cloned();
133
134        v_flex()
135            .gap_1p5()
136            .child(
137                h_flex()
138                    .justify_between()
139                    .child(
140                        h_flex()
141                            .gap_2()
142                            .child(
143                                Icon::new(provider.icon())
144                                    .size(IconSize::Small)
145                                    .color(Color::Muted),
146                            )
147                            .child(Label::new(provider_name.clone())),
148                    )
149                    .when(provider.is_authenticated(cx), |parent| {
150                        parent.child(
151                            Button::new(
152                                SharedString::from(format!("new-thread-{provider_id}")),
153                                "Start New Thread",
154                            )
155                            .icon_position(IconPosition::Start)
156                            .icon(IconName::Plus)
157                            .icon_size(IconSize::Small)
158                            .style(ButtonStyle::Filled)
159                            .layer(ElevationIndex::ModalSurface)
160                            .label_size(LabelSize::Small)
161                            .on_click(cx.listener({
162                                let provider = provider.clone();
163                                move |_this, _event, _window, cx| {
164                                    cx.emit(AssistantConfigurationEvent::NewThread(
165                                        provider.clone(),
166                                    ))
167                                }
168                            })),
169                        )
170                    }),
171            )
172            .child(
173                div()
174                    .p(DynamicSpacing::Base08.rems(cx))
175                    .bg(cx.theme().colors().editor_background)
176                    .border_1()
177                    .border_color(cx.theme().colors().border)
178                    .rounded_sm()
179                    .map(|parent| match configuration_view {
180                        Some(configuration_view) => parent.child(configuration_view),
181                        None => parent.child(div().child(Label::new(format!(
182                            "No configuration view for {provider_name}",
183                        )))),
184                    }),
185            )
186    }
187
188    fn render_provider_configuration_section(
189        &mut self,
190        cx: &mut Context<Self>,
191    ) -> impl IntoElement {
192        let providers = LanguageModelRegistry::read_global(cx).providers();
193
194        v_flex()
195            .p(DynamicSpacing::Base16.rems(cx))
196            .pr(DynamicSpacing::Base20.rems(cx))
197            .gap_4()
198            .flex_1()
199            .child(
200                v_flex()
201                    .gap_0p5()
202                    .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
203                    .child(
204                        Label::new("Add at least one provider to use AI-powered features.")
205                            .color(Color::Muted),
206                    ),
207            )
208            .children(
209                providers
210                    .into_iter()
211                    .map(|provider| self.render_provider_configuration_block(&provider, cx)),
212            )
213    }
214
215    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
216        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
217
218        const HEADING: &str = "Allow running tools without asking for confirmation";
219
220        v_flex()
221            .p(DynamicSpacing::Base16.rems(cx))
222            .pr(DynamicSpacing::Base20.rems(cx))
223            .gap_2()
224            .flex_1()
225            .child(Headline::new("General Settings").size(HeadlineSize::Small))
226            .child(
227                h_flex()
228                    .p_2p5()
229                    .rounded_sm()
230                    .bg(cx.theme().colors().editor_background)
231                    .border_1()
232                    .border_color(cx.theme().colors().border)
233                    .gap_4()
234                    .justify_between()
235                    .flex_wrap()
236                    .child(
237                        v_flex()
238                            .gap_0p5()
239                            .max_w_5_6()
240                            .child(Label::new(HEADING))
241                            .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
242                    )
243                    .child(
244                        Switch::new(
245                            "always-allow-tool-actions-switch",
246                            always_allow_tool_actions.into(),
247                        )
248                        .on_click({
249                            let fs = self.fs.clone();
250                            move |state, _window, cx| {
251                                let allow = state == &ToggleState::Selected;
252                                update_settings_file::<AssistantSettings>(
253                                    fs.clone(),
254                                    cx,
255                                    move |settings, _| {
256                                        settings.set_always_allow_tool_actions(allow);
257                                    },
258                                );
259                            }
260                        }),
261                    ),
262            )
263    }
264
265    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
266        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
267        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
268        let empty = Vec::new();
269
270        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
271
272        v_flex()
273            .p(DynamicSpacing::Base16.rems(cx))
274            .pr(DynamicSpacing::Base20.rems(cx))
275            .gap_2()
276            .flex_1()
277            .child(
278                v_flex()
279                    .gap_0p5()
280                    .child(
281                        Headline::new("Model Context Protocol (MCP) Servers")
282                            .size(HeadlineSize::Small),
283                    )
284                    .child(Label::new(SUBHEADING).color(Color::Muted)),
285            )
286            .children(context_servers.into_iter().map(|context_server| {
287                let is_running = context_server.client().is_some();
288                let are_tools_expanded = self
289                    .expanded_context_server_tools
290                    .get(&context_server.id())
291                    .copied()
292                    .unwrap_or_default();
293
294                let tools = tools_by_source
295                    .get(&ToolSource::ContextServer {
296                        id: context_server.id().into(),
297                    })
298                    .unwrap_or_else(|| &empty);
299                let tool_count = tools.len();
300
301                v_flex()
302                    .id(SharedString::from(context_server.id()))
303                    .border_1()
304                    .rounded_sm()
305                    .border_color(cx.theme().colors().border)
306                    .bg(cx.theme().colors().editor_background)
307                    .child(
308                        h_flex()
309                            .p_1()
310                            .justify_between()
311                            .when(are_tools_expanded && tool_count > 1, |element| {
312                                element
313                                    .border_b_1()
314                                    .border_color(cx.theme().colors().border)
315                            })
316                            .child(
317                                h_flex()
318                                    .gap_2()
319                                    .child(
320                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
321                                            .disabled(tool_count == 0)
322                                            .on_click(cx.listener({
323                                                let context_server_id = context_server.id();
324                                                move |this, _event, _window, _cx| {
325                                                    let is_open = this
326                                                        .expanded_context_server_tools
327                                                        .entry(context_server_id.clone())
328                                                        .or_insert(false);
329
330                                                    *is_open = !*is_open;
331                                                }
332                                            })),
333                                    )
334                                    .child(Indicator::dot().color(if is_running {
335                                        Color::Success
336                                    } else {
337                                        Color::Error
338                                    }))
339                                    .child(Label::new(context_server.id()))
340                                    .child(
341                                        Label::new(format!("{tool_count} tools"))
342                                            .color(Color::Muted)
343                                            .size(LabelSize::Small),
344                                    ),
345                            )
346                            .child(
347                                Switch::new("context-server-switch", is_running.into()).on_click({
348                                    let context_server_manager =
349                                        self.context_server_manager.clone();
350                                    let context_server = context_server.clone();
351                                    move |state, _window, cx| match state {
352                                        ToggleState::Unselected | ToggleState::Indeterminate => {
353                                            context_server_manager.update(cx, |this, cx| {
354                                                this.stop_server(context_server.clone(), cx)
355                                                    .log_err();
356                                            });
357                                        }
358                                        ToggleState::Selected => {
359                                            cx.spawn({
360                                                let context_server_manager =
361                                                    context_server_manager.clone();
362                                                let context_server = context_server.clone();
363                                                async move |cx| {
364                                                    if let Some(start_server_task) =
365                                                        context_server_manager
366                                                            .update(cx, |this, cx| {
367                                                                this.start_server(
368                                                                    context_server,
369                                                                    cx,
370                                                                )
371                                                            })
372                                                            .log_err()
373                                                    {
374                                                        start_server_task.await.log_err();
375                                                    }
376                                                }
377                                            })
378                                            .detach();
379                                        }
380                                    }
381                                }),
382                            ),
383                    )
384                    .map(|parent| {
385                        if !are_tools_expanded {
386                            return parent;
387                        }
388
389                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
390                            |(ix, tool)| {
391                                h_flex()
392                                    .id("tool-item")
393                                    .pl_2()
394                                    .pr_1()
395                                    .py_1()
396                                    .gap_2()
397                                    .justify_between()
398                                    .when(ix < tool_count - 1, |element| {
399                                        element
400                                            .border_b_1()
401                                            .border_color(cx.theme().colors().border_variant)
402                                    })
403                                    .child(
404                                        Label::new(tool.name())
405                                            .buffer_font(cx)
406                                            .size(LabelSize::Small),
407                                    )
408                                    .child(
409                                        IconButton::new(("tool-description", ix), IconName::Info)
410                                            .shape(ui::IconButtonShape::Square)
411                                            .icon_size(IconSize::Small)
412                                            .icon_color(Color::Ignored)
413                                            .tooltip(Tooltip::text(tool.description())),
414                                    )
415                            },
416                        )))
417                    })
418            }))
419            .child(
420                h_flex()
421                    .justify_between()
422                    .gap_2()
423                    .child(
424                        h_flex().w_full().child(
425                            Button::new("add-context-server", "Add MCPs Directly")
426                                .style(ButtonStyle::Filled)
427                                .layer(ElevationIndex::ModalSurface)
428                                .full_width()
429                                .icon(IconName::Plus)
430                                .icon_size(IconSize::Small)
431                                .icon_position(IconPosition::Start)
432                                .on_click(|_event, window, cx| {
433                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
434                                }),
435                        ),
436                    )
437                    .child(
438                        h_flex().w_full().child(
439                            Button::new(
440                                "install-context-server-extensions",
441                                "Install MCP Extensions",
442                            )
443                            .style(ButtonStyle::Filled)
444                            .layer(ElevationIndex::ModalSurface)
445                            .full_width()
446                            .icon(IconName::DatabaseZap)
447                            .icon_size(IconSize::Small)
448                            .icon_position(IconPosition::Start)
449                            .on_click(|_event, window, cx| {
450                                window.dispatch_action(
451                                    zed_actions::Extensions {
452                                        category_filter: Some(
453                                            ExtensionCategoryFilter::ContextServers,
454                                        ),
455                                    }
456                                    .boxed_clone(),
457                                    cx,
458                                )
459                            }),
460                        ),
461                    ),
462            )
463    }
464}
465
466impl Render for AssistantConfiguration {
467    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
468        v_flex()
469            .id("assistant-configuration")
470            .key_context("AgentConfiguration")
471            .track_focus(&self.focus_handle(cx))
472            .relative()
473            .size_full()
474            .pb_8()
475            .bg(cx.theme().colors().panel_background)
476            .child(
477                v_flex()
478                    .id("assistant-configuration-content")
479                    .track_scroll(&self.scroll_handle)
480                    .size_full()
481                    .overflow_y_scroll()
482                    .child(self.render_command_permission(cx))
483                    .child(Divider::horizontal().color(DividerColor::Border))
484                    .child(self.render_context_servers_section(cx))
485                    .child(Divider::horizontal().color(DividerColor::Border))
486                    .child(self.render_provider_configuration_section(cx)),
487            )
488            .child(
489                div()
490                    .id("assistant-configuration-scrollbar")
491                    .occlude()
492                    .absolute()
493                    .right(px(3.))
494                    .top_0()
495                    .bottom_0()
496                    .pb_6()
497                    .w(px(12.))
498                    .cursor_default()
499                    .on_mouse_move(cx.listener(|_, _, _window, cx| {
500                        cx.notify();
501                        cx.stop_propagation()
502                    }))
503                    .on_hover(|_, _window, cx| {
504                        cx.stop_propagation();
505                    })
506                    .on_any_mouse_down(|_, _window, cx| {
507                        cx.stop_propagation();
508                    })
509                    .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
510                        cx.notify();
511                    }))
512                    .children(Scrollbar::vertical(self.scrollbar_state.clone())),
513            )
514    }
515}