1use std::sync::Arc;
2
3use assistant_tool::{ToolSource, ToolWorkingSet};
4use collections::HashMap;
5use context_server::manager::ContextServerManager;
6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
8use ui::{
9 prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
10};
11use util::ResultExt as _;
12use zed_actions::assistant::DeployPromptLibrary;
13use zed_actions::ExtensionCategoryFilter;
14
15pub struct AssistantConfiguration {
16 focus_handle: FocusHandle,
17 configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
18 context_server_manager: Entity<ContextServerManager>,
19 expanded_context_server_tools: HashMap<Arc<str>, bool>,
20 tools: Arc<ToolWorkingSet>,
21 _registry_subscription: Subscription,
22}
23
24impl AssistantConfiguration {
25 pub fn new(
26 context_server_manager: Entity<ContextServerManager>,
27 tools: Arc<ToolWorkingSet>,
28 window: &mut Window,
29 cx: &mut Context<Self>,
30 ) -> Self {
31 let focus_handle = cx.focus_handle();
32
33 let registry_subscription = cx.subscribe_in(
34 &LanguageModelRegistry::global(cx),
35 window,
36 |this, _, event: &language_model::Event, window, cx| match event {
37 language_model::Event::AddedProvider(provider_id) => {
38 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
39 if let Some(provider) = provider {
40 this.add_provider_configuration_view(&provider, window, cx);
41 }
42 }
43 language_model::Event::RemovedProvider(provider_id) => {
44 this.remove_provider_configuration_view(provider_id);
45 }
46 _ => {}
47 },
48 );
49
50 let mut this = Self {
51 focus_handle,
52 configuration_views_by_provider: HashMap::default(),
53 context_server_manager,
54 expanded_context_server_tools: HashMap::default(),
55 tools,
56 _registry_subscription: registry_subscription,
57 };
58 this.build_provider_configuration_views(window, cx);
59 this
60 }
61
62 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
63 let providers = LanguageModelRegistry::read_global(cx).providers();
64 for provider in providers {
65 self.add_provider_configuration_view(&provider, window, cx);
66 }
67 }
68
69 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
70 self.configuration_views_by_provider.remove(provider_id);
71 }
72
73 fn add_provider_configuration_view(
74 &mut self,
75 provider: &Arc<dyn LanguageModelProvider>,
76 window: &mut Window,
77 cx: &mut Context<Self>,
78 ) {
79 let configuration_view = provider.configuration_view(window, cx);
80 self.configuration_views_by_provider
81 .insert(provider.id(), configuration_view);
82 }
83}
84
85impl Focusable for AssistantConfiguration {
86 fn focus_handle(&self, _: &App) -> FocusHandle {
87 self.focus_handle.clone()
88 }
89}
90
91pub enum AssistantConfigurationEvent {
92 NewThread(Arc<dyn LanguageModelProvider>),
93}
94
95impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
96
97impl AssistantConfiguration {
98 fn render_provider_configuration(
99 &mut self,
100 provider: &Arc<dyn LanguageModelProvider>,
101 cx: &mut Context<Self>,
102 ) -> impl IntoElement {
103 let provider_id = provider.id().0.clone();
104 let provider_name = provider.name().0.clone();
105 let configuration_view = self
106 .configuration_views_by_provider
107 .get(&provider.id())
108 .cloned();
109
110 v_flex()
111 .gap_1p5()
112 .child(
113 h_flex()
114 .justify_between()
115 .child(
116 h_flex()
117 .gap_2()
118 .child(
119 Icon::new(provider.icon())
120 .size(IconSize::Small)
121 .color(Color::Muted),
122 )
123 .child(Label::new(provider_name.clone())),
124 )
125 .when(provider.is_authenticated(cx), |parent| {
126 parent.child(
127 Button::new(
128 SharedString::from(format!("new-thread-{provider_id}")),
129 "Start New Thread",
130 )
131 .icon_position(IconPosition::Start)
132 .icon(IconName::Plus)
133 .icon_size(IconSize::Small)
134 .style(ButtonStyle::Filled)
135 .layer(ElevationIndex::ModalSurface)
136 .label_size(LabelSize::Small)
137 .on_click(cx.listener({
138 let provider = provider.clone();
139 move |_this, _event, _window, cx| {
140 cx.emit(AssistantConfigurationEvent::NewThread(
141 provider.clone(),
142 ))
143 }
144 })),
145 )
146 }),
147 )
148 .child(
149 div()
150 .p(DynamicSpacing::Base08.rems(cx))
151 .bg(cx.theme().colors().editor_background)
152 .border_1()
153 .border_color(cx.theme().colors().border_variant)
154 .rounded_sm()
155 .map(|parent| match configuration_view {
156 Some(configuration_view) => parent.child(configuration_view),
157 None => parent.child(div().child(Label::new(format!(
158 "No configuration view for {provider_name}",
159 )))),
160 }),
161 )
162 }
163
164 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
165 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
166 let tools_by_source = self.tools.tools_by_source(cx);
167 let empty = Vec::new();
168
169 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
170
171 v_flex()
172 .p(DynamicSpacing::Base16.rems(cx))
173 .mt_1()
174 .gap_2()
175 .flex_1()
176 .child(
177 v_flex()
178 .gap_0p5()
179 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
180 .child(Label::new(SUBHEADING).color(Color::Muted)),
181 )
182 .children(context_servers.into_iter().map(|context_server| {
183 let is_running = context_server.client().is_some();
184 let are_tools_expanded = self
185 .expanded_context_server_tools
186 .get(&context_server.id())
187 .copied()
188 .unwrap_or_default();
189
190 let tools = tools_by_source
191 .get(&ToolSource::ContextServer {
192 id: context_server.id().into(),
193 })
194 .unwrap_or_else(|| &empty);
195 let tool_count = tools.len();
196
197 v_flex()
198 .id(SharedString::from(context_server.id()))
199 .border_1()
200 .rounded_sm()
201 .border_color(cx.theme().colors().border)
202 .bg(cx.theme().colors().editor_background)
203 .child(
204 h_flex()
205 .justify_between()
206 .px_2()
207 .py_1()
208 .when(are_tools_expanded, |element| {
209 element
210 .border_b_1()
211 .border_color(cx.theme().colors().border)
212 })
213 .child(
214 h_flex()
215 .gap_2()
216 .child(
217 Disclosure::new("tool-list-disclosure", are_tools_expanded)
218 .on_click(cx.listener({
219 let context_server_id = context_server.id();
220 move |this, _event, _window, _cx| {
221 let is_open = this
222 .expanded_context_server_tools
223 .entry(context_server_id.clone())
224 .or_insert(false);
225
226 *is_open = !*is_open;
227 }
228 })),
229 )
230 .child(Indicator::dot().color(if is_running {
231 Color::Success
232 } else {
233 Color::Error
234 }))
235 .child(Label::new(context_server.id()))
236 .child(
237 Label::new(format!("{tool_count} tools"))
238 .color(Color::Muted),
239 ),
240 )
241 .child(h_flex().child(
242 Switch::new("context-server-switch", is_running.into()).on_click({
243 let context_server_manager =
244 self.context_server_manager.clone();
245 let context_server = context_server.clone();
246 move |state, _window, cx| match state {
247 ToggleState::Unselected | ToggleState::Indeterminate => {
248 context_server_manager.update(cx, |this, cx| {
249 this.stop_server(context_server.clone(), cx)
250 .log_err();
251 });
252 }
253 ToggleState::Selected => {
254 cx.spawn({
255 let context_server_manager =
256 context_server_manager.clone();
257 let context_server = context_server.clone();
258 async move |cx| {
259 if let Some(start_server_task) =
260 context_server_manager
261 .update(cx, |this, cx| {
262 this.start_server(
263 context_server,
264 cx,
265 )
266 })
267 .log_err()
268 {
269 start_server_task.await.log_err();
270 }
271 }
272 })
273 .detach();
274 }
275 }
276 }),
277 )),
278 )
279 .map(|parent| {
280 if !are_tools_expanded {
281 return parent;
282 }
283
284 parent.child(v_flex().children(tools.into_iter().enumerate().map(
285 |(ix, tool)| {
286 h_flex()
287 .px_2()
288 .py_1()
289 .when(ix < tool_count - 1, |element| {
290 element
291 .border_b_1()
292 .border_color(cx.theme().colors().border)
293 })
294 .child(Label::new(tool.name()))
295 },
296 )))
297 })
298 }))
299 .child(
300 h_flex()
301 .justify_between()
302 .gap_2()
303 .child(
304 h_flex().w_full().child(
305 Button::new("add-context-server", "Add Context Server")
306 .style(ButtonStyle::Filled)
307 .layer(ElevationIndex::ModalSurface)
308 .full_width()
309 .icon(IconName::Plus)
310 .icon_size(IconSize::Small)
311 .icon_position(IconPosition::Start)
312 .disabled(true)
313 .tooltip(Tooltip::text("Not yet implemented")),
314 ),
315 )
316 .child(
317 h_flex().w_full().child(
318 Button::new(
319 "install-context-server-extensions",
320 "Install Context Server Extensions",
321 )
322 .style(ButtonStyle::Filled)
323 .layer(ElevationIndex::ModalSurface)
324 .full_width()
325 .icon(IconName::DatabaseZap)
326 .icon_size(IconSize::Small)
327 .icon_position(IconPosition::Start)
328 .on_click(|_event, window, cx| {
329 window.dispatch_action(
330 zed_actions::Extensions {
331 category_filter: Some(
332 ExtensionCategoryFilter::ContextServers,
333 ),
334 }
335 .boxed_clone(),
336 cx,
337 )
338 }),
339 ),
340 ),
341 )
342 }
343}
344
345impl Render for AssistantConfiguration {
346 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
347 let providers = LanguageModelRegistry::read_global(cx).providers();
348
349 v_flex()
350 .id("assistant-configuration")
351 .track_focus(&self.focus_handle(cx))
352 .bg(cx.theme().colors().panel_background)
353 .size_full()
354 .overflow_y_scroll()
355 .child(
356 v_flex()
357 .p(DynamicSpacing::Base16.rems(cx))
358 .gap_2()
359 .child(
360 v_flex()
361 .gap_0p5()
362 .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
363 .child(
364 Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
365 .color(Color::Muted),
366 ),
367 )
368 .child(
369 Button::new("open-prompt-library", "Open Prompt Library")
370 .style(ButtonStyle::Filled)
371 .layer(ElevationIndex::ModalSurface)
372 .full_width()
373 .icon(IconName::Book)
374 .icon_size(IconSize::Small)
375 .icon_position(IconPosition::Start)
376 .on_click(|_event, window, cx| {
377 window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
378 }),
379 ),
380 )
381 .child(Divider::horizontal().color(DividerColor::Border))
382 .child(self.render_context_servers_section(cx))
383 .child(Divider::horizontal().color(DividerColor::Border))
384 .child(
385 v_flex()
386 .p(DynamicSpacing::Base16.rems(cx))
387 .mt_1()
388 .gap_6()
389 .flex_1()
390 .child(
391 v_flex()
392 .gap_0p5()
393 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
394 .child(
395 Label::new("Add at least one provider to use AI-powered features.")
396 .color(Color::Muted),
397 ),
398 )
399 .children(
400 providers
401 .into_iter()
402 .map(|provider| self.render_provider_configuration(&provider, cx)),
403 ),
404 )
405 }
406}