assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod tool_picker;
  4
  5use std::sync::Arc;
  6
  7use assistant_settings::AssistantSettings;
  8use assistant_tool::{ToolSource, ToolWorkingSet};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use fs::Fs;
 12use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
 13use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 14use settings::{Settings, update_settings_file};
 15use ui::{
 16    Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
 17};
 18use util::ResultExt as _;
 19use zed_actions::ExtensionCategoryFilter;
 20
 21pub(crate) use add_context_server_modal::AddContextServerModal;
 22pub(crate) use manage_profiles_modal::ManageProfilesModal;
 23
 24use crate::AddContextServer;
 25
 26pub struct AssistantConfiguration {
 27    fs: Arc<dyn Fs>,
 28    focus_handle: FocusHandle,
 29    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 30    context_server_manager: Entity<ContextServerManager>,
 31    expanded_context_server_tools: HashMap<Arc<str>, bool>,
 32    tools: Entity<ToolWorkingSet>,
 33    _registry_subscription: Subscription,
 34}
 35
 36impl AssistantConfiguration {
 37    pub fn new(
 38        fs: Arc<dyn Fs>,
 39        context_server_manager: Entity<ContextServerManager>,
 40        tools: Entity<ToolWorkingSet>,
 41        window: &mut Window,
 42        cx: &mut Context<Self>,
 43    ) -> Self {
 44        let focus_handle = cx.focus_handle();
 45
 46        let registry_subscription = cx.subscribe_in(
 47            &LanguageModelRegistry::global(cx),
 48            window,
 49            |this, _, event: &language_model::Event, window, cx| match event {
 50                language_model::Event::AddedProvider(provider_id) => {
 51                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 52                    if let Some(provider) = provider {
 53                        this.add_provider_configuration_view(&provider, window, cx);
 54                    }
 55                }
 56                language_model::Event::RemovedProvider(provider_id) => {
 57                    this.remove_provider_configuration_view(provider_id);
 58                }
 59                _ => {}
 60            },
 61        );
 62
 63        let mut this = Self {
 64            fs,
 65            focus_handle,
 66            configuration_views_by_provider: HashMap::default(),
 67            context_server_manager,
 68            expanded_context_server_tools: HashMap::default(),
 69            tools,
 70            _registry_subscription: registry_subscription,
 71        };
 72        this.build_provider_configuration_views(window, cx);
 73        this
 74    }
 75
 76    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 77        let providers = LanguageModelRegistry::read_global(cx).providers();
 78        for provider in providers {
 79            self.add_provider_configuration_view(&provider, window, cx);
 80        }
 81    }
 82
 83    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 84        self.configuration_views_by_provider.remove(provider_id);
 85    }
 86
 87    fn add_provider_configuration_view(
 88        &mut self,
 89        provider: &Arc<dyn LanguageModelProvider>,
 90        window: &mut Window,
 91        cx: &mut Context<Self>,
 92    ) {
 93        let configuration_view = provider.configuration_view(window, cx);
 94        self.configuration_views_by_provider
 95            .insert(provider.id(), configuration_view);
 96    }
 97}
 98
 99impl Focusable for AssistantConfiguration {
100    fn focus_handle(&self, _: &App) -> FocusHandle {
101        self.focus_handle.clone()
102    }
103}
104
105pub enum AssistantConfigurationEvent {
106    NewThread(Arc<dyn LanguageModelProvider>),
107}
108
109impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
110
111impl AssistantConfiguration {
112    fn render_provider_configuration(
113        &mut self,
114        provider: &Arc<dyn LanguageModelProvider>,
115        cx: &mut Context<Self>,
116    ) -> impl IntoElement + use<> {
117        let provider_id = provider.id().0.clone();
118        let provider_name = provider.name().0.clone();
119        let configuration_view = self
120            .configuration_views_by_provider
121            .get(&provider.id())
122            .cloned();
123
124        v_flex()
125            .gap_1p5()
126            .child(
127                h_flex()
128                    .justify_between()
129                    .child(
130                        h_flex()
131                            .gap_2()
132                            .child(
133                                Icon::new(provider.icon())
134                                    .size(IconSize::Small)
135                                    .color(Color::Muted),
136                            )
137                            .child(Label::new(provider_name.clone())),
138                    )
139                    .when(provider.is_authenticated(cx), |parent| {
140                        parent.child(
141                            Button::new(
142                                SharedString::from(format!("new-thread-{provider_id}")),
143                                "Start New Thread",
144                            )
145                            .icon_position(IconPosition::Start)
146                            .icon(IconName::Plus)
147                            .icon_size(IconSize::Small)
148                            .style(ButtonStyle::Filled)
149                            .layer(ElevationIndex::ModalSurface)
150                            .label_size(LabelSize::Small)
151                            .on_click(cx.listener({
152                                let provider = provider.clone();
153                                move |_this, _event, _window, cx| {
154                                    cx.emit(AssistantConfigurationEvent::NewThread(
155                                        provider.clone(),
156                                    ))
157                                }
158                            })),
159                        )
160                    }),
161            )
162            .child(
163                div()
164                    .p(DynamicSpacing::Base08.rems(cx))
165                    .bg(cx.theme().colors().editor_background)
166                    .border_1()
167                    .border_color(cx.theme().colors().border_variant)
168                    .rounded_sm()
169                    .map(|parent| match configuration_view {
170                        Some(configuration_view) => parent.child(configuration_view),
171                        None => parent.child(div().child(Label::new(format!(
172                            "No configuration view for {provider_name}",
173                        )))),
174                    }),
175            )
176    }
177
178    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
179        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
180
181        const HEADING: &str = "Allow running tools without asking for confirmation";
182
183        v_flex()
184            .p(DynamicSpacing::Base16.rems(cx))
185            .gap_2()
186            .flex_1()
187            .child(Headline::new("General Settings").size(HeadlineSize::Small))
188            .child(
189                h_flex()
190                    .p_2p5()
191                    .rounded_sm()
192                    .bg(cx.theme().colors().editor_background)
193                    .border_1()
194                    .border_color(cx.theme().colors().border)
195                    .gap_4()
196                    .justify_between()
197                    .flex_wrap()
198                    .child(
199                        v_flex()
200                            .gap_0p5()
201                            .max_w_5_6()
202                            .child(Label::new(HEADING))
203                            .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
204                    )
205                    .child(
206                        Switch::new(
207                            "always-allow-tool-actions-switch",
208                            always_allow_tool_actions.into(),
209                        )
210                        .on_click({
211                            let fs = self.fs.clone();
212                            move |state, _window, cx| {
213                                let allow = state == &ToggleState::Selected;
214                                update_settings_file::<AssistantSettings>(
215                                    fs.clone(),
216                                    cx,
217                                    move |settings, _| {
218                                        settings.set_always_allow_tool_actions(allow);
219                                    },
220                                );
221                            }
222                        }),
223                    ),
224            )
225    }
226
227    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
228        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
229        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
230        let empty = Vec::new();
231
232        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
233
234        v_flex()
235            .p(DynamicSpacing::Base16.rems(cx))
236            .gap_2()
237            .flex_1()
238            .child(
239                v_flex()
240                    .gap_0p5()
241                    .child(
242                        Headline::new("Model Context Protocol (MCP) Servers")
243                            .size(HeadlineSize::Small),
244                    )
245                    .child(Label::new(SUBHEADING).color(Color::Muted)),
246            )
247            .children(context_servers.into_iter().map(|context_server| {
248                let is_running = context_server.client().is_some();
249                let are_tools_expanded = self
250                    .expanded_context_server_tools
251                    .get(&context_server.id())
252                    .copied()
253                    .unwrap_or_default();
254
255                let tools = tools_by_source
256                    .get(&ToolSource::ContextServer {
257                        id: context_server.id().into(),
258                    })
259                    .unwrap_or_else(|| &empty);
260                let tool_count = tools.len();
261
262                v_flex()
263                    .id(SharedString::from(context_server.id()))
264                    .border_1()
265                    .rounded_sm()
266                    .border_color(cx.theme().colors().border)
267                    .bg(cx.theme().colors().editor_background)
268                    .child(
269                        h_flex()
270                            .p_1()
271                            .justify_between()
272                            .when(are_tools_expanded && tool_count > 1, |element| {
273                                element
274                                    .border_b_1()
275                                    .border_color(cx.theme().colors().border)
276                            })
277                            .child(
278                                h_flex()
279                                    .gap_2()
280                                    .child(
281                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
282                                            .disabled(tool_count == 0)
283                                            .on_click(cx.listener({
284                                                let context_server_id = context_server.id();
285                                                move |this, _event, _window, _cx| {
286                                                    let is_open = this
287                                                        .expanded_context_server_tools
288                                                        .entry(context_server_id.clone())
289                                                        .or_insert(false);
290
291                                                    *is_open = !*is_open;
292                                                }
293                                            })),
294                                    )
295                                    .child(Indicator::dot().color(if is_running {
296                                        Color::Success
297                                    } else {
298                                        Color::Error
299                                    }))
300                                    .child(Label::new(context_server.id()))
301                                    .child(
302                                        Label::new(format!("{tool_count} tools"))
303                                            .color(Color::Muted)
304                                            .size(LabelSize::Small),
305                                    ),
306                            )
307                            .child(
308                                Switch::new("context-server-switch", is_running.into()).on_click({
309                                    let context_server_manager =
310                                        self.context_server_manager.clone();
311                                    let context_server = context_server.clone();
312                                    move |state, _window, cx| match state {
313                                        ToggleState::Unselected | ToggleState::Indeterminate => {
314                                            context_server_manager.update(cx, |this, cx| {
315                                                this.stop_server(context_server.clone(), cx)
316                                                    .log_err();
317                                            });
318                                        }
319                                        ToggleState::Selected => {
320                                            cx.spawn({
321                                                let context_server_manager =
322                                                    context_server_manager.clone();
323                                                let context_server = context_server.clone();
324                                                async move |cx| {
325                                                    if let Some(start_server_task) =
326                                                        context_server_manager
327                                                            .update(cx, |this, cx| {
328                                                                this.start_server(
329                                                                    context_server,
330                                                                    cx,
331                                                                )
332                                                            })
333                                                            .log_err()
334                                                    {
335                                                        start_server_task.await.log_err();
336                                                    }
337                                                }
338                                            })
339                                            .detach();
340                                        }
341                                    }
342                                }),
343                            ),
344                    )
345                    .map(|parent| {
346                        if !are_tools_expanded {
347                            return parent;
348                        }
349
350                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
351                            |(ix, tool)| {
352                                h_flex()
353                                    .id("tool-item")
354                                    .pl_2()
355                                    .pr_1()
356                                    .py_1()
357                                    .gap_2()
358                                    .justify_between()
359                                    .when(ix < tool_count - 1, |element| {
360                                        element
361                                            .border_b_1()
362                                            .border_color(cx.theme().colors().border_variant)
363                                    })
364                                    .child(
365                                        Label::new(tool.name())
366                                            .buffer_font(cx)
367                                            .size(LabelSize::Small),
368                                    )
369                                    .child(
370                                        IconButton::new(("tool-description", ix), IconName::Info)
371                                            .shape(ui::IconButtonShape::Square)
372                                            .icon_size(IconSize::Small)
373                                            .icon_color(Color::Ignored)
374                                            .tooltip(Tooltip::text(tool.description())),
375                                    )
376                            },
377                        )))
378                    })
379            }))
380            .child(
381                h_flex()
382                    .justify_between()
383                    .gap_2()
384                    .child(
385                        h_flex().w_full().child(
386                            Button::new("add-context-server", "Add MCPs Directly")
387                                .style(ButtonStyle::Filled)
388                                .layer(ElevationIndex::ModalSurface)
389                                .full_width()
390                                .icon(IconName::Plus)
391                                .icon_size(IconSize::Small)
392                                .icon_position(IconPosition::Start)
393                                .on_click(|_event, window, cx| {
394                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
395                                }),
396                        ),
397                    )
398                    .child(
399                        h_flex().w_full().child(
400                            Button::new(
401                                "install-context-server-extensions",
402                                "Install MCP Extensions",
403                            )
404                            .style(ButtonStyle::Filled)
405                            .layer(ElevationIndex::ModalSurface)
406                            .full_width()
407                            .icon(IconName::DatabaseZap)
408                            .icon_size(IconSize::Small)
409                            .icon_position(IconPosition::Start)
410                            .on_click(|_event, window, cx| {
411                                window.dispatch_action(
412                                    zed_actions::Extensions {
413                                        category_filter: Some(
414                                            ExtensionCategoryFilter::ContextServers,
415                                        ),
416                                    }
417                                    .boxed_clone(),
418                                    cx,
419                                )
420                            }),
421                        ),
422                    ),
423            )
424    }
425}
426
427impl Render for AssistantConfiguration {
428    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
429        let providers = LanguageModelRegistry::read_global(cx).providers();
430
431        v_flex()
432            .id("assistant-configuration")
433            .key_context("AgentConfiguration")
434            .track_focus(&self.focus_handle(cx))
435            .bg(cx.theme().colors().panel_background)
436            .size_full()
437            .overflow_y_scroll()
438            .child(self.render_command_permission(cx))
439            .child(Divider::horizontal().color(DividerColor::Border))
440            .child(self.render_context_servers_section(cx))
441            .child(Divider::horizontal().color(DividerColor::Border))
442            .child(
443                v_flex()
444                    .p(DynamicSpacing::Base16.rems(cx))
445                    .mt_1()
446                    .gap_6()
447                    .flex_1()
448                    .child(
449                        v_flex()
450                            .gap_0p5()
451                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
452                            .child(
453                                Label::new("Add at least one provider to use AI-powered features.")
454                                    .color(Color::Muted),
455                            ),
456                    )
457                    .children(
458                        providers
459                            .into_iter()
460                            .map(|provider| self.render_provider_configuration(&provider, cx)),
461                    ),
462            )
463    }
464}