assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod tool_picker;
  4
  5use std::sync::Arc;
  6
  7use assistant_settings::AssistantSettings;
  8use assistant_tool::{ToolSource, ToolWorkingSet};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use fs::Fs;
 12use gpui::{
 13    Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
 14};
 15use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 16use settings::{Settings, update_settings_file};
 17use ui::{
 18    Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
 19    Switch, SwitchColor, Tooltip, prelude::*,
 20};
 21use util::ResultExt as _;
 22use zed_actions::ExtensionCategoryFilter;
 23
 24pub(crate) use add_context_server_modal::AddContextServerModal;
 25pub(crate) use manage_profiles_modal::ManageProfilesModal;
 26
 27use crate::AddContextServer;
 28
/// Configuration view for the assistant.
///
/// Renders three sections:
/// - a general toggle for `always_allow_tool_actions`;
/// - the Model Context Protocol (MCP) servers known to the [`ContextServerManager`];
/// - one configuration block per registered language model provider.
pub struct AssistantConfiguration {
    /// Filesystem handle passed to `update_settings_file` when a setting is toggled.
    fs: Arc<dyn Fs>,
    focus_handle: FocusHandle,
    /// Configuration view for each provider.
    ///
    /// Populated in `new` and updated on the registry's added/removed-provider events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the context servers listed in the MCP section; also used to start/stop them.
    context_server_manager: Entity<ContextServerManager>,
    /// Per-context-server expanded state of the tool list.
    ///
    /// Keyed by server id; a missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Working set queried (via `tools_by_source`) for the tools each context server provides.
    tools: Entity<ToolWorkingSet>,
    /// Never read; stored so the registry subscription created in `new` lives as long as
    /// this view.
    _registry_subscription: Subscription,
    /// Scroll state of the content column; shared with `scrollbar_state`.
    scroll_handle: ScrollHandle,
    /// State backing the vertical scrollbar overlay.
    scrollbar_state: ScrollbarState,
}
 40
 41impl AssistantConfiguration {
 42    pub fn new(
 43        fs: Arc<dyn Fs>,
 44        context_server_manager: Entity<ContextServerManager>,
 45        tools: Entity<ToolWorkingSet>,
 46        window: &mut Window,
 47        cx: &mut Context<Self>,
 48    ) -> Self {
 49        let focus_handle = cx.focus_handle();
 50
 51        let registry_subscription = cx.subscribe_in(
 52            &LanguageModelRegistry::global(cx),
 53            window,
 54            |this, _, event: &language_model::Event, window, cx| match event {
 55                language_model::Event::AddedProvider(provider_id) => {
 56                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 57                    if let Some(provider) = provider {
 58                        this.add_provider_configuration_view(&provider, window, cx);
 59                    }
 60                }
 61                language_model::Event::RemovedProvider(provider_id) => {
 62                    this.remove_provider_configuration_view(provider_id);
 63                }
 64                _ => {}
 65            },
 66        );
 67
 68        let scroll_handle = ScrollHandle::new();
 69        let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
 70
 71        let mut this = Self {
 72            fs,
 73            focus_handle,
 74            configuration_views_by_provider: HashMap::default(),
 75            context_server_manager,
 76            expanded_context_server_tools: HashMap::default(),
 77            tools,
 78            _registry_subscription: registry_subscription,
 79            scroll_handle,
 80            scrollbar_state,
 81        };
 82        this.build_provider_configuration_views(window, cx);
 83        this
 84    }
 85
 86    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 87        let providers = LanguageModelRegistry::read_global(cx).providers();
 88        for provider in providers {
 89            self.add_provider_configuration_view(&provider, window, cx);
 90        }
 91    }
 92
 93    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 94        self.configuration_views_by_provider.remove(provider_id);
 95    }
 96
 97    fn add_provider_configuration_view(
 98        &mut self,
 99        provider: &Arc<dyn LanguageModelProvider>,
100        window: &mut Window,
101        cx: &mut Context<Self>,
102    ) {
103        let configuration_view = provider.configuration_view(window, cx);
104        self.configuration_views_by_provider
105            .insert(provider.id(), configuration_view);
106    }
107}
108
109impl Focusable for AssistantConfiguration {
110    fn focus_handle(&self, _: &App) -> FocusHandle {
111        self.focus_handle.clone()
112    }
113}
114
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when "Start New Thread" is clicked on a provider's configuration block.
    ///
    /// The button is only shown while the provider reports itself as authenticated.
    /// Carries the provider the thread should use.
    NewThread(Arc<dyn LanguageModelProvider>),
}

