1mod add_context_server_modal;
2
3use std::sync::Arc;
4
5use assistant_tool::{ToolSource, ToolWorkingSet};
6use collections::HashMap;
7use context_server::manager::ContextServerManager;
8use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
9use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
10use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
11use util::ResultExt as _;
12use zed_actions::ExtensionCategoryFilter;
13
14pub(crate) use add_context_server_modal::AddContextServerModal;
15
16use crate::AddContextServer;
17
/// View state for the assistant configuration page, which lists context servers and
/// language model providers.
pub struct AssistantConfiguration {
    /// Focus handle handed out by the `Focusable` impl and tracked by the rendered root element.
    focus_handle: FocusHandle,
    /// Configuration view obtained from each provider, keyed by provider id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Manager used to list the context servers and to start/stop them from the UI.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether the tool list of a context server (keyed by server id) is expanded.
    /// A server without an entry is rendered collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set queried (by source) to list the tools provided by each context server.
    tools: Arc<ToolWorkingSet>,
    /// Subscription to `LanguageModelRegistry` events. Never read; stored so it lives as long
    /// as this view (presumably dropping it would cancel the subscription — TODO confirm).
    _registry_subscription: Subscription,
}
26
27impl AssistantConfiguration {
28 pub fn new(
29 context_server_manager: Entity<ContextServerManager>,
30 tools: Arc<ToolWorkingSet>,
31 window: &mut Window,
32 cx: &mut Context<Self>,
33 ) -> Self {
34 let focus_handle = cx.focus_handle();
35
36 let registry_subscription = cx.subscribe_in(
37 &LanguageModelRegistry::global(cx),
38 window,
39 |this, _, event: &language_model::Event, window, cx| match event {
40 language_model::Event::AddedProvider(provider_id) => {
41 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
42 if let Some(provider) = provider {
43 this.add_provider_configuration_view(&provider, window, cx);
44 }
45 }
46 language_model::Event::RemovedProvider(provider_id) => {
47 this.remove_provider_configuration_view(provider_id);
48 }
49 _ => {}
50 },
51 );
52
53 let mut this = Self {
54 focus_handle,
55 configuration_views_by_provider: HashMap::default(),
56 context_server_manager,
57 expanded_context_server_tools: HashMap::default(),
58 tools,
59 _registry_subscription: registry_subscription,
60 };
61 this.build_provider_configuration_views(window, cx);
62 this
63 }
64
65 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
66 let providers = LanguageModelRegistry::read_global(cx).providers();
67 for provider in providers {
68 self.add_provider_configuration_view(&provider, window, cx);
69 }
70 }
71
72 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
73 self.configuration_views_by_provider.remove(provider_id);
74 }
75
76 fn add_provider_configuration_view(
77 &mut self,
78 provider: &Arc<dyn LanguageModelProvider>,
79 window: &mut Window,
80 cx: &mut Context<Self>,
81 ) {
82 let configuration_view = provider.configuration_view(window, cx);
83 self.configuration_views_by_provider
84 .insert(provider.id(), configuration_view);
85 }
86}
87
88impl Focusable for AssistantConfiguration {
89 fn focus_handle(&self, _: &App) -> FocusHandle {
90 self.focus_handle.clone()
91 }
92}
93
/// Events emitted by `AssistantConfiguration`.
pub enum AssistantConfigurationEvent {
    /// Emitted when the "Start New Thread" button of the carried provider is clicked.
    NewThread(Arc<dyn LanguageModelProvider>),
}
97
// Marks `AssistantConfiguration` as an emitter of `AssistantConfigurationEvent`
// (emitted via `cx.emit` from the "Start New Thread" button handler).
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
99
100impl AssistantConfiguration {
101 fn render_provider_configuration(
102 &mut self,
103 provider: &Arc<dyn LanguageModelProvider>,
104 cx: &mut Context<Self>,
105 ) -> impl IntoElement {
106 let provider_id = provider.id().0.clone();
107 let provider_name = provider.name().0.clone();
108 let configuration_view = self
109 .configuration_views_by_provider
110 .get(&provider.id())
111 .cloned();
112
113 v_flex()
114 .gap_1p5()
115 .child(
116 h_flex()
117 .justify_between()
118 .child(
119 h_flex()
120 .gap_2()
121 .child(
122 Icon::new(provider.icon())
123 .size(IconSize::Small)
124 .color(Color::Muted),
125 )
126 .child(Label::new(provider_name.clone())),
127 )
128 .when(provider.is_authenticated(cx), |parent| {
129 parent.child(
130 Button::new(
131 SharedString::from(format!("new-thread-{provider_id}")),
132 "Start New Thread",
133 )
134 .icon_position(IconPosition::Start)
135 .icon(IconName::Plus)
136 .icon_size(IconSize::Small)
137 .style(ButtonStyle::Filled)
138 .layer(ElevationIndex::ModalSurface)
139 .label_size(LabelSize::Small)
140 .on_click(cx.listener({
141 let provider = provider.clone();
142 move |_this, _event, _window, cx| {
143 cx.emit(AssistantConfigurationEvent::NewThread(
144 provider.clone(),
145 ))
146 }
147 })),
148 )
149 }),
150 )
151 .child(
152 div()
153 .p(DynamicSpacing::Base08.rems(cx))
154 .bg(cx.theme().colors().editor_background)
155 .border_1()
156 .border_color(cx.theme().colors().border_variant)
157 .rounded_sm()
158 .map(|parent| match configuration_view {
159 Some(configuration_view) => parent.child(configuration_view),
160 None => parent.child(div().child(Label::new(format!(
161 "No configuration view for {provider_name}",
162 )))),
163 }),
164 )
165 }
166
167 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
168 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
169 let tools_by_source = self.tools.tools_by_source(cx);
170 let empty = Vec::new();
171
172 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
173
174 v_flex()
175 .p(DynamicSpacing::Base16.rems(cx))
176 .gap_2()
177 .flex_1()
178 .child(
179 v_flex()
180 .gap_0p5()
181 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
182 .child(Label::new(SUBHEADING).color(Color::Muted)),
183 )
184 .children(context_servers.into_iter().map(|context_server| {
185 let is_running = context_server.client().is_some();
186 let are_tools_expanded = self
187 .expanded_context_server_tools
188 .get(&context_server.id())
189 .copied()
190 .unwrap_or_default();
191
192 let tools = tools_by_source
193 .get(&ToolSource::ContextServer {
194 id: context_server.id().into(),
195 })
196 .unwrap_or_else(|| &empty);
197 let tool_count = tools.len();
198
199 v_flex()
200 .id(SharedString::from(context_server.id()))
201 .border_1()
202 .rounded_sm()
203 .border_color(cx.theme().colors().border)
204 .bg(cx.theme().colors().editor_background)
205 .child(
206 h_flex()
207 .justify_between()
208 .px_2()
209 .py_1()
210 .when(are_tools_expanded, |element| {
211 element
212 .border_b_1()
213 .border_color(cx.theme().colors().border)
214 })
215 .child(
216 h_flex()
217 .gap_2()
218 .child(
219 Disclosure::new("tool-list-disclosure", are_tools_expanded)
220 .on_click(cx.listener({
221 let context_server_id = context_server.id();
222 move |this, _event, _window, _cx| {
223 let is_open = this
224 .expanded_context_server_tools
225 .entry(context_server_id.clone())
226 .or_insert(false);
227
228 *is_open = !*is_open;
229 }
230 })),
231 )
232 .child(Indicator::dot().color(if is_running {
233 Color::Success
234 } else {
235 Color::Error
236 }))
237 .child(Label::new(context_server.id()))
238 .child(
239 Label::new(format!("{tool_count} tools"))
240 .color(Color::Muted),
241 ),
242 )
243 .child(h_flex().child(
244 Switch::new("context-server-switch", is_running.into()).on_click({
245 let context_server_manager =
246 self.context_server_manager.clone();
247 let context_server = context_server.clone();
248 move |state, _window, cx| match state {
249 ToggleState::Unselected | ToggleState::Indeterminate => {
250 context_server_manager.update(cx, |this, cx| {
251 this.stop_server(context_server.clone(), cx)
252 .log_err();
253 });
254 }
255 ToggleState::Selected => {
256 cx.spawn({
257 let context_server_manager =
258 context_server_manager.clone();
259 let context_server = context_server.clone();
260 async move |cx| {
261 if let Some(start_server_task) =
262 context_server_manager
263 .update(cx, |this, cx| {
264 this.start_server(
265 context_server,
266 cx,
267 )
268 })
269 .log_err()
270 {
271 start_server_task.await.log_err();
272 }
273 }
274 })
275 .detach();
276 }
277 }
278 }),
279 )),
280 )
281 .map(|parent| {
282 if !are_tools_expanded {
283 return parent;
284 }
285
286 parent.child(v_flex().children(tools.into_iter().enumerate().map(
287 |(ix, tool)| {
288 h_flex()
289 .px_2()
290 .py_1()
291 .when(ix < tool_count - 1, |element| {
292 element
293 .border_b_1()
294 .border_color(cx.theme().colors().border)
295 })
296 .child(Label::new(tool.name()))
297 },
298 )))
299 })
300 }))
301 .child(
302 h_flex()
303 .justify_between()
304 .gap_2()
305 .child(
306 h_flex().w_full().child(
307 Button::new("add-context-server", "Add Context Server")
308 .style(ButtonStyle::Filled)
309 .layer(ElevationIndex::ModalSurface)
310 .full_width()
311 .icon(IconName::Plus)
312 .icon_size(IconSize::Small)
313 .icon_position(IconPosition::Start)
314 .on_click(|_event, window, cx| {
315 window.dispatch_action(AddContextServer.boxed_clone(), cx)
316 }),
317 ),
318 )
319 .child(
320 h_flex().w_full().child(
321 Button::new(
322 "install-context-server-extensions",
323 "Install Context Server Extensions",
324 )
325 .style(ButtonStyle::Filled)
326 .layer(ElevationIndex::ModalSurface)
327 .full_width()
328 .icon(IconName::DatabaseZap)
329 .icon_size(IconSize::Small)
330 .icon_position(IconPosition::Start)
331 .on_click(|_event, window, cx| {
332 window.dispatch_action(
333 zed_actions::Extensions {
334 category_filter: Some(
335 ExtensionCategoryFilter::ContextServers,
336 ),
337 }
338 .boxed_clone(),
339 cx,
340 )
341 }),
342 ),
343 ),
344 )
345 }
346}
347
348impl Render for AssistantConfiguration {
349 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
350 let providers = LanguageModelRegistry::read_global(cx).providers();
351
352 v_flex()
353 .id("assistant-configuration")
354 .track_focus(&self.focus_handle(cx))
355 .bg(cx.theme().colors().panel_background)
356 .size_full()
357 .overflow_y_scroll()
358 .child(self.render_context_servers_section(cx))
359 .child(Divider::horizontal().color(DividerColor::Border))
360 .child(
361 v_flex()
362 .p(DynamicSpacing::Base16.rems(cx))
363 .mt_1()
364 .gap_6()
365 .flex_1()
366 .child(
367 v_flex()
368 .gap_0p5()
369 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
370 .child(
371 Label::new("Add at least one provider to use AI-powered features.")
372 .color(Color::Muted),
373 ),
374 )
375 .children(
376 providers
377 .into_iter()
378 .map(|provider| self.render_provider_configuration(&provider, cx)),
379 ),
380 )
381 }
382}