assistant_configuration.rs

  1mod add_context_server_modal;
  2mod manage_profiles_modal;
  3mod profile_picker;
  4mod tool_picker;
  5
  6use std::sync::Arc;
  7
  8use assistant_tool::{ToolSource, ToolWorkingSet};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
 12use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
 13use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
 14use util::ResultExt as _;
 15use zed_actions::ExtensionCategoryFilter;
 16
 17pub(crate) use add_context_server_modal::AddContextServerModal;
 18pub(crate) use manage_profiles_modal::ManageProfilesModal;
 19
 20use crate::AddContextServer;
 21
/// The assistant configuration panel: lists context servers (MCP) with controls to
/// start/stop them, and shows a configuration view for each language model provider.
pub struct AssistantConfiguration {
    /// Focus handle returned by the `Focusable` impl and tracked by the root element in `render`.
    focus_handle: FocusHandle,
    /// Configuration view produced by each provider's `configuration_view`, keyed by provider id.
    /// Entries are added/removed in response to the registry's `AddedProvider` /
    /// `RemovedProvider` events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the context servers listed (and started/stopped) in the context servers section.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each context server's tool list is expanded, keyed by server id.
    /// A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set queried via `tools_by_source` to find the tools each context server provides.
    tools: Arc<ToolWorkingSet>,
    /// Holds the `LanguageModelRegistry` subscription created in `new`
    /// (presumably dropping it ends the subscription — GPUI `Subscription` semantics).
    _registry_subscription: Subscription,
}
 30
 31impl AssistantConfiguration {
 32    pub fn new(
 33        context_server_manager: Entity<ContextServerManager>,
 34        tools: Arc<ToolWorkingSet>,
 35        window: &mut Window,
 36        cx: &mut Context<Self>,
 37    ) -> Self {
 38        let focus_handle = cx.focus_handle();
 39
 40        let registry_subscription = cx.subscribe_in(
 41            &LanguageModelRegistry::global(cx),
 42            window,
 43            |this, _, event: &language_model::Event, window, cx| match event {
 44                language_model::Event::AddedProvider(provider_id) => {
 45                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 46                    if let Some(provider) = provider {
 47                        this.add_provider_configuration_view(&provider, window, cx);
 48                    }
 49                }
 50                language_model::Event::RemovedProvider(provider_id) => {
 51                    this.remove_provider_configuration_view(provider_id);
 52                }
 53                _ => {}
 54            },
 55        );
 56
 57        let mut this = Self {
 58            focus_handle,
 59            configuration_views_by_provider: HashMap::default(),
 60            context_server_manager,
 61            expanded_context_server_tools: HashMap::default(),
 62            tools,
 63            _registry_subscription: registry_subscription,
 64        };
 65        this.build_provider_configuration_views(window, cx);
 66        this
 67    }
 68
 69    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 70        let providers = LanguageModelRegistry::read_global(cx).providers();
 71        for provider in providers {
 72            self.add_provider_configuration_view(&provider, window, cx);
 73        }
 74    }
 75
 76    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 77        self.configuration_views_by_provider.remove(provider_id);
 78    }
 79
 80    fn add_provider_configuration_view(
 81        &mut self,
 82        provider: &Arc<dyn LanguageModelProvider>,
 83        window: &mut Window,
 84        cx: &mut Context<Self>,
 85    ) {
 86        let configuration_view = provider.configuration_view(window, cx);
 87        self.configuration_views_by_provider
 88            .insert(provider.id(), configuration_view);
 89    }
 90}
 91
 92impl Focusable for AssistantConfiguration {
 93    fn focus_handle(&self, _: &App) -> FocusHandle {
 94        self.focus_handle.clone()
 95    }
 96}
 97
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when the "Start New Thread" button on an authenticated provider's card is
    /// clicked. It carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}

