assistant_configuration.rs

  1mod add_context_server_modal;
  2mod configure_context_server_modal;
  3mod manage_profiles_modal;
  4mod tool_picker;
  5
  6use std::{sync::Arc, time::Duration};
  7
  8use assistant_settings::AssistantSettings;
  9use assistant_tool::{ToolSource, ToolWorkingSet};
 10use collections::HashMap;
 11use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
 12use fs::Fs;
 13use gpui::{
 14    Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
 15    Focusable, ScrollHandle, Subscription, pulsating_between,
 16};
 17use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 18use settings::{Settings, update_settings_file};
 19use ui::{
 20    Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
 21    Switch, SwitchColor, Tooltip, prelude::*,
 22};
 23use util::ResultExt as _;
 24use zed_actions::ExtensionCategoryFilter;
 25
 26pub(crate) use add_context_server_modal::AddContextServerModal;
 27pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
 28pub(crate) use manage_profiles_modal::ManageProfilesModal;
 29
 30use crate::AddContextServer;
 31
/// Agent configuration panel: general settings toggles, context (MCP) servers,
/// and per-provider language-model configuration, in one scrollable view.
pub struct AssistantConfiguration {
    /// Used to persist setting toggles via `update_settings_file`.
    fs: Arc<dyn Fs>,
    focus_handle: FocusHandle,
    /// Cached configuration view per provider id; kept in sync with the
    /// language-model registry's added/removed-provider events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the server list and per-server status; also used to start/stop servers.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each server's tool list is expanded, keyed by server id.
    /// A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Queried for the tools each context server contributes.
    tools: Entity<ToolWorkingSet>,
    /// Held only to keep the registry subscription alive
    /// (presumably dropping it unsubscribes — confirm against gpui docs).
    _registry_subscription: Subscription,
    scroll_handle: ScrollHandle,
    /// Scrollbar state built from `scroll_handle`.
    scrollbar_state: ScrollbarState,
}
 43
 44impl AssistantConfiguration {
 45    pub fn new(
 46        fs: Arc<dyn Fs>,
 47        context_server_manager: Entity<ContextServerManager>,
 48        tools: Entity<ToolWorkingSet>,
 49        window: &mut Window,
 50        cx: &mut Context<Self>,
 51    ) -> Self {
 52        let focus_handle = cx.focus_handle();
 53
 54        let registry_subscription = cx.subscribe_in(
 55            &LanguageModelRegistry::global(cx),
 56            window,
 57            |this, _, event: &language_model::Event, window, cx| match event {
 58                language_model::Event::AddedProvider(provider_id) => {
 59                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 60                    if let Some(provider) = provider {
 61                        this.add_provider_configuration_view(&provider, window, cx);
 62                    }
 63                }
 64                language_model::Event::RemovedProvider(provider_id) => {
 65                    this.remove_provider_configuration_view(provider_id);
 66                }
 67                _ => {}
 68            },
 69        );
 70
 71        let scroll_handle = ScrollHandle::new();
 72        let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
 73
 74        let mut this = Self {
 75            fs,
 76            focus_handle,
 77            configuration_views_by_provider: HashMap::default(),
 78            context_server_manager,
 79            expanded_context_server_tools: HashMap::default(),
 80            tools,
 81            _registry_subscription: registry_subscription,
 82            scroll_handle,
 83            scrollbar_state,
 84        };
 85        this.build_provider_configuration_views(window, cx);
 86        this
 87    }
 88
 89    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 90        let providers = LanguageModelRegistry::read_global(cx).providers();
 91        for provider in providers {
 92            self.add_provider_configuration_view(&provider, window, cx);
 93        }
 94    }
 95
 96    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 97        self.configuration_views_by_provider.remove(provider_id);
 98    }
 99
