1use std::sync::Arc;
2
3use assistant_tool::{ToolSource, ToolWorkingSet};
4use collections::HashMap;
5use context_server::manager::ContextServerManager;
6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
8use ui::{
9 prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
10};
11use util::ResultExt as _;
12use zed_actions::ExtensionCategoryFilter;
13
/// Assistant settings view.
///
/// It lists the context servers (MCP), with their run state and tools, and every
/// registered LLM provider together with that provider's configuration view.
pub struct AssistantConfiguration {
    focus_handle: FocusHandle,
    /// Configuration views created by each provider, keyed by provider id.
    /// Kept in sync with the language model registry via `_registry_subscription`.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Used to list context servers and to start/stop them from the UI.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether the tool list of a given context server id is expanded.
    /// A missing entry is rendered as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool working set used to look up the tools each context server provides.
    tools: Arc<ToolWorkingSet>,
    /// Never read. Stored so the registry subscription stays alive as long as
    /// this view does.
    _registry_subscription: Subscription,
}
22
23impl AssistantConfiguration {
24 pub fn new(
25 context_server_manager: Entity<ContextServerManager>,
26 tools: Arc<ToolWorkingSet>,
27 window: &mut Window,
28 cx: &mut Context<Self>,
29 ) -> Self {
30 let focus_handle = cx.focus_handle();
31
32 let registry_subscription = cx.subscribe_in(
33 &LanguageModelRegistry::global(cx),
34 window,
35 |this, _, event: &language_model::Event, window, cx| match event {
36 language_model::Event::AddedProvider(provider_id) => {
37 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
38 if let Some(provider) = provider {
39 this.add_provider_configuration_view(&provider, window, cx);
40 }
41 }
42 language_model::Event::RemovedProvider(provider_id) => {
43 this.remove_provider_configuration_view(provider_id);
44 }
45 _ => {}
46 },
47 );
48
49 let mut this = Self {
50 focus_handle,
51 configuration_views_by_provider: HashMap::default(),
52 context_server_manager,
53 expanded_context_server_tools: HashMap::default(),
54 tools,
55 _registry_subscription: registry_subscription,
56 };
57 this.build_provider_configuration_views(window, cx);
58 this
59 }
60
61 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
62 let providers = LanguageModelRegistry::read_global(cx).providers();
63 for provider in providers {
64 self.add_provider_configuration_view(&provider, window, cx);
65 }
66 }
67
68 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
69 self.configuration_views_by_provider.remove(provider_id);
70 }
71
72 fn add_provider_configuration_view(
73 &mut self,
74 provider: &Arc<dyn LanguageModelProvider>,
75 window: &mut Window,
76 cx: &mut Context<Self>,
77 ) {
78 let configuration_view = provider.configuration_view(window, cx);
79 self.configuration_views_by_provider
80 .insert(provider.id(), configuration_view);
81 }
82}
83
84impl Focusable for AssistantConfiguration {
85 fn focus_handle(&self, _: &App) -> FocusHandle {
86 self.focus_handle.clone()
87 }
88}
89
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when the user clicks "Start New Thread" on an authenticated
    /// provider. Carries the provider that was chosen.
    NewThread(Arc<dyn LanguageModelProvider>),
}

impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
95
96impl AssistantConfiguration {
97 fn render_provider_configuration(
98 &mut self,
99 provider: &Arc<dyn LanguageModelProvider>,
100 cx: &mut Context<Self>,
101 ) -> impl IntoElement {
102 let provider_id = provider.id().0.clone();
103 let provider_name = provider.name().0.clone();
104 let configuration_view = self
105 .configuration_views_by_provider
106 .get(&provider.id())
107 .cloned();
108
109 v_flex()
110 .gap_1p5()
111 .child(
112 h_flex()
113 .justify_between()
114 .child(
115 h_flex()
116 .gap_2()
117 .child(
118 Icon::new(provider.icon())
119 .size(IconSize::Small)
120 .color(Color::Muted),
121 )
122 .child(Label::new(provider_name.clone())),
123 )
124 .when(provider.is_authenticated(cx), |parent| {
125 parent.child(
126 Button::new(
127 SharedString::from(format!("new-thread-{provider_id}")),
128 "Start New Thread",
129 )
130 .icon_position(IconPosition::Start)
131 .icon(IconName::Plus)
132 .icon_size(IconSize::Small)
133 .style(ButtonStyle::Filled)
134 .layer(ElevationIndex::ModalSurface)
135 .label_size(LabelSize::Small)
136 .on_click(cx.listener({
137 let provider = provider.clone();
138 move |_this, _event, _window, cx| {
139 cx.emit(AssistantConfigurationEvent::NewThread(
140 provider.clone(),
141 ))
142 }
143 })),
144 )
145 }),
146 )
147 .child(
148 div()
149 .p(DynamicSpacing::Base08.rems(cx))
150 .bg(cx.theme().colors().editor_background)
151 .border_1()
152 .border_color(cx.theme().colors().border_variant)
153 .rounded_sm()
154 .map(|parent| match configuration_view {
155 Some(configuration_view) => parent.child(configuration_view),
156 None => parent.child(div().child(Label::new(format!(
157 "No configuration view for {provider_name}",
158 )))),
159 }),
160 )
161 }
162
163 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
164 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
165 let tools_by_source = self.tools.tools_by_source(cx);
166 let empty = Vec::new();
167
168 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
169
170 v_flex()
171 .p(DynamicSpacing::Base16.rems(cx))
172 .gap_2()
173 .flex_1()
174 .child(
175 v_flex()
176 .gap_0p5()
177 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
178 .child(Label::new(SUBHEADING).color(Color::Muted)),
179 )
180 .children(context_servers.into_iter().map(|context_server| {
181 let is_running = context_server.client().is_some();
182 let are_tools_expanded = self
183 .expanded_context_server_tools
184 .get(&context_server.id())
185 .copied()
186 .unwrap_or_default();
187
188 let tools = tools_by_source
189 .get(&ToolSource::ContextServer {
190 id: context_server.id().into(),
191 })
192 .unwrap_or_else(|| &empty);
193 let tool_count = tools.len();
194
195 v_flex()
196 .id(SharedString::from(context_server.id()))
197 .border_1()
198 .rounded_sm()
199 .border_color(cx.theme().colors().border)
200 .bg(cx.theme().colors().editor_background)
201 .child(
202 h_flex()
203 .justify_between()
204 .px_2()
205 .py_1()
206 .when(are_tools_expanded, |element| {
207 element
208 .border_b_1()
209 .border_color(cx.theme().colors().border)
210 })
211 .child(
212 h_flex()
213 .gap_2()
214 .child(
215 Disclosure::new("tool-list-disclosure", are_tools_expanded)
216 .on_click(cx.listener({
217 let context_server_id = context_server.id();
218 move |this, _event, _window, _cx| {
219 let is_open = this
220 .expanded_context_server_tools
221 .entry(context_server_id.clone())
222 .or_insert(false);
223
224 *is_open = !*is_open;
225 }
226 })),
227 )
228 .child(Indicator::dot().color(if is_running {
229 Color::Success
230 } else {
231 Color::Error
232 }))
233 .child(Label::new(context_server.id()))
234 .child(
235 Label::new(format!("{tool_count} tools"))
236 .color(Color::Muted),
237 ),
238 )
239 .child(h_flex().child(
240 Switch::new("context-server-switch", is_running.into()).on_click({
241 let context_server_manager =
242 self.context_server_manager.clone();
243 let context_server = context_server.clone();
244 move |state, _window, cx| match state {
245 ToggleState::Unselected | ToggleState::Indeterminate => {
246 context_server_manager.update(cx, |this, cx| {
247 this.stop_server(context_server.clone(), cx)
248 .log_err();
249 });
250 }
251 ToggleState::Selected => {
252 cx.spawn({
253 let context_server_manager =
254 context_server_manager.clone();
255 let context_server = context_server.clone();
256 async move |cx| {
257 if let Some(start_server_task) =
258 context_server_manager
259 .update(cx, |this, cx| {
260 this.start_server(
261 context_server,
262 cx,
263 )
264 })
265 .log_err()
266 {
267 start_server_task.await.log_err();
268 }
269 }
270 })
271 .detach();
272 }
273 }
274 }),
275 )),
276 )
277 .map(|parent| {
278 if !are_tools_expanded {
279 return parent;
280 }
281
282 parent.child(v_flex().children(tools.into_iter().enumerate().map(
283 |(ix, tool)| {
284 h_flex()
285 .px_2()
286 .py_1()
287 .when(ix < tool_count - 1, |element| {
288 element
289 .border_b_1()
290 .border_color(cx.theme().colors().border)
291 })
292 .child(Label::new(tool.name()))
293 },
294 )))
295 })
296 }))
297 .child(
298 h_flex()
299 .justify_between()
300 .gap_2()
301 .child(
302 h_flex().w_full().child(
303 Button::new("add-context-server", "Add Context Server")
304 .style(ButtonStyle::Filled)
305 .layer(ElevationIndex::ModalSurface)
306 .full_width()
307 .icon(IconName::Plus)
308 .icon_size(IconSize::Small)
309 .icon_position(IconPosition::Start)
310 .disabled(true)
311 .tooltip(Tooltip::text("Not yet implemented")),
312 ),
313 )
314 .child(
315 h_flex().w_full().child(
316 Button::new(
317 "install-context-server-extensions",
318 "Install Context Server Extensions",
319 )
320 .style(ButtonStyle::Filled)
321 .layer(ElevationIndex::ModalSurface)
322 .full_width()
323 .icon(IconName::DatabaseZap)
324 .icon_size(IconSize::Small)
325 .icon_position(IconPosition::Start)
326 .on_click(|_event, window, cx| {
327 window.dispatch_action(
328 zed_actions::Extensions {
329 category_filter: Some(
330 ExtensionCategoryFilter::ContextServers,
331 ),
332 }
333 .boxed_clone(),
334 cx,
335 )
336 }),
337 ),
338 ),
339 )
340 }
341}
342
343impl Render for AssistantConfiguration {
344 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
345 let providers = LanguageModelRegistry::read_global(cx).providers();
346
347 v_flex()
348 .id("assistant-configuration")
349 .track_focus(&self.focus_handle(cx))
350 .bg(cx.theme().colors().panel_background)
351 .size_full()
352 .overflow_y_scroll()
353 .child(self.render_context_servers_section(cx))
354 .child(Divider::horizontal().color(DividerColor::Border))
355 .child(
356 v_flex()
357 .p(DynamicSpacing::Base16.rems(cx))
358 .mt_1()
359 .gap_6()
360 .flex_1()
361 .child(
362 v_flex()
363 .gap_0p5()
364 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
365 .child(
366 Label::new("Add at least one provider to use AI-powered features.")
367 .color(Color::Muted),
368 ),
369 )
370 .children(
371 providers
372 .into_iter()
373 .map(|provider| self.render_provider_configuration(&provider, cx)),
374 ),
375 )
376 }
377}