1mod add_context_server_modal;
2mod manage_profiles_modal;
3mod tool_picker;
4
5use std::sync::Arc;
6
7use assistant_settings::AssistantSettings;
8use assistant_tool::{ToolSource, ToolWorkingSet};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use fs::Fs;
12use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
13use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
14use settings::{Settings, update_settings_file};
15use ui::{
16 Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
17};
18use util::ResultExt as _;
19use zed_actions::ExtensionCategoryFilter;
20
21pub(crate) use add_context_server_modal::AddContextServerModal;
22pub(crate) use manage_profiles_modal::ManageProfilesModal;
23
24use crate::AddContextServer;
25
/// View that renders the assistant's configuration: the tool-confirmation setting,
/// the list of context (MCP) servers, and one configuration panel per language
/// model provider.
pub struct AssistantConfiguration {
    /// Filesystem handle passed to `update_settings_file` when a setting is toggled.
    fs: Arc<dyn Fs>,
    /// Focus handle returned by the `Focusable` implementation and tracked by the root element.
    focus_handle: FocusHandle,
    /// Configuration view obtained from each language model provider, keyed by provider id.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Manager used to list context servers and to start or stop them.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether a context server's tool list is expanded, keyed by server id.
    /// A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool working set queried for the tools contributed by each context server.
    tools: Entity<ToolWorkingSet>,
    /// Subscription to `LanguageModelRegistry` events, held for as long as this view exists.
    _registry_subscription: Subscription,
}
35
36impl AssistantConfiguration {
37 pub fn new(
38 fs: Arc<dyn Fs>,
39 context_server_manager: Entity<ContextServerManager>,
40 tools: Entity<ToolWorkingSet>,
41 window: &mut Window,
42 cx: &mut Context<Self>,
43 ) -> Self {
44 let focus_handle = cx.focus_handle();
45
46 let registry_subscription = cx.subscribe_in(
47 &LanguageModelRegistry::global(cx),
48 window,
49 |this, _, event: &language_model::Event, window, cx| match event {
50 language_model::Event::AddedProvider(provider_id) => {
51 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
52 if let Some(provider) = provider {
53 this.add_provider_configuration_view(&provider, window, cx);
54 }
55 }
56 language_model::Event::RemovedProvider(provider_id) => {
57 this.remove_provider_configuration_view(provider_id);
58 }
59 _ => {}
60 },
61 );
62
63 let mut this = Self {
64 fs,
65 focus_handle,
66 configuration_views_by_provider: HashMap::default(),
67 context_server_manager,
68 expanded_context_server_tools: HashMap::default(),
69 tools,
70 _registry_subscription: registry_subscription,
71 };
72 this.build_provider_configuration_views(window, cx);
73 this
74 }
75
76 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
77 let providers = LanguageModelRegistry::read_global(cx).providers();
78 for provider in providers {
79 self.add_provider_configuration_view(&provider, window, cx);
80 }
81 }
82
83 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
84 self.configuration_views_by_provider.remove(provider_id);
85 }
86
87 fn add_provider_configuration_view(
88 &mut self,
89 provider: &Arc<dyn LanguageModelProvider>,
90 window: &mut Window,
91 cx: &mut Context<Self>,
92 ) {
93 let configuration_view = provider.configuration_view(window, cx);
94 self.configuration_views_by_provider
95 .insert(provider.id(), configuration_view);
96 }
97}
98
99impl Focusable for AssistantConfiguration {
100 fn focus_handle(&self, _: &App) -> FocusHandle {
101 self.focus_handle.clone()
102 }
103}
104
/// Events emitted by `AssistantConfiguration`.
pub enum AssistantConfigurationEvent {
    /// Emitted when a provider's "Start New Thread" button is clicked; carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}
108
// Marks `AssistantConfiguration` as an emitter of `AssistantConfigurationEvent`
// (emitted through `cx.emit` from the "Start New Thread" button).
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
110
111impl AssistantConfiguration {
112 fn render_provider_configuration(
113 &mut self,
114 provider: &Arc<dyn LanguageModelProvider>,
115 cx: &mut Context<Self>,
116 ) -> impl IntoElement + use<> {
117 let provider_id = provider.id().0.clone();
118 let provider_name = provider.name().0.clone();
119 let configuration_view = self
120 .configuration_views_by_provider
121 .get(&provider.id())
122 .cloned();
123
124 v_flex()
125 .gap_1p5()
126 .child(
127 h_flex()
128 .justify_between()
129 .child(
130 h_flex()
131 .gap_2()
132 .child(
133 Icon::new(provider.icon())
134 .size(IconSize::Small)
135 .color(Color::Muted),
136 )
137 .child(Label::new(provider_name.clone())),
138 )
139 .when(provider.is_authenticated(cx), |parent| {
140 parent.child(
141 Button::new(
142 SharedString::from(format!("new-thread-{provider_id}")),
143 "Start New Thread",
144 )
145 .icon_position(IconPosition::Start)
146 .icon(IconName::Plus)
147 .icon_size(IconSize::Small)
148 .style(ButtonStyle::Filled)
149 .layer(ElevationIndex::ModalSurface)
150 .label_size(LabelSize::Small)
151 .on_click(cx.listener({
152 let provider = provider.clone();
153 move |_this, _event, _window, cx| {
154 cx.emit(AssistantConfigurationEvent::NewThread(
155 provider.clone(),
156 ))
157 }
158 })),
159 )
160 }),
161 )
162 .child(
163 div()
164 .p(DynamicSpacing::Base08.rems(cx))
165 .bg(cx.theme().colors().editor_background)
166 .border_1()
167 .border_color(cx.theme().colors().border_variant)
168 .rounded_sm()
169 .map(|parent| match configuration_view {
170 Some(configuration_view) => parent.child(configuration_view),
171 None => parent.child(div().child(Label::new(format!(
172 "No configuration view for {provider_name}",
173 )))),
174 }),
175 )
176 }
177
178 fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
179 let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
180
181 const HEADING: &str = "Allow running tools without asking for confirmation";
182
183 v_flex()
184 .p(DynamicSpacing::Base16.rems(cx))
185 .gap_2()
186 .flex_1()
187 .child(Headline::new("General Settings").size(HeadlineSize::Small))
188 .child(
189 h_flex()
190 .p_2p5()
191 .rounded_sm()
192 .bg(cx.theme().colors().editor_background)
193 .border_1()
194 .border_color(cx.theme().colors().border)
195 .gap_4()
196 .justify_between()
197 .flex_wrap()
198 .child(
199 v_flex()
200 .gap_0p5()
201 .max_w_5_6()
202 .child(Label::new(HEADING))
203 .child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
204 )
205 .child(
206 Switch::new(
207 "always-allow-tool-actions-switch",
208 always_allow_tool_actions.into(),
209 )
210 .on_click({
211 let fs = self.fs.clone();
212 move |state, _window, cx| {
213 let allow = state == &ToggleState::Selected;
214 update_settings_file::<AssistantSettings>(
215 fs.clone(),
216 cx,
217 move |settings, _| {
218 settings.set_always_allow_tool_actions(allow);
219 },
220 );
221 }
222 }),
223 ),
224 )
225 }
226
227 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
228 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
229 let tools_by_source = self.tools.read(cx).tools_by_source(cx);
230 let empty = Vec::new();
231
232 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
233
234 v_flex()
235 .p(DynamicSpacing::Base16.rems(cx))
236 .gap_2()
237 .flex_1()
238 .child(
239 v_flex()
240 .gap_0p5()
241 .child(
242 Headline::new("Model Context Protocol (MCP) Servers")
243 .size(HeadlineSize::Small),
244 )
245 .child(Label::new(SUBHEADING).color(Color::Muted)),
246 )
247 .children(context_servers.into_iter().map(|context_server| {
248 let is_running = context_server.client().is_some();
249 let are_tools_expanded = self
250 .expanded_context_server_tools
251 .get(&context_server.id())
252 .copied()
253 .unwrap_or_default();
254
255 let tools = tools_by_source
256 .get(&ToolSource::ContextServer {
257 id: context_server.id().into(),
258 })
259 .unwrap_or_else(|| &empty);
260 let tool_count = tools.len();
261
262 v_flex()
263 .id(SharedString::from(context_server.id()))
264 .border_1()
265 .rounded_sm()
266 .border_color(cx.theme().colors().border)
267 .bg(cx.theme().colors().editor_background)
268 .child(
269 h_flex()
270 .p_1()
271 .justify_between()
272 .when(are_tools_expanded && tool_count > 1, |element| {
273 element
274 .border_b_1()
275 .border_color(cx.theme().colors().border)
276 })
277 .child(
278 h_flex()
279 .gap_2()
280 .child(
281 Disclosure::new("tool-list-disclosure", are_tools_expanded)
282 .disabled(tool_count == 0)
283 .on_click(cx.listener({
284 let context_server_id = context_server.id();
285 move |this, _event, _window, _cx| {
286 let is_open = this
287 .expanded_context_server_tools
288 .entry(context_server_id.clone())
289 .or_insert(false);
290
291 *is_open = !*is_open;
292 }
293 })),
294 )
295 .child(Indicator::dot().color(if is_running {
296 Color::Success
297 } else {
298 Color::Error
299 }))
300 .child(Label::new(context_server.id()))
301 .child(
302 Label::new(format!("{tool_count} tools"))
303 .color(Color::Muted)
304 .size(LabelSize::Small),
305 ),
306 )
307 .child(
308 Switch::new("context-server-switch", is_running.into()).on_click({
309 let context_server_manager =
310 self.context_server_manager.clone();
311 let context_server = context_server.clone();
312 move |state, _window, cx| match state {
313 ToggleState::Unselected | ToggleState::Indeterminate => {
314 context_server_manager.update(cx, |this, cx| {
315 this.stop_server(context_server.clone(), cx)
316 .log_err();
317 });
318 }
319 ToggleState::Selected => {
320 cx.spawn({
321 let context_server_manager =
322 context_server_manager.clone();
323 let context_server = context_server.clone();
324 async move |cx| {
325 if let Some(start_server_task) =
326 context_server_manager
327 .update(cx, |this, cx| {
328 this.start_server(
329 context_server,
330 cx,
331 )
332 })
333 .log_err()
334 {
335 start_server_task.await.log_err();
336 }
337 }
338 })
339 .detach();
340 }
341 }
342 }),
343 ),
344 )
345 .map(|parent| {
346 if !are_tools_expanded {
347 return parent;
348 }
349
350 parent.child(v_flex().children(tools.into_iter().enumerate().map(
351 |(ix, tool)| {
352 h_flex()
353 .id("tool-item")
354 .pl_2()
355 .pr_1()
356 .py_1()
357 .gap_2()
358 .justify_between()
359 .when(ix < tool_count - 1, |element| {
360 element
361 .border_b_1()
362 .border_color(cx.theme().colors().border_variant)
363 })
364 .child(
365 Label::new(tool.name())
366 .buffer_font(cx)
367 .size(LabelSize::Small),
368 )
369 .child(
370 IconButton::new(("tool-description", ix), IconName::Info)
371 .shape(ui::IconButtonShape::Square)
372 .icon_size(IconSize::Small)
373 .icon_color(Color::Ignored)
374 .tooltip(Tooltip::text(tool.description())),
375 )
376 },
377 )))
378 })
379 }))
380 .child(
381 h_flex()
382 .justify_between()
383 .gap_2()
384 .child(
385 h_flex().w_full().child(
386 Button::new("add-context-server", "Add MCPs Directly")
387 .style(ButtonStyle::Filled)
388 .layer(ElevationIndex::ModalSurface)
389 .full_width()
390 .icon(IconName::Plus)
391 .icon_size(IconSize::Small)
392 .icon_position(IconPosition::Start)
393 .on_click(|_event, window, cx| {
394 window.dispatch_action(AddContextServer.boxed_clone(), cx)
395 }),
396 ),
397 )
398 .child(
399 h_flex().w_full().child(
400 Button::new(
401 "install-context-server-extensions",
402 "Install MCP Extensions",
403 )
404 .style(ButtonStyle::Filled)
405 .layer(ElevationIndex::ModalSurface)
406 .full_width()
407 .icon(IconName::DatabaseZap)
408 .icon_size(IconSize::Small)
409 .icon_position(IconPosition::Start)
410 .on_click(|_event, window, cx| {
411 window.dispatch_action(
412 zed_actions::Extensions {
413 category_filter: Some(
414 ExtensionCategoryFilter::ContextServers,
415 ),
416 }
417 .boxed_clone(),
418 cx,
419 )
420 }),
421 ),
422 ),
423 )
424 }
425}
426
427impl Render for AssistantConfiguration {
428 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
429 let providers = LanguageModelRegistry::read_global(cx).providers();
430
431 v_flex()
432 .id("assistant-configuration")
433 .key_context("AgentConfiguration")
434 .track_focus(&self.focus_handle(cx))
435 .bg(cx.theme().colors().panel_background)
436 .size_full()
437 .overflow_y_scroll()
438 .child(self.render_command_permission(cx))
439 .child(Divider::horizontal().color(DividerColor::Border))
440 .child(self.render_context_servers_section(cx))
441 .child(Divider::horizontal().color(DividerColor::Border))
442 .child(
443 v_flex()
444 .p(DynamicSpacing::Base16.rems(cx))
445 .mt_1()
446 .gap_6()
447 .flex_1()
448 .child(
449 v_flex()
450 .gap_0p5()
451 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
452 .child(
453 Label::new("Add at least one provider to use AI-powered features.")
454 .color(Color::Muted),
455 ),
456 )
457 .children(
458 providers
459 .into_iter()
460 .map(|provider| self.render_provider_configuration(&provider, cx)),
461 ),
462 )
463 }
464}