assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod tool_picker;
  4
  5use std::sync::Arc;
  6
  7use assistant_settings::AssistantSettings;
  8use assistant_tool::{ToolSource, ToolWorkingSet};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use fs::Fs;
 12use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
 13use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 14use settings::{Settings, update_settings_file};
 15use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
 16use util::ResultExt as _;
 17use zed_actions::ExtensionCategoryFilter;
 18
 19pub(crate) use add_context_server_modal::AddContextServerModal;
 20pub(crate) use manage_profiles_modal::ManageProfilesModal;
 21
 22use crate::AddContextServer;
 23
 24pub struct AssistantConfiguration {
 25    fs: Arc<dyn Fs>,
 26    focus_handle: FocusHandle,
 27    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 28    context_server_manager: Entity<ContextServerManager>,
 29    expanded_context_server_tools: HashMap<Arc<str>, bool>,
 30    tools: Arc<ToolWorkingSet>,
 31    _registry_subscription: Subscription,
 32}
 33
 34impl AssistantConfiguration {
 35    pub fn new(
 36        fs: Arc<dyn Fs>,
 37        context_server_manager: Entity<ContextServerManager>,
 38        tools: Arc<ToolWorkingSet>,
 39        window: &mut Window,
 40        cx: &mut Context<Self>,
 41    ) -> Self {
 42        let focus_handle = cx.focus_handle();
 43
 44        let registry_subscription = cx.subscribe_in(
 45            &LanguageModelRegistry::global(cx),
 46            window,
 47            |this, _, event: &language_model::Event, window, cx| match event {
 48                language_model::Event::AddedProvider(provider_id) => {
 49                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 50                    if let Some(provider) = provider {
 51                        this.add_provider_configuration_view(&provider, window, cx);
 52                    }
 53                }
 54                language_model::Event::RemovedProvider(provider_id) => {
 55                    this.remove_provider_configuration_view(provider_id);
 56                }
 57                _ => {}
 58            },
 59        );
 60
 61        let mut this = Self {
 62            fs,
 63            focus_handle,
 64            configuration_views_by_provider: HashMap::default(),
 65            context_server_manager,
 66            expanded_context_server_tools: HashMap::default(),
 67            tools,
 68            _registry_subscription: registry_subscription,
 69        };
 70        this.build_provider_configuration_views(window, cx);
 71        this
 72    }
 73
 74    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 75        let providers = LanguageModelRegistry::read_global(cx).providers();
 76        for provider in providers {
 77            self.add_provider_configuration_view(&provider, window, cx);
 78        }
 79    }
 80
 81    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 82        self.configuration_views_by_provider.remove(provider_id);
 83    }
 84
 85    fn add_provider_configuration_view(
 86        &mut self,
 87        provider: &Arc<dyn LanguageModelProvider>,
 88        window: &mut Window,
 89        cx: &mut Context<Self>,
 90    ) {
 91        let configuration_view = provider.configuration_view(window, cx);
 92        self.configuration_views_by_provider
 93            .insert(provider.id(), configuration_view);
 94    }
 95}
 96
 97impl Focusable for AssistantConfiguration {
 98    fn focus_handle(&self, _: &App) -> FocusHandle {
 99        self.focus_handle.clone()
100    }
101}
102
103pub enum AssistantConfigurationEvent {
104    NewThread(Arc<dyn LanguageModelProvider>),
105}
106
107impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
108
109impl AssistantConfiguration {
110    fn render_provider_configuration(
111        &mut self,
112        provider: &Arc<dyn LanguageModelProvider>,
113        cx: &mut Context<Self>,
114    ) -> impl IntoElement + use<> {
115        let provider_id = provider.id().0.clone();
116        let provider_name = provider.name().0.clone();
117        let configuration_view = self
118            .configuration_views_by_provider
119            .get(&provider.id())
120            .cloned();
121
122        v_flex()
123            .gap_1p5()
124            .child(
125                h_flex()
126                    .justify_between()
127                    .child(
128                        h_flex()
129                            .gap_2()
130                            .child(
131                                Icon::new(provider.icon())
132                                    .size(IconSize::Small)
133                                    .color(Color::Muted),
134                            )
135                            .child(Label::new(provider_name.clone())),
136                    )
137                    .when(provider.is_authenticated(cx), |parent| {
138                        parent.child(
139                            Button::new(
140                                SharedString::from(format!("new-thread-{provider_id}")),
141                                "Start New Thread",
142                            )
143                            .icon_position(IconPosition::Start)
144                            .icon(IconName::Plus)
145                            .icon_size(IconSize::Small)
146                            .style(ButtonStyle::Filled)
147                            .layer(ElevationIndex::ModalSurface)
148                            .label_size(LabelSize::Small)
149                            .on_click(cx.listener({
150                                let provider = provider.clone();
151                                move |_this, _event, _window, cx| {
152                                    cx.emit(AssistantConfigurationEvent::NewThread(
153                                        provider.clone(),
154                                    ))
155                                }
156                            })),
157                        )
158                    }),
159            )
160            .child(
161                div()
162                    .p(DynamicSpacing::Base08.rems(cx))
163                    .bg(cx.theme().colors().editor_background)
164                    .border_1()
165                    .border_color(cx.theme().colors().border_variant)
166                    .rounded_sm()
167                    .map(|parent| match configuration_view {
168                        Some(configuration_view) => parent.child(configuration_view),
169                        None => parent.child(div().child(Label::new(format!(
170                            "No configuration view for {provider_name}",
171                        )))),
172                    }),
173            )
174    }
175
176    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
177        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
178
179        const HEADING: &str = "Allow running tools without asking for confirmation";
180
181        v_flex()
182            .p(DynamicSpacing::Base16.rems(cx))
183            .gap_2()
184            .flex_1()
185            .child(Headline::new("General Settings").size(HeadlineSize::Small))
186            .child(
187                h_flex()
188                    .p_2p5()
189                    .rounded_sm()
190                    .bg(cx.theme().colors().editor_background)
191                    .border_1()
192                    .border_color(cx.theme().colors().border)
193                    .gap_4()
194                    .justify_between()
195                    .flex_wrap()
196                    .child(
197                        v_flex()
198                            .gap_0p5()
199                            .max_w_5_6()
200                            .child(Label::new(HEADING))
201                            .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
202                    )
203                    .child(
204                        Switch::new(
205                            "always-allow-tool-actions-switch",
206                            always_allow_tool_actions.into(),
207                        )
208                        .on_click({
209                            let fs = self.fs.clone();
210                            move |state, _window, cx| {
211                                let allow = state == &ToggleState::Selected;
212                                update_settings_file::<AssistantSettings>(
213                                    fs.clone(),
214                                    cx,
215                                    move |settings, _| {
216                                        settings.set_always_allow_tool_actions(allow);
217                                    },
218                                );
219                            }
220                        }),
221                    ),
222            )
223    }
224
225    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
226        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
227        let tools_by_source = self.tools.tools_by_source(cx);
228        let empty = Vec::new();
229
230        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
231
232        v_flex()
233            .p(DynamicSpacing::Base16.rems(cx))
234            .gap_2()
235            .flex_1()
236            .child(
237                v_flex()
238                    .gap_0p5()
239                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
240                    .child(Label::new(SUBHEADING).color(Color::Muted)),
241            )
242            .children(context_servers.into_iter().map(|context_server| {
243                let is_running = context_server.client().is_some();
244                let are_tools_expanded = self
245                    .expanded_context_server_tools
246                    .get(&context_server.id())
247                    .copied()
248                    .unwrap_or_default();
249
250                let tools = tools_by_source
251                    .get(&ToolSource::ContextServer {
252                        id: context_server.id().into(),
253                    })
254                    .unwrap_or_else(|| &empty);
255                let tool_count = tools.len();
256
257                v_flex()
258                    .id(SharedString::from(context_server.id()))
259                    .border_1()
260                    .rounded_sm()
261                    .border_color(cx.theme().colors().border)
262                    .bg(cx.theme().colors().editor_background)
263                    .child(
264                        h_flex()
265                            .justify_between()
266                            .px_2()
267                            .py_1()
268                            .when(are_tools_expanded, |element| {
269                                element
270                                    .border_b_1()
271                                    .border_color(cx.theme().colors().border)
272                            })
273                            .child(
274                                h_flex()
275                                    .gap_2()
276                                    .child(
277                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
278                                            .on_click(cx.listener({
279                                                let context_server_id = context_server.id();
280                                                move |this, _event, _window, _cx| {
281                                                    let is_open = this
282                                                        .expanded_context_server_tools
283                                                        .entry(context_server_id.clone())
284                                                        .or_insert(false);
285
286                                                    *is_open = !*is_open;
287                                                }
288                                            })),
289                                    )
290                                    .child(Indicator::dot().color(if is_running {
291                                        Color::Success
292                                    } else {
293                                        Color::Error
294                                    }))
295                                    .child(Label::new(context_server.id()))
296                                    .child(
297                                        Label::new(format!("{tool_count} tools"))
298                                            .color(Color::Muted),
299                                    ),
300                            )
301                            .child(h_flex().child(
302                                Switch::new("context-server-switch", is_running.into()).on_click({
303                                    let context_server_manager =
304                                        self.context_server_manager.clone();
305                                    let context_server = context_server.clone();
306                                    move |state, _window, cx| match state {
307                                        ToggleState::Unselected | ToggleState::Indeterminate => {
308                                            context_server_manager.update(cx, |this, cx| {
309                                                this.stop_server(context_server.clone(), cx)
310                                                    .log_err();
311                                            });
312                                        }
313                                        ToggleState::Selected => {
314                                            cx.spawn({
315                                                let context_server_manager =
316                                                    context_server_manager.clone();
317                                                let context_server = context_server.clone();
318                                                async move |cx| {
319                                                    if let Some(start_server_task) =
320                                                        context_server_manager
321                                                            .update(cx, |this, cx| {
322                                                                this.start_server(
323                                                                    context_server,
324                                                                    cx,
325                                                                )
326                                                            })
327                                                            .log_err()
328                                                    {
329                                                        start_server_task.await.log_err();
330                                                    }
331                                                }
332                                            })
333                                            .detach();
334                                        }
335                                    }
336                                }),
337                            )),
338                    )
339                    .map(|parent| {
340                        if !are_tools_expanded {
341                            return parent;
342                        }
343
344                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
345                            |(ix, tool)| {
346                                h_flex()
347                                    .px_2()
348                                    .py_1()
349                                    .when(ix < tool_count - 1, |element| {
350                                        element
351                                            .border_b_1()
352                                            .border_color(cx.theme().colors().border)
353                                    })
354                                    .child(Label::new(tool.name()))
355                            },
356                        )))
357                    })
358            }))
359            .child(
360                h_flex()
361                    .justify_between()
362                    .gap_2()
363                    .child(
364                        h_flex().w_full().child(
365                            Button::new("add-context-server", "Add Context Server")
366                                .style(ButtonStyle::Filled)
367                                .layer(ElevationIndex::ModalSurface)
368                                .full_width()
369                                .icon(IconName::Plus)
370                                .icon_size(IconSize::Small)
371                                .icon_position(IconPosition::Start)
372                                .on_click(|_event, window, cx| {
373                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
374                                }),
375                        ),
376                    )
377                    .child(
378                        h_flex().w_full().child(
379                            Button::new(
380                                "install-context-server-extensions",
381                                "Install Context Server Extensions",
382                            )
383                            .style(ButtonStyle::Filled)
384                            .layer(ElevationIndex::ModalSurface)
385                            .full_width()
386                            .icon(IconName::DatabaseZap)
387                            .icon_size(IconSize::Small)
388                            .icon_position(IconPosition::Start)
389                            .on_click(|_event, window, cx| {
390                                window.dispatch_action(
391                                    zed_actions::Extensions {
392                                        category_filter: Some(
393                                            ExtensionCategoryFilter::ContextServers,
394                                        ),
395                                    }
396                                    .boxed_clone(),
397                                    cx,
398                                )
399                            }),
400                        ),
401                    ),
402            )
403    }
404}
405
406impl Render for AssistantConfiguration {
407    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
408        let providers = LanguageModelRegistry::read_global(cx).providers();
409
410        v_flex()
411            .id("assistant-configuration")
412            .key_context("AgentConfiguration")
413            .track_focus(&self.focus_handle(cx))
414            .bg(cx.theme().colors().panel_background)
415            .size_full()
416            .overflow_y_scroll()
417            .child(self.render_command_permission(cx))
418            .child(Divider::horizontal().color(DividerColor::Border))
419            .child(self.render_context_servers_section(cx))
420            .child(Divider::horizontal().color(DividerColor::Border))
421            .child(
422                v_flex()
423                    .p(DynamicSpacing::Base16.rems(cx))
424                    .mt_1()
425                    .gap_6()
426                    .flex_1()
427                    .child(
428                        v_flex()
429                            .gap_0p5()
430                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
431                            .child(
432                                Label::new("Add at least one provider to use AI-powered features.")
433                                    .color(Color::Muted),
434                            ),
435                    )
436                    .children(
437                        providers
438                            .into_iter()
439                            .map(|provider| self.render_provider_configuration(&provider, cx)),
440                    ),
441            )
442    }
443}