assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::{ToolSource, ToolWorkingSet};
  4use collections::HashMap;
  5use context_server::manager::ContextServerManager;
  6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
  7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  8use ui::{
  9    prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
 10};
 11use util::ResultExt as _;
 12use zed_actions::ExtensionCategoryFilter;
 13
/// Assistant settings view: lists context servers (MCP) together with their
/// tools, and shows a configuration card for each registered language-model
/// provider.
pub struct AssistantConfiguration {
    /// Focus handle tracked by the root element in `render`.
    focus_handle: FocusHandle,
    /// One view per provider, created via `LanguageModelProvider::configuration_view`
    /// and kept in sync with the registry's `AddedProvider` / `RemovedProvider` events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the listed context servers; also used to start/stop them.
    context_server_manager: Entity<ContextServerManager>,
    /// Per-server "tool list expanded" flag, keyed by server id.
    /// A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set queried via `tools_by_source` to list each server's tools.
    tools: Arc<ToolWorkingSet>,
    /// Held only so the registry subscription created in `new` stays active
    /// (presumably dropping it unsubscribes, per gpui `Subscription`).
    _registry_subscription: Subscription,
}
 22
 23impl AssistantConfiguration {
 24    pub fn new(
 25        context_server_manager: Entity<ContextServerManager>,
 26        tools: Arc<ToolWorkingSet>,
 27        window: &mut Window,
 28        cx: &mut Context<Self>,
 29    ) -> Self {
 30        let focus_handle = cx.focus_handle();
 31
 32        let registry_subscription = cx.subscribe_in(
 33            &LanguageModelRegistry::global(cx),
 34            window,
 35            |this, _, event: &language_model::Event, window, cx| match event {
 36                language_model::Event::AddedProvider(provider_id) => {
 37                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 38                    if let Some(provider) = provider {
 39                        this.add_provider_configuration_view(&provider, window, cx);
 40                    }
 41                }
 42                language_model::Event::RemovedProvider(provider_id) => {
 43                    this.remove_provider_configuration_view(provider_id);
 44                }
 45                _ => {}
 46            },
 47        );
 48
 49        let mut this = Self {
 50            focus_handle,
 51            configuration_views_by_provider: HashMap::default(),
 52            context_server_manager,
 53            expanded_context_server_tools: HashMap::default(),
 54            tools,
 55            _registry_subscription: registry_subscription,
 56        };
 57        this.build_provider_configuration_views(window, cx);
 58        this
 59    }
 60
 61    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 62        let providers = LanguageModelRegistry::read_global(cx).providers();
 63        for provider in providers {
 64            self.add_provider_configuration_view(&provider, window, cx);
 65        }
 66    }
 67
 68    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 69        self.configuration_views_by_provider.remove(provider_id);
 70    }
 71
 72    fn add_provider_configuration_view(
 73        &mut self,
 74        provider: &Arc<dyn LanguageModelProvider>,
 75        window: &mut Window,
 76        cx: &mut Context<Self>,
 77    ) {
 78        let configuration_view = provider.configuration_view(window, cx);
 79        self.configuration_views_by_provider
 80            .insert(provider.id(), configuration_view);
 81    }
 82}
 83
 84impl Focusable for AssistantConfiguration {
 85    fn focus_handle(&self, _: &App) -> FocusHandle {
 86        self.focus_handle.clone()
 87    }
 88}
 89