// Lets `AssistantConfiguration` emit `AssistantConfigurationEvent`s via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
103
104impl AssistantConfiguration {
105    fn render_provider_configuration(
106        &mut self,
107        provider: &Arc<dyn LanguageModelProvider>,
108        cx: &mut Context<Self>,
109    ) -> impl IntoElement {
110        let provider_id = provider.id().0.clone();
111        let provider_name = provider.name().0.clone();
112        let configuration_view = self
113            .configuration_views_by_provider
114            .get(&provider.id())
115            .cloned();
116
117        v_flex()
118            .gap_1p5()
119            .child(
120                h_flex()
121                    .justify_between()
122                    .child(
123                        h_flex()
124                            .gap_2()
125                            .child(
126                                Icon::new(provider.icon())
127                                    .size(IconSize::Small)
128                                    .color(Color::Muted),
129                            )
130                            .child(Label::new(provider_name.clone())),
131                    )
132                    .when(provider.is_authenticated(cx), |parent| {
133                        parent.child(
134                            Button::new(
135                                SharedString::from(format!("new-thread-{provider_id}")),
136                                "Start New Thread",
137                            )
138                            .icon_position(IconPosition::Start)
139                            .icon(IconName::Plus)
140                            .icon_size(IconSize::Small)
141                            .style(ButtonStyle::Filled)
142                            .layer(ElevationIndex::ModalSurface)
143                            .label_size(LabelSize::Small)
144                            .on_click(cx.listener({
145                                let provider = provider.clone();
146                                move |_this, _event, _window, cx| {
147                                    cx.emit(AssistantConfigurationEvent::NewThread(
148                                        provider.clone(),
149                                    ))
150                                }
151                            })),
152                        )
153                    }),
154            )
155            .child(
156                div()
157                    .p(DynamicSpacing::Base08.rems(cx))
158                    .bg(cx.theme().colors().editor_background)
159                    .border_1()
160                    .border_color(cx.theme().colors().border_variant)
161                    .rounded_sm()
162                    .map(|parent| match configuration_view {
163                        Some(configuration_view) => parent.child(configuration_view),
164                        None => parent.child(div().child(Label::new(format!(
165                            "No configuration view for {provider_name}",
166                        )))),
167                    }),
168            )
169    }
170
171    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
172        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
173        let tools_by_source = self.tools.tools_by_source(cx);
174        let empty = Vec::new();
175
176        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
177
178        v_flex()
179            .p(DynamicSpacing::Base16.rems(cx))
180            .gap_2()
181            .flex_1()
182            .child(
183                v_flex()
184                    .gap_0p5()
185                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
186                    .child(Label::new(SUBHEADING).color(Color::Muted)),
187            )
188            .children(context_servers.into_iter().map(|context_server| {
189                let is_running = context_server.client().is_some();
190                let are_tools_expanded = self
191                    .expanded_context_server_tools
192                    .get(&context_server.id())
193                    .copied()
194                    .unwrap_or_default();
195
196                let tools = tools_by_source
197                    .get(&ToolSource::ContextServer {
198                        id: context_server.id().into(),
199                    })
200                    .unwrap_or_else(|| &empty);
201                let tool_count = tools.len();
202
203                v_flex()
204                    .id(SharedString::from(context_server.id()))
205                    .border_1()
206                    .rounded_sm()
207                    .border_color(cx.theme().colors().border)
208                    .bg(cx.theme().colors().editor_background)
209                    .child(
210                        h_flex()
211                            .justify_between()
212                            .px_2()
213                            .py_1()
214                            .when(are_tools_expanded, |element| {
215                                element
216                                    .border_b_1()
217                                    .border_color(cx.theme().colors().border)
218                            })
219                            .child(
220                                h_flex()
221                                    .gap_2()
222                                    .child(
223                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
224                                            .on_click(cx.listener({
225                                                let context_server_id = context_server.id();
226                                                move |this, _event, _window, _cx| {
227                                                    let is_open = this
228                                                        .expanded_context_server_tools
229                                                        .entry(context_server_id.clone())
230                                                        .or_insert(false);
231
232                                                    *is_open = !*is_open;
233                                                }
234                                            })),
235                                    )
236                                    .child(Indicator::dot().color(if is_running {
237                                        Color::Success
238                                    } else {
239                                        Color::Error
240                                    }))
241                                    .child(Label::new(context_server.id()))
242                                    .child(
243                                        Label::new(format!("{tool_count} tools"))
244                                            .color(Color::Muted),
245                                    ),
246                            )
247                            .child(h_flex().child(
248                                Switch::new("context-server-switch", is_running.into()).on_click({
249                                    let context_server_manager =
250                                        self.context_server_manager.clone();
251                                    let context_server = context_server.clone();
252                                    move |state, _window, cx| match state {
253                                        ToggleState::Unselected | ToggleState::Indeterminate => {
254                                            context_server_manager.update(cx, |this, cx| {
255                                                this.stop_server(context_server.clone(), cx)
256                                                    .log_err();
257                                            });
258                                        }
259                                        ToggleState::Selected => {
260                                            cx.spawn({
261                                                let context_server_manager =
262                                                    context_server_manager.clone();
263                                                let context_server = context_server.clone();
264                                                async move |cx| {
265                                                    if let Some(start_server_task) =
266                                                        context_server_manager
267                                                            .update(cx, |this, cx| {
268                                                                this.start_server(
269                                                                    context_server,
270                                                                    cx,
271                                                                )
272                                                            })
273                                                            .log_err()
274                                                    {
275                                                        start_server_task.await.log_err();
276                                                    }
277                                                }
278                                            })
279                                            .detach();
280                                        }
281                                    }
282                                }),
283                            )),
284                    )
285                    .map(|parent| {
286                        if !are_tools_expanded {
287                            return parent;
288                        }
289
290                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
291                            |(ix, tool)| {
292                                h_flex()
293                                    .px_2()
294                                    .py_1()
295                                    .when(ix < tool_count - 1, |element| {
296                                        element
297                                            .border_b_1()
298                                            .border_color(cx.theme().colors().border)
299                                    })
300                                    .child(Label::new(tool.name()))
301                            },
302                        )))
303                    })
304            }))
305            .child(
306                h_flex()
307                    .justify_between()
308                    .gap_2()
309                    .child(
310                        h_flex().w_full().child(
311                            Button::new("add-context-server", "Add Context Server")
312                                .style(ButtonStyle::Filled)
313                                .layer(ElevationIndex::ModalSurface)
314                                .full_width()
315                                .icon(IconName::Plus)
316                                .icon_size(IconSize::Small)
317                                .icon_position(IconPosition::Start)
318                                .on_click(|_event, window, cx| {
319                                    window.dispatch_action(AddContextServer.boxed_clone(), cx)
320                                }),
321                        ),
322                    )
323                    .child(
324                        h_flex().w_full().child(
325                            Button::new(
326                                "install-context-server-extensions",
327                                "Install Context Server Extensions",
328                            )
329                            .style(ButtonStyle::Filled)
330                            .layer(ElevationIndex::ModalSurface)
331                            .full_width()
332                            .icon(IconName::DatabaseZap)
333                            .icon_size(IconSize::Small)
334                            .icon_position(IconPosition::Start)
335                            .on_click(|_event, window, cx| {
336                                window.dispatch_action(
337                                    zed_actions::Extensions {
338                                        category_filter: Some(
339                                            ExtensionCategoryFilter::ContextServers,
340                                        ),
341                                    }
342                                    .boxed_clone(),
343                                    cx,
344                                )
345                            }),
346                        ),
347                    ),
348            )
349    }
350}
351
352impl Render for AssistantConfiguration {
353    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
354        let providers = LanguageModelRegistry::read_global(cx).providers();
355
356        v_flex()
357            .id("assistant-configuration")
358            .track_focus(&self.focus_handle(cx))
359            .bg(cx.theme().colors().panel_background)
360            .size_full()
361            .overflow_y_scroll()
362            .child(self.render_context_servers_section(cx))
363            .child(Divider::horizontal().color(DividerColor::Border))
364            .child(
365                v_flex()
366                    .p(DynamicSpacing::Base16.rems(cx))
367                    .mt_1()
368                    .gap_6()
369                    .flex_1()
370                    .child(
371                        v_flex()
372                            .gap_0p5()
373                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
374                            .child(
375                                Label::new("Add at least one provider to use AI-powered features.")
376                                    .color(Color::Muted),
377                            ),
378                    )
379                    .children(
380                        providers
381                            .into_iter()
382                            .map(|provider| self.render_provider_configuration(&provider, cx)),
383                    ),
384            )
385    }
386}