assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod tool_picker;
  4
  5use std::sync::Arc;
  6
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use collections::HashMap;
  9use context_server::manager::ContextServerManager;
 10use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
 11use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 12use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
 13use util::ResultExt as _;
 14use zed_actions::ExtensionCategoryFilter;
 15
 16pub(crate) use add_context_server_modal::AddContextServerModal;
 17pub(crate) use manage_profiles_modal::ManageProfilesModal;
 18
 19use crate::AddContextServer;
 20
/// Assistant configuration panel.
///
/// Shows the context servers (MCP) known to a `ContextServerManager` and the
/// configuration view of every language model provider in the global
/// `LanguageModelRegistry`.
pub struct AssistantConfiguration {
    focus_handle: FocusHandle,
    /// Configuration view built for each provider, keyed by provider id. Entries are
    /// added/removed in response to registry events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    context_server_manager: Entity<ContextServerManager>,
    /// Expansion state of each context server's tool list, keyed by server id.
    /// A server with no entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set queried (via `tools_by_source`) for the tools each context server provides.
    tools: Arc<ToolWorkingSet>,
    /// Held only to keep the registry subscription alive; presumably dropping the
    /// `Subscription` unsubscribes (GPUI convention) — confirm.
    _registry_subscription: Subscription,
}
 29
 30impl AssistantConfiguration {
 31    pub fn new(
 32        context_server_manager: Entity<ContextServerManager>,
 33        tools: Arc<ToolWorkingSet>,
 34        window: &mut Window,
 35        cx: &mut Context<Self>,
 36    ) -> Self {
 37        let focus_handle = cx.focus_handle();
 38
 39        let registry_subscription = cx.subscribe_in(
 40            &LanguageModelRegistry::global(cx),
 41            window,
 42            |this, _, event: &language_model::Event, window, cx| match event {
 43                language_model::Event::AddedProvider(provider_id) => {
 44                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 45                    if let Some(provider) = provider {
 46                        this.add_provider_configuration_view(&provider, window, cx);
 47                    }
 48                }
 49                language_model::Event::RemovedProvider(provider_id) => {
 50                    this.remove_provider_configuration_view(provider_id);
 51                }
 52                _ => {}
 53            },
 54        );
 55
 56        let mut this = Self {
 57            focus_handle,
 58            configuration_views_by_provider: HashMap::default(),
 59            context_server_manager,
 60            expanded_context_server_tools: HashMap::default(),
 61            tools,
 62            _registry_subscription: registry_subscription,
 63        };
 64        this.build_provider_configuration_views(window, cx);
 65        this
 66    }
 67
 68    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 69        let providers = LanguageModelRegistry::read_global(cx).providers();
 70        for provider in providers {
 71            self.add_provider_configuration_view(&provider, window, cx);
 72        }
 73    }
 74
 75    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 76        self.configuration_views_by_provider.remove(provider_id);
 77    }
 78
 79    fn add_provider_configuration_view(
 80        &mut self,
 81        provider: &Arc<dyn LanguageModelProvider>,
 82        window: &mut Window,
 83        cx: &mut Context<Self>,
 84    ) {
 85        let configuration_view = provider.configuration_view(window, cx);
 86        self.configuration_views_by_provider
 87            .insert(provider.id(), configuration_view);
 88    }
 89}
 90
 91impl Focusable for AssistantConfiguration {
 92    fn focus_handle(&self, _: &App) -> FocusHandle {
 93        self.focus_handle.clone()
 94    }
 95}
 96
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// The user clicked "Start New Thread" for this provider. The button is only shown
    /// while the provider reports itself as authenticated.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
