assistant_configuration.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::{ToolSource, ToolWorkingSet};
  4use collections::HashMap;
  5use context_server::manager::ContextServerManager;
  6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
  7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
  8use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator};
  9use zed_actions::assistant::DeployPromptLibrary;
 10
 11pub struct AssistantConfiguration {
 12    focus_handle: FocusHandle,
 13    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
 14    context_server_manager: Entity<ContextServerManager>,
 15    expanded_context_server_tools: HashMap<Arc<str>, bool>,
 16    tools: Arc<ToolWorkingSet>,
 17    _registry_subscription: Subscription,
 18}
 19
 20impl AssistantConfiguration {
 21    pub fn new(
 22        context_server_manager: Entity<ContextServerManager>,
 23        tools: Arc<ToolWorkingSet>,
 24        window: &mut Window,
 25        cx: &mut Context<Self>,
 26    ) -> Self {
 27        let focus_handle = cx.focus_handle();
 28
 29        let registry_subscription = cx.subscribe_in(
 30            &LanguageModelRegistry::global(cx),
 31            window,
 32            |this, _, event: &language_model::Event, window, cx| match event {
 33                language_model::Event::AddedProvider(provider_id) => {
 34                    let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
 35                    if let Some(provider) = provider {
 36                        this.add_provider_configuration_view(&provider, window, cx);
 37                    }
 38                }
 39                language_model::Event::RemovedProvider(provider_id) => {
 40                    this.remove_provider_configuration_view(provider_id);
 41                }
 42                _ => {}
 43            },
 44        );
 45
 46        let mut this = Self {
 47            focus_handle,
 48            configuration_views_by_provider: HashMap::default(),
 49            context_server_manager,
 50            expanded_context_server_tools: HashMap::default(),
 51            tools,
 52            _registry_subscription: registry_subscription,
 53        };
 54        this.build_provider_configuration_views(window, cx);
 55        this
 56    }
 57
 58    fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 59        let providers = LanguageModelRegistry::read_global(cx).providers();
 60        for provider in providers {
 61            self.add_provider_configuration_view(&provider, window, cx);
 62        }
 63    }
 64
 65    fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
 66        self.configuration_views_by_provider.remove(provider_id);
 67    }
 68
 69    fn add_provider_configuration_view(
 70        &mut self,
 71        provider: &Arc<dyn LanguageModelProvider>,
 72        window: &mut Window,
 73        cx: &mut Context<Self>,
 74    ) {
 75        let configuration_view = provider.configuration_view(window, cx);
 76        self.configuration_views_by_provider
 77            .insert(provider.id(), configuration_view);
 78    }
 79}
 80
 81impl Focusable for AssistantConfiguration {
 82    fn focus_handle(&self, _: &App) -> FocusHandle {
 83        self.focus_handle.clone()
 84    }
 85}
 86
 87pub enum AssistantConfigurationEvent {
 88    NewThread(Arc<dyn LanguageModelProvider>),
 89}
 90
 91impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
 92
 93impl AssistantConfiguration {
 94    fn render_provider_configuration(
 95        &mut self,
 96        provider: &Arc<dyn LanguageModelProvider>,
 97        cx: &mut Context<Self>,
 98    ) -> impl IntoElement {
 99        let provider_id = provider.id().0.clone();
100        let provider_name = provider.name().0.clone();
101        let configuration_view = self
102            .configuration_views_by_provider
103            .get(&provider.id())
104            .cloned();
105
106        v_flex()
107            .gap_1p5()
108            .child(
109                h_flex()
110                    .justify_between()
111                    .child(
112                        h_flex()
113                            .gap_2()
114                            .child(
115                                Icon::new(provider.icon())
116                                    .size(IconSize::Small)
117                                    .color(Color::Muted),
118                            )
119                            .child(Label::new(provider_name.clone())),
120                    )
121                    .when(provider.is_authenticated(cx), |parent| {
122                        parent.child(
123                            Button::new(
124                                SharedString::from(format!("new-thread-{provider_id}")),
125                                "Start New Thread",
126                            )
127                            .icon_position(IconPosition::Start)
128                            .icon(IconName::Plus)
129                            .icon_size(IconSize::Small)
130                            .style(ButtonStyle::Filled)
131                            .layer(ElevationIndex::ModalSurface)
132                            .label_size(LabelSize::Small)
133                            .on_click(cx.listener({
134                                let provider = provider.clone();
135                                move |_this, _event, _window, cx| {
136                                    cx.emit(AssistantConfigurationEvent::NewThread(
137                                        provider.clone(),
138                                    ))
139                                }
140                            })),
141                        )
142                    }),
143            )
144            .child(
145                div()
146                    .p(DynamicSpacing::Base08.rems(cx))
147                    .bg(cx.theme().colors().editor_background)
148                    .border_1()
149                    .border_color(cx.theme().colors().border_variant)
150                    .rounded_sm()
151                    .map(|parent| match configuration_view {
152                        Some(configuration_view) => parent.child(configuration_view),
153                        None => parent.child(div().child(Label::new(format!(
154                            "No configuration view for {provider_name}",
155                        )))),
156                    }),
157            )
158    }
159
160    fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
161        let context_servers = self.context_server_manager.read(cx).servers().clone();
162        let tools_by_source = self.tools.tools_by_source(cx);
163        let empty = Vec::new();
164
165        const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
166
167        v_flex()
168            .p(DynamicSpacing::Base16.rems(cx))
169            .mt_1()
170            .gap_6()
171            .flex_1()
172            .child(
173                v_flex()
174                    .gap_0p5()
175                    .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
176                    .child(Label::new(SUBHEADING).color(Color::Muted)),
177            )
178            .children(context_servers.into_iter().map(|context_server| {
179                let is_running = context_server.client().is_some();
180                let are_tools_expanded = self
181                    .expanded_context_server_tools
182                    .get(&context_server.id())
183                    .copied()
184                    .unwrap_or_default();
185
186                let tools = tools_by_source
187                    .get(&ToolSource::ContextServer {
188                        id: context_server.id().into(),
189                    })
190                    .unwrap_or_else(|| &empty);
191                let tool_count = tools.len();
192
193                v_flex()
194                    .border_1()
195                    .rounded_sm()
196                    .border_color(cx.theme().colors().border)
197                    .bg(cx.theme().colors().editor_background)
198                    .child(
199                        h_flex()
200                            .gap_2()
201                            .px_2()
202                            .py_1()
203                            .when(are_tools_expanded, |element| {
204                                element
205                                    .border_b_1()
206                                    .border_color(cx.theme().colors().border)
207                            })
208                            .child(
209                                Disclosure::new("tool-list-disclosure", are_tools_expanded)
210                                    .on_click(cx.listener({
211                                        let context_server_id = context_server.id();
212                                        move |this, _event, _window, _cx| {
213                                            let is_open = this
214                                                .expanded_context_server_tools
215                                                .entry(context_server_id.clone())
216                                                .or_insert(false);
217
218                                            *is_open = !*is_open;
219                                        }
220                                    })),
221                            )
222                            .child(Indicator::dot().color(if is_running {
223                                Color::Success
224                            } else {
225                                Color::Error
226                            }))
227                            .child(Label::new(context_server.id()))
228                            .child(Label::new(format!("{tool_count} tools")).color(Color::Muted)),
229                    )
230                    .map(|parent| {
231                        if !are_tools_expanded {
232                            return parent;
233                        }
234
235                        parent.child(v_flex().children(tools.into_iter().enumerate().map(
236                            |(ix, tool)| {
237                                h_flex()
238                                    .px_2()
239                                    .py_1()
240                                    .when(ix < tool_count - 1, |element| {
241                                        element
242                                            .border_b_1()
243                                            .border_color(cx.theme().colors().border)
244                                    })
245                                    .child(Label::new(tool.name()))
246                            },
247                        )))
248                    })
249            }))
250    }
251}
252
253impl Render for AssistantConfiguration {
254    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
255        let providers = LanguageModelRegistry::read_global(cx).providers();
256
257        v_flex()
258            .id("assistant-configuration")
259            .track_focus(&self.focus_handle(cx))
260            .bg(cx.theme().colors().panel_background)
261            .size_full()
262            .overflow_y_scroll()
263            .child(
264                v_flex()
265                    .p(DynamicSpacing::Base16.rems(cx))
266                    .gap_2()
267                    .child(
268                        v_flex()
269                            .gap_0p5()
270                            .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
271                            .child(
272                                Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
273                                    .color(Color::Muted),
274                            ),
275                    )
276                    .child(
277                        Button::new("open-prompt-library", "Open Prompt Library")
278                            .style(ButtonStyle::Filled)
279                            .layer(ElevationIndex::ModalSurface)
280                            .full_width()
281                            .icon(IconName::Book)
282                            .icon_size(IconSize::Small)
283                            .icon_position(IconPosition::Start)
284                            .on_click(|_event, window, cx| {
285                                window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
286                            }),
287                    ),
288            )
289            .child(Divider::horizontal().color(DividerColor::Border))
290            .child(self.render_context_servers_section(cx))
291            .child(Divider::horizontal().color(DividerColor::Border))
292            .child(
293                v_flex()
294                    .p(DynamicSpacing::Base16.rems(cx))
295                    .mt_1()
296                    .gap_6()
297                    .flex_1()
298                    .child(
299                        v_flex()
300                            .gap_0p5()
301                            .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
302                            .child(
303                                Label::new("Add at least one provider to use AI-powered features.")
304                                    .color(Color::Muted),
305                            ),
306                    )
307                    .children(
308                        providers
309                            .into_iter()
310                            .map(|provider| self.render_provider_configuration(&provider, cx)),
311                    ),
312            )
313    }
314}