1use std::sync::Arc;
2
3use assistant_tool::{ToolSource, ToolWorkingSet};
4use collections::HashMap;
5use context_server::manager::ContextServerManager;
6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
8use ui::{
9 prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
10};
11use util::ResultExt as _;
12use zed_actions::assistant::DeployPromptLibrary;
13use zed_actions::ExtensionCategoryFilter;
14
15pub struct AssistantConfiguration {
16 focus_handle: FocusHandle,
17 configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
18 context_server_manager: Entity<ContextServerManager>,
19 expanded_context_server_tools: HashMap<Arc<str>, bool>,
20 tools: Arc<ToolWorkingSet>,
21 _registry_subscription: Subscription,
22}
23
24impl AssistantConfiguration {
25 pub fn new(
26 context_server_manager: Entity<ContextServerManager>,
27 tools: Arc<ToolWorkingSet>,
28 window: &mut Window,
29 cx: &mut Context<Self>,
30 ) -> Self {
31 let focus_handle = cx.focus_handle();
32
33 let registry_subscription = cx.subscribe_in(
34 &LanguageModelRegistry::global(cx),
35 window,
36 |this, _, event: &language_model::Event, window, cx| match event {
37 language_model::Event::AddedProvider(provider_id) => {
38 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
39 if let Some(provider) = provider {
40 this.add_provider_configuration_view(&provider, window, cx);
41 }
42 }
43 language_model::Event::RemovedProvider(provider_id) => {
44 this.remove_provider_configuration_view(provider_id);
45 }
46 _ => {}
47 },
48 );
49
50 let mut this = Self {
51 focus_handle,
52 configuration_views_by_provider: HashMap::default(),
53 context_server_manager,
54 expanded_context_server_tools: HashMap::default(),
55 tools,
56 _registry_subscription: registry_subscription,
57 };
58 this.build_provider_configuration_views(window, cx);
59 this
60 }
61
62 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
63 let providers = LanguageModelRegistry::read_global(cx).providers();
64 for provider in providers {
65 self.add_provider_configuration_view(&provider, window, cx);
66 }
67 }
68
69 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
70 self.configuration_views_by_provider.remove(provider_id);
71 }
72
73 fn add_provider_configuration_view(
74 &mut self,
75 provider: &Arc<dyn LanguageModelProvider>,
76 window: &mut Window,
77 cx: &mut Context<Self>,
78 ) {
79 let configuration_view = provider.configuration_view(window, cx);
80 self.configuration_views_by_provider
81 .insert(provider.id(), configuration_view);
82 }
83}
84
85impl Focusable for AssistantConfiguration {
86 fn focus_handle(&self, _: &App) -> FocusHandle {
87 self.focus_handle.clone()
88 }
89}
90
91pub enum AssistantConfigurationEvent {
92 NewThread(Arc<dyn LanguageModelProvider>),
93}
94
95impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
96
97impl AssistantConfiguration {
98 fn render_provider_configuration(
99 &mut self,
100 provider: &Arc<dyn LanguageModelProvider>,
101 cx: &mut Context<Self>,
102 ) -> impl IntoElement {
103 let provider_id = provider.id().0.clone();
104 let provider_name = provider.name().0.clone();
105 let configuration_view = self
106 .configuration_views_by_provider
107 .get(&provider.id())
108 .cloned();
109
110 v_flex()
111 .gap_1p5()
112 .child(
113 h_flex()
114 .justify_between()
115 .child(
116 h_flex()
117 .gap_2()
118 .child(
119 Icon::new(provider.icon())
120 .size(IconSize::Small)
121 .color(Color::Muted),
122 )
123 .child(Label::new(provider_name.clone())),
124 )
125 .when(provider.is_authenticated(cx), |parent| {
126 parent.child(
127 Button::new(
128 SharedString::from(format!("new-thread-{provider_id}")),
129 "Start New Thread",
130 )
131 .icon_position(IconPosition::Start)
132 .icon(IconName::Plus)
133 .icon_size(IconSize::Small)
134 .style(ButtonStyle::Filled)
135 .layer(ElevationIndex::ModalSurface)
136 .label_size(LabelSize::Small)
137 .on_click(cx.listener({
138 let provider = provider.clone();
139 move |_this, _event, _window, cx| {
140 cx.emit(AssistantConfigurationEvent::NewThread(
141 provider.clone(),
142 ))
143 }
144 })),
145 )
146 }),
147 )
148 .child(
149 div()
150 .p(DynamicSpacing::Base08.rems(cx))
151 .bg(cx.theme().colors().editor_background)
152 .border_1()
153 .border_color(cx.theme().colors().border_variant)
154 .rounded_sm()
155 .map(|parent| match configuration_view {
156 Some(configuration_view) => parent.child(configuration_view),
157 None => parent.child(div().child(Label::new(format!(
158 "No configuration view for {provider_name}",
159 )))),
160 }),
161 )
162 }
163
164 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
165 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
166 let tools_by_source = self.tools.tools_by_source(cx);
167 let empty = Vec::new();
168
169 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
170
171 v_flex()
172 .p(DynamicSpacing::Base16.rems(cx))
173 .mt_1()
174 .gap_2()
175 .flex_1()
176 .child(
177 v_flex()
178 .gap_0p5()
179 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
180 .child(Label::new(SUBHEADING).color(Color::Muted)),
181 )
182 .children(context_servers.into_iter().map(|context_server| {
183 let is_running = context_server.client().is_some();
184 let are_tools_expanded = self
185 .expanded_context_server_tools
186 .get(&context_server.id())
187 .copied()
188 .unwrap_or_default();
189
190 let tools = tools_by_source
191 .get(&ToolSource::ContextServer {
192 id: context_server.id().into(),
193 })
194 .unwrap_or_else(|| &empty);
195 let tool_count = tools.len();
196
197 v_flex()
198 .border_1()
199 .rounded_sm()
200 .border_color(cx.theme().colors().border)
201 .bg(cx.theme().colors().editor_background)
202 .child(
203 h_flex()
204 .justify_between()
205 .px_2()
206 .py_1()
207 .when(are_tools_expanded, |element| {
208 element
209 .border_b_1()
210 .border_color(cx.theme().colors().border)
211 })
212 .child(
213 h_flex()
214 .gap_2()
215 .child(
216 Disclosure::new("tool-list-disclosure", are_tools_expanded)
217 .on_click(cx.listener({
218 let context_server_id = context_server.id();
219 move |this, _event, _window, _cx| {
220 let is_open = this
221 .expanded_context_server_tools
222 .entry(context_server_id.clone())
223 .or_insert(false);
224
225 *is_open = !*is_open;
226 }
227 })),
228 )
229 .child(Indicator::dot().color(if is_running {
230 Color::Success
231 } else {
232 Color::Error
233 }))
234 .child(Label::new(context_server.id()))
235 .child(
236 Label::new(format!("{tool_count} tools"))
237 .color(Color::Muted),
238 ),
239 )
240 .child(h_flex().child(
241 Switch::new("context-server-switch", is_running.into()).on_click({
242 let context_server_manager =
243 self.context_server_manager.clone();
244 let context_server = context_server.clone();
245 move |state, _window, cx| match state {
246 ToggleState::Unselected | ToggleState::Indeterminate => {
247 context_server_manager.update(cx, |this, cx| {
248 this.stop_server(context_server.clone(), cx)
249 .log_err();
250 });
251 }
252 ToggleState::Selected => {
253 cx.spawn({
254 let context_server_manager =
255 context_server_manager.clone();
256 let context_server = context_server.clone();
257 async move |cx| {
258 if let Some(start_server_task) =
259 context_server_manager
260 .update(cx, |this, cx| {
261 this.start_server(
262 context_server,
263 cx,
264 )
265 })
266 .log_err()
267 {
268 start_server_task.await.log_err();
269 }
270 }
271 })
272 .detach();
273 }
274 }
275 }),
276 )),
277 )
278 .map(|parent| {
279 if !are_tools_expanded {
280 return parent;
281 }
282
283 parent.child(v_flex().children(tools.into_iter().enumerate().map(
284 |(ix, tool)| {
285 h_flex()
286 .px_2()
287 .py_1()
288 .when(ix < tool_count - 1, |element| {
289 element
290 .border_b_1()
291 .border_color(cx.theme().colors().border)
292 })
293 .child(Label::new(tool.name()))
294 },
295 )))
296 })
297 }))
298 .child(
299 h_flex()
300 .justify_between()
301 .gap_2()
302 .child(
303 h_flex().w_full().child(
304 Button::new("add-context-server", "Add Context Server")
305 .style(ButtonStyle::Filled)
306 .layer(ElevationIndex::ModalSurface)
307 .full_width()
308 .icon(IconName::Plus)
309 .icon_size(IconSize::Small)
310 .icon_position(IconPosition::Start)
311 .disabled(true)
312 .tooltip(Tooltip::text("Not yet implemented")),
313 ),
314 )
315 .child(
316 h_flex().w_full().child(
317 Button::new(
318 "install-context-server-extensions",
319 "Install Context Server Extensions",
320 )
321 .style(ButtonStyle::Filled)
322 .layer(ElevationIndex::ModalSurface)
323 .full_width()
324 .icon(IconName::DatabaseZap)
325 .icon_size(IconSize::Small)
326 .icon_position(IconPosition::Start)
327 .on_click(|_event, window, cx| {
328 window.dispatch_action(
329 zed_actions::Extensions {
330 category_filter: Some(
331 ExtensionCategoryFilter::ContextServers,
332 ),
333 }
334 .boxed_clone(),
335 cx,
336 )
337 }),
338 ),
339 ),
340 )
341 }
342}
343
344impl Render for AssistantConfiguration {
345 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
346 let providers = LanguageModelRegistry::read_global(cx).providers();
347
348 v_flex()
349 .id("assistant-configuration")
350 .track_focus(&self.focus_handle(cx))
351 .bg(cx.theme().colors().panel_background)
352 .size_full()
353 .overflow_y_scroll()
354 .child(
355 v_flex()
356 .p(DynamicSpacing::Base16.rems(cx))
357 .gap_2()
358 .child(
359 v_flex()
360 .gap_0p5()
361 .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
362 .child(
363 Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
364 .color(Color::Muted),
365 ),
366 )
367 .child(
368 Button::new("open-prompt-library", "Open Prompt Library")
369 .style(ButtonStyle::Filled)
370 .layer(ElevationIndex::ModalSurface)
371 .full_width()
372 .icon(IconName::Book)
373 .icon_size(IconSize::Small)
374 .icon_position(IconPosition::Start)
375 .on_click(|_event, window, cx| {
376 window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
377 }),
378 ),
379 )
380 .child(Divider::horizontal().color(DividerColor::Border))
381 .child(self.render_context_servers_section(cx))
382 .child(Divider::horizontal().color(DividerColor::Border))
383 .child(
384 v_flex()
385 .p(DynamicSpacing::Base16.rems(cx))
386 .mt_1()
387 .gap_6()
388 .flex_1()
389 .child(
390 v_flex()
391 .gap_0p5()
392 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
393 .child(
394 Label::new("Add at least one provider to use AI-powered features.")
395 .color(Color::Muted),
396 ),
397 )
398 .children(
399 providers
400 .into_iter()
401 .map(|provider| self.render_provider_configuration(&provider, cx)),
402 ),
403 )
404 }
405}