1mod add_context_server_modal;
2mod manage_profiles_modal;
3mod tool_picker;
4
5use std::sync::Arc;
6
7use assistant_settings::AssistantSettings;
8use assistant_tool::{ToolSource, ToolWorkingSet};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use fs::Fs;
12use gpui::{
13 Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
14};
15use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
16use settings::{Settings, update_settings_file};
17use ui::{
18 Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
19 Switch, Tooltip, prelude::*,
20};
21use util::ResultExt as _;
22use zed_actions::ExtensionCategoryFilter;
23
24pub(crate) use add_context_server_modal::AddContextServerModal;
25pub(crate) use manage_profiles_modal::ManageProfilesModal;
26
27use crate::AddContextServer;
28
/// Assistant configuration panel: general settings, MCP context servers and
/// LLM provider configuration, rendered in a single scrollable column.
pub struct AssistantConfiguration {
    /// Filesystem handle used when persisting setting changes to the settings file.
    fs: Arc<dyn Fs>,
    /// Focus handle returned via `Focusable` and tracked by the root element.
    focus_handle: FocusHandle,
    /// Configuration view for each registered language model provider, keyed by provider id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the listed context servers; also used to start/stop them.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether the tool list of a context server (keyed by server id) is expanded.
    /// Missing entries are treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool working set, queried for the tools each context server provides.
    tools: Entity<ToolWorkingSet>,
    /// Subscription to the global `LanguageModelRegistry`; stored only so it
    /// stays alive for as long as this panel does.
    _registry_subscription: Subscription,
    /// Scroll handle tracked by the scrollable content column.
    scroll_handle: ScrollHandle,
    /// Scrollbar state derived from `scroll_handle`, used to draw the vertical scrollbar.
    scrollbar_state: ScrollbarState,
}
40
41impl AssistantConfiguration {
42 pub fn new(
43 fs: Arc<dyn Fs>,
44 context_server_manager: Entity<ContextServerManager>,
45 tools: Entity<ToolWorkingSet>,
46 window: &mut Window,
47 cx: &mut Context<Self>,
48 ) -> Self {
49 let focus_handle = cx.focus_handle();
50
51 let registry_subscription = cx.subscribe_in(
52 &LanguageModelRegistry::global(cx),
53 window,
54 |this, _, event: &language_model::Event, window, cx| match event {
55 language_model::Event::AddedProvider(provider_id) => {
56 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
57 if let Some(provider) = provider {
58 this.add_provider_configuration_view(&provider, window, cx);
59 }
60 }
61 language_model::Event::RemovedProvider(provider_id) => {
62 this.remove_provider_configuration_view(provider_id);
63 }
64 _ => {}
65 },
66 );
67
68 let scroll_handle = ScrollHandle::new();
69 let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
70
71 let mut this = Self {
72 fs,
73 focus_handle,
74 configuration_views_by_provider: HashMap::default(),
75 context_server_manager,
76 expanded_context_server_tools: HashMap::default(),
77 tools,
78 _registry_subscription: registry_subscription,
79 scroll_handle,
80 scrollbar_state,
81 };
82 this.build_provider_configuration_views(window, cx);
83 this
84 }
85
86 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
87 let providers = LanguageModelRegistry::read_global(cx).providers();
88 for provider in providers {
89 self.add_provider_configuration_view(&provider, window, cx);
90 }
91 }
92
93 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
94 self.configuration_views_by_provider.remove(provider_id);
95 }
96
97 fn add_provider_configuration_view(
98 &mut self,
99 provider: &Arc<dyn LanguageModelProvider>,
100 window: &mut Window,
101 cx: &mut Context<Self>,
102 ) {
103 let configuration_view = provider.configuration_view(window, cx);
104 self.configuration_views_by_provider
105 .insert(provider.id(), configuration_view);
106 }
107}
108
109impl Focusable for AssistantConfiguration {
110 fn focus_handle(&self, _: &App) -> FocusHandle {
111 self.focus_handle.clone()
112 }
113}
114
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// The "Start New Thread" button of an authenticated provider was clicked;
    /// carries the provider whose button was clicked.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
