1mod add_context_server_modal;
2mod manage_profiles_modal;
3mod tool_picker;
4
5use std::sync::Arc;
6
7use assistant_tool::{ToolSource, ToolWorkingSet};
8use collections::HashMap;
9use context_server::manager::ContextServerManager;
10use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
11use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
12use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
13use util::ResultExt as _;
14use zed_actions::ExtensionCategoryFilter;
15
16pub(crate) use add_context_server_modal::AddContextServerModal;
17pub(crate) use manage_profiles_modal::ManageProfilesModal;
18
19use crate::AddContextServer;
20
/// View that lists context servers (MCP) and language model providers and lets
/// the user configure them.
pub struct AssistantConfiguration {
    /// Focus handle tracked by the view's root element.
    focus_handle: FocusHandle,
    /// Cached provider configuration views, keyed by provider id. Entries are added
    /// and removed in response to `LanguageModelRegistry` events.
    configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
    /// Manager used to list, start, and stop context servers.
    context_server_manager: Entity<ContextServerManager>,
    /// Whether each context server's tool list is expanded, keyed by server id.
    /// A server with no entry is treated as collapsed.
    expanded_context_server_tools: HashMap<Arc<str>, bool>,
    /// Tool set queried for the tools each context server provides.
    tools: Arc<ToolWorkingSet>,
    /// Subscription to `LanguageModelRegistry` events, created in `new` and stored so
    /// it lives as long as this view (NOTE: assumes dropping a `Subscription`
    /// cancels it — confirm against gpui).
    _registry_subscription: Subscription,
}
29
30impl AssistantConfiguration {
31 pub fn new(
32 context_server_manager: Entity<ContextServerManager>,
33 tools: Arc<ToolWorkingSet>,
34 window: &mut Window,
35 cx: &mut Context<Self>,
36 ) -> Self {
37 let focus_handle = cx.focus_handle();
38
39 let registry_subscription = cx.subscribe_in(
40 &LanguageModelRegistry::global(cx),
41 window,
42 |this, _, event: &language_model::Event, window, cx| match event {
43 language_model::Event::AddedProvider(provider_id) => {
44 let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
45 if let Some(provider) = provider {
46 this.add_provider_configuration_view(&provider, window, cx);
47 }
48 }
49 language_model::Event::RemovedProvider(provider_id) => {
50 this.remove_provider_configuration_view(provider_id);
51 }
52 _ => {}
53 },
54 );
55
56 let mut this = Self {
57 focus_handle,
58 configuration_views_by_provider: HashMap::default(),
59 context_server_manager,
60 expanded_context_server_tools: HashMap::default(),
61 tools,
62 _registry_subscription: registry_subscription,
63 };
64 this.build_provider_configuration_views(window, cx);
65 this
66 }
67
68 fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
69 let providers = LanguageModelRegistry::read_global(cx).providers();
70 for provider in providers {
71 self.add_provider_configuration_view(&provider, window, cx);
72 }
73 }
74
75 fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
76 self.configuration_views_by_provider.remove(provider_id);
77 }
78
79 fn add_provider_configuration_view(
80 &mut self,
81 provider: &Arc<dyn LanguageModelProvider>,
82 window: &mut Window,
83 cx: &mut Context<Self>,
84 ) {
85 let configuration_view = provider.configuration_view(window, cx);
86 self.configuration_views_by_provider
87 .insert(provider.id(), configuration_view);
88 }
89}
90
91impl Focusable for AssistantConfiguration {
92 fn focus_handle(&self, _: &App) -> FocusHandle {
93 self.focus_handle.clone()
94 }
95}
96
/// Events emitted by [`AssistantConfiguration`].
pub enum AssistantConfigurationEvent {
    /// The user clicked "Start New Thread" on this provider's card. The button is
    /// only rendered while the provider reports itself as authenticated.
    NewThread(Arc<dyn LanguageModelProvider>),
}
100
/// Lets [`AssistantConfiguration`] emit [`AssistantConfigurationEvent`]s via `cx.emit`
/// (e.g. `NewThread` from a provider's "Start New Thread" button).
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
102
103impl AssistantConfiguration {
104 fn render_provider_configuration(
105 &mut self,
106 provider: &Arc<dyn LanguageModelProvider>,
107 cx: &mut Context<Self>,
108 ) -> impl IntoElement + use<> {
109 let provider_id = provider.id().0.clone();
110 let provider_name = provider.name().0.clone();
111 let configuration_view = self
112 .configuration_views_by_provider
113 .get(&provider.id())
114 .cloned();
115
116 v_flex()
117 .gap_1p5()
118 .child(
119 h_flex()
120 .justify_between()
121 .child(
122 h_flex()
123 .gap_2()
124 .child(
125 Icon::new(provider.icon())
126 .size(IconSize::Small)
127 .color(Color::Muted),
128 )
129 .child(Label::new(provider_name.clone())),
130 )
131 .when(provider.is_authenticated(cx), |parent| {
132 parent.child(
133 Button::new(
134 SharedString::from(format!("new-thread-{provider_id}")),
135 "Start New Thread",
136 )
137 .icon_position(IconPosition::Start)
138 .icon(IconName::Plus)
139 .icon_size(IconSize::Small)
140 .style(ButtonStyle::Filled)
141 .layer(ElevationIndex::ModalSurface)
142 .label_size(LabelSize::Small)
143 .on_click(cx.listener({
144 let provider = provider.clone();
145 move |_this, _event, _window, cx| {
146 cx.emit(AssistantConfigurationEvent::NewThread(
147 provider.clone(),
148 ))
149 }
150 })),
151 )
152 }),
153 )
154 .child(
155 div()
156 .p(DynamicSpacing::Base08.rems(cx))
157 .bg(cx.theme().colors().editor_background)
158 .border_1()
159 .border_color(cx.theme().colors().border_variant)
160 .rounded_sm()
161 .map(|parent| match configuration_view {
162 Some(configuration_view) => parent.child(configuration_view),
163 None => parent.child(div().child(Label::new(format!(
164 "No configuration view for {provider_name}",
165 )))),
166 }),
167 )
168 }
169
170 fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
171 let context_servers = self.context_server_manager.read(cx).all_servers().clone();
172 let tools_by_source = self.tools.tools_by_source(cx);
173 let empty = Vec::new();
174
175 const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
176
177 v_flex()
178 .p(DynamicSpacing::Base16.rems(cx))
179 .gap_2()
180 .flex_1()
181 .child(
182 v_flex()
183 .gap_0p5()
184 .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
185 .child(Label::new(SUBHEADING).color(Color::Muted)),
186 )
187 .children(context_servers.into_iter().map(|context_server| {
188 let is_running = context_server.client().is_some();
189 let are_tools_expanded = self
190 .expanded_context_server_tools
191 .get(&context_server.id())
192 .copied()
193 .unwrap_or_default();
194
195 let tools = tools_by_source
196 .get(&ToolSource::ContextServer {
197 id: context_server.id().into(),
198 })
199 .unwrap_or_else(|| &empty);
200 let tool_count = tools.len();
201
202 v_flex()
203 .id(SharedString::from(context_server.id()))
204 .border_1()
205 .rounded_sm()
206 .border_color(cx.theme().colors().border)
207 .bg(cx.theme().colors().editor_background)
208 .child(
209 h_flex()
210 .justify_between()
211 .px_2()
212 .py_1()
213 .when(are_tools_expanded, |element| {
214 element
215 .border_b_1()
216 .border_color(cx.theme().colors().border)
217 })
218 .child(
219 h_flex()
220 .gap_2()
221 .child(
222 Disclosure::new("tool-list-disclosure", are_tools_expanded)
223 .on_click(cx.listener({
224 let context_server_id = context_server.id();
225 move |this, _event, _window, _cx| {
226 let is_open = this
227 .expanded_context_server_tools
228 .entry(context_server_id.clone())
229 .or_insert(false);
230
231 *is_open = !*is_open;
232 }
233 })),
234 )
235 .child(Indicator::dot().color(if is_running {
236 Color::Success
237 } else {
238 Color::Error
239 }))
240 .child(Label::new(context_server.id()))
241 .child(
242 Label::new(format!("{tool_count} tools"))
243 .color(Color::Muted),
244 ),
245 )
246 .child(h_flex().child(
247 Switch::new("context-server-switch", is_running.into()).on_click({
248 let context_server_manager =
249 self.context_server_manager.clone();
250 let context_server = context_server.clone();
251 move |state, _window, cx| match state {
252 ToggleState::Unselected | ToggleState::Indeterminate => {
253 context_server_manager.update(cx, |this, cx| {
254 this.stop_server(context_server.clone(), cx)
255 .log_err();
256 });
257 }
258 ToggleState::Selected => {
259 cx.spawn({
260 let context_server_manager =
261 context_server_manager.clone();
262 let context_server = context_server.clone();
263 async move |cx| {
264 if let Some(start_server_task) =
265 context_server_manager
266 .update(cx, |this, cx| {
267 this.start_server(
268 context_server,
269 cx,
270 )
271 })
272 .log_err()
273 {
274 start_server_task.await.log_err();
275 }
276 }
277 })
278 .detach();
279 }
280 }
281 }),
282 )),
283 )
284 .map(|parent| {
285 if !are_tools_expanded {
286 return parent;
287 }
288
289 parent.child(v_flex().children(tools.into_iter().enumerate().map(
290 |(ix, tool)| {
291 h_flex()
292 .px_2()
293 .py_1()
294 .when(ix < tool_count - 1, |element| {
295 element
296 .border_b_1()
297 .border_color(cx.theme().colors().border)
298 })
299 .child(Label::new(tool.name()))
300 },
301 )))
302 })
303 }))
304 .child(
305 h_flex()
306 .justify_between()
307 .gap_2()
308 .child(
309 h_flex().w_full().child(
310 Button::new("add-context-server", "Add Context Server")
311 .style(ButtonStyle::Filled)
312 .layer(ElevationIndex::ModalSurface)
313 .full_width()
314 .icon(IconName::Plus)
315 .icon_size(IconSize::Small)
316 .icon_position(IconPosition::Start)
317 .on_click(|_event, window, cx| {
318 window.dispatch_action(AddContextServer.boxed_clone(), cx)
319 }),
320 ),
321 )
322 .child(
323 h_flex().w_full().child(
324 Button::new(
325 "install-context-server-extensions",
326 "Install Context Server Extensions",
327 )
328 .style(ButtonStyle::Filled)
329 .layer(ElevationIndex::ModalSurface)
330 .full_width()
331 .icon(IconName::DatabaseZap)
332 .icon_size(IconSize::Small)
333 .icon_position(IconPosition::Start)
334 .on_click(|_event, window, cx| {
335 window.dispatch_action(
336 zed_actions::Extensions {
337 category_filter: Some(
338 ExtensionCategoryFilter::ContextServers,
339 ),
340 }
341 .boxed_clone(),
342 cx,
343 )
344 }),
345 ),
346 ),
347 )
348 }
349}
350
351impl Render for AssistantConfiguration {
352 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
353 let providers = LanguageModelRegistry::read_global(cx).providers();
354
355 v_flex()
356 .id("assistant-configuration")
357 .track_focus(&self.focus_handle(cx))
358 .bg(cx.theme().colors().panel_background)
359 .size_full()
360 .overflow_y_scroll()
361 .child(self.render_context_servers_section(cx))
362 .child(Divider::horizontal().color(DividerColor::Border))
363 .child(
364 v_flex()
365 .p(DynamicSpacing::Base16.rems(cx))
366 .mt_1()
367 .gap_6()
368 .flex_1()
369 .child(
370 v_flex()
371 .gap_0p5()
372 .child(Headline::new("LLM Providers").size(HeadlineSize::Small))
373 .child(
374 Label::new("Add at least one provider to use AI-powered features.")
375 .color(Color::Muted),
376 ),
377 )
378 .children(
379 providers
380 .into_iter()
381 .map(|provider| self.render_provider_configuration(&provider, cx)),
382 ),
383 )
384 }
385}