assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::{ToolSource, ToolWorkingSet};
  4use collections::HashMap;
  5use context_server::manager::ContextServerManager;
  6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
  7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  8use ui::{
  9    prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
 10};
 11use util::ResultExt as _;
 12use zed_actions::assistant::OpenPromptLibrary;
 13use zed_actions::ExtensionCategoryFilter;
 14
/// Configuration view for the assistant, covering the prompt library, context
/// servers (MCP), and language model providers.
pub struct AssistantConfiguration {
    /// Focus handle tracked by the root element when rendering.
    focus_handle: FocusHandle,
    /// Configuration UI for each provider, built via
    /// `LanguageModelProvider::configuration_view`. The registry subscription
    /// keeps it in sync as providers are added or removed.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the context servers that this view lists and starts/stops.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each context server's tool list is expanded, keyed by server id.
    /// Servers without an entry are treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set used to look up the tools contributed by each context server.
    tools: Arc<ToolWorkingSet>,
    /// Stored (never read) so the registry subscription created at construction
    /// lives as long as this view.
    _registry_subscription: Subscription,
}
 23
 24impl AssistantConfiguration {
 25    pub fn new(
 26        context_server_manager: Entity<ContextServerManager>,
 27        tools: Arc<ToolWorkingSet>,
 28        window: &mut Window,
 29        cx: &mut Context<Self>,
 30    ) -> Self {
 31        let focus_handle = cx.focus_handle();
 32
 33        let registry_subscription = cx.subscribe_in(
 34            &LanguageModelRegistry::global(cx),
 35            window,
 36            |this, _, event: &language_model::Event, window, cx| match event {
 37                language_model::Event::AddedProvider(provider_id) => {
 38                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 39                    if let Some(provider) = provider {
 40                        this.add_provider_configuration_view(&provider, window, cx);
 41                    }
 42                }
 43                language_model::Event::RemovedProvider(provider_id) => {
 44                    this.remove_provider_configuration_view(provider_id);
 45                }
 46                _ => {}
 47            },
 48        );
 49
 50        let mut this = Self {
 51            focus_handle,
 52            configuration_views_by_provider: HashMap::default(),
 53            context_server_manager,
 54            expanded_context_server_tools: HashMap::default(),
 55            tools,
 56            _registry_subscription: registry_subscription,
 57        };
 58        this.build_provider_configuration_views(window, cx);
 59        this
 60    }
 61
 62    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 63        let providers = LanguageModelRegistry::read_global(cx).providers();
 64        for provider in providers {
 65            self.add_provider_configuration_view(&provider, window, cx);
 66        }
 67    }
 68
 69    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 70        self.configuration_views_by_provider.remove(provider_id);
 71    }
 72
 73    fn add_provider_configuration_view(
 74        &mut self,
 75        provider: &Arc<dyn LanguageModelProvider>,
 76        window: &mut Window,
 77        cx: &mut Context<Self>,
 78    ) {
 79        let configuration_view = provider.configuration_view(window, cx);
 80        self.configuration_views_by_provider
 81            .insert(provider.id(), configuration_view);
 82    }
 83}
 84
 85impl Focusable for AssistantConfiguration {
 86    fn focus_handle(&self, _: &App) -> FocusHandle {
 87        self.focus_handle.clone()
 88    }
 89}
 90
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// The user clicked "Start New Thread" on the card for this provider.
    /// That button is only rendered for authenticated providers.
    NewThread(Arc<dyn LanguageModelProvider>),
}