120
121impl AssistantConfiguration {
122 fn render_provider_configuration_block(
123 &mut self,
124 provider: &Arc<dyn LanguageModelProvider>,
125 cx: &mut Context<Self>,
126 ) -> impl IntoElement + use<> {
127 let provider_id = provider.id().0.clone();
128 let provider_name = provider.name().0.clone();
129 let configuration_view = self
130 .configuration_views_by_provider
131 .get(&provider.id())
132 .cloned();
133
134 v_flex()
135 .gap_1p5()
136 .child(
137 h_flex()
138 .justify_between()
139 .child(
140 h_flex()
141 .gap_2()
142 .child(
143 Icon::new(provider.icon())
144 .size(IconSize::Small)
145 .color(Color::Muted),
146 )
147 .child(Label::new(provider_name.clone())),
148 )
149 .when(provider.is_authenticated(cx), |parent| {
150 parent.child(
151 Button::new(
152 SharedString::from(format!("new-thread-{provider_id}")),
153 "Start New Thread",
154 )
155 .icon_position(IconPosition::Start)
156 .icon(IconName::Plus)
157 .icon_size(IconSize::Small)
158 .style(ButtonStyle::Filled)
159 .layer(ElevationIndex::ModalSurface)
160 .label_size(LabelSize::Small)
161 .on_click(cx.listener({
162 let provider = provider.clone();
163 move |_this, _event, _window, cx| {
164 cx.emit(AssistantConfigurationEvent::NewThread(
165 provider.clone(),
166 ))
167 }
168 })),
169 )
170 }),
171 )
172 .child(
173 div()
174 .p(DynamicSpacing::Base08.rems(cx))
175 .bg(cx.theme().colors().editor_background)
176 .border_1()
177 .border_color(cx.theme().colors().border)
178 .rounded_sm()
179 .map(|parent| match configuration_view {
180 Some(configuration_view) => parent.child(configuration_view),
181 None => parent.child(div().child(Label::new(format!(
182 "No configuration view for {provider_name}",
183 )))),
184 }),
185 )
186 }
187
188 fn render_provider_configuration_section(
189 &mut self,
190 cx: &mut Context<Self>,
191 ) -> impl IntoElement {
192 let providers = LanguageModelRegistry::read_global(cx).providers();
193
194 v_flex()
195 .p(DynamicSpacing::Base16.rems(cx))
196 .pr(DynamicSpacing::Base20.rems(cx))
197 .gap_4()
198 .flex_1()
199 .child(
200 v_flex()
201 .gap_0p5()
202 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
203 .child(
204 Label::new("Add at least one provider to use AI-powered features.")
205 .color(Color::Muted),
206 ),
207 )
208 .children(
209 providers
210 .into_iter()
211 .map(|provider| self.render_provider_configuration_block(&provider, cx)),
212 )
213 }
214
215 fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
216 let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
217
218 const HEADING: &str = "Allow running tools without asking for confirmation";
219
220 v_flex()
221 .p(DynamicSpacing::Base16.rems(cx))
222 .pr(DynamicSpacing::Base20.rems(cx))
223 .gap_2()
224 .flex_1()
225 .child(Headline::new("General Settings").size(HeadlineSize::Small))
226 .child(
227 h_flex()
228 .p_2p5()
229 .rounded_sm()
230 .bg(cx.theme().colors().editor_background)
231 .border_1()
232 .border_color(cx.theme().colors().border)
233 .gap_4()
234 .justify_between()
235 .flex_wrap()
236 .child(
237 v_flex()
238 .gap_0p5()
239 .max_w_5_6()
240 .child(Label::new(HEADING))
241 .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
242 )
243 .child(
244 Switch::new(
245 "always-allow-tool-actions-switch",
246 always_allow_tool_actions.into(),
247 )
248 .on_click({
249 let fs = self.fs.clone();
250 move |state, _window, cx| {
251 let allow = state == &ToggleState::Selected;
252 update_settings_file::<AssistantSettings>(
253 fs.clone(),
254 cx,
255 move |settings, _| {
256 settings.set_always_allow_tool_actions(allow);
257 },
258 );
259 }
260 }),
261 ),
262 )
263 }
264
265 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
266 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
267 let tools_by_source = self.tools.read(cx).tools_by_source(cx);
268 let empty = Vec::new();
269
270 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
271
272 v_flex()
273 .p(DynamicSpacing::Base16.rems(cx))
274 .pr(DynamicSpacing::Base20.rems(cx))
275 .gap_2()
276 .flex_1()
277 .child(
278 v_flex()
279 .gap_0p5()
280 .child(
281 Headline::new("Model Context Protocol (MCP) Servers")
282 .size(HeadlineSize::Small),
283 )
284 .child(Label::new(SUBHEADING).color(Color::Muted)),
285 )
286 .children(context_servers.into_iter().map(|context_server| {
287 let is_running = context_server.client().is_some();
288 let are_tools_expanded = self
289 .expanded_context_server_tools
290 .get(&context_server.id())
291 .copied()
292 .unwrap_or_default();
293
294 let tools = tools_by_source
295 .get(&ToolSource::ContextServer {
296 id: context_server.id().into(),
297 })
298 .unwrap_or_else(|| &empty);
299 let tool_count = tools.len();
300
301 v_flex()
302 .id(SharedString::from(context_server.id()))
303 .border_1()
304 .rounded_sm()
305 .border_color(cx.theme().colors().border)
306 .bg(cx.theme().colors().editor_background)
307 .child(
308 h_flex()
309 .p_1()
310 .justify_between()
311 .when(are_tools_expanded && tool_count > 1, |element| {
312 element
313 .border_b_1()
314 .border_color(cx.theme().colors().border)
315 })
316 .child(
317 h_flex()
318 .gap_2()
319 .child(
320 Disclosure::new("tool-list-disclosure", are_tools_expanded)
321 .disabled(tool_count == 0)
322 .on_click(cx.listener({
323 let context_server_id = context_server.id();
324 move |this, _event, _window, _cx| {
325 let is_open = this
326 .expanded_context_server_tools
327 .entry(context_server_id.clone())
328 .or_insert(false);
329
330 *is_open = !*is_open;
331 }
332 })),
333 )
334 .child(Indicator::dot().color(if is_running {
335 Color::Success
336 } else {
337 Color::Error
338 }))
339 .child(Label::new(context_server.id()))
340 .child(
341 Label::new(format!("{tool_count} tools"))
342 .color(Color::Muted)
343 .size(LabelSize::Small),
344 ),
345 )
346 .child(
347 Switch::new("context-server-switch", is_running.into()).on_click({
348 let context_server_manager =
349 self.context_server_manager.clone();
350 let context_server = context_server.clone();
351 move |state, _window, cx| match state {
352 ToggleState::Unselected | ToggleState::Indeterminate => {
353 context_server_manager.update(cx, |this, cx| {
354 this.stop_server(context_server.clone(), cx)
355 .log_err();
356 });
357 }
358 ToggleState::Selected => {
359 cx.spawn({
360 let context_server_manager =
361 context_server_manager.clone();
362 let context_server = context_server.clone();
363 async move |cx| {
364 if let Some(start_server_task) =
365 context_server_manager
366 .update(cx, |this, cx| {
367 this.start_server(
368 context_server,
369 cx,
370 )
371 })
372 .log_err()
373 {
374 start_server_task.await.log_err();
375 }
376 }
377 })
378 .detach();
379 }
380 }
381 }),
382 ),
383 )
384 .map(|parent| {
385 if !are_tools_expanded {
386 return parent;
387 }
388
389 parent.child(v_flex().children(tools.into_iter().enumerate().map(
390 |(ix, tool)| {
391 h_flex()
392 .id("tool-item")
393 .pl_2()
394 .pr_1()
395 .py_1()
396 .gap_2()
397 .justify_between()
398 .when(ix < tool_count - 1, |element| {
399 element
400 .border_b_1()
401 .border_color(cx.theme().colors().border_variant)
402 })
403 .child(
404 Label::new(tool.name())
405 .buffer_font(cx)
406 .size(LabelSize::Small),
407 )
408 .child(
409 IconButton::new(("tool-description", ix), IconName::Info)
410 .shape(ui::IconButtonShape::Square)
411 .icon_size(IconSize::Small)
412 .icon_color(Color::Ignored)
413 .tooltip(Tooltip::text(tool.description())),
414 )
415 },
416 )))
417 })
418 }))
419 .child(
420 h_flex()
421 .justify_between()
422 .gap_2()
423 .child(
424 h_flex().w_full().child(
425 Button::new("add-context-server", "Add MCPs Directly")
426 .style(ButtonStyle::Filled)
427 .layer(ElevationIndex::ModalSurface)
428 .full_width()
429 .icon(IconName::Plus)
430 .icon_size(IconSize::Small)
431 .icon_position(IconPosition::Start)
432 .on_click(|_event, window, cx| {
433 window.dispatch_action(AddContextServer.boxed_clone(), cx)
434 }),
435 ),
436 )
437 .child(
438 h_flex().w_full().child(
439 Button::new(
440 "install-context-server-extensions",
441 "Install MCP Extensions",
442 )
443 .style(ButtonStyle::Filled)
444 .layer(ElevationIndex::ModalSurface)
445 .full_width()
446 .icon(IconName::DatabaseZap)
447 .icon_size(IconSize::Small)
448 .icon_position(IconPosition::Start)
449 .on_click(|_event, window, cx| {
450 window.dispatch_action(
451 zed_actions::Extensions {
452 category_filter: Some(
453 ExtensionCategoryFilter::ContextServers,
454 ),
455 }
456 .boxed_clone(),
457 cx,
458 )
459 }),
460 ),
461 ),
462 )
463 }
464}
465
466impl Render for AssistantConfiguration {
467 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
468 v_flex()
469 .id("assistant-configuration")
470 .key_context("AgentConfiguration")
471 .track_focus(&self.focus_handle(cx))
472 .relative()
473 .size_full()
474 .pb_8()
475 .bg(cx.theme().colors().panel_background)
476 .child(
477 v_flex()
478 .id("assistant-configuration-content")
479 .track_scroll(&self.scroll_handle)
480 .size_full()
481 .overflow_y_scroll()
482 .child(self.render_command_permission(cx))
483 .child(Divider::horizontal().color(DividerColor::Border))
484 .child(self.render_context_servers_section(cx))
485 .child(Divider::horizontal().color(DividerColor::Border))
486 .child(self.render_provider_configuration_section(cx)),
487 )
488 .child(
489 div()
490 .id("assistant-configuration-scrollbar")
491 .occlude()
492 .absolute()
493 .right(px(3.))
494 .top_0()
495 .bottom_0()
496 .pb_6()
497 .w(px(12.))
498 .cursor_default()
499 .on_mouse_move(cx.listener(|_, _, _window, cx| {
500 cx.notify();
501 cx.stop_propagation()
502 }))
503 .on_hover(|_, _window, cx| {
504 cx.stop_propagation();
505 })
506 .on_any_mouse_down(|_, _window, cx| {
507 cx.stop_propagation();
508 })
509 .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
510 cx.notify();
511 }))
512 .children(Scrollbar::vertical(self.scrollbar_state.clone())),
513 )
514 }
515}