1use std::sync::Arc;
2
3use assistant_tool::{ToolSource, ToolWorkingSet};
4use collections::HashMap;
5use context_server::manager::ContextServerManager;
6use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
7use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
8use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
9use util::ResultExt as _;
10use zed_actions::assistant::DeployPromptLibrary;
11
/// State backing the assistant configuration panel: the prompt library entry
/// point, the MCP context server list, and per-provider configuration views.
pub struct AssistantConfiguration {
    /// Focus handle tracked by the root element in `render`.
    focus_handle: FocusHandle,
    /// Configuration view for each language model provider, keyed by provider id.
    /// Entries are added/removed in response to registry events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the context servers listed (and started/stopped) by this view.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each context server's tool list is expanded, keyed by server id.
    /// A server with no entry is rendered collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool working set queried for the tools contributed by each context server.
    tools: Arc<ToolWorkingSet>,
    /// Subscription to `LanguageModelRegistry` events; stored on the view so it
    /// is not dropped after construction.
    _registry_subscription: Subscription,
}
20
21impl AssistantConfiguration {
22 pub fn new(
23 context_server_manager: Entity<ContextServerManager>,
24 tools: Arc<ToolWorkingSet>,
25 window: &mut Window,
26 cx: &mut Context<Self>,
27 ) -> Self {
28 let focus_handle = cx.focus_handle();
29
30 let registry_subscription = cx.subscribe_in(
31 &LanguageModelRegistry::global(cx),
32 window,
33 |this, _, event: &language_model::Event, window, cx| match event {
34 language_model::Event::AddedProvider(provider_id) => {
35 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
36 if let Some(provider) = provider {
37 this.add_provider_configuration_view(&provider, window, cx);
38 }
39 }
40 language_model::Event::RemovedProvider(provider_id) => {
41 this.remove_provider_configuration_view(provider_id);
42 }
43 _ => {}
44 },
45 );
46
47 let mut this = Self {
48 focus_handle,
49 configuration_views_by_provider: HashMap::default(),
50 context_server_manager,
51 expanded_context_server_tools: HashMap::default(),
52 tools,
53 _registry_subscription: registry_subscription,
54 };
55 this.build_provider_configuration_views(window, cx);
56 this
57 }
58
59 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
60 let providers = LanguageModelRegistry::read_global(cx).providers();
61 for provider in providers {
62 self.add_provider_configuration_view(&provider, window, cx);
63 }
64 }
65
66 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
67 self.configuration_views_by_provider.remove(provider_id);
68 }
69
70 fn add_provider_configuration_view(
71 &mut self,
72 provider: &Arc<dyn LanguageModelProvider>,
73 window: &mut Window,
74 cx: &mut Context<Self>,
75 ) {
76 let configuration_view = provider.configuration_view(window, cx);
77 self.configuration_views_by_provider
78 .insert(provider.id(), configuration_view);
79 }
80}
81
82impl Focusable for AssistantConfiguration {
83 fn focus_handle(&self, _: &App) -> FocusHandle {
84 self.focus_handle.clone()
85 }
86}
87
/// Events emitted by `AssistantConfiguration`.
pub enum AssistantConfigurationEvent {
    /// Emitted when the user clicks "Start New Thread" on an authenticated
    /// provider's card; carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}

// Lets `AssistantConfiguration` emit `AssistantConfigurationEvent`s via `cx.emit`.
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
93
94impl AssistantConfiguration {
95 fn render_provider_configuration(
96 &mut self,
97 provider: &Arc<dyn LanguageModelProvider>,
98 cx: &mut Context<Self>,
99 ) -> impl IntoElement {
100 let provider_id = provider.id().0.clone();
101 let provider_name = provider.name().0.clone();
102 let configuration_view = self
103 .configuration_views_by_provider
104 .get(&provider.id())
105 .cloned();
106
107 v_flex()
108 .gap_1p5()
109 .child(
110 h_flex()
111 .justify_between()
112 .child(
113 h_flex()
114 .gap_2()
115 .child(
116 Icon::new(provider.icon())
117 .size(IconSize::Small)
118 .color(Color::Muted),
119 )
120 .child(Label::new(provider_name.clone())),
121 )
122 .when(provider.is_authenticated(cx), |parent| {
123 parent.child(
124 Button::new(
125 SharedString::from(format!("new-thread-{provider_id}")),
126 "Start New Thread",
127 )
128 .icon_position(IconPosition::Start)
129 .icon(IconName::Plus)
130 .icon_size(IconSize::Small)
131 .style(ButtonStyle::Filled)
132 .layer(ElevationIndex::ModalSurface)
133 .label_size(LabelSize::Small)
134 .on_click(cx.listener({
135 let provider = provider.clone();
136 move |_this, _event, _window, cx| {
137 cx.emit(AssistantConfigurationEvent::NewThread(
138 provider.clone(),
139 ))
140 }
141 })),
142 )
143 }),
144 )
145 .child(
146 div()
147 .p(DynamicSpacing::Base08.rems(cx))
148 .bg(cx.theme().colors().editor_background)
149 .border_1()
150 .border_color(cx.theme().colors().border_variant)
151 .rounded_sm()
152 .map(|parent| match configuration_view {
153 Some(configuration_view) => parent.child(configuration_view),
154 None => parent.child(div().child(Label::new(format!(
155 "No configuration view for {provider_name}",
156 )))),
157 }),
158 )
159 }
160
161 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
162 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
163 let tools_by_source = self.tools.tools_by_source(cx);
164 let empty = Vec::new();
165
166 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
167
168 v_flex()
169 .p(DynamicSpacing::Base16.rems(cx))
170 .mt_1()
171 .gap_6()
172 .flex_1()
173 .child(
174 v_flex()
175 .gap_0p5()
176 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
177 .child(Label::new(SUBHEADING).color(Color::Muted)),
178 )
179 .children(context_servers.into_iter().map(|context_server| {
180 let is_running = context_server.client().is_some();
181 let are_tools_expanded = self
182 .expanded_context_server_tools
183 .get(&context_server.id())
184 .copied()
185 .unwrap_or_default();
186
187 let tools = tools_by_source
188 .get(&ToolSource::ContextServer {
189 id: context_server.id().into(),
190 })
191 .unwrap_or_else(|| &empty);
192 let tool_count = tools.len();
193
194 v_flex()
195 .border_1()
196 .rounded_sm()
197 .border_color(cx.theme().colors().border)
198 .bg(cx.theme().colors().editor_background)
199 .child(
200 h_flex()
201 .justify_between()
202 .px_2()
203 .py_1()
204 .when(are_tools_expanded, |element| {
205 element
206 .border_b_1()
207 .border_color(cx.theme().colors().border)
208 })
209 .child(
210 h_flex()
211 .gap_2()
212 .child(
213 Disclosure::new("tool-list-disclosure", are_tools_expanded)
214 .on_click(cx.listener({
215 let context_server_id = context_server.id();
216 move |this, _event, _window, _cx| {
217 let is_open = this
218 .expanded_context_server_tools
219 .entry(context_server_id.clone())
220 .or_insert(false);
221
222 *is_open = !*is_open;
223 }
224 })),
225 )
226 .child(Indicator::dot().color(if is_running {
227 Color::Success
228 } else {
229 Color::Error
230 }))
231 .child(Label::new(context_server.id()))
232 .child(
233 Label::new(format!("{tool_count} tools"))
234 .color(Color::Muted),
235 ),
236 )
237 .child(h_flex().child(
238 Switch::new("context-server-switch", is_running.into()).on_click({
239 let context_server_manager =
240 self.context_server_manager.clone();
241 let context_server = context_server.clone();
242 move |state, _window, cx| match state {
243 ToggleState::Unselected | ToggleState::Indeterminate => {
244 context_server_manager.update(cx, |this, cx| {
245 this.stop_server(context_server.clone(), cx)
246 .log_err();
247 });
248 }
249 ToggleState::Selected => {
250 cx.spawn({
251 let context_server_manager =
252 context_server_manager.clone();
253 let context_server = context_server.clone();
254 async move |cx| {
255 if let Some(start_server_task) =
256 context_server_manager
257 .update(cx, |this, cx| {
258 this.start_server(
259 context_server,
260 cx,
261 )
262 })
263 .log_err()
264 {
265 start_server_task.await.log_err();
266 }
267 }
268 })
269 .detach();
270 }
271 }
272 }),
273 )),
274 )
275 .map(|parent| {
276 if !are_tools_expanded {
277 return parent;
278 }
279
280 parent.child(v_flex().children(tools.into_iter().enumerate().map(
281 |(ix, tool)| {
282 h_flex()
283 .px_2()
284 .py_1()
285 .when(ix < tool_count - 1, |element| {
286 element
287 .border_b_1()
288 .border_color(cx.theme().colors().border)
289 })
290 .child(Label::new(tool.name()))
291 },
292 )))
293 })
294 }))
295 }
296}
297
298impl Render for AssistantConfiguration {
299 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
300 let providers = LanguageModelRegistry::read_global(cx).providers();
301
302 v_flex()
303 .id("assistant-configuration")
304 .track_focus(&self.focus_handle(cx))
305 .bg(cx.theme().colors().panel_background)
306 .size_full()
307 .overflow_y_scroll()
308 .child(
309 v_flex()
310 .p(DynamicSpacing::Base16.rems(cx))
311 .gap_2()
312 .child(
313 v_flex()
314 .gap_0p5()
315 .child(Headline::new("Prompt Library").size(HeadlineSize::Small))
316 .child(
317 Label::new("Create reusable prompts and tag which ones you want sent in every LLM interaction.")
318 .color(Color::Muted),
319 ),
320 )
321 .child(
322 Button::new("open-prompt-library", "Open Prompt Library")
323 .style(ButtonStyle::Filled)
324 .layer(ElevationIndex::ModalSurface)
325 .full_width()
326 .icon(IconName::Book)
327 .icon_size(IconSize::Small)
328 .icon_position(IconPosition::Start)
329 .on_click(|_event, window, cx| {
330 window.dispatch_action(DeployPromptLibrary.boxed_clone(), cx)
331 }),
332 ),
333 )
334 .child(Divider::horizontal().color(DividerColor::Border))
335 .child(self.render_context_servers_section(cx))
336 .child(Divider::horizontal().color(DividerColor::Border))
337 .child(
338 v_flex()
339 .p(DynamicSpacing::Base16.rems(cx))
340 .mt_1()
341 .gap_6()
342 .flex_1()
343 .child(
344 v_flex()
345 .gap_0p5()
346 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
347 .child(
348 Label::new("Add at least one provider to use AI-powered features.")
349 .color(Color::Muted),
350 ),
351 )
352 .children(
353 providers
354 .into_iter()
355 .map(|provider| self.render_provider_configuration(&provider, cx)),
356 ),
357 )
358 }
359}