102
103impl AssistantConfiguration {
104    fn render_provider_configuration(
105        &mut self,
106        provider: &Arc<dyn LanguageModelProvider>,
107        cx: &mut Context<Self>,
108    ) -> impl IntoElement + use<> {
109        let provider_id = provider.id().0.clone();
110        let provider_name = provider.name().0.clone();
111        let configuration_view = self
112            .configuration_views_by_provider
113            .get(&provider.id())
114            .cloned();
115
116        v_flex()
117            .gap_1p5()
118            .child(
119                h_flex()
120                    .justify_between()
121                    .child(
122                        h_flex()
123                            .gap_2()
124                            .child(
125                                Icon::new(provider.icon())
126                                    .size(IconSize::Small)
127                                    .color(Color::Muted),
128                            )
129                            .child(Label::new(provider_name.clone())),
130                    )
131                    .when(provider.is_authenticated(cx), |parent| {
132                        parent.child(
133                            Button::new(
134                                SharedString::from(format!("new-thread-{provider_id}")),
135                                "Start New Thread",
136                            )
137                            .icon_position(IconPosition::Start)
138                            .icon(IconName::Plus)
139                            .icon_size(IconSize::Small)
140                            .style(ButtonStyle::Filled)
141                            .layer(ElevationIndex::ModalSurface)
142                            .label_size(LabelSize::Small)
143                            .on_click(cx.listener({
144                                let provider = provider.clone();
145                                move |_this, _event, _window, cx| {
146                                    cx.emit(AssistantConfigurationEvent::NewThread(
147                                        provider.clone(),
148                                    ))
149                                }
150                            })),
151                        )
152                    }),
153            )
154            .child(
155                div()
156                    .p(DynamicSpacing::Base08.rems(cx))
157                    .bg(cx.theme().colors().editor_background)
158                    .border_1()
159                    .border_color(cx.theme().colors().border_variant)
160                    .rounded_sm()
161                    .map(|parent| match configuration_view {
162                        Some(configuration_view) => parent.child(configuration_view),
163                        None => parent.child(div().child(Label::new(format!(
164                            "No configuration view for {provider_name}",
165                        )))),
166                    }),
167            )
168    }
169
170    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
171        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
172        let tools_by_source = self.tools.tools_by_source(cx);
173        let empty = Vec::new();
174
175        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
176
177        v_flex()
178            .p(DynamicSpacing::Base16.rems(cx))
179            .gap_2()
180            .flex_1()
181            .child(
182                v_flex()
183                    .gap_0p5()
184                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
185                    .child(Label::new(SUBHEADING).color(Color::Muted)),
186            )
187            .children(context_servers.into_iter().map(|context_server| {
188                let is_running = context_server.client().is_some();
189                let are_tools_expanded = self
190                    .expanded_context_server_tools
191                    .get(&context_server.id())
192                    .copied()
193                    .unwrap_or_default();
194
195                let tools = tools_by_source
196                    .get(&ToolSource::ContextServer {
197                        id: context_server.id().into(),
198                    })
199                    .unwrap_or_else(|| &empty);
200                let tool_count = tools.len();
201
202                v_flex()
203                    .id(SharedString::from(context_server.id()))
204                    .border_1()
205                    .rounded_sm()
206                    .border_color(cx.theme().colors().border)
207                    .bg(cx.theme().colors().editor_background)
208                    .child(
209                        h_flex()
210                            .justify_between()
211                            .px_2()
212                            .py_1()
213                            .when(are_tools_expanded, |element| {
214                                element
215                                    .border_b_1()
216                                    .border_color(cx.theme().colors().border)
217                            })
218                            .child(
219                                h_flex()
220                                    .gap_2()
221                                    .child(
222                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
223                                            .on_click(cx.listener({
224                                                let context_server_id = context_server.id();
225                                                move |this, _event, _window, _cx| {
226                                                    let is_open = this
227                                                        .expanded_context_server_tools
228                                                        .entry(context_server_id.clone())
229                                                        .or_insert(false);
230
231                                                    *is_open = !*is_open;
232                                                }
233                                            })),
234                                    )
235                                    .child(Indicator::dot().color(if is_running {
236                                        Color::Success
237                                    } else {
238                                        Color::Error
239                                    }))
240                                    .child(Label::new(context_server.id()))
241                                    .child(
242                                        Label::new(format!("{tool_count} tools"))
243                                            .color(Color::Muted),
244                                    ),
245                            )
246                            .child(h_flex().child(
247                                Switch::new("context-server-switch", is_running.into()).on_click({
248                                    let context_server_manager =
249                                        self.context_server_manager.clone();
250                                    let context_server = context_server.clone();
251                                    move |state, _window, cx| match state {
252                                        ToggleState::Unselected | ToggleState::Indeterminate => {
253                                            context_server_manager.update(cx, |this, cx| {
254                                                this.stop_server(context_server.clone(), cx)
255                                                    .log_err();
256                                            });
257                                        }
258                                        ToggleState::Selected => {
259                                            cx.spawn({
260                                                let context_server_manager =
261                                                    context_server_manager.clone();
262                                                let context_server = context_server.clone();
263                                                async move |cx| {
264                                                    if let Some(start_server_task) =
265                                                        context_server_manager
266                                                            .update(cx, |this, cx| {
267                                                                this.start_server(
268                                                                    context_server,
269                                                                    cx,
270                                                                )
271                                                            })
272                                                            .log_err()
273                                                    {
274                                                        start_server_task.await.log_err();
275                                                    }
276                                                }
277                                            })
278                                            .detach();
279                                        }
280                                    }
281                                }),
282                            )),
283                    )
284                    .map(|parent| {
285                        if !are_tools_expanded {
286                            return parent;
287                        }
288
289                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
290                            |(ix, tool)| {
291                                h_flex()
292                                    .px_2()
293                                    .py_1()
294                                    .when(ix < tool_count - 1, |element| {
295                                        element
296                                            .border_b_1()
297                                            .border_color(cx.theme().colors().border)
298                                    })
299                                    .child(Label::new(tool.name()))
300                            },
301                        )))
302                    })
303            }))
304            .child(
305                h_flex()
306                    .justify_between()
307                    .gap_2()
308                    .child(
309                        h_flex().w_full().child(
310                            Button::new("add-context-server", "Add Context Server")
311                                .style(ButtonStyle::Filled)
312                                .layer(ElevationIndex::ModalSurface)
313                                .full_width()
314                                .icon(IconName::Plus)
315                                .icon_size(IconSize::Small)
316                                .icon_position(IconPosition::Start)
317                                .on_click(|_event, window, cx| {
318                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
319                                }),
320                        ),
321                    )
322                    .child(
323                        h_flex().w_full().child(
324                            Button::new(
325                                "install-context-server-extensions",
326                                "Install Context Server Extensions",
327                            )
328                            .style(ButtonStyle::Filled)
329                            .layer(ElevationIndex::ModalSurface)
330                            .full_width()
331                            .icon(IconName::DatabaseZap)
332                            .icon_size(IconSize::Small)
333                            .icon_position(IconPosition::Start)
334                            .on_click(|_event, window, cx| {
335                                window.dispatch_action(
336                                    zed_actions::Extensions {
337                                        category_filter: Some(
338                                            ExtensionCategoryFilter::ContextServers,
339                                        ),
340                                    }
341                                    .boxed_clone(),
342                                    cx,
343                                )
344                            }),
345                        ),
346                    ),
347            )
348    }
349}
350
351impl Render for AssistantConfiguration {
352    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
353        let providers = LanguageModelRegistry::read_global(cx).providers();
354
355        v_flex()
356            .id("assistant-configuration")
357            .track_focus(&self.focus_handle(cx))
358            .bg(cx.theme().colors().panel_background)
359            .size_full()
360            .overflow_y_scroll()
361            .child(self.render_context_servers_section(cx))
362            .child(Divider::horizontal().color(DividerColor::Border))
363            .child(
364                v_flex()
365                    .p(DynamicSpacing::Base16.rems(cx))
366                    .mt_1()
367                    .gap_6()
368                    .flex_1()
369                    .child(
370                        v_flex()
371                            .gap_0p5()
372                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
373                            .child(
374                                Label::new("Add at least one provider to use AI-powered features.")
375                                    .color(Color::Muted),
376                            ),
377                    )
378                    .children(
379                        providers
380                            .into_iter()
381                            .map(|provider| self.render_provider_configuration(&provider, cx)),
382                    ),
383            )
384    }
385}