assistant_configuration.rs

  1mod add_context_server_modal;
  2
  3use std::sync::Arc;
  4
  5use assistant_tool::{ToolSource, ToolWorkingSet};
  6use collections::HashMap;
  7use context_server::manager::ContextServerManager;
  8use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
  9use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 10use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
 11use util::ResultExt as _;
 12use zed_actions::ExtensionCategoryFilter;
 13
 14pub(crate) use add_context_server_modal::AddContextServerModal;
 15
 16use crate::AddContextServer;
 17
/// View for configuring the assistant.
///
/// It lists the context servers (MCP) together with their tools, followed by every
/// registered language model provider and that provider's configuration view.
pub struct AssistantConfiguration {
    /// Focus handle for this view; tracked by the root element when rendering.
    focus_handle: FocusHandle,
    /// One configuration view per registered language model provider, keyed by provider id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the listed context servers; also used to start and stop them.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each context server's tool list is expanded, keyed by context server id.
    /// A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Queried for tools grouped by their source, to list the tools of each context server.
    tools: Arc<ToolWorkingSet>,
    /// Subscription to language model registry events.
    ///
    /// Held only to keep the provider configuration views in sync with providers being
    /// added or removed (presumably dropping it unsubscribes).
    _registry_subscription: Subscription,
}
 26
 27impl AssistantConfiguration {
 28    pub fn new(
 29        context_server_manager: Entity<ContextServerManager>,
 30        tools: Arc<ToolWorkingSet>,
 31        window: &mut Window,
 32        cx: &mut Context<Self>,
 33    ) -> Self {
 34        let focus_handle = cx.focus_handle();
 35
 36        let registry_subscription = cx.subscribe_in(
 37            &LanguageModelRegistry::global(cx),
 38            window,
 39            |this, _, event: &language_model::Event, window, cx| match event {
 40                language_model::Event::AddedProvider(provider_id) => {
 41                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 42                    if let Some(provider) = provider {
 43                        this.add_provider_configuration_view(&provider, window, cx);
 44                    }
 45                }
 46                language_model::Event::RemovedProvider(provider_id) => {
 47                    this.remove_provider_configuration_view(provider_id);
 48                }
 49                _ => {}
 50            },
 51        );
 52
 53        let mut this = Self {
 54            focus_handle,
 55            configuration_views_by_provider: HashMap::default(),
 56            context_server_manager,
 57            expanded_context_server_tools: HashMap::default(),
 58            tools,
 59            _registry_subscription: registry_subscription,
 60        };
 61        this.build_provider_configuration_views(window, cx);
 62        this
 63    }
 64
 65    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 66        let providers = LanguageModelRegistry::read_global(cx).providers();
 67        for provider in providers {
 68            self.add_provider_configuration_view(&provider, window, cx);
 69        }
 70    }
 71
 72    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 73        self.configuration_views_by_provider.remove(provider_id);
 74    }
 75
 76    fn add_provider_configuration_view(
 77        &mut self,
 78        provider: &Arc<dyn LanguageModelProvider>,
 79        window: &mut Window,
 80        cx: &mut Context<Self>,
 81    ) {
 82        let configuration_view = provider.configuration_view(window, cx);
 83        self.configuration_views_by_provider
 84            .insert(provider.id(), configuration_view);
 85    }
 86}
 87
 88impl Focusable for AssistantConfiguration {
 89    fn focus_handle(&self, _: &App) -> FocusHandle {
 90        self.focus_handle.clone()
 91    }
 92}
 93
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// The user clicked "Start New Thread" for the given provider. That button is only
    /// rendered for authenticated providers.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 99
