agent_configuration.rs

  1mod add_context_server_modal;
  2mod configure_context_server_modal;
  3mod manage_profiles_modal;
  4mod tool_picker;
  5
  6use std::{sync::Arc, time::Duration};
  7
  8use assistant_settings::AssistantSettings;
  9use assistant_tool::{ToolSource, ToolWorkingSet};
 10use collections::HashMap;
 11use context_server::ContextServerId;
 12use fs::Fs;
 13use gpui::{
 14    Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
 15    Focusable, ScrollHandle, Subscription, pulsating_between,
 16};
 17use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 18use project::context_server_store::{ContextServerStatus, ContextServerStore};
 19use settings::{Settings, update_settings_file};
 20use ui::{
 21    Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
 22    Switch, SwitchColor, Tooltip, prelude::*,
 23};
 24use util::ResultExt as _;
 25use zed_actions::ExtensionCategoryFilter;
 26
 27pub(crate) use add_context_server_modal::AddContextServerModal;
 28pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
 29pub(crate) use manage_profiles_modal::ManageProfilesModal;
 30
 31use crate::AddContextServer;
 32
/// View state backing the agent configuration panel.
///
/// The panel renders three sections: general settings, Model Context Protocol
/// servers, and language model providers.
pub struct AgentConfiguration {
    /// Filesystem handle handed to `update_settings_file` when a settings switch is toggled.
    fs: Arc<dyn Fs>,
    /// Focus handle returned by the `Focusable` impl and tracked by the root element.
    focus_handle: FocusHandle,
    /// Configuration view of each known language model provider, keyed by provider id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Store queried for context server ids and statuses, and used to start/stop servers.
    context_server_store: Entity<ContextServerStore>,
    /// Whether each context server's tool list is expanded; a missing entry counts as collapsed.
    expanded_context_server_tools: HashMap<ContextServerId, bool>,
    /// Tool working set queried for the tools grouped by their source.
    tools: Entity<ToolWorkingSet>,
    /// Subscription to language model registry events created in `new`. It is never read;
    /// presumably it is stored only to keep the subscription alive — TODO confirm.
    _registry_subscription: Subscription,
    /// Scroll handle tracked by the scrollable content container.
    scroll_handle: ScrollHandle,
    /// Scrollbar state created from a clone of `scroll_handle`.
    scrollbar_state: ScrollbarState,
}
 44
 45impl AgentConfiguration {
 46    pub fn new(
 47        fs: Arc<dyn Fs>,
 48        context_server_store: Entity<ContextServerStore>,
 49        tools: Entity<ToolWorkingSet>,
 50        window: &mut Window,
 51        cx: &mut Context<Self>,
 52    ) -> Self {
 53        let focus_handle = cx.focus_handle();
 54
 55        let registry_subscription = cx.subscribe_in(
 56            &LanguageModelRegistry::global(cx),
 57            window,
 58            |this, _, event: &language_model::Event, window, cx| match event {
 59                language_model::Event::AddedProvider(provider_id) => {
 60                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 61                    if let Some(provider) = provider {
 62                        this.add_provider_configuration_view(&provider, window, cx);
 63                    }
 64                }
 65                language_model::Event::RemovedProvider(provider_id) => {
 66                    this.remove_provider_configuration_view(provider_id);
 67                }
 68                _ => {}
 69            },
 70        );
 71
 72        let scroll_handle = ScrollHandle::new();
 73        let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
 74
 75        let mut this = Self {
 76            fs,
 77            focus_handle,
 78            configuration_views_by_provider: HashMap::default(),
 79            context_server_store,
 80            expanded_context_server_tools: HashMap::default(),
 81            tools,
 82            _registry_subscription: registry_subscription,
 83            scroll_handle,
 84            scrollbar_state,
 85        };
 86        this.build_provider_configuration_views(window, cx);
 87        this
 88    }
 89
 90    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 91        let providers = LanguageModelRegistry::read_global(cx).providers();
 92        for provider in providers {
 93            self.add_provider_configuration_view(&provider, window, cx);
 94        }
 95    }
 96
 97    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 98        self.configuration_views_by_provider.remove(provider_id);
 99    }