/// Events emitted by `AssistantConfiguration`.
pub enum AssistantConfigurationEvent {
    /// Emitted when the user clicks "Start New Thread" on an authenticated
    /// provider's card; carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}
 93
// Allows `cx.emit(AssistantConfigurationEvent::…)` from this view so subscribers
// (e.g. whoever hosts the panel) can react to `NewThread`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 95
 96impl AssistantConfiguration {
 97    fn render_provider_configuration(
 98        &mut self,
 99        provider: &Arc<dyn LanguageModelProvider>,
100        cx: &mut Context<Self>,
101    ) -> impl IntoElement {
102        let provider_id = provider.id().0.clone();
103        let provider_name = provider.name().0.clone();
104        let configuration_view = self
105            .configuration_views_by_provider
106            .get(&provider.id())
107            .cloned();
108
109        v_flex()
110            .gap_1p5()
111            .child(
112                h_flex()
113                    .justify_between()
114                    .child(
115                        h_flex()
116                            .gap_2()
117                            .child(
118                                Icon::new(provider.icon())
119                                    .size(IconSize::Small)
120                                    .color(Color::Muted),
121                            )
122                            .child(Label::new(provider_name.clone())),
123                    )
124                    .when(provider.is_authenticated(cx), |parent| {
125                        parent.child(
126                            Button::new(
127                                SharedString::from(format!("new-thread-{provider_id}")),
128                                "Start New Thread",
129                            )
130                            .icon_position(IconPosition::Start)
131                            .icon(IconName::Plus)
132                            .icon_size(IconSize::Small)
133                            .style(ButtonStyle::Filled)
134                            .layer(ElevationIndex::ModalSurface)
135                            .label_size(LabelSize::Small)
136                            .on_click(cx.listener({
137                                let provider = provider.clone();
138                                move |_this, _event, _window, cx| {
139                                    cx.emit(AssistantConfigurationEvent::NewThread(
140                                        provider.clone(),
141                                    ))
142                                }
143                            })),
144                        )
145                    }),
146            )
147            .child(
148                div()
149                    .p(DynamicSpacing::Base08.rems(cx))
150                    .bg(cx.theme().colors().editor_background)
151                    .border_1()
152                    .border_color(cx.theme().colors().border_variant)
153                    .rounded_sm()
154                    .map(|parent| match configuration_view {
155                        Some(configuration_view) => parent.child(configuration_view),
156                        None => parent.child(div().child(Label::new(format!(
157                            "No configuration view for {provider_name}",
158                        )))),
159                    }),
160            )
161    }
162
163    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
164        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
165        let tools_by_source = self.tools.tools_by_source(cx);
166        let empty = Vec::new();
167
168        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
169
170        v_flex()
171            .p(DynamicSpacing::Base16.rems(cx))
172            .gap_2()
173            .flex_1()
174            .child(
175                v_flex()
176                    .gap_0p5()
177                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
178                    .child(Label::new(SUBHEADING).color(Color::Muted)),
179            )
180            .children(context_servers.into_iter().map(|context_server| {
181                let is_running = context_server.client().is_some();
182                let are_tools_expanded = self
183                    .expanded_context_server_tools
184                    .get(&context_server.id())
185                    .copied()
186                    .unwrap_or_default();
187
188                let tools = tools_by_source
189                    .get(&ToolSource::ContextServer {
190                        id: context_server.id().into(),
191                    })
192                    .unwrap_or_else(|| &empty);
193                let tool_count = tools.len();
194
195                v_flex()
196                    .id(SharedString::from(context_server.id()))
197                    .border_1()
198                    .rounded_sm()
199                    .border_color(cx.theme().colors().border)
200                    .bg(cx.theme().colors().editor_background)
201                    .child(
202                        h_flex()
203                            .justify_between()
204                            .px_2()
205                            .py_1()
206                            .when(are_tools_expanded, |element| {
207                                element
208                                    .border_b_1()
209                                    .border_color(cx.theme().colors().border)
210                            })
211                            .child(
212                                h_flex()
213                                    .gap_2()
214                                    .child(
215                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
216                                            .on_click(cx.listener({
217                                                let context_server_id = context_server.id();
218                                                move |this, _event, _window, _cx| {
219                                                    let is_open = this
220                                                        .expanded_context_server_tools
221                                                        .entry(context_server_id.clone())
222                                                        .or_insert(false);
223
224                                                    *is_open = !*is_open;
225                                                }
226                                            })),
227                                    )
228                                    .child(Indicator::dot().color(if is_running {
229                                        Color::Success
230                                    } else {
231                                        Color::Error
232                                    }))
233                                    .child(Label::new(context_server.id()))
234                                    .child(
235                                        Label::new(format!("{tool_count} tools"))
236                                            .color(Color::Muted),
237                                    ),
238                            )
239                            .child(h_flex().child(
240                                Switch::new("context-server-switch", is_running.into()).on_click({
241                                    let context_server_manager =
242                                        self.context_server_manager.clone();
243                                    let context_server = context_server.clone();
244                                    move |state, _window, cx| match state {
245                                        ToggleState::Unselected | ToggleState::Indeterminate => {
246                                            context_server_manager.update(cx, |this, cx| {
247                                                this.stop_server(context_server.clone(), cx)
248                                                    .log_err();
249                                            });
250                                        }
251                                        ToggleState::Selected => {
252                                            cx.spawn({
253                                                let context_server_manager =
254                                                    context_server_manager.clone();
255                                                let context_server = context_server.clone();
256                                                async move |cx| {
257                                                    if let Some(start_server_task) =
258                                                        context_server_manager
259                                                            .update(cx, |this, cx| {
260                                                                this.start_server(
261                                                                    context_server,
262                                                                    cx,
263                                                                )
264                                                            })
265                                                            .log_err()
266                                                    {
267                                                        start_server_task.await.log_err();
268                                                    }
269                                                }
270                                            })
271                                            .detach();
272                                        }
273                                    }
274                                }),
275                            )),
276                    )
277                    .map(|parent| {
278                        if !are_tools_expanded {
279                            return parent;
280                        }
281
282                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
283                            |(ix, tool)| {
284                                h_flex()
285                                    .px_2()
286                                    .py_1()
287                                    .when(ix < tool_count - 1, |element| {
288                                        element
289                                            .border_b_1()
290                                            .border_color(cx.theme().colors().border)
291                                    })
292                                    .child(Label::new(tool.name()))
293                            },
294                        )))
295                    })
296            }))
297            .child(
298                h_flex()
299                    .justify_between()
300                    .gap_2()
301                    .child(
302                        h_flex().w_full().child(
303                            Button::new("add-context-server", "Add Context Server")
304                                .style(ButtonStyle::Filled)
305                                .layer(ElevationIndex::ModalSurface)
306                                .full_width()
307                                .icon(IconName::Plus)
308                                .icon_size(IconSize::Small)
309                                .icon_position(IconPosition::Start)
310                                .disabled(true)
311                                .tooltip(Tooltip::text("Not yet implemented")),
312                        ),
313                    )
314                    .child(
315                        h_flex().w_full().child(
316                            Button::new(
317                                "install-context-server-extensions",
318                                "Install Context Server Extensions",
319                            )
320                            .style(ButtonStyle::Filled)
321                            .layer(ElevationIndex::ModalSurface)
322                            .full_width()
323                            .icon(IconName::DatabaseZap)
324                            .icon_size(IconSize::Small)
325                            .icon_position(IconPosition::Start)
326                            .on_click(|_event, window, cx| {
327                                window.dispatch_action(
328                                    zed_actions::Extensions {
329                                        category_filter: Some(
330                                            ExtensionCategoryFilter::ContextServers,
331                                        ),
332                                    }
333                                    .boxed_clone(),
334                                    cx,
335                                )
336                            }),
337                        ),
338                    ),
339            )
340    }
341}
342
343impl Render for AssistantConfiguration {
344    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
345        let providers = LanguageModelRegistry::read_global(cx).providers();
346
347        v_flex()
348            .id("assistant-configuration")
349            .track_focus(&self.focus_handle(cx))
350            .bg(cx.theme().colors().panel_background)
351            .size_full()
352            .overflow_y_scroll()
353            .child(self.render_context_servers_section(cx))
354            .child(Divider::horizontal().color(DividerColor::Border))
355            .child(
356                v_flex()
357                    .p(DynamicSpacing::Base16.rems(cx))
358                    .mt_1()
359                    .gap_6()
360                    .flex_1()
361                    .child(
362                        v_flex()
363                            .gap_0p5()
364                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
365                            .child(
366                                Label::new("Add at least one provider to use AI-powered features.")
367                                    .color(Color::Muted),
368                            ),
369                    )
370                    .children(
371                        providers
372                            .into_iter()
373                            .map(|provider| self.render_provider_configuration(&provider, cx)),
374                    ),
375            )
376    }
377}