100impl AssistantConfiguration {
101    fn render_provider_configuration(
102        &mut self,
103        provider: &Arc<dyn LanguageModelProvider>,
104        cx: &mut Context<Self>,
105    ) -> impl IntoElement {
106        let provider_id = provider.id().0.clone();
107        let provider_name = provider.name().0.clone();
108        let configuration_view = self
109            .configuration_views_by_provider
110            .get(&provider.id())
111            .cloned();
112
113        v_flex()
114            .gap_1p5()
115            .child(
116                h_flex()
117                    .justify_between()
118                    .child(
119                        h_flex()
120                            .gap_2()
121                            .child(
122                                Icon::new(provider.icon())
123                                    .size(IconSize::Small)
124                                    .color(Color::Muted),
125                            )
126                            .child(Label::new(provider_name.clone())),
127                    )
128                    .when(provider.is_authenticated(cx), |parent| {
129                        parent.child(
130                            Button::new(
131                                SharedString::from(format!("new-thread-{provider_id}")),
132                                "Start New Thread",
133                            )
134                            .icon_position(IconPosition::Start)
135                            .icon(IconName::Plus)
136                            .icon_size(IconSize::Small)
137                            .style(ButtonStyle::Filled)
138                            .layer(ElevationIndex::ModalSurface)
139                            .label_size(LabelSize::Small)
140                            .on_click(cx.listener({
141                                let provider = provider.clone();
142                                move |_this, _event, _window, cx| {
143                                    cx.emit(AssistantConfigurationEvent::NewThread(
144                                        provider.clone(),
145                                    ))
146                                }
147                            })),
148                        )
149                    }),
150            )
151            .child(
152                div()
153                    .p(DynamicSpacing::Base08.rems(cx))
154                    .bg(cx.theme().colors().editor_background)
155                    .border_1()
156                    .border_color(cx.theme().colors().border_variant)
157                    .rounded_sm()
158                    .map(|parent| match configuration_view {
159                        Some(configuration_view) => parent.child(configuration_view),
160                        None => parent.child(div().child(Label::new(format!(
161                            "No configuration view for {provider_name}",
162                        )))),
163                    }),
164            )
165    }
166
167    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
168        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
169        let tools_by_source = self.tools.tools_by_source(cx);
170        let empty = Vec::new();
171
172        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
173
174        v_flex()
175            .p(DynamicSpacing::Base16.rems(cx))
176            .gap_2()
177            .flex_1()
178            .child(
179                v_flex()
180                    .gap_0p5()
181                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
182                    .child(Label::new(SUBHEADING).color(Color::Muted)),
183            )
184            .children(context_servers.into_iter().map(|context_server| {
185                let is_running = context_server.client().is_some();
186                let are_tools_expanded = self
187                    .expanded_context_server_tools
188                    .get(&context_server.id())
189                    .copied()
190                    .unwrap_or_default();
191
192                let tools = tools_by_source
193                    .get(&ToolSource::ContextServer {
194                        id: context_server.id().into(),
195                    })
196                    .unwrap_or_else(|| &empty);
197                let tool_count = tools.len();
198
199                v_flex()
200                    .id(SharedString::from(context_server.id()))
201                    .border_1()
202                    .rounded_sm()
203                    .border_color(cx.theme().colors().border)
204                    .bg(cx.theme().colors().editor_background)
205                    .child(
206                        h_flex()
207                            .justify_between()
208                            .px_2()
209                            .py_1()
210                            .when(are_tools_expanded, |element| {
211                                element
212                                    .border_b_1()
213                                    .border_color(cx.theme().colors().border)
214                            })
215                            .child(
216                                h_flex()
217                                    .gap_2()
218                                    .child(
219                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
220                                            .on_click(cx.listener({
221                                                let context_server_id = context_server.id();
222                                                move |this, _event, _window, _cx| {
223                                                    let is_open = this
224                                                        .expanded_context_server_tools
225                                                        .entry(context_server_id.clone())
226                                                        .or_insert(false);
227
228                                                    *is_open = !*is_open;
229                                                }
230                                            })),
231                                    )
232                                    .child(Indicator::dot().color(if is_running {
233                                        Color::Success
234                                    } else {
235                                        Color::Error
236                                    }))
237                                    .child(Label::new(context_server.id()))
238                                    .child(
239                                        Label::new(format!("{tool_count} tools"))
240                                            .color(Color::Muted),
241                                    ),
242                            )
243                            .child(h_flex().child(
244                                Switch::new("context-server-switch", is_running.into()).on_click({
245                                    let context_server_manager =
246                                        self.context_server_manager.clone();
247                                    let context_server = context_server.clone();
248                                    move |state, _window, cx| match state {
249                                        ToggleState::Unselected | ToggleState::Indeterminate => {
250                                            context_server_manager.update(cx, |this, cx| {
251                                                this.stop_server(context_server.clone(), cx)
252                                                    .log_err();
253                                            });
254                                        }
255                                        ToggleState::Selected => {
256                                            cx.spawn({
257                                                let context_server_manager =
258                                                    context_server_manager.clone();
259                                                let context_server = context_server.clone();
260                                                async move |cx| {
261                                                    if let Some(start_server_task) =
262                                                        context_server_manager
263                                                            .update(cx, |this, cx| {
264                                                                this.start_server(
265                                                                    context_server,
266                                                                    cx,
267                                                                )
268                                                            })
269                                                            .log_err()
270                                                    {
271                                                        start_server_task.await.log_err();
272                                                    }
273                                                }
274                                            })
275                                            .detach();
276                                        }
277                                    }
278                                }),
279                            )),
280                    )
281                    .map(|parent| {
282                        if !are_tools_expanded {
283                            return parent;
284                        }
285
286                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
287                            |(ix, tool)| {
288                                h_flex()
289                                    .px_2()
290                                    .py_1()
291                                    .when(ix < tool_count - 1, |element| {
292                                        element
293                                            .border_b_1()
294                                            .border_color(cx.theme().colors().border)
295                                    })
296                                    .child(Label::new(tool.name()))
297                            },
298                        )))
299                    })
300            }))
301            .child(
302                h_flex()
303                    .justify_between()
304                    .gap_2()
305                    .child(
306                        h_flex().w_full().child(
307                            Button::new("add-context-server", "Add Context Server")
308                                .style(ButtonStyle::Filled)
309                                .layer(ElevationIndex::ModalSurface)
310                                .full_width()
311                                .icon(IconName::Plus)
312                                .icon_size(IconSize::Small)
313                                .icon_position(IconPosition::Start)
314                                .on_click(|_event, window, cx| {
315                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
316                                }),
317                        ),
318                    )
319                    .child(
320                        h_flex().w_full().child(
321                            Button::new(
322                                "install-context-server-extensions",
323                                "Install Context Server Extensions",
324                            )
325                            .style(ButtonStyle::Filled)
326                            .layer(ElevationIndex::ModalSurface)
327                            .full_width()
328                            .icon(IconName::DatabaseZap)
329                            .icon_size(IconSize::Small)
330                            .icon_position(IconPosition::Start)
331                            .on_click(|_event, window, cx| {
332                                window.dispatch_action(
333                                    zed_actions::Extensions {
334                                        category_filter: Some(
335                                            ExtensionCategoryFilter::ContextServers,
336                                        ),
337                                    }
338                                    .boxed_clone(),
339                                    cx,
340                                )
341                            }),
342                        ),
343                    ),
344            )
345    }
346}
347
348impl Render for AssistantConfiguration {
349    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
350        let providers = LanguageModelRegistry::read_global(cx).providers();
351
352        v_flex()
353            .id("assistant-configuration")
354            .track_focus(&self.focus_handle(cx))
355            .bg(cx.theme().colors().panel_background)
356            .size_full()
357            .overflow_y_scroll()
358            .child(self.render_context_servers_section(cx))
359            .child(Divider::horizontal().color(DividerColor::Border))
360            .child(
361                v_flex()
362                    .p(DynamicSpacing::Base16.rems(cx))
363                    .mt_1()
364                    .gap_6()
365                    .flex_1()
366                    .child(
367                        v_flex()
368                            .gap_0p5()
369                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
370                            .child(
371                                Label::new("Add at least one provider to use AI-powered features.")
372                                    .color(Color::Muted),
373                            ),
374                    )
375                    .children(
376                        providers
377                            .into_iter()
378                            .map(|provider| self.render_provider_configuration(&provider, cx)),
379                    ),
380            )
381    }
382}