assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::{ToolSource, ToolWorkingSet};
  4use collections::HashMap;
  5use context_server::manager::ContextServerManager;
  6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
  7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  8use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
  9use util::ResultExt as _;
 10use zed_actions::assistant::DeployPromptLibrary;
 11
 12pub struct AssistantConfiguration {
 13    focus_handle: FocusHandle,
 14    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 15    context_server_manager: Entity<ContextServerManager>,
 16    expanded_context_server_tools: HashMap<Arc<str>, bool>,
 17    tools: Arc<ToolWorkingSet>,
 18    _registry_subscription: Subscription,
 19}
 20
 21impl AssistantConfiguration {
 22    pub fn new(
 23        context_server_manager: Entity<ContextServerManager>,
 24        tools: Arc<ToolWorkingSet>,
 25        window: &mut Window,
 26        cx: &mut Context<Self>,
 27    ) -> Self {
 28        let focus_handle = cx.focus_handle();
 29
 30        let registry_subscription = cx.subscribe_in(
 31            &LanguageModelRegistry::global(cx),
 32            window,
 33            |this, _, event: &language_model::Event, window, cx| match event {
 34                language_model::Event::AddedProvider(provider_id) => {
 35                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 36                    if let Some(provider) = provider {
 37                        this.add_provider_configuration_view(&provider, window, cx);
 38                    }
 39                }
 40                language_model::Event::RemovedProvider(provider_id) => {
 41                    this.remove_provider_configuration_view(provider_id);
 42                }
 43                _ => {}
 44            },
 45        );
 46
 47        let mut this = Self {
 48            focus_handle,
 49            configuration_views_by_provider: HashMap::default(),
 50            context_server_manager,
 51            expanded_context_server_tools: HashMap::default(),
 52            tools,
 53            _registry_subscription: registry_subscription,
 54        };
 55        this.build_provider_configuration_views(window, cx);
 56        this
 57    }
 58
 59    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 60        let providers = LanguageModelRegistry::read_global(cx).providers();
 61        for provider in providers {
 62            self.add_provider_configuration_view(&provider, window, cx);
 63        }
 64    }
 65
 66    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 67        self.configuration_views_by_provider.remove(provider_id);
 68    }
 69
 70    fn add_provider_configuration_view(
 71        &mut self,
 72        provider: &Arc<dyn LanguageModelProvider>,
 73        window: &mut Window,
 74        cx: &mut Context<Self>,
 75    ) {
 76        let configuration_view = provider.configuration_view(window, cx);
 77        self.configuration_views_by_provider
 78            .insert(provider.id(), configuration_view);
 79    }
 80}
 81
 82impl Focusable for AssistantConfiguration {
 83    fn focus_handle(&self, _: &App) -> FocusHandle {
 84        self.focus_handle.clone()
 85    }
 86}
 87
 88pub enum AssistantConfigurationEvent {
 89    NewThread(Arc<dyn LanguageModelProvider>),
 90}
 91
 92impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 93
 94impl AssistantConfiguration {
 95    fn render_provider_configuration(
 96        &mut self,
 97        provider: &Arc<dyn LanguageModelProvider>,
 98        cx: &mut Context<Self>,
 99    ) -> impl IntoElement {
100        let provider_id = provider.id().0.clone();
101        let provider_name = provider.name().0.clone();
102        let configuration_view = self
103            .configuration_views_by_provider
104            .get(&provider.id())
105            .cloned();
106
107        v_flex()
108            .gap_1p5()
109            .child(
110                h_flex()
111                    .justify_between()
112                    .child(
113                        h_flex()
114                            .gap_2()
115                            .child(
116                                Icon::new(provider.icon())
117                                    .size(IconSize::Small)
118                                    .color(Color::Muted),
119                            )
120                            .child(Label::new(provider_name.clone())),
121                    )
122                    .when(provider.is_authenticated(cx), |parent| {
123                        parent.child(
124                            Button::new(
125                                SharedString::from(format!("new-thread-{provider_id}")),
126                                "Start New Thread",
127                            )
128                            .icon_position(IconPosition::Start)
129                            .icon(IconName::Plus)
130                            .icon_size(IconSize::Small)
131                            .style(ButtonStyle::Filled)
132                            .layer(ElevationIndex::ModalSurface)
133                            .label_size(LabelSize::Small)
134                            .on_click(cx.listener({
135                                let provider = provider.clone();
136                                move |_this, _event, _window, cx| {
137                                    cx.emit(AssistantConfigurationEvent::NewThread(
138                                        provider.clone(),
139                                    ))
140                                }
141                            })),
142                        )
143                    }),
144            )
145            .child(
146                div()
147                    .p(DynamicSpacing::Base08.rems(cx))
148                    .bg(cx.theme().colors().editor_background)
149                    .border_1()
150                    .border_color(cx.theme().colors().border_variant)
151                    .rounded_sm()
152                    .map(|parent| match configuration_view {
153                        Some(configuration_view) => parent.child(configuration_view),
154                        None => parent.child(div().child(Label::new(format!(
155                            "No configuration view for {provider_name}",
156                        )))),
157                    }),
158            )
159    }
160
161    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
162        let context_servers = self.context_server_manager.read(cx).all_servers().clone();
163        let tools_by_source = self.tools.tools_by_source(cx);
164        let empty = Vec::new();
165
166        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
167
168        v_flex()
169            .p(DynamicSpacing::Base16.rems(cx))
170            .mt_1()
171            .gap_6()
172            .flex_1()
173            .child(
174                v_flex()
175                    .gap_0p5()
176                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
177                    .child(Label::new(SUBHEADING).color(Color::Muted)),
178            )
179            .children(context_servers.into_iter().map(|context_server| {
180                let is_running = context_server.client().is_some();
181                let are_tools_expanded = self
182                    .expanded_context_server_tools
183                    .get(&context_server.id())
184                    .copied()
185                    .unwrap_or_default();
186
187                let tools = tools_by_source
188                    .get(&ToolSource::ContextServer {
189                        id: context_server.id().into(),
190                    })
191                    .unwrap_or_else(|| &empty);
192                let tool_count = tools.len();
193
194                v_flex()
195                    .border_1()
196                    .rounded_sm()
197                    .border_color(cx.theme().colors().border)
198                    .bg(cx.theme().colors().editor_background)
199                    .child(
200                        h_flex()
201                            .justify_between()
202                            .px_2()
203                            .py_1()
204                            .when(are_tools_expanded, |element| {
205                                element
206                                    .border_b_1()
207                                    .border_color(cx.theme().colors().border)
208                            })
209                            .child(
210                                h_flex()
211                                    .gap_2()
212                                    .child(
213                                        Disclosure::new("tool-list-disclosure", are_tools_expanded)
214                                            .on_click(cx.listener({
215                                                let context_server_id = context_server.id();
216                                                move |this, _event, _window, _cx| {
217                                                    let is_open = this
218                                                        .expanded_context_server_tools
219                                                        .entry(context_server_id.clone())
220                                                        .or_insert(false);
221
222                                                    *is_open = !*is_open;
223                                                }
224                                            })),
225                                    )
226                                    .child(Indicator::dot().color(if is_running {
227                                        Color::Success
228                                    } else {
229                                        Color::Error
230                                    }))
231                                    .child(Label::new(context_server.id()))
232                                    .child(
233                                        Label::new(format!("{tool_count} tools"))
234                                            .color(Color::Muted),
235                                    ),
236                            )
237                            .child(h_flex().child(
238                                Switch::new("context-server-switch", is_running.into()).on_click({
239                                    let context_server_manager =
240                                        self.context_server_manager.clone();
241                                    let context_server = context_server.clone();
242                                    move |state, _window, cx| match state {
243                                        ToggleState::Unselected | ToggleState::Indeterminate => {
244                                            context_server_manager.update(cx, |this, cx| {
245                                                this.stop_server(context_server.clone(), cx)
246                                                    .log_err();
247                                            });
248                                        }
249                                        ToggleState::Selected => {
250                                            cx.spawn({
251                                                let context_server_manager =
252                                                    context_server_manager.clone();
253                                                let context_server = context_server.clone();
254                                                async move |cx| {
255                                                    if let Some(start_server_task) =
256                                                        context_server_manager
257                                                            .update(cx, |this, cx| {
258                                                                this.start_server(
259                                                                    context_server,
260                                                                    cx,
261                                                                )
262                                                            })
263                                                            .log_err()
264                                                    {
265                                                        start_server_task.await.log_err();
266                                                    }
267                                                }
268                                            })
269                                            .detach();
270                                        }
271                                    }
272                                }),
273                            )),
274                    )
275                    .map(|parent| {
276                        if !are_tools_expanded {
277                            return parent;
278                        }
279
280                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
281                            |(ix, tool)| {
282                                h_flex()
283                                    .px_2()
284                                    .py_1()
285                                    .when(ix < tool_count - 1, |element| {
286                                        element
287                                            .border_b_1()
288                                            .border_color(cx.theme().colors().border)
289                                    })
290                                    .child(Label::new(tool.name()))
291                            },
292                        )))
293                    })
294            }))
295    }
296}
297
298impl Render for AssistantConfiguration {
299    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
300        let providers = LanguageModelRegistry::read_global(cx).providers();
301
302        v_flex()
303            .id("assistant-configuration")
304            .track_focus(&self.focus_handle(cx))
305            .bg(cx.theme().colors().panel_background)
306            .size_full()
307            .overflow_y_scroll()
308            .child(
309                v_flex()
310                    .p(DynamicSpacing::Base16.rems(cx))
311                    .gap_2()
312                    .child(
313                        v_flex()
314                            .gap_0p5()
315                            .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
316                            .child(
317                                Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
318                                    .color(Color::Muted),
319                            ),
320                    )
321                    .child(
322                        Button::new("open-prompt-library", "Open Prompt Library")
323                            .style(ButtonStyle::Filled)
324                            .layer(ElevationIndex::ModalSurface)
325                            .full_width()
326                            .icon(IconName::Book)
327                            .icon_size(IconSize::Small)
328                            .icon_position(IconPosition::Start)
329                            .on_click(|_event, window, cx| {
330                                window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
331                            }),
332                    ),
333            )
334            .child(Divider::horizontal().color(DividerColor::Border))
335            .child(self.render_context_servers_section(cx))
336            .child(Divider::horizontal().color(DividerColor::Border))
337            .child(
338                v_flex()
339                    .p(DynamicSpacing::Base16.rems(cx))
340                    .mt_1()
341                    .gap_6()
342                    .flex_1()
343                    .child(
344                        v_flex()
345                            .gap_0p5()
346                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
347                            .child(
348                                Label::new("Add at least one provider to use AI-powered features.")
349                                    .color(Color::Muted),
350                            ),
351                    )
352                    .children(
353                        providers
354                            .into_iter()
355                            .map(|provider| self.render_provider_configuration(&provider, cx)),
356                    ),
357            )
358    }
359}