// Allows subscribers to observe `AssistantConfigurationEvent`s emitted via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
120
121impl AssistantConfiguration {
122    fn render_provider_configuration_block(
123        &mut self,
124        provider: &Arc<dyn LanguageModelProvider>,
125        cx: &mut Context<Self>,
126    ) -> impl IntoElement + use<> {
127        let provider_id = provider.id().0.clone();
128        let provider_name = provider.name().0.clone();
129        let configuration_view = self
130            .configuration_views_by_provider
131            .get(&provider.id())
132            .cloned();
133
134        v_flex()
135            .pt_3()
136            .pb_1()
137            .gap_1p5()
138            .border_t_1()
139            .border_color(cx.theme().colors().border.opacity(0.6))
140            .child(
141                h_flex()
142                    .justify_between()
143                    .child(
144                        h_flex()
145                            .gap_2()
146                            .child(
147                                Icon::new(provider.icon())
148                                    .size(IconSize::Small)
149                                    .color(Color::Muted),
150                            )
151                            .child(Label::new(provider_name.clone()).size(LabelSize::Large)),
152                    )
153                    .when(provider.is_authenticated(cx), |parent| {
154                        parent.child(
155                            Button::new(
156                                SharedString::from(format!("new-thread-{provider_id}")),
157                                "Start New Thread",
158                            )
159                            .icon_position(IconPosition::Start)
160                            .icon(IconName::Plus)
161                            .icon_size(IconSize::Small)
162                            .style(ButtonStyle::Filled)
163                            .layer(ElevationIndex::ModalSurface)
164                            .label_size(LabelSize::Small)
165                            .on_click(cx.listener({
166                                let provider = provider.clone();
167                                move |_this, _event, _window, cx| {
168                                    cx.emit(AssistantConfigurationEvent::NewThread(
169                                        provider.clone(),
170                                    ))
171                                }
172                            })),
173                        )
174                    }),
175            )
176            .map(|parent| match configuration_view {
177                Some(configuration_view) => parent.child(configuration_view),
178                None => parent.child(div().child(Label::new(format!(
179                    "No configuration view for {provider_name}",
180                )))),
181            })
182    }
183
184    fn render_provider_configuration_section(
185        &mut self,
186        cx: &mut Context<Self>,
187    ) -> impl IntoElement {
188        let providers = LanguageModelRegistry::read_global(cx).providers();
189
190        v_flex()
191            .p(DynamicSpacing::Base16.rems(cx))
192            .pr(DynamicSpacing::Base20.rems(cx))
193            .gap_4()
194            .flex_1()
195            .child(
196                v_flex()
197                    .gap_0p5()
198                    .child(Headline::new("LLM Providers"))
199                    .child(
200                        Label::new("Add at least one provider to use AI-powered features.")
201                            .color(Color::Muted),
202                    ),
203            )
204            .children(
205                providers
206                    .into_iter()
207                    .map(|provider| self.render_provider_configuration_block(&provider, cx)),
208            )
209    }
210
211    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
212        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
213
214        const HEADING: &str = "Allow running editing tools without asking for confirmation";
215
216        v_flex()
217            .p(DynamicSpacing::Base16.rems(cx))
218            .pr(DynamicSpacing::Base20.rems(cx))
219            .gap_2()
220            .flex_1()
221            .child(Headline::new("General Settings"))
222            .child(
223                h_flex()
224                    .gap_4()
225                    .justify_between()
226                    .flex_wrap()
227                    .child(
228                        v_flex()
229                            .gap_0p5()
230                            .max_w_5_6()
231                            .child(Label::new(HEADING))
232                            .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
233                    )
234                    .child(
235                        Switch::new(
236                            "always-allow-tool-actions-switch",
237                            always_allow_tool_actions.into(),
238                        )
239                        .color(SwitchColor::Accent)
240                        .on_click({
241                            let fs = self.fs.clone();
242                            move |state, _window, cx| {
243                                let allow = state == &ToggleState::Selected;
244                                update_settings_file::<AssistantSettings>(
245                                    fs.clone(),
246                                    cx,
247                                    move |settings, _| {
248                                        settings.set_always_allow_tool_actions(allow);
249                                    },
250                                );
251                            }
252                        }),
253                    ),
254            )
255    }
256
257    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
258        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
259        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
260        let empty = Vec::new();
261
262        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
263
264        v_flex()
265            .p(DynamicSpacing::Base16.rems(cx))
266            .pr(DynamicSpacing::Base20.rems(cx))
267            .gap_2()
268            .flex_1()
269            .child(
270                v_flex()
271                    .gap_0p5()
272                    .child(Headline::new("Model Context Protocol (MCP) Servers"))
273                    .child(Label::new(SUBHEADING).color(Color::Muted)),
274            )
275            .children(context_servers.into_iter().map(|context_server| {
276                let is_running = context_server.client().is_some();
277                let are_tools_expanded = self
278                    .expanded_context_server_tools
279                    .get(&context_server.id())
280                    .copied()
281                    .unwrap_or_default();
282
283                let tools = tools_by_source
284                    .get(&ToolSource::ContextServer {
285                        id: context_server.id().into(),
286                    })
287                    .unwrap_or_else(|| &empty);
288                let tool_count = tools.len();
289
290                v_flex()
291                    .id(SharedString::from(context_server.id()))
292                    .border_1()
293                    .rounded_md()
294                    .border_color(cx.theme().colors().border)
295                    .bg(cx.theme().colors().background.opacity(0.25))
296                    .child(
297                        h_flex()
298                            .p_1()
299                            .justify_between()
300                            .when(are_tools_expanded && tool_count > 1, |element| {
301                                element
302                                    .border_b_1()
303                                    .border_color(cx.theme().colors().border)
304                            })
305                            .child(
306                                h_flex()
307                                    .gap_2()
308                                    .child(
309                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
310                                            .disabled(tool_count == 0)
311                                            .on_click(cx.listener({
312                                                let context_server_id = context_server.id();
313                                                move |this, _event, _window, _cx| {
314                                                    let is_open = this
315                                                        .expanded_context_server_tools
316                                                        .entry(context_server_id.clone())
317                                                        .or_insert(false);
318
319                                                    *is_open = !*is_open;
320                                                }
321                                            })),
322                                    )
323                                    .child(Indicator::dot().color(if is_running {
324                                        Color::Success
325                                    } else {
326                                        Color::Error
327                                    }))
328                                    .child(Label::new(context_server.id()))
329                                    .child(
330                                        Label::new(format!("{tool_count} tools"))
331                                            .color(Color::Muted)
332                                            .size(LabelSize::Small),
333                                    ),
334                            )
335                            .child(
336                                Switch::new("context-server-switch", is_running.into())
337                                    .color(SwitchColor::Accent)
338                                    .on_click({
339                                        let context_server_manager =
340                                            self.context_server_manager.clone();
341                                        let context_server = context_server.clone();
342                                        move |state, _window, cx| match state {
343                                            ToggleState::Unselected
344                                            | ToggleState::Indeterminate => {
345                                                context_server_manager.update(cx, |this, cx| {
346                                                    this.stop_server(context_server.clone(), cx)
347                                                        .log_err();
348                                                });
349                                            }
350                                            ToggleState::Selected => {
351                                                cx.spawn({
352                                                    let context_server_manager =
353                                                        context_server_manager.clone();
354                                                    let context_server = context_server.clone();
355                                                    async move |cx| {
356                                                        if let Some(start_server_task) =
357                                                            context_server_manager
358                                                                .update(cx, |this, cx| {
359                                                                    this.start_server(
360                                                                        context_server,
361                                                                        cx,
362                                                                    )
363                                                                })
364                                                                .log_err()
365                                                        {
366                                                            start_server_task.await.log_err();
367                                                        }
368                                                    }
369                                                })
370                                                .detach();
371                                            }
372                                        }
373                                    }),
374                            ),
375                    )
376                    .map(|parent| {
377                        if !are_tools_expanded {
378                            return parent;
379                        }
380
381                        parent.child(v_flex().py_1p5().px_1().gap_1().children(
382                            tools.into_iter().enumerate().map(|(ix, tool)| {
383                                h_flex()
384                                    .id(("tool-item", ix))
385                                    .px_1()
386                                    .gap_2()
387                                    .justify_between()
388                                    .hover(|style| style.bg(cx.theme().colors().element_hover))
389                                    .rounded_sm()
390                                    .child(
391                                        Label::new(tool.name())
392                                            .buffer_font(cx)
393                                            .size(LabelSize::Small),
394                                    )
395                                    .child(
396                                        Icon::new(IconName::Info)
397                                            .size(IconSize::Small)
398                                            .color(Color::Ignored),
399                                    )
400                                    .tooltip(Tooltip::text(tool.description()))
401                            }),
402                        ))
403                    })
404            }))
405            .child(
406                h_flex()
407                    .justify_between()
408                    .gap_2()
409                    .child(
410                        h_flex().w_full().child(
411                            Button::new("add-context-server", "Add Custom Server")
412                                .style(ButtonStyle::Filled)
413                                .layer(ElevationIndex::ModalSurface)
414                                .full_width()
415                                .icon(IconName::Plus)
416                                .icon_size(IconSize::Small)
417                                .icon_position(IconPosition::Start)
418                                .on_click(|_event, window, cx| {
419                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
420                                }),
421                        ),
422                    )
423                    .child(
424                        h_flex().w_full().child(
425                            Button::new(
426                                "install-context-server-extensions",
427                                "Install MCP Extensions",
428                            )
429                            .style(ButtonStyle::Filled)
430                            .layer(ElevationIndex::ModalSurface)
431                            .full_width()
432                            .icon(IconName::DatabaseZap)
433                            .icon_size(IconSize::Small)
434                            .icon_position(IconPosition::Start)
435                            .on_click(|_event, window, cx| {
436                                window.dispatch_action(
437                                    zed_actions::Extensions {
438                                        category_filter: Some(
439                                            ExtensionCategoryFilter::ContextServers,
440                                        ),
441                                    }
442                                    .boxed_clone(),
443                                    cx,
444                                )
445                            }),
446                        ),
447                    ),
448            )
449    }
450}
451
452impl Render for AssistantConfiguration {
453    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
454        v_flex()
455            .id("assistant-configuration")
456            .key_context("AgentConfiguration")
457            .track_focus(&self.focus_handle(cx))
458            .relative()
459            .size_full()
460            .pb_8()
461            .bg(cx.theme().colors().panel_background)
462            .child(
463                v_flex()
464                    .id("assistant-configuration-content")
465                    .track_scroll(&self.scroll_handle)
466                    .size_full()
467                    .overflow_y_scroll()
468                    .child(self.render_command_permission(cx))
469                    .child(Divider::horizontal().color(DividerColor::Border))
470                    .child(self.render_context_servers_section(cx))
471                    .child(Divider::horizontal().color(DividerColor::Border))
472                    .child(self.render_provider_configuration_section(cx)),
473            )
474            .child(
475                div()
476                    .id("assistant-configuration-scrollbar")
477                    .occlude()
478                    .absolute()
479                    .right(px(3.))
480                    .top_0()
481                    .bottom_0()
482                    .pb_6()
483                    .w(px(12.))
484                    .cursor_default()
485                    .on_mouse_move(cx.listener(|_, _, _window, cx| {
486                        cx.notify();
487                        cx.stop_propagation()
488                    }))
489                    .on_hover(|_, _window, cx| {
490                        cx.stop_propagation();
491                    })
492                    .on_any_mouse_down(|_, _window, cx| {
493                        cx.stop_propagation();
494                    })
495                    .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
496                        cx.notify();
497                    }))
498                    .children(Scrollbar::vertical(self.scrollbar_state.clone())),
499            )
500    }
501}