100    fn add_provider_configuration_view(
101        &mut self,
102        provider: &Arc<dyn LanguageModelProvider>,
103        window: &mut Window,
104        cx: &mut Context<Self>,
105    ) {
106        let configuration_view = provider.configuration_view(window, cx);
107        self.configuration_views_by_provider
108            .insert(provider.id(), configuration_view);
109    }
110}
111
112impl Focusable for AssistantConfiguration {
113    fn focus_handle(&self, _: &App) -> FocusHandle {
114        self.focus_handle.clone()
115    }
116}
117
/// Events emitted by [`AssistantConfiguration`] to its subscribers.
pub enum AssistantConfigurationEvent {
    /// Emitted when the "Start New Thread" button on an authenticated
    /// provider's block is clicked; carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
123
124impl AssistantConfiguration {
125    fn render_provider_configuration_block(
126        &mut self,
127        provider: &Arc<dyn LanguageModelProvider>,
128        cx: &mut Context<Self>,
129    ) -> impl IntoElement + use<> {
130        let provider_id = provider.id().0.clone();
131        let provider_name = provider.name().0.clone();
132        let configuration_view = self
133            .configuration_views_by_provider
134            .get(&provider.id())
135            .cloned();
136
137        v_flex()
138            .pt_3()
139            .pb_1()
140            .gap_1p5()
141            .border_t_1()
142            .border_color(cx.theme().colors().border.opacity(0.6))
143            .child(
144                h_flex()
145                    .justify_between()
146                    .child(
147                        h_flex()
148                            .gap_2()
149                            .child(
150                                Icon::new(provider.icon())
151                                    .size(IconSize::Small)
152                                    .color(Color::Muted),
153                            )
154                            .child(Label::new(provider_name.clone()).size(LabelSize::Large)),
155                    )
156                    .when(provider.is_authenticated(cx), |parent| {
157                        parent.child(
158                            Button::new(
159                                SharedString::from(format!("new-thread-{provider_id}")),
160                                "Start New Thread",
161                            )
162                            .icon_position(IconPosition::Start)
163                            .icon(IconName::Plus)
164                            .icon_size(IconSize::Small)
165                            .style(ButtonStyle::Filled)
166                            .layer(ElevationIndex::ModalSurface)
167                            .label_size(LabelSize::Small)
168                            .on_click(cx.listener({
169                                let provider = provider.clone();
170                                move |_this, _event, _window, cx| {
171                                    cx.emit(AssistantConfigurationEvent::NewThread(
172                                        provider.clone(),
173                                    ))
174                                }
175                            })),
176                        )
177                    }),
178            )
179            .map(|parent| match configuration_view {
180                Some(configuration_view) => parent.child(configuration_view),
181                None => parent.child(div().child(Label::new(format!(
182                    "No configuration view for {provider_name}",
183                )))),
184            })
185    }
186
187    fn render_provider_configuration_section(
188        &mut self,
189        cx: &mut Context<Self>,
190    ) -> impl IntoElement {
191        let providers = LanguageModelRegistry::read_global(cx).providers();
192
193        v_flex()
194            .p(DynamicSpacing::Base16.rems(cx))
195            .pr(DynamicSpacing::Base20.rems(cx))
196            .gap_4()
197            .flex_1()
198            .child(
199                v_flex()
200                    .gap_0p5()
201                    .child(Headline::new("LLM Providers"))
202                    .child(
203                        Label::new("Add at least one provider to use AI-powered features.")
204                            .color(Color::Muted),
205                    ),
206            )
207            .children(
208                providers
209                    .into_iter()
210                    .map(|provider| self.render_provider_configuration_block(&provider, cx)),
211            )
212    }
213
214    fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
215        let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
216
217        h_flex()
218            .gap_4()
219            .justify_between()
220            .flex_wrap()
221            .child(
222                v_flex()
223                    .gap_0p5()
224                    .max_w_5_6()
225                    .child(Label::new("Allow running editing tools without asking for confirmation"))
226                    .child(
227                        Label::new(
228                            "The agent can perform potentially destructive actions without asking for your confirmation.",
229                        )
230                        .color(Color::Muted),
231                    ),
232            )
233            .child(
234                Switch::new(
235                    "always-allow-tool-actions-switch",
236                    always_allow_tool_actions.into(),
237                )
238                .color(SwitchColor::Accent)
239                .on_click({
240                    let fs = self.fs.clone();
241                    move |state, _window, cx| {
242                        let allow = state == &ToggleState::Selected;
243                        update_settings_file::<AssistantSettings>(
244                            fs.clone(),
245                            cx,
246                            move |settings, _| {
247                                settings.set_always_allow_tool_actions(allow);
248                            },
249                        );
250                    }
251                }),
252            )
253    }
254
255    fn render_single_file_review(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
256        let single_file_review = AssistantSettings::get_global(cx).single_file_review;
257
258        h_flex()
259            .gap_4()
260            .justify_between()
261            .flex_wrap()
262            .child(
263                v_flex()
264                    .gap_0p5()
265                    .max_w_5_6()
266                    .child(Label::new("Enable single-file agent reviews"))
267                    .child(
268                        Label::new(
269                            "Agent edits are also displayed in single-file editors for review.",
270                        )
271                        .color(Color::Muted),
272                    ),
273            )
274            .child(
275                Switch::new("single-file-review-switch", single_file_review.into())
276                    .color(SwitchColor::Accent)
277                    .on_click({
278                        let fs = self.fs.clone();
279                        move |state, _window, cx| {
280                            let allow = state == &ToggleState::Selected;
281                            update_settings_file::<AssistantSettings>(
282                                fs.clone(),
283                                cx,
284                                move |settings, _| {
285                                    settings.set_single_file_review(allow);
286                                },
287                            );
288                        }
289                    }),
290            )
291    }
292
293    fn render_general_settings_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
294        v_flex()
295            .p(DynamicSpacing::Base16.rems(cx))
296            .pr(DynamicSpacing::Base20.rems(cx))
297            .gap_2p5()
298            .flex_1()
299            .child(Headline::new("General Settings"))
300            .child(self.render_command_permission(cx))
301            .child(self.render_single_file_review(cx))
302    }
303
304    fn render_context_servers_section(
305        &mut self,
306        window: &mut Window,
307        cx: &mut Context<Self>,
308    ) -> impl IntoElement {
309        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
310
311        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
312
313        v_flex()
314            .p(DynamicSpacing::Base16.rems(cx))
315            .pr(DynamicSpacing::Base20.rems(cx))
316            .gap_2()
317            .flex_1()
318            .child(
319                v_flex()
320                    .gap_0p5()
321                    .child(Headline::new("Model Context Protocol (MCP) Servers"))
322                    .child(Label::new(SUBHEADING).color(Color::Muted)),
323            )
324            .children(
325                context_servers
326                    .into_iter()
327                    .map(|context_server| self.render_context_server(context_server, window, cx)),
328            )
329            .child(
330                h_flex()
331                    .justify_between()
332                    .gap_2()
333                    .child(
334                        h_flex().w_full().child(
335                            Button::new("add-context-server", "Add Custom Server")
336                                .style(ButtonStyle::Filled)
337                                .layer(ElevationIndex::ModalSurface)
338                                .full_width()
339                                .icon(IconName::Plus)
340                                .icon_size(IconSize::Small)
341                                .icon_position(IconPosition::Start)
342                                .on_click(|_event, window, cx| {
343                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
344                                }),
345                        ),
346                    )
347                    .child(
348                        h_flex().w_full().child(
349                            Button::new(
350                                "install-context-server-extensions",
351                                "Install MCP Extensions",
352                            )
353                            .style(ButtonStyle::Filled)
354                            .layer(ElevationIndex::ModalSurface)
355                            .full_width()
356                            .icon(IconName::Hammer)
357                            .icon_size(IconSize::Small)
358                            .icon_position(IconPosition::Start)
359                            .on_click(|_event, window, cx| {
360                                window.dispatch_action(
361                                    zed_actions::Extensions {
362                                        category_filter: Some(
363                                            ExtensionCategoryFilter::ContextServers,
364                                        ),
365                                    }
366                                    .boxed_clone(),
367                                    cx,
368                                )
369                            }),
370                        ),
371                    ),
372            )
373    }
374
375    fn render_context_server(
376        &self,
377        context_server: Arc<ContextServer>,
378        window: &mut Window,
379        cx: &mut Context<Self>,
380    ) -> impl use<> + IntoElement {
381        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
382        let server_status = self
383            .context_server_manager
384            .read(cx)
385            .status_for_server(&context_server.id());
386
387        let is_running = matches!(server_status, Some(ContextServerStatus::Running));
388
389        let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
390            Some(error)
391        } else {
392            None
393        };
394
395        let are_tools_expanded = self
396            .expanded_context_server_tools
397            .get(&context_server.id())
398            .copied()
399            .unwrap_or_default();
400
401        let tools = tools_by_source
402            .get(&ToolSource::ContextServer {
403                id: context_server.id().into(),
404            })
405            .map_or([].as_slice(), |tools| tools.as_slice());
406        let tool_count = tools.len();
407
408        let border_color = cx.theme().colors().border.opacity(0.6);
409
410        v_flex()
411            .id(SharedString::from(context_server.id()))
412            .border_1()
413            .rounded_md()
414            .border_color(border_color)
415            .bg(cx.theme().colors().background.opacity(0.2))
416            .overflow_hidden()
417            .child(
418                h_flex()
419                    .p_1()
420                    .justify_between()
421                    .when(
422                        error.is_some() || are_tools_expanded && tool_count > 1,
423                        |element| element.border_b_1().border_color(border_color),
424                    )
425                    .child(
426                        h_flex()
427                            .gap_1p5()
428                            .child(
429                                Disclosure::new(
430                                    "tool-list-disclosure",
431                                    are_tools_expanded || error.is_some(),
432                                )
433                                .disabled(tool_count == 0)
434                                .on_click(cx.listener({
435                                    let context_server_id = context_server.id();
436                                    move |this, _event, _window, _cx| {
437                                        let is_open = this
438                                            .expanded_context_server_tools
439                                            .entry(context_server_id.clone())
440                                            .or_insert(false);
441
442                                        *is_open = !*is_open;
443                                    }
444                                })),
445                            )
446                            .child(match server_status {
447                                Some(ContextServerStatus::Starting) => {
448                                    let color = Color::Success.color(cx);
449                                    Indicator::dot()
450                                        .color(Color::Success)
451                                        .with_animation(
452                                            SharedString::from(format!(
453                                                "{}-starting",
454                                                context_server.id(),
455                                            )),
456                                            Animation::new(Duration::from_secs(2))
457                                                .repeat()
458                                                .with_easing(pulsating_between(0.4, 1.)),
459                                            move |this, delta| {
460                                                this.color(color.alpha(delta).into())
461                                            },
462                                        )
463                                        .into_any_element()
464                                }
465                                Some(ContextServerStatus::Running) => {
466                                    Indicator::dot().color(Color::Success).into_any_element()
467                                }
468                                Some(ContextServerStatus::Error(_)) => {
469                                    Indicator::dot().color(Color::Error).into_any_element()
470                                }
471                                None => Indicator::dot().color(Color::Muted).into_any_element(),
472                            })
473                            .child(Label::new(context_server.id()).ml_0p5())
474                            .when(is_running, |this| {
475                                this.child(
476                                    Label::new(if tool_count == 1 {
477                                        SharedString::from("1 tool")
478                                    } else {
479                                        SharedString::from(format!("{} tools", tool_count))
480                                    })
481                                    .color(Color::Muted)
482                                    .size(LabelSize::Small),
483                                )
484                            }),
485                    )
486                    .child(
487                        Switch::new("context-server-switch", is_running.into())
488                            .color(SwitchColor::Accent)
489                            .on_click({
490                                let context_server_manager = self.context_server_manager.clone();
491                                let context_server = context_server.clone();
492                                move |state, _window, cx| match state {
493                                    ToggleState::Unselected | ToggleState::Indeterminate => {
494                                        context_server_manager.update(cx, |this, cx| {
495                                            this.stop_server(context_server.clone(), cx).log_err();
496                                        });
497                                    }
498                                    ToggleState::Selected => {
499                                        cx.spawn({
500                                            let context_server_manager =
501                                                context_server_manager.clone();
502                                            let context_server = context_server.clone();
503                                            async move |cx| {
504                                                if let Some(start_server_task) =
505                                                    context_server_manager
506                                                        .update(cx, |this, cx| {
507                                                            this.start_server(context_server, cx)
508                                                        })
509                                                        .log_err()
510                                                {
511                                                    start_server_task.await.log_err();
512                                                }
513                                            }
514                                        })
515                                        .detach();
516                                    }
517                                }
518                            }),
519                    ),
520            )
521            .map(|parent| {
522                if let Some(error) = error {
523                    return parent.child(
524                        h_flex()
525                            .p_2()
526                            .gap_2()
527                            .items_start()
528                            .child(
529                                h_flex()
530                                    .flex_none()
531                                    .h(window.line_height() / 1.6_f32)
532                                    .justify_center()
533                                    .child(
534                                        Icon::new(IconName::XCircle)
535                                            .size(IconSize::XSmall)
536                                            .color(Color::Error),
537                                    ),
538                            )
539                            .child(
540                                div().w_full().child(
541                                    Label::new(error)
542                                        .buffer_font(cx)
543                                        .color(Color::Muted)
544                                        .size(LabelSize::Small),
545                                ),
546                            ),
547                    );
548                }
549
550                if !are_tools_expanded || tools.is_empty() {
551                    return parent;
552                }
553
554                parent.child(v_flex().py_1p5().px_1().gap_1().children(
555                    tools.into_iter().enumerate().map(|(ix, tool)| {
556                        h_flex()
557                            .id(("tool-item", ix))
558                            .px_1()
559                            .gap_2()
560                            .justify_between()
561                            .hover(|style| style.bg(cx.theme().colors().element_hover))
562                            .rounded_sm()
563                            .child(
564                                Label::new(tool.name())
565                                    .buffer_font(cx)
566                                    .size(LabelSize::Small),
567                            )
568                            .child(
569                                Icon::new(IconName::Info)
570                                    .size(IconSize::Small)
571                                    .color(Color::Ignored),
572                            )
573                            .tooltip(Tooltip::text(tool.description()))
574                    }),
575                ))
576            })
577    }
578}
579
580impl Render for AssistantConfiguration {
581    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
582        v_flex()
583            .id("assistant-configuration")
584            .key_context("AgentConfiguration")
585            .track_focus(&self.focus_handle(cx))
586            .relative()
587            .size_full()
588            .pb_8()
589            .bg(cx.theme().colors().panel_background)
590            .child(
591                v_flex()
592                    .id("assistant-configuration-content")
593                    .track_scroll(&self.scroll_handle)
594                    .size_full()
595                    .overflow_y_scroll()
596                    .child(self.render_general_settings_section(cx))
597                    .child(Divider::horizontal().color(DividerColor::Border))
598                    .child(self.render_context_servers_section(window, cx))
599                    .child(Divider::horizontal().color(DividerColor::Border))
600                    .child(self.render_provider_configuration_section(cx)),
601            )
602            .child(
603                div()
604                    .id("assistant-configuration-scrollbar")
605                    .occlude()
606                    .absolute()
607                    .right(px(3.))
608                    .top_0()
609                    .bottom_0()
610                    .pb_6()
611                    .w(px(12.))
612                    .cursor_default()
613                    .on_mouse_move(cx.listener(|_, _, _window, cx| {
614                        cx.notify();
615                        cx.stop_propagation()
616                    }))
617                    .on_hover(|_, _window, cx| {
618                        cx.stop_propagation();
619                    })
620                    .on_any_mouse_down(|_, _window, cx| {
621                        cx.stop_propagation();
622                    })
623                    .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
624                        cx.notify();
625                    }))
626                    .children(Scrollbar::vertical(self.scrollbar_state.clone())),
627            )
628    }
629}