// Allows `AssistantConfiguration` to emit `AssistantConfigurationEvent`s via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 96
 97impl AssistantConfiguration {
 98    fn render_provider_configuration(
 99        &mut self,
100        provider: &Arc<dyn LanguageModelProvider>,
101        cx: &mut Context<Self>,
102    ) -> impl IntoElement {
103        let provider_id = provider.id().0.clone();
104        let provider_name = provider.name().0.clone();
105        let configuration_view = self
106            .configuration_views_by_provider
107            .get(&provider.id())
108            .cloned();
109
110        v_flex()
111            .gap_1p5()
112            .child(
113                h_flex()
114                    .justify_between()
115                    .child(
116                        h_flex()
117                            .gap_2()
118                            .child(
119                                Icon::new(provider.icon())
120                                    .size(IconSize::Small)
121                                    .color(Color::Muted),
122                            )
123                            .child(Label::new(provider_name.clone())),
124                    )
125                    .when(provider.is_authenticated(cx), |parent| {
126                        parent.child(
127                            Button::new(
128                                SharedString::from(format!("new-thread-{provider_id}")),
129                                "Start New Thread",
130                            )
131                            .icon_position(IconPosition::Start)
132                            .icon(IconName::Plus)
133                            .icon_size(IconSize::Small)
134                            .style(ButtonStyle::Filled)
135                            .layer(ElevationIndex::ModalSurface)
136                            .label_size(LabelSize::Small)
137                            .on_click(cx.listener({
138                                let provider = provider.clone();
139                                move |_this, _event, _window, cx| {
140                                    cx.emit(AssistantConfigurationEvent::NewThread(
141                                        provider.clone(),
142                                    ))
143                                }
144                            })),
145                        )
146                    }),
147            )
148            .child(
149                div()
150                    .p(DynamicSpacing::Base08.rems(cx))
151                    .bg(cx.theme().colors().editor_background)
152                    .border_1()
153                    .border_color(cx.theme().colors().border_variant)
154                    .rounded_sm()
155                    .map(|parent| match configuration_view {
156                        Some(configuration_view) => parent.child(configuration_view),
157                        None => parent.child(div().child(Label::new(format!(
158                            "No configuration view for {provider_name}",
159                        )))),
160                    }),
161            )
162    }
163
164    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
165        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
166        let tools_by_source = self.tools.tools_by_source(cx);
167        let empty = Vec::new();
168
169        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
170
171        v_flex()
172            .p(DynamicSpacing::Base16.rems(cx))
173            .mt_1()
174            .gap_2()
175            .flex_1()
176            .child(
177                v_flex()
178                    .gap_0p5()
179                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
180                    .child(Label::new(SUBHEADING).color(Color::Muted)),
181            )
182            .children(context_servers.into_iter().map(|context_server| {
183                let is_running = context_server.client().is_some();
184                let are_tools_expanded = self
185                    .expanded_context_server_tools
186                    .get(&context_server.id())
187                    .copied()
188                    .unwrap_or_default();
189
190                let tools = tools_by_source
191                    .get(&ToolSource::ContextServer {
192                        id: context_server.id().into(),
193                    })
194                    .unwrap_or_else(|| &empty);
195                let tool_count = tools.len();
196
197                v_flex()
198                    .id(SharedString::from(context_server.id()))
199                    .border_1()
200                    .rounded_sm()
201                    .border_color(cx.theme().colors().border)
202                    .bg(cx.theme().colors().editor_background)
203                    .child(
204                        h_flex()
205                            .justify_between()
206                            .px_2()
207                            .py_1()
208                            .when(are_tools_expanded, |element| {
209                                element
210                                    .border_b_1()
211                                    .border_color(cx.theme().colors().border)
212                            })
213                            .child(
214                                h_flex()
215                                    .gap_2()
216                                    .child(
217                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
218                                            .on_click(cx.listener({
219                                                let context_server_id = context_server.id();
220                                                move |this, _event, _window, _cx| {
221                                                    let is_open = this
222                                                        .expanded_context_server_tools
223                                                        .entry(context_server_id.clone())
224                                                        .or_insert(false);
225
226                                                    *is_open = !*is_open;
227                                                }
228                                            })),
229                                    )
230                                    .child(Indicator::dot().color(if is_running {
231                                        Color::Success
232                                    } else {
233                                        Color::Error
234                                    }))
235                                    .child(Label::new(context_server.id()))
236                                    .child(
237                                        Label::new(format!("{tool_count} tools"))
238                                            .color(Color::Muted),
239                                    ),
240                            )
241                            .child(h_flex().child(
242                                Switch::new("context-server-switch", is_running.into()).on_click({
243                                    let context_server_manager =
244                                        self.context_server_manager.clone();
245                                    let context_server = context_server.clone();
246                                    move |state, _window, cx| match state {
247                                        ToggleState::Unselected | ToggleState::Indeterminate => {
248                                            context_server_manager.update(cx, |this, cx| {
249                                                this.stop_server(context_server.clone(), cx)
250                                                    .log_err();
251                                            });
252                                        }
253                                        ToggleState::Selected => {
254                                            cx.spawn({
255                                                let context_server_manager =
256                                                    context_server_manager.clone();
257                                                let context_server = context_server.clone();
258                                                async move |cx| {
259                                                    if let Some(start_server_task) =
260                                                        context_server_manager
261                                                            .update(cx, |this, cx| {
262                                                                this.start_server(
263                                                                    context_server,
264                                                                    cx,
265                                                                )
266                                                            })
267                                                            .log_err()
268                                                    {
269                                                        start_server_task.await.log_err();
270                                                    }
271                                                }
272                                            })
273                                            .detach();
274                                        }
275                                    }
276                                }),
277                            )),
278                    )
279                    .map(|parent| {
280                        if !are_tools_expanded {
281                            return parent;
282                        }
283
284                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
285                            |(ix, tool)| {
286                                h_flex()
287                                    .px_2()
288                                    .py_1()
289                                    .when(ix < tool_count - 1, |element| {
290                                        element
291                                            .border_b_1()
292                                            .border_color(cx.theme().colors().border)
293                                    })
294                                    .child(Label::new(tool.name()))
295                            },
296                        )))
297                    })
298            }))
299            .child(
300                h_flex()
301                    .justify_between()
302                    .gap_2()
303                    .child(
304                        h_flex().w_full().child(
305                            Button::new("add-context-server", "Add Context Server")
306                                .style(ButtonStyle::Filled)
307                                .layer(ElevationIndex::ModalSurface)
308                                .full_width()
309                                .icon(IconName::Plus)
310                                .icon_size(IconSize::Small)
311                                .icon_position(IconPosition::Start)
312                                .disabled(true)
313                                .tooltip(Tooltip::text("Not yet implemented")),
314                        ),
315                    )
316                    .child(
317                        h_flex().w_full().child(
318                            Button::new(
319                                "install-context-server-extensions",
320                                "Install Context Server Extensions",
321                            )
322                            .style(ButtonStyle::Filled)
323                            .layer(ElevationIndex::ModalSurface)
324                            .full_width()
325                            .icon(IconName::DatabaseZap)
326                            .icon_size(IconSize::Small)
327                            .icon_position(IconPosition::Start)
328                            .on_click(|_event, window, cx| {
329                                window.dispatch_action(
330                                    zed_actions::Extensions {
331                                        category_filter: Some(
332                                            ExtensionCategoryFilter::ContextServers,
333                                        ),
334                                    }
335                                    .boxed_clone(),
336                                    cx,
337                                )
338                            }),
339                        ),
340                    ),
341            )
342    }
343}
344
345impl Render for AssistantConfiguration {
346    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
347        let providers = LanguageModelRegistry::read_global(cx).providers();
348
349        v_flex()
350            .id("assistant-configuration")
351            .track_focus(&self.focus_handle(cx))
352            .bg(cx.theme().colors().panel_background)
353            .size_full()
354            .overflow_y_scroll()
355            .child(
356                v_flex()
357                    .p(DynamicSpacing::Base16.rems(cx))
358                    .gap_2()
359                    .child(
360                        v_flex()
361                            .gap_0p5()
362                            .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
363                            .child(
364                                Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
365                                    .color(Color::Muted),
366                            ),
367                    )
368                    .child(
369                        Button::new("open-prompt-library", "Open Prompt Library")
370                            .style(ButtonStyle::Filled)
371                            .layer(ElevationIndex::ModalSurface)
372                            .full_width()
373                            .icon(IconName::Book)
374                            .icon_size(IconSize::Small)
375                            .icon_position(IconPosition::Start)
376                            .on_click(|_event, window, cx| {
377                                window.dispatch_action(OpenPromptLibrary.boxed_clone(), cx)
378                            }),
379                    ),
380            )
381            .child(Divider::horizontal().color(DividerColor::Border))
382            .child(self.render_context_servers_section(cx))
383            .child(Divider::horizontal().color(DividerColor::Border))
384            .child(
385                v_flex()
386                    .p(DynamicSpacing::Base16.rems(cx))
387                    .mt_1()
388                    .gap_6()
389                    .flex_1()
390                    .child(
391                        v_flex()
392                            .gap_0p5()
393                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
394                            .child(
395                                Label::new("Add at least one provider to use AI-powered features.")
396                                    .color(Color::Muted),
397                            ),
398                    )
399                    .children(
400                        providers
401                            .into_iter()
402                            .map(|provider| self.render_provider_configuration(&provider, cx)),
403                    ),
404            )
405    }
406}