100
101    fn add_provider_configuration_view(
102        &mut self,
103        provider: &Arc<dyn LanguageModelProvider>,
104        window: &mut Window,
105        cx: &mut Context<Self>,
106    ) {
107        let configuration_view = provider.configuration_view(window, cx);
108        self.configuration_views_by_provider
109            .insert(provider.id(), configuration_view);
110    }
111}
112
113impl Focusable for AgentConfiguration {
114    fn focus_handle(&self, _: &App) -> FocusHandle {
115        self.focus_handle.clone()
116    }
117}
118
/// Events emitted by [`AgentConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when the "Start New Thread" button of a provider is clicked;
    /// carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}
122
// Lets `AgentConfiguration` emit `AssistantConfigurationEvent`s via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AgentConfiguration {}
124
125impl AgentConfiguration {
126    fn render_provider_configuration_block(
127        &mut self,
128        provider: &Arc<dyn LanguageModelProvider>,
129        cx: &mut Context<Self>,
130    ) -> impl IntoElement + use<> {
131        let provider_id = provider.id().0.clone();
132        let provider_name = provider.name().0.clone();
133        let configuration_view = self
134            .configuration_views_by_provider
135            .get(&provider.id())
136            .cloned();
137
138        v_flex()
139            .pt_3()
140            .pb_1()
141            .gap_1p5()
142            .border_t_1()
143            .border_color(cx.theme().colors().border.opacity(0.6))
144            .child(
145                h_flex()
146                    .justify_between()
147                    .child(
148                        h_flex()
149                            .gap_2()
150                            .child(
151                                Icon::new(provider.icon())
152                                    .size(IconSize::Small)
153                                    .color(Color::Muted),
154                            )
155                            .child(Label::new(provider_name.clone()).size(LabelSize::Large)),
156                    )
157                    .when(provider.is_authenticated(cx), |parent| {
158                        parent.child(
159                            Button::new(
160                                SharedString::from(format!("new-thread-{provider_id}")),
161                                "Start New Thread",
162                            )
163                            .icon_position(IconPosition::Start)
164                            .icon(IconName::Plus)
165                            .icon_size(IconSize::Small)
166                            .style(ButtonStyle::Filled)
167                            .layer(ElevationIndex::ModalSurface)
168                            .label_size(LabelSize::Small)
169                            .on_click(cx.listener({
170                                let provider = provider.clone();
171                                move |_this, _event, _window, cx| {
172                                    cx.emit(AssistantConfigurationEvent::NewThread(
173                                        provider.clone(),
174                                    ))
175                                }
176                            })),
177                        )
178                    }),
179            )
180            .map(|parent| match configuration_view {
181                Some(configuration_view) => parent.child(configuration_view),
182                None => parent.child(div().child(Label::new(format!(
183                    "No configuration view for {provider_name}",
184                )))),
185            })
186    }
187
188    fn render_provider_configuration_section(
189        &mut self,
190        cx: &mut Context<Self>,
191    ) -> impl IntoElement {
192        let providers = LanguageModelRegistry::read_global(cx).providers();
193
194        v_flex()
195            .p(DynamicSpacing::Base16.rems(cx))
196            .pr(DynamicSpacing::Base20.rems(cx))
197            .gap_4()
198            .flex_1()
199            .child(
200                v_flex()
201                    .gap_0p5()
202                    .child(Headline::new("LLM Providers"))
203                    .child(
204                        Label::new("Add at least one provider to use AI-powered features.")
205                            .color(Color::Muted),
206                    ),
207            )
208            .children(
209                providers
210                    .into_iter()
211                    .map(|provider| self.render_provider_configuration_block(&provider, cx)),
212            )
213    }
214
215    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
216        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
217
218        h_flex()
219            .gap_4()
220            .justify_between()
221            .flex_wrap()
222            .child(
223                v_flex()
224                    .gap_0p5()
225                    .max_w_5_6()
226                    .child(Label::new("Allow running editing tools without asking for confirmation"))
227                    .child(
228                        Label::new(
229                            "The agent can perform potentially destructive actions without asking for your confirmation.",
230                        )
231                        .color(Color::Muted),
232                    ),
233            )
234            .child(
235                Switch::new(
236                    "always-allow-tool-actions-switch",
237                    always_allow_tool_actions.into(),
238                )
239                .color(SwitchColor::Accent)
240                .on_click({
241                    let fs = self.fs.clone();
242                    move |state, _window, cx| {
243                        let allow = state == &ToggleState::Selected;
244                        update_settings_file::<AssistantSettings>(
245                            fs.clone(),
246                            cx,
247                            move |settings, _| {
248                                settings.set_always_allow_tool_actions(allow);
249                            },
250                        );
251                    }
252                }),
253            )
254    }
255
256    fn render_single_file_review(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
257        let single_file_review = AssistantSettings::get_global(cx).single_file_review;
258
259        h_flex()
260            .gap_4()
261            .justify_between()
262            .flex_wrap()
263            .child(
264                v_flex()
265                    .gap_0p5()
266                    .max_w_5_6()
267                    .child(Label::new("Enable single-file agent reviews"))
268                    .child(
269                        Label::new(
270                            "Agent edits are also displayed in single-file editors for review.",
271                        )
272                        .color(Color::Muted),
273                    ),
274            )
275            .child(
276                Switch::new("single-file-review-switch", single_file_review.into())
277                    .color(SwitchColor::Accent)
278                    .on_click({
279                        let fs = self.fs.clone();
280                        move |state, _window, cx| {
281                            let allow = state == &ToggleState::Selected;
282                            update_settings_file::<AssistantSettings>(
283                                fs.clone(),
284                                cx,
285                                move |settings, _| {
286                                    settings.set_single_file_review(allow);
287                                },
288                            );
289                        }
290                    }),
291            )
292    }
293
294    fn render_general_settings_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
295        v_flex()
296            .p(DynamicSpacing::Base16.rems(cx))
297            .pr(DynamicSpacing::Base20.rems(cx))
298            .gap_2p5()
299            .flex_1()
300            .child(Headline::new("General Settings"))
301            .child(self.render_command_permission(cx))
302            .child(self.render_single_file_review(cx))
303    }
304
305    fn render_context_servers_section(
306        &mut self,
307        window: &mut Window,
308        cx: &mut Context<Self>,
309    ) -> impl IntoElement {
310        let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
311
312        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
313
314        v_flex()
315            .p(DynamicSpacing::Base16.rems(cx))
316            .pr(DynamicSpacing::Base20.rems(cx))
317            .gap_2()
318            .flex_1()
319            .child(
320                v_flex()
321                    .gap_0p5()
322                    .child(Headline::new("Model Context Protocol (MCP) Servers"))
323                    .child(Label::new(SUBHEADING).color(Color::Muted)),
324            )
325            .children(
326                context_server_ids.into_iter().map(|context_server_id| {
327                    self.render_context_server(context_server_id, window, cx)
328                }),
329            )
330            .child(
331                h_flex()
332                    .justify_between()
333                    .gap_2()
334                    .child(
335                        h_flex().w_full().child(
336                            Button::new("add-context-server", "Add Custom Server")
337                                .style(ButtonStyle::Filled)
338                                .layer(ElevationIndex::ModalSurface)
339                                .full_width()
340                                .icon(IconName::Plus)
341                                .icon_size(IconSize::Small)
342                                .icon_position(IconPosition::Start)
343                                .on_click(|_event, window, cx| {
344                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
345                                }),
346                        ),
347                    )
348                    .child(
349                        h_flex().w_full().child(
350                            Button::new(
351                                "install-context-server-extensions",
352                                "Install MCP Extensions",
353                            )
354                            .style(ButtonStyle::Filled)
355                            .layer(ElevationIndex::ModalSurface)
356                            .full_width()
357                            .icon(IconName::Hammer)
358                            .icon_size(IconSize::Small)
359                            .icon_position(IconPosition::Start)
360                            .on_click(|_event, window, cx| {
361                                window.dispatch_action(
362                                    zed_actions::Extensions {
363                                        category_filter: Some(
364                                            ExtensionCategoryFilter::ContextServers,
365                                        ),
366                                    }
367                                    .boxed_clone(),
368                                    cx,
369                                )
370                            }),
371                        ),
372                    ),
373            )
374    }
375
376    fn render_context_server(
377        &self,
378        context_server_id: ContextServerId,
379        window: &mut Window,
380        cx: &mut Context<Self>,
381    ) -> impl use<> + IntoElement {
382        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
383        let server_status = self
384            .context_server_store
385            .read(cx)
386            .status_for_server(&context_server_id)
387            .unwrap_or(ContextServerStatus::Stopped);
388
389        let is_running = matches!(server_status, ContextServerStatus::Running);
390
391        let error = if let ContextServerStatus::Error(error) = server_status.clone() {
392            Some(error)
393        } else {
394            None
395        };
396
397        let are_tools_expanded = self
398            .expanded_context_server_tools
399            .get(&context_server_id)
400            .copied()
401            .unwrap_or_default();
402
403        let tools = tools_by_source
404            .get(&ToolSource::ContextServer {
405                id: context_server_id.0.clone().into(),
406            })
407            .map_or([].as_slice(), |tools| tools.as_slice());
408        let tool_count = tools.len();
409
410        let border_color = cx.theme().colors().border.opacity(0.6);
411
412        v_flex()
413            .id(SharedString::from(context_server_id.0.clone()))
414            .border_1()
415            .rounded_md()
416            .border_color(border_color)
417            .bg(cx.theme().colors().background.opacity(0.2))
418            .overflow_hidden()
419            .child(
420                h_flex()
421                    .p_1()
422                    .justify_between()
423                    .when(
424                        error.is_some() || are_tools_expanded && tool_count > 1,
425                        |element| element.border_b_1().border_color(border_color),
426                    )
427                    .child(
428                        h_flex()
429                            .gap_1p5()
430                            .child(
431                                Disclosure::new(
432                                    "tool-list-disclosure",
433                                    are_tools_expanded || error.is_some(),
434                                )
435                                .disabled(tool_count == 0)
436                                .on_click(cx.listener({
437                                    let context_server_id = context_server_id.clone();
438                                    move |this, _event, _window, _cx| {
439                                        let is_open = this
440                                            .expanded_context_server_tools
441                                            .entry(context_server_id.clone())
442                                            .or_insert(false);
443
444                                        *is_open = !*is_open;
445                                    }
446                                })),
447                            )
448                            .child(match server_status {
449                                ContextServerStatus::Starting => {
450                                    let color = Color::Success.color(cx);
451                                    Indicator::dot()
452                                        .color(Color::Success)
453                                        .with_animation(
454                                            SharedString::from(format!(
455                                                "{}-starting",
456                                                context_server_id.0.clone(),
457                                            )),
458                                            Animation::new(Duration::from_secs(2))
459                                                .repeat()
460                                                .with_easing(pulsating_between(0.4, 1.)),
461                                            move |this, delta| {
462                                                this.color(color.alpha(delta).into())
463                                            },
464                                        )
465                                        .into_any_element()
466                                }
467                                ContextServerStatus::Running => {
468                                    Indicator::dot().color(Color::Success).into_any_element()
469                                }
470                                ContextServerStatus::Error(_) => {
471                                    Indicator::dot().color(Color::Error).into_any_element()
472                                }
473                                ContextServerStatus::Stopped => {
474                                    Indicator::dot().color(Color::Muted).into_any_element()
475                                }
476                            })
477                            .child(Label::new(context_server_id.0.clone()).ml_0p5())
478                            .when(is_running, |this| {
479                                this.child(
480                                    Label::new(if tool_count == 1 {
481                                        SharedString::from("1 tool")
482                                    } else {
483                                        SharedString::from(format!("{} tools", tool_count))
484                                    })
485                                    .color(Color::Muted)
486                                    .size(LabelSize::Small),
487                                )
488                            }),
489                    )
490                    .child(
491                        Switch::new("context-server-switch", is_running.into())
492                            .color(SwitchColor::Accent)
493                            .on_click({
494                                let context_server_manager = self.context_server_store.clone();
495                                let context_server_id = context_server_id.clone();
496                                move |state, _window, cx| match state {
497                                    ToggleState::Unselected | ToggleState::Indeterminate => {
498                                        context_server_manager.update(cx, |this, cx| {
499                                            this.stop_server(&context_server_id, cx).log_err();
500                                        });
501                                    }
502                                    ToggleState::Selected => {
503                                        context_server_manager.update(cx, |this, cx| {
504                                            if let Some(server) =
505                                                this.get_server(&context_server_id)
506                                            {
507                                                this.start_server(server, cx).log_err();
508                                            }
509                                        })
510                                    }
511                                }
512                            }),
513                    ),
514            )
515            .map(|parent| {
516                if let Some(error) = error {
517                    return parent.child(
518                        h_flex()
519                            .p_2()
520                            .gap_2()
521                            .items_start()
522                            .child(
523                                h_flex()
524                                    .flex_none()
525                                    .h(window.line_height() / 1.6_f32)
526                                    .justify_center()
527                                    .child(
528                                        Icon::new(IconName::XCircle)
529                                            .size(IconSize::XSmall)
530                                            .color(Color::Error),
531                                    ),
532                            )
533                            .child(
534                                div().w_full().child(
535                                    Label::new(error)
536                                        .buffer_font(cx)
537                                        .color(Color::Muted)
538                                        .size(LabelSize::Small),
539                                ),
540                            ),
541                    );
542                }
543
544                if !are_tools_expanded || tools.is_empty() {
545                    return parent;
546                }
547
548                parent.child(v_flex().py_1p5().px_1().gap_1().children(
549                    tools.into_iter().enumerate().map(|(ix, tool)| {
550                        h_flex()
551                            .id(("tool-item", ix))
552                            .px_1()
553                            .gap_2()
554                            .justify_between()
555                            .hover(|style| style.bg(cx.theme().colors().element_hover))
556                            .rounded_sm()
557                            .child(
558                                Label::new(tool.name())
559                                    .buffer_font(cx)
560                                    .size(LabelSize::Small),
561                            )
562                            .child(
563                                Icon::new(IconName::Info)
564                                    .size(IconSize::Small)
565                                    .color(Color::Ignored),
566                            )
567                            .tooltip(Tooltip::text(tool.description()))
568                    }),
569                ))
570            })
571    }
572}
573
574impl Render for AgentConfiguration {
575    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
576        v_flex()
577            .id("assistant-configuration")
578            .key_context("AgentConfiguration")
579            .track_focus(&self.focus_handle(cx))
580            .relative()
581            .size_full()
582            .pb_8()
583            .bg(cx.theme().colors().panel_background)
584            .child(
585                v_flex()
586                    .id("assistant-configuration-content")
587                    .track_scroll(&self.scroll_handle)
588                    .size_full()
589                    .overflow_y_scroll()
590                    .child(self.render_general_settings_section(cx))
591                    .child(Divider::horizontal().color(DividerColor::Border))
592                    .child(self.render_context_servers_section(window, cx))
593                    .child(Divider::horizontal().color(DividerColor::Border))
594                    .child(self.render_provider_configuration_section(cx)),
595            )
596            .child(
597                div()
598                    .id("assistant-configuration-scrollbar")
599                    .occlude()
600                    .absolute()
601                    .right(px(3.))
602                    .top_0()
603                    .bottom_0()
604                    .pb_6()
605                    .w(px(12.))
606                    .cursor_default()
607                    .on_mouse_move(cx.listener(|_, _, _window, cx| {
608                        cx.notify();
609                        cx.stop_propagation()
610                    }))
611                    .on_hover(|_, _window, cx| {
612                        cx.stop_propagation();
613                    })
614                    .on_any_mouse_down(|_, _window, cx| {
615                        cx.stop_propagation();
616                    })
617                    .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
618                        cx.notify();
619                    }))
620                    .children(Scrollbar::vertical(self.scrollbar_state.clone())),
621            )
622    }
623}