1mod add_context_server_modal;
2mod manage_profiles_modal;
3mod profile_picker;
4mod tool_picker;
5
6use std::sync::Arc;
7
8use assistant_tool::{ToolSource, ToolWorkingSet};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
12use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
13use ui::{prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch};
14use util::ResultExt as _;
15use zed_actions::ExtensionCategoryFilter;
16
17pub(crate) use add_context_server_modal::AddContextServerModal;
18pub(crate) use manage_profiles_modal::ManageProfilesModal;
19
20use crate::AddContextServer;
21
/// View state for the assistant configuration panel: the context servers
/// section and the per-provider LLM configuration list.
pub struct AssistantConfiguration {
    /// Focus handle returned by the `Focusable` impl and tracked on the root
    /// element in `render`.
    focus_handle: FocusHandle,
    /// Configuration view for each language model provider, keyed by provider id.
    /// Filled in `new` and updated when the registry reports added or removed
    /// providers.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Source of the context servers listed in the panel; also used to start
    /// and stop a server when its switch is toggled.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether the tool list of a context server (keyed by server id) is
    /// expanded. A missing entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool working set, queried by source to list each context server's tools.
    tools: Arc<ToolWorkingSet>,
    /// Subscription to `LanguageModelRegistry` events created in `new`; stored
    /// here so it lives as long as this view does.
    _registry_subscription: Subscription,
}
30
31impl AssistantConfiguration {
32 pub fn new(
33 context_server_manager: Entity<ContextServerManager>,
34 tools: Arc<ToolWorkingSet>,
35 window: &mut Window,
36 cx: &mut Context<Self>,
37 ) -> Self {
38 let focus_handle = cx.focus_handle();
39
40 let registry_subscription = cx.subscribe_in(
41 &LanguageModelRegistry::global(cx),
42 window,
43 |this, _, event: &language_model::Event, window, cx| match event {
44 language_model::Event::AddedProvider(provider_id) => {
45 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
46 if let Some(provider) = provider {
47 this.add_provider_configuration_view(&provider, window, cx);
48 }
49 }
50 language_model::Event::RemovedProvider(provider_id) => {
51 this.remove_provider_configuration_view(provider_id);
52 }
53 _ => {}
54 },
55 );
56
57 let mut this = Self {
58 focus_handle,
59 configuration_views_by_provider: HashMap::default(),
60 context_server_manager,
61 expanded_context_server_tools: HashMap::default(),
62 tools,
63 _registry_subscription: registry_subscription,
64 };
65 this.build_provider_configuration_views(window, cx);
66 this
67 }
68
69 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
70 let providers = LanguageModelRegistry::read_global(cx).providers();
71 for provider in providers {
72 self.add_provider_configuration_view(&provider, window, cx);
73 }
74 }
75
76 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
77 self.configuration_views_by_provider.remove(provider_id);
78 }
79
80 fn add_provider_configuration_view(
81 &mut self,
82 provider: &Arc<dyn LanguageModelProvider>,
83 window: &mut Window,
84 cx: &mut Context<Self>,
85 ) {
86 let configuration_view = provider.configuration_view(window, cx);
87 self.configuration_views_by_provider
88 .insert(provider.id(), configuration_view);
89 }
90}
91
92impl Focusable for AssistantConfiguration {
93 fn focus_handle(&self, _: &App) -> FocusHandle {
94 self.focus_handle.clone()
95 }
96}
97
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// Emitted when the "Start New Thread" button of a provider is clicked;
    /// carries that provider.
    NewThread(Arc<dyn LanguageModelProvider>),
}
101
// Lets `AssistantConfiguration` emit `AssistantConfigurationEvent`s (see the
// `cx.emit` call in the "Start New Thread" click handler).
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
103
104impl AssistantConfiguration {
105 fn render_provider_configuration(
106 &mut self,
107 provider: &Arc<dyn LanguageModelProvider>,
108 cx: &mut Context<Self>,
109 ) -> impl IntoElement {
110 let provider_id = provider.id().0.clone();
111 let provider_name = provider.name().0.clone();
112 let configuration_view = self
113 .configuration_views_by_provider
114 .get(&provider.id())
115 .cloned();
116
117 v_flex()
118 .gap_1p5()
119 .child(
120 h_flex()
121 .justify_between()
122 .child(
123 h_flex()
124 .gap_2()
125 .child(
126 Icon::new(provider.icon())
127 .size(IconSize::Small)
128 .color(Color::Muted),
129 )
130 .child(Label::new(provider_name.clone())),
131 )
132 .when(provider.is_authenticated(cx), |parent| {
133 parent.child(
134 Button::new(
135 SharedString::from(format!("new-thread-{provider_id}")),
136 "Start New Thread",
137 )
138 .icon_position(IconPosition::Start)
139 .icon(IconName::Plus)
140 .icon_size(IconSize::Small)
141 .style(ButtonStyle::Filled)
142 .layer(ElevationIndex::ModalSurface)
143 .label_size(LabelSize::Small)
144 .on_click(cx.listener({
145 let provider = provider.clone();
146 move |_this, _event, _window, cx| {
147 cx.emit(AssistantConfigurationEvent::NewThread(
148 provider.clone(),
149 ))
150 }
151 })),
152 )
153 }),
154 )
155 .child(
156 div()
157 .p(DynamicSpacing::Base08.rems(cx))
158 .bg(cx.theme().colors().editor_background)
159 .border_1()
160 .border_color(cx.theme().colors().border_variant)
161 .rounded_sm()
162 .map(|parent| match configuration_view {
163 Some(configuration_view) => parent.child(configuration_view),
164 None => parent.child(div().child(Label::new(format!(
165 "No configuration view for {provider_name}",
166 )))),
167 }),
168 )
169 }
170
171 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
172 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
173 let tools_by_source = self.tools.tools_by_source(cx);
174 let empty = Vec::new();
175
176 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
177
178 v_flex()
179 .p(DynamicSpacing::Base16.rems(cx))
180 .gap_2()
181 .flex_1()
182 .child(
183 v_flex()
184 .gap_0p5()
185 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
186 .child(Label::new(SUBHEADING).color(Color::Muted)),
187 )
188 .children(context_servers.into_iter().map(|context_server| {
189 let is_running = context_server.client().is_some();
190 let are_tools_expanded = self
191 .expanded_context_server_tools
192 .get(&context_server.id())
193 .copied()
194 .unwrap_or_default();
195
196 let tools = tools_by_source
197 .get(&ToolSource::ContextServer {
198 id: context_server.id().into(),
199 })
200 .unwrap_or_else(|| &empty);
201 let tool_count = tools.len();
202
203 v_flex()
204 .id(SharedString::from(context_server.id()))
205 .border_1()
206 .rounded_sm()
207 .border_color(cx.theme().colors().border)
208 .bg(cx.theme().colors().editor_background)
209 .child(
210 h_flex()
211 .justify_between()
212 .px_2()
213 .py_1()
214 .when(are_tools_expanded, |element| {
215 element
216 .border_b_1()
217 .border_color(cx.theme().colors().border)
218 })
219 .child(
220 h_flex()
221 .gap_2()
222 .child(
223 Disclosure::new("tool-list-disclosure", are_tools_expanded)
224 .on_click(cx.listener({
225 let context_server_id = context_server.id();
226 move |this, _event, _window, _cx| {
227 let is_open = this
228 .expanded_context_server_tools
229 .entry(context_server_id.clone())
230 .or_insert(false);
231
232 *is_open = !*is_open;
233 }
234 })),
235 )
236 .child(Indicator::dot().color(if is_running {
237 Color::Success
238 } else {
239 Color::Error
240 }))
241 .child(Label::new(context_server.id()))
242 .child(
243 Label::new(format!("{tool_count} tools"))
244 .color(Color::Muted),
245 ),
246 )
247 .child(h_flex().child(
248 Switch::new("context-server-switch", is_running.into()).on_click({
249 let context_server_manager =
250 self.context_server_manager.clone();
251 let context_server = context_server.clone();
252 move |state, _window, cx| match state {
253 ToggleState::Unselected | ToggleState::Indeterminate => {
254 context_server_manager.update(cx, |this, cx| {
255 this.stop_server(context_server.clone(), cx)
256 .log_err();
257 });
258 }
259 ToggleState::Selected => {
260 cx.spawn({
261 let context_server_manager =
262 context_server_manager.clone();
263 let context_server = context_server.clone();
264 async move |cx| {
265 if let Some(start_server_task) =
266 context_server_manager
267 .update(cx, |this, cx| {
268 this.start_server(
269 context_server,
270 cx,
271 )
272 })
273 .log_err()
274 {
275 start_server_task.await.log_err();
276 }
277 }
278 })
279 .detach();
280 }
281 }
282 }),
283 )),
284 )
285 .map(|parent| {
286 if !are_tools_expanded {
287 return parent;
288 }
289
290 parent.child(v_flex().children(tools.into_iter().enumerate().map(
291 |(ix, tool)| {
292 h_flex()
293 .px_2()
294 .py_1()
295 .when(ix < tool_count - 1, |element| {
296 element
297 .border_b_1()
298 .border_color(cx.theme().colors().border)
299 })
300 .child(Label::new(tool.name()))
301 },
302 )))
303 })
304 }))
305 .child(
306 h_flex()
307 .justify_between()
308 .gap_2()
309 .child(
310 h_flex().w_full().child(
311 Button::new("add-context-server", "Add Context Server")
312 .style(ButtonStyle::Filled)
313 .layer(ElevationIndex::ModalSurface)
314 .full_width()
315 .icon(IconName::Plus)
316 .icon_size(IconSize::Small)
317 .icon_position(IconPosition::Start)
318 .on_click(|_event, window, cx| {
319 window.dispatch_action(AddContextServer.boxed_clone(), cx)
320 }),
321 ),
322 )
323 .child(
324 h_flex().w_full().child(
325 Button::new(
326 "install-context-server-extensions",
327 "Install Context Server Extensions",
328 )
329 .style(ButtonStyle::Filled)
330 .layer(ElevationIndex::ModalSurface)
331 .full_width()
332 .icon(IconName::DatabaseZap)
333 .icon_size(IconSize::Small)
334 .icon_position(IconPosition::Start)
335 .on_click(|_event, window, cx| {
336 window.dispatch_action(
337 zed_actions::Extensions {
338 category_filter: Some(
339 ExtensionCategoryFilter::ContextServers,
340 ),
341 }
342 .boxed_clone(),
343 cx,
344 )
345 }),
346 ),
347 ),
348 )
349 }
350}
351
352impl Render for AssistantConfiguration {
353 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
354 let providers = LanguageModelRegistry::read_global(cx).providers();
355
356 v_flex()
357 .id("assistant-configuration")
358 .track_focus(&self.focus_handle(cx))
359 .bg(cx.theme().colors().panel_background)
360 .size_full()
361 .overflow_y_scroll()
362 .child(self.render_context_servers_section(cx))
363 .child(Divider::horizontal().color(DividerColor::Border))
364 .child(
365 v_flex()
366 .p(DynamicSpacing::Base16.rems(cx))
367 .mt_1()
368 .gap_6()
369 .flex_1()
370 .child(
371 v_flex()
372 .gap_0p5()
373 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
374 .child(
375 Label::new("Add at least one provider to use AI-powered features.")
376 .color(Color::Muted),
377 ),
378 )
379 .children(
380 providers
381 .into_iter()
382 .map(|provider| self.render_provider_configuration(&provider, cx)),
383 ),
384 )
385 }
386}