1pub mod extension;
2pub mod registry;
3
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use collections::{HashMap, HashSet};
10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
11use context_server::transport::{HttpTransport, TransportError};
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use credentials_provider::CredentialsProvider;
14use futures::future::Either;
15use futures::{FutureExt as _, StreamExt as _, future::join_all};
16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
17use http_client::HttpClient;
18use itertools::Itertools;
19use rand::Rng as _;
20use registry::ContextServerDescriptorRegistry;
21use remote::RemoteClient;
22use rpc::{AnyProtoClient, TypedEnvelope, proto};
23use settings::{Settings as _, SettingsStore};
24use util::{ResultExt as _, rel_path::RelPath};
25
26use crate::{
27 DisableAiSettings, Project,
28 project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
29 worktree_store::WorktreeStore,
30};
31
/// Upper bound, in seconds, applied to context server request timeouts.
///
/// Both per-server and global timeout values are clamped to this limit when a
/// server is created, which prevents extremely large timeout values from tying
/// up resources indefinitely.
const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
35
/// Initializes this module by delegating to the `extension` submodule's `init`.
pub fn init(cx: &mut App) {
    extension::init(cx);
}
39
// Actions declared under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
47
/// Externally visible status of a context server.
///
/// This is the payload-free counterpart of the store's internal per-server
/// state; it is what `ServerStatusChangedEvent` and `status_for_server` expose.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start attempt is currently in progress.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is stopped. Also reported when a server is removed from the store.
    Stopped,
    /// Starting or authenticating the server failed; carries the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The server has a pre-registered OAuth client_id, but a client_secret
    /// is needed and not available in settings or the keychain.
    ClientSecretRequired {
        /// Error message to show alongside the prompt, if any.
        error: Option<Arc<str>>,
    },
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
66
67impl ContextServerStatus {
68 fn from_state(state: &ContextServerState) -> Self {
69 match state {
70 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
71 ContextServerState::Running { .. } => ContextServerStatus::Running,
72 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
73 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
74 ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
75 ContextServerState::ClientSecretRequired { error, .. } => {
76 ContextServerStatus::ClientSecretRequired {
77 error: error.clone(),
78 }
79 }
80 ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
81 }
82 }
83}
84
/// Internal per-server state tracked by the store.
///
/// Every variant carries the server handle and the configuration it was
/// created from; some variants additionally own the task driving the
/// current transition or the data needed for the next step.
enum ContextServerState {
    /// A start attempt is in flight; `_task` owns the future driving it.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped but is still tracked by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Starting or authenticating failed; `error` holds the message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// A pre-registered client_id is configured but no client_secret was found
    /// in settings or the keychain.
    ClientSecretRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
        /// Error message to surface with the prompt, if any.
        error: Option<Arc<str>>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
}
127
128impl ContextServerState {
129 pub fn server(&self) -> Arc<ContextServer> {
130 match self {
131 ContextServerState::Starting { server, .. }
132 | ContextServerState::Running { server, .. }
133 | ContextServerState::Stopped { server, .. }
134 | ContextServerState::Error { server, .. }
135 | ContextServerState::AuthRequired { server, .. }
136 | ContextServerState::ClientSecretRequired { server, .. }
137 | ContextServerState::Authenticating { server, .. } => server.clone(),
138 }
139 }
140
141 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
142 match self {
143 ContextServerState::Starting { configuration, .. }
144 | ContextServerState::Running { configuration, .. }
145 | ContextServerState::Stopped { configuration, .. }
146 | ContextServerState::Error { configuration, .. }
147 | ContextServerState::AuthRequired { configuration, .. }
148 | ContextServerState::ClientSecretRequired { configuration, .. }
149 | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
150 }
151 }
152}
153
/// Resolved configuration a context server is created from.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server launched from an explicitly configured command.
    Custom {
        command: ContextServerCommand,
        /// When set and the project is remote, the command to run is requested
        /// from the upstream client instead of being used as-is.
        remote: bool,
    },
    /// A server whose command is provided by an extension descriptor.
    Extension {
        command: ContextServerCommand,
        /// Extension-specific settings, kept as raw JSON.
        settings: serde_json::Value,
        /// Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP rather than spawned as a process.
    Http {
        url: url::Url,
        /// Extra request headers. An `Authorization` header here counts as
        /// static auth (see `has_static_auth_header`).
        headers: HashMap<String, String>,
        /// Per-server request timeout in seconds; the global setting is used when `None`.
        timeout: Option<u64>,
        /// Optional pre-registered OAuth client settings.
        oauth: Option<OAuthClientSettings>,
    },
}
172
173impl ContextServerConfiguration {
174 pub fn command(&self) -> Option<&ContextServerCommand> {
175 match self {
176 ContextServerConfiguration::Custom { command, .. } => Some(command),
177 ContextServerConfiguration::Extension { command, .. } => Some(command),
178 ContextServerConfiguration::Http { .. } => None,
179 }
180 }
181
182 pub fn has_static_auth_header(&self) -> bool {
183 match self {
184 ContextServerConfiguration::Http { headers, .. } => headers
185 .keys()
186 .any(|k| k.eq_ignore_ascii_case("authorization")),
187 _ => false,
188 }
189 }
190
191 pub fn remote(&self) -> bool {
192 match self {
193 ContextServerConfiguration::Custom { remote, .. } => *remote,
194 ContextServerConfiguration::Extension { remote, .. } => *remote,
195 ContextServerConfiguration::Http { .. } => false,
196 }
197 }
198
199 pub async fn from_settings(
200 settings: ContextServerSettings,
201 id: ContextServerId,
202 registry: Entity<ContextServerDescriptorRegistry>,
203 worktree_store: Entity<WorktreeStore>,
204 cx: &AsyncApp,
205 ) -> Option<Self> {
206 const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
207
208 match settings {
209 ContextServerSettings::Stdio {
210 enabled: _,
211 command,
212 remote,
213 } => Some(ContextServerConfiguration::Custom { command, remote }),
214 ContextServerSettings::Extension {
215 enabled: _,
216 settings,
217 remote,
218 } => {
219 let descriptor =
220 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
221
222 let command_future = descriptor.command(worktree_store, cx);
223 let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
224
225 match futures::future::select(command_future, timeout_future).await {
226 Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
227 command,
228 settings,
229 remote,
230 }),
231 Either::Left((Err(e), _)) => {
232 log::error!(
233 "Failed to create context server configuration from settings: {e:#}"
234 );
235 None
236 }
237 Either::Right(_) => {
238 log::error!(
239 "Timed out resolving command for extension context server {id}"
240 );
241 None
242 }
243 }
244 }
245 ContextServerSettings::Http {
246 enabled: _,
247 url,
248 headers: auth,
249 timeout,
250 oauth,
251 } => {
252 let url = url::Url::parse(&url).log_err()?;
253 Some(ContextServerConfiguration::Http {
254 url,
255 headers: auth,
256 timeout,
257 oauth,
258 })
259 }
260 }
261 }
262}
263
/// Callback that builds a [`ContextServer`] for a given id and configuration.
///
/// When the store has a factory set, `create_context_server` uses it instead
/// of constructing the server itself.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
266
/// Whether the store belongs to a local or a remote project.
enum ContextServerStoreState {
    /// The project lives on this machine.
    Local {
        /// Project id and client set by `shared`; `None` until then.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// Whether this store was created in headless mode. Only headless
        /// stores answer `GetContextServerCommand` requests.
        is_headless: bool,
    },
    /// The project lives on a remote host reached through `upstream_client`.
    Remote {
        project_id: u64,
        upstream_client: Entity<RemoteClient>,
    },
}
277
/// Tracks the context servers of a project: their settings, their current
/// state, and the subscriptions that keep both up to date.
pub struct ContextServerStore {
    /// Local vs. remote project information.
    state: ContextServerStoreState,
    /// Last observed per-server settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every tracked server.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Sorted, de-duplicated ids rebuilt by `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    /// Owning project, when there is one; used to find the active project directory.
    project: Option<WeakEntity<Project>>,
    /// Registry of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): presumably the in-flight server reconciliation task — confirm
    // against `available_context_servers_changed`, which is outside this view.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for how servers are constructed.
    context_server_factory: Option<ContextServerFactory>,
    // NOTE(review): presumably marks that another reconciliation pass is
    // pending — confirm against the code that reads it.
    needs_server_update: bool,
    /// Value of the "disable AI" setting as of the last settings observation.
    ai_disabled: bool,
    /// Kept alive so the settings/registry observers stay registered.
    _subscriptions: Vec<Subscription>,
}
292
/// Event emitted by [`ContextServerStore`] when a server's status changes.
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The status the server now has.
    pub status: ContextServerStatus,
}

// Lets observers subscribe to status changes on the store.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
299
300impl ContextServerStore {
301 pub fn local(
302 worktree_store: Entity<WorktreeStore>,
303 weak_project: Option<WeakEntity<Project>>,
304 headless: bool,
305 cx: &mut Context<Self>,
306 ) -> Self {
307 Self::new_internal(
308 !headless,
309 None,
310 ContextServerDescriptorRegistry::default_global(cx),
311 worktree_store,
312 weak_project,
313 ContextServerStoreState::Local {
314 downstream_client: None,
315 is_headless: headless,
316 },
317 cx,
318 )
319 }
320
321 pub fn remote(
322 project_id: u64,
323 upstream_client: Entity<RemoteClient>,
324 worktree_store: Entity<WorktreeStore>,
325 weak_project: Option<WeakEntity<Project>>,
326 cx: &mut Context<Self>,
327 ) -> Self {
328 Self::new_internal(
329 true,
330 None,
331 ContextServerDescriptorRegistry::default_global(cx),
332 worktree_store,
333 weak_project,
334 ContextServerStoreState::Remote {
335 project_id,
336 upstream_client,
337 },
338 cx,
339 )
340 }
341
    /// Registers the `GetContextServerCommand` request handler on `session`.
    pub fn init_headless(session: &AnyProtoClient) {
        session.add_entity_request_handler(Self::handle_get_context_server_command);
    }
345
346 pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
347 if let ContextServerStoreState::Local {
348 downstream_client, ..
349 } = &mut self.state
350 {
351 *downstream_client = Some((project_id, client));
352 }
353 }
354
355 pub fn is_remote_project(&self) -> bool {
356 matches!(self.state, ContextServerStoreState::Remote { .. })
357 }
358
359 /// Returns all configured context server ids, excluding the ones that are disabled
360 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
361 self.context_server_settings
362 .iter()
363 .filter(|(_, settings)| settings.enabled())
364 .map(|(id, _)| ContextServerId(id.clone()))
365 .collect()
366 }
367
368 #[cfg(feature = "test-support")]
369 pub fn test(
370 registry: Entity<ContextServerDescriptorRegistry>,
371 worktree_store: Entity<WorktreeStore>,
372 weak_project: Option<WeakEntity<Project>>,
373 cx: &mut Context<Self>,
374 ) -> Self {
375 Self::new_internal(
376 false,
377 None,
378 registry,
379 worktree_store,
380 weak_project,
381 ContextServerStoreState::Local {
382 downstream_client: None,
383 is_headless: false,
384 },
385 cx,
386 )
387 }
388
389 #[cfg(feature = "test-support")]
390 pub fn test_maintain_server_loop(
391 context_server_factory: Option<ContextServerFactory>,
392 registry: Entity<ContextServerDescriptorRegistry>,
393 worktree_store: Entity<WorktreeStore>,
394 weak_project: Option<WeakEntity<Project>>,
395 cx: &mut Context<Self>,
396 ) -> Self {
397 Self::new_internal(
398 true,
399 context_server_factory,
400 registry,
401 worktree_store,
402 weak_project,
403 ContextServerStoreState::Local {
404 downstream_client: None,
405 is_headless: false,
406 },
407 cx,
408 )
409 }
410
    /// Test-only: installs `factory`, replacing any previously set factory.
    #[cfg(feature = "test-support")]
    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
        self.context_server_factory = Some(factory);
    }
415
    /// Test-only: returns the descriptor registry this store was created with.
    #[cfg(feature = "test-support")]
    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
        &self.registry
    }
420
421 #[cfg(feature = "test-support")]
422 pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
423 let configuration = Arc::new(ContextServerConfiguration::Custom {
424 command: ContextServerCommand {
425 path: "test".into(),
426 args: vec![],
427 env: None,
428 timeout: None,
429 },
430 remote: false,
431 });
432 self.run_server(server, configuration, cx);
433 }
434
435 fn new_internal(
436 maintain_server_loop: bool,
437 context_server_factory: Option<ContextServerFactory>,
438 registry: Entity<ContextServerDescriptorRegistry>,
439 worktree_store: Entity<WorktreeStore>,
440 weak_project: Option<WeakEntity<Project>>,
441 state: ContextServerStoreState,
442 cx: &mut Context<Self>,
443 ) -> Self {
444 let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
445 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
446 let ai_was_disabled = this.ai_disabled;
447 this.ai_disabled = ai_disabled;
448
449 let settings =
450 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
451 let settings_changed = &this.context_server_settings != settings;
452
453 if settings_changed {
454 this.context_server_settings = settings.clone();
455 }
456
457 // When AI is disabled, stop all running servers
458 if ai_disabled {
459 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
460 for id in server_ids {
461 this.stop_server(&id, cx).log_err();
462 }
463 return;
464 }
465
466 // Trigger updates if AI was re-enabled or settings changed
467 if maintain_server_loop && (ai_was_disabled || settings_changed) {
468 this.available_context_servers_changed(cx);
469 }
470 })];
471
472 if maintain_server_loop {
473 subscriptions.push(cx.observe(®istry, |this, _registry, cx| {
474 if !DisableAiSettings::get_global(cx).disable_ai {
475 this.available_context_servers_changed(cx);
476 }
477 }));
478 }
479
480 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
481 let mut this = Self {
482 state,
483 _subscriptions: subscriptions,
484 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
485 .context_servers
486 .clone(),
487 worktree_store,
488 project: weak_project,
489 registry,
490 needs_server_update: false,
491 ai_disabled,
492 servers: HashMap::default(),
493 server_ids: Default::default(),
494 update_servers_task: None,
495 context_server_factory,
496 };
497 if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
498 this.available_context_servers_changed(cx);
499 }
500 this
501 }
502
503 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
504 self.servers.get(id).map(|state| state.server())
505 }
506
507 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
508 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
509 Some(server.clone())
510 } else {
511 None
512 }
513 }
514
515 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
516 self.servers.get(id).map(ContextServerStatus::from_state)
517 }
518
519 pub fn configuration_for_server(
520 &self,
521 id: &ContextServerId,
522 ) -> Option<Arc<ContextServerConfiguration>> {
523 self.servers.get(id).map(|state| state.configuration())
524 }
525
526 /// Returns a sorted slice of available unique context server IDs. Within the
527 /// slice, context servers which have `mcp-server-` as a prefix in their ID will
528 /// appear after servers that do not have this prefix in their ID.
529 pub fn server_ids(&self) -> &[ContextServerId] {
530 self.server_ids.as_slice()
531 }
532
533 fn populate_server_ids(&mut self, cx: &App) {
534 self.server_ids = self
535 .servers
536 .keys()
537 .cloned()
538 .chain(
539 self.registry
540 .read(cx)
541 .context_server_descriptors()
542 .into_iter()
543 .map(|(id, _)| ContextServerId(id)),
544 )
545 .chain(
546 self.context_server_settings
547 .keys()
548 .map(|id| ContextServerId(id.clone())),
549 )
550 .unique()
551 .sorted_unstable_by(
552 // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
553 |a, b| {
554 const MCP_PREFIX: &str = "mcp-server-";
555 match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
556 // If one has mcp-server- prefix and other doesn't, non-mcp comes first
557 (Some(_), None) => std::cmp::Ordering::Greater,
558 (None, Some(_)) => std::cmp::Ordering::Less,
559 // If both have same prefix status, sort by appropriate key
560 (Some(a), Some(b)) => a.cmp(b),
561 (None, None) => a.0.cmp(&b.0),
562 }
563 },
564 )
565 .collect();
566 }
567
568 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
569 self.servers
570 .values()
571 .filter_map(|state| {
572 if let ContextServerState::Running { server, .. } = state {
573 Some(server.clone())
574 } else {
575 None
576 }
577 })
578 .collect()
579 }
580
581 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
582 cx.spawn(async move |this, cx| {
583 let this = this.upgrade().context("Context server store dropped")?;
584 let id = server.id();
585 let settings = this
586 .update(cx, |this, _| {
587 this.context_server_settings.get(&id.0).cloned()
588 })
589 .context("Failed to get context server settings")?;
590
591 if !settings.enabled() {
592 return anyhow::Ok(());
593 }
594
595 let (registry, worktree_store) = this.update(cx, |this, _| {
596 (this.registry.clone(), this.worktree_store.clone())
597 });
598 let configuration = ContextServerConfiguration::from_settings(
599 settings,
600 id.clone(),
601 registry,
602 worktree_store,
603 cx,
604 )
605 .await
606 .context("Failed to create context server configuration")?;
607
608 this.update(cx, |this, cx| {
609 this.run_server(server, Arc::new(configuration), cx)
610 });
611 Ok(())
612 })
613 .detach_and_log_err(cx);
614 }
615
616 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
617 if matches!(
618 self.servers.get(id),
619 Some(ContextServerState::Stopped { .. })
620 ) {
621 return Ok(());
622 }
623
624 let state = self
625 .servers
626 .remove(id)
627 .context("Context server not found")?;
628
629 let server = state.server();
630 let configuration = state.configuration();
631 let mut result = Ok(());
632 if let ContextServerState::Running { server, .. } = &state {
633 result = server.stop();
634 }
635 drop(state);
636
637 self.update_server_state(
638 id.clone(),
639 ContextServerState::Stopped {
640 configuration,
641 server,
642 },
643 cx,
644 );
645
646 result
647 }
648
649 fn run_server(
650 &mut self,
651 server: Arc<ContextServer>,
652 configuration: Arc<ContextServerConfiguration>,
653 cx: &mut Context<Self>,
654 ) {
655 let id = server.id();
656 if matches!(
657 self.servers.get(&id),
658 Some(
659 ContextServerState::Starting { .. }
660 | ContextServerState::Running { .. }
661 | ContextServerState::Authenticating { .. },
662 )
663 ) {
664 self.stop_server(&id, cx).log_err();
665 }
666 let task = cx.spawn({
667 let id = server.id();
668 let server = server.clone();
669 let configuration = configuration.clone();
670
671 async move |this, cx| {
672 let new_state = match server.clone().start(cx).await {
673 Ok(_) => {
674 debug_assert!(server.client().is_some());
675 ContextServerState::Running {
676 server,
677 configuration,
678 }
679 }
680 Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
681 };
682 this.update(cx, |this, cx| {
683 this.update_server_state(id.clone(), new_state, cx)
684 })
685 .log_err();
686 }
687 });
688
689 self.update_server_state(
690 id.clone(),
691 ContextServerState::Starting {
692 configuration,
693 _task: task,
694 server,
695 },
696 cx,
697 );
698 }
699
700 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
701 let state = self
702 .servers
703 .remove(id)
704 .context("Context server not found")?;
705
706 if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
707 let server_url = url.clone();
708 let id = id.clone();
709 cx.spawn(async move |_this, cx| {
710 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
711 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
712 {
713 log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
714 }
715 })
716 .detach();
717 }
718
719 drop(state);
720 cx.emit(ServerStatusChangedEvent {
721 server_id: id.clone(),
722 status: ContextServerStatus::Stopped,
723 });
724 Ok(())
725 }
726
727 pub async fn create_context_server(
728 this: WeakEntity<Self>,
729 id: ContextServerId,
730 configuration: Arc<ContextServerConfiguration>,
731 cx: &mut AsyncApp,
732 ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
733 let remote = configuration.remote();
734 let needs_remote_command = match configuration.as_ref() {
735 ContextServerConfiguration::Custom { .. }
736 | ContextServerConfiguration::Extension { .. } => remote,
737 ContextServerConfiguration::Http { .. } => false,
738 };
739
740 let (remote_state, is_remote_project) = this.update(cx, |this, _| {
741 let remote_state = match &this.state {
742 ContextServerStoreState::Remote {
743 project_id,
744 upstream_client,
745 } if needs_remote_command => Some((*project_id, upstream_client.clone())),
746 _ => None,
747 };
748 (remote_state, this.is_remote_project())
749 })?;
750
751 let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
752 this.project
753 .as_ref()
754 .and_then(|project| {
755 project
756 .read_with(cx, |project, cx| project.active_project_directory(cx))
757 .ok()
758 .flatten()
759 })
760 .or_else(|| {
761 this.worktree_store.read_with(cx, |store, cx| {
762 store.visible_worktrees(cx).fold(None, |acc, item| {
763 if acc.is_none() {
764 item.read(cx).root_dir()
765 } else {
766 acc
767 }
768 })
769 })
770 })
771 })?;
772
773 let configuration = if let Some((project_id, upstream_client)) = remote_state {
774 let root_dir = root_path.as_ref().map(|p| p.display().to_string());
775
776 let response = upstream_client
777 .update(cx, |client, _| {
778 client
779 .proto_client()
780 .request(proto::GetContextServerCommand {
781 project_id,
782 server_id: id.0.to_string(),
783 root_dir: root_dir.clone(),
784 })
785 })
786 .await?;
787
788 let remote_command = upstream_client.update(cx, |client, _| {
789 client.build_command(
790 Some(response.path),
791 &response.args,
792 &response.env.into_iter().collect(),
793 root_dir,
794 None,
795 )
796 })?;
797
798 let command = ContextServerCommand {
799 path: remote_command.program.into(),
800 args: remote_command.args,
801 env: Some(remote_command.env.into_iter().collect()),
802 timeout: None,
803 };
804
805 Arc::new(ContextServerConfiguration::Custom { command, remote })
806 } else {
807 configuration
808 };
809
810 if let Some(server) = this.update(cx, |this, _| {
811 this.context_server_factory
812 .as_ref()
813 .map(|factory| factory(id.clone(), configuration.clone()))
814 })? {
815 return Ok((server, configuration));
816 }
817
818 let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
819 if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
820 if configuration.has_static_auth_header() {
821 None
822 } else {
823 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
824 let http_client = cx.update(|cx| cx.http_client());
825
826 match Self::load_session(&credentials_provider, url, &cx).await {
827 Ok(Some(session)) => {
828 log::info!("{} loaded cached OAuth session from keychain", id);
829 Some(Self::create_oauth_token_provider(
830 &id,
831 url,
832 session,
833 http_client,
834 credentials_provider,
835 cx,
836 ))
837 }
838 Ok(None) => None,
839 Err(err) => {
840 log::warn!("{} failed to load cached OAuth session: {}", id, err);
841 None
842 }
843 }
844 }
845 } else {
846 None
847 };
848
849 let server: Arc<ContextServer> = this.update(cx, |this, cx| {
850 let global_timeout =
851 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
852
853 match configuration.as_ref() {
854 ContextServerConfiguration::Http {
855 url,
856 headers,
857 timeout,
858 oauth: _,
859 } => {
860 let transport = HttpTransport::new_with_token_provider(
861 cx.http_client(),
862 url.to_string(),
863 headers.clone(),
864 cx.background_executor().clone(),
865 cached_token_provider.clone(),
866 );
867 anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
868 id,
869 Arc::new(transport),
870 Some(Duration::from_secs(
871 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
872 )),
873 )))
874 }
875 _ => {
876 let mut command = configuration
877 .command()
878 .context("Missing command configuration for stdio context server")?
879 .clone();
880 command.timeout = Some(
881 command
882 .timeout
883 .unwrap_or(global_timeout)
884 .min(MAX_TIMEOUT_SECS),
885 );
886
887 // Don't pass remote paths as working directory for locally-spawned processes
888 let working_directory = if is_remote_project { None } else { root_path };
889 anyhow::Ok(Arc::new(ContextServer::stdio(
890 id,
891 command,
892 working_directory,
893 )))
894 }
895 }
896 })??;
897
898 Ok((server, configuration))
899 }
900
901 async fn handle_get_context_server_command(
902 this: Entity<Self>,
903 envelope: TypedEnvelope<proto::GetContextServerCommand>,
904 mut cx: AsyncApp,
905 ) -> Result<proto::ContextServerCommand> {
906 let server_id = ContextServerId(envelope.payload.server_id.into());
907
908 let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
909 let ContextServerStoreState::Local {
910 is_headless: true, ..
911 } = &this.state
912 else {
913 anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
914 };
915
916 let settings = this
917 .context_server_settings
918 .get(&server_id.0)
919 .cloned()
920 .or_else(|| {
921 this.registry
922 .read(inner_cx)
923 .context_server_descriptor(&server_id.0)
924 .map(|_| ContextServerSettings::default_extension())
925 })
926 .with_context(|| format!("context server `{}` not found", server_id))?;
927
928 anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
929 })?;
930
931 let configuration = ContextServerConfiguration::from_settings(
932 settings,
933 server_id.clone(),
934 registry,
935 worktree_store,
936 &cx,
937 )
938 .await
939 .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
940
941 let command = configuration
942 .command()
943 .context("context server has no command (HTTP servers don't need RPC)")?;
944
945 Ok(proto::ContextServerCommand {
946 path: command.path.display().to_string(),
947 args: command.args.clone(),
948 env: command
949 .env
950 .clone()
951 .map(|env| env.into_iter().collect())
952 .unwrap_or_default(),
953 })
954 }
955
956 fn resolve_project_settings<'a>(
957 worktree_store: &'a Entity<WorktreeStore>,
958 cx: &'a App,
959 ) -> &'a ProjectSettings {
960 let location = worktree_store
961 .read(cx)
962 .visible_worktrees(cx)
963 .next()
964 .map(|worktree| settings::SettingsLocation {
965 worktree_id: worktree.read(cx).id(),
966 path: RelPath::empty(),
967 });
968 ProjectSettings::get(location, cx)
969 }
970
971 fn create_oauth_token_provider(
972 id: &ContextServerId,
973 server_url: &url::Url,
974 session: OAuthSession,
975 http_client: Arc<dyn HttpClient>,
976 credentials_provider: Arc<dyn CredentialsProvider>,
977 cx: &mut AsyncApp,
978 ) -> Arc<dyn oauth::OAuthTokenProvider> {
979 let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
980 let id = id.clone();
981 let server_url = server_url.clone();
982
983 cx.spawn(async move |cx| {
984 while let Some(refreshed_session) = token_refresh_rx.next().await {
985 if let Err(err) =
986 Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
987 .await
988 {
989 log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
990 }
991 }
992 log::debug!("{} OAuth session persistence task ended", id);
993 })
994 .detach();
995
996 Arc::new(McpOAuthTokenProvider::new(
997 session,
998 http_client,
999 Some(token_refresh_tx),
1000 ))
1001 }
1002
1003 /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
1004 ///
1005 /// This starts a loopback HTTP callback server on an ephemeral port, builds
1006 /// the authorization URL, opens the user's browser, waits for the callback,
1007 /// exchanges the code for tokens, persists them in the keychain, and restarts
1008 /// the server with the new token provider.
1009 pub fn authenticate_server(
1010 &mut self,
1011 id: &ContextServerId,
1012 cx: &mut Context<Self>,
1013 ) -> Result<()> {
1014 let state = self.servers.get(id).context("Context server not found")?;
1015
1016 let (discovery, server, configuration) = match state {
1017 ContextServerState::AuthRequired {
1018 discovery,
1019 server,
1020 configuration,
1021 } => (discovery.clone(), server.clone(), configuration.clone()),
1022 _ => anyhow::bail!("Server is not in AuthRequired state"),
1023 };
1024
1025 let needs_keychain_check = match configuration.as_ref() {
1026 ContextServerConfiguration::Http {
1027 url,
1028 oauth: Some(oauth_settings),
1029 ..
1030 } if oauth_settings.client_secret.is_none() => Some(url.clone()),
1031 _ => None,
1032 };
1033
1034 let id = id.clone();
1035
1036 let task = cx.spawn({
1037 let id = id.clone();
1038 let server = server.clone();
1039 let configuration = configuration.clone();
1040 async move |this, cx| {
1041 if let Some(server_url) = needs_keychain_check {
1042 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1043 let has_keychain_secret =
1044 Self::load_client_secret(&credentials_provider, &server_url, cx)
1045 .await
1046 .ok()
1047 .flatten()
1048 .is_some();
1049
1050 if !has_keychain_secret {
1051 this.update(cx, |this, cx| {
1052 this.update_server_state(
1053 id.clone(),
1054 ContextServerState::ClientSecretRequired {
1055 server,
1056 configuration,
1057 discovery,
1058 error: None,
1059 },
1060 cx,
1061 );
1062 })
1063 .log_err();
1064 return;
1065 }
1066 }
1067
1068 let result = Self::run_oauth_flow(
1069 this.clone(),
1070 id.clone(),
1071 discovery.clone(),
1072 configuration.clone(),
1073 cx,
1074 )
1075 .await;
1076
1077 if let Err(err) = &result {
1078 log::error!("{} OAuth authentication failed: {:?}", id, err);
1079 this.update(cx, |this, cx| {
1080 this.update_server_state(
1081 id.clone(),
1082 ContextServerState::Error {
1083 server,
1084 configuration,
1085 error: format!("{err:#}").into(),
1086 },
1087 cx,
1088 )
1089 })
1090 .log_err();
1091 }
1092 }
1093 });
1094
1095 self.update_server_state(
1096 id,
1097 ContextServerState::Authenticating {
1098 server,
1099 configuration,
1100 _task: task,
1101 },
1102 cx,
1103 );
1104
1105 Ok(())
1106 }
1107
1108 /// Store the client secret and proceed with authentication.
1109 pub fn submit_client_secret(
1110 &mut self,
1111 id: &ContextServerId,
1112 secret: String,
1113 cx: &mut Context<Self>,
1114 ) -> Result<()> {
1115 let state = self.servers.get(id).context("Context server not found")?;
1116
1117 let (server, configuration, discovery) = match state {
1118 ContextServerState::ClientSecretRequired {
1119 server,
1120 configuration,
1121 discovery,
1122 ..
1123 } => (server.clone(), configuration.clone(), discovery.clone()),
1124 _ => anyhow::bail!("Server is not in ClientSecretRequired state"),
1125 };
1126
1127 let server_url = match configuration.as_ref() {
1128 ContextServerConfiguration::Http { url, .. } => url.clone(),
1129 _ => anyhow::bail!("OAuth only supported for HTTP servers"),
1130 };
1131
1132 let id = id.clone();
1133
1134 let task = cx.spawn({
1135 let id = id.clone();
1136 let server = server.clone();
1137 let configuration = configuration.clone();
1138 async move |this, cx| {
1139 // Store the secret if non-empty (empty means public client / skip).
1140 if !secret.is_empty() {
1141 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1142 if let Err(err) =
1143 Self::store_client_secret(&credentials_provider, &server_url, &secret, cx)
1144 .await
1145 {
1146 log::error!(
1147 "{} failed to store client secret in keychain: {:?}",
1148 id,
1149 err
1150 );
1151 }
1152 }
1153
1154 let result = Self::run_oauth_flow(
1155 this.clone(),
1156 id.clone(),
1157 discovery.clone(),
1158 configuration.clone(),
1159 cx,
1160 )
1161 .await;
1162
1163 if let Err(err) = &result {
1164 log::error!("{} OAuth authentication failed: {:?}", id, err);
1165
1166 let is_bad_client_credentials = err
1167 .downcast_ref::<oauth::OAuthTokenError>()
1168 .is_some_and(|e| e.error == "unauthorized_client");
1169
1170 if is_bad_client_credentials {
1171 // Clear the bad secret from the keychain so the user
1172 // gets a fresh prompt.
1173 let credentials_provider =
1174 cx.update(|cx| zed_credentials_provider::global(cx));
1175 Self::clear_client_secret(&credentials_provider, &server_url, cx)
1176 .await
1177 .log_err();
1178
1179 this.update(cx, |this, cx| {
1180 this.update_server_state(
1181 id.clone(),
1182 ContextServerState::ClientSecretRequired {
1183 server,
1184 configuration,
1185 discovery,
1186 error: Some(format!("{err:#}").into()),
1187 },
1188 cx,
1189 );
1190 })
1191 .log_err();
1192 } else {
1193 this.update(cx, |this, cx| {
1194 this.update_server_state(
1195 id.clone(),
1196 ContextServerState::Error {
1197 server,
1198 configuration,
1199 error: format!("{err:#}").into(),
1200 },
1201 cx,
1202 )
1203 })
1204 .log_err();
1205 }
1206 }
1207 }
1208 });
1209
1210 self.update_server_state(
1211 id,
1212 ContextServerState::Authenticating {
1213 server,
1214 configuration,
1215 _task: task,
1216 },
1217 cx,
1218 );
1219
1220 Ok(())
1221 }
1222
1223 async fn run_oauth_flow(
1224 this: WeakEntity<Self>,
1225 id: ContextServerId,
1226 discovery: Arc<OAuthDiscovery>,
1227 configuration: Arc<ContextServerConfiguration>,
1228 cx: &mut AsyncApp,
1229 ) -> Result<()> {
1230 let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1231 let pkce = oauth::generate_pkce_challenge();
1232
1233 let mut state_bytes = [0u8; 32];
1234 rand::rng().fill(&mut state_bytes);
1235 let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1236
1237 // Start a loopback HTTP server on an ephemeral port. The redirect URI
1238 // includes this port so the browser sends the callback directly to our
1239 // process.
1240 let (redirect_uri, callback_rx) = oauth::start_callback_server()
1241 .await
1242 .context("Failed to start OAuth callback server")?;
1243
1244 let http_client = cx.update(|cx| cx.http_client());
1245 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1246 let server_url = match configuration.as_ref() {
1247 ContextServerConfiguration::Http { url, .. } => url.clone(),
1248 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1249 };
1250
1251 let client_registration = match configuration.as_ref() {
1252 ContextServerConfiguration::Http {
1253 url,
1254 oauth: Some(oauth_settings),
1255 ..
1256 } => {
1257 // Pre-registered client. Resolve the secret from settings, then keychain.
1258 let client_secret = if oauth_settings.client_secret.is_some() {
1259 oauth_settings.client_secret.clone()
1260 } else {
1261 Self::load_client_secret(&credentials_provider, url, cx)
1262 .await
1263 .ok()
1264 .flatten()
1265 };
1266 oauth::OAuthClientRegistration {
1267 client_id: oauth_settings.client_id.clone(),
1268 client_secret,
1269 }
1270 }
1271 _ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1272 .await
1273 .context("Failed to resolve OAuth client registration")?,
1274 };
1275
1276 let auth_url = oauth::build_authorization_url(
1277 &discovery.auth_server_metadata,
1278 &client_registration.client_id,
1279 &redirect_uri,
1280 &discovery.scopes,
1281 &resource,
1282 &pkce,
1283 &state_param,
1284 );
1285
1286 cx.update(|cx| cx.open_url(auth_url.as_str()));
1287
1288 let callback = callback_rx
1289 .await
1290 .map_err(|_| {
1291 anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1292 })?
1293 .context("OAuth callback server received an invalid request")?;
1294
1295 if callback.state != state_param {
1296 anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1297 }
1298
1299 let tokens = oauth::exchange_code(
1300 &http_client,
1301 &discovery.auth_server_metadata,
1302 &callback.code,
1303 &client_registration.client_id,
1304 &redirect_uri,
1305 &pkce.verifier,
1306 &resource,
1307 client_registration.client_secret.as_deref(),
1308 )
1309 .await
1310 .context("Failed to exchange authorization code for tokens")?;
1311
1312 let session = OAuthSession {
1313 token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1314 resource: discovery.resource_metadata.resource.clone(),
1315 client_registration,
1316 tokens,
1317 };
1318
1319 Self::store_session(&credentials_provider, &server_url, &session, cx)
1320 .await
1321 .context("Failed to persist OAuth session in keychain")?;
1322
1323 let token_provider = Self::create_oauth_token_provider(
1324 &id,
1325 &server_url,
1326 session,
1327 http_client.clone(),
1328 credentials_provider,
1329 cx,
1330 );
1331
1332 let new_server = this.update(cx, |this, cx| {
1333 let global_timeout =
1334 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1335
1336 match configuration.as_ref() {
1337 ContextServerConfiguration::Http {
1338 url,
1339 headers,
1340 timeout,
1341 oauth: _,
1342 } => {
1343 let transport = HttpTransport::new_with_token_provider(
1344 http_client.clone(),
1345 url.to_string(),
1346 headers.clone(),
1347 cx.background_executor().clone(),
1348 Some(token_provider.clone()),
1349 );
1350 Ok(Arc::new(ContextServer::new_with_timeout(
1351 id.clone(),
1352 Arc::new(transport),
1353 Some(Duration::from_secs(
1354 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1355 )),
1356 )))
1357 }
1358 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1359 }
1360 })??;
1361
1362 this.update(cx, |this, cx| {
1363 this.run_server(new_server, configuration, cx);
1364 })?;
1365
1366 Ok(())
1367 }
1368
1369 /// Store the full OAuth session in the system keychain, keyed by the
1370 /// server's canonical URI.
1371 async fn store_session(
1372 credentials_provider: &Arc<dyn CredentialsProvider>,
1373 server_url: &url::Url,
1374 session: &OAuthSession,
1375 cx: &AsyncApp,
1376 ) -> Result<()> {
1377 let key = Self::keychain_key(server_url);
1378 let json = serde_json::to_string(session)?;
1379 credentials_provider
1380 .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1381 .await
1382 }
1383
1384 /// Load the full OAuth session from the system keychain for the given
1385 /// server URL.
1386 async fn load_session(
1387 credentials_provider: &Arc<dyn CredentialsProvider>,
1388 server_url: &url::Url,
1389 cx: &AsyncApp,
1390 ) -> Result<Option<OAuthSession>> {
1391 let key = Self::keychain_key(server_url);
1392 match credentials_provider.read_credentials(&key, cx).await? {
1393 Some((_username, password_bytes)) => {
1394 let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1395 Ok(Some(session))
1396 }
1397 None => Ok(None),
1398 }
1399 }
1400
1401 /// Clear the stored OAuth session from the system keychain.
1402 async fn clear_session(
1403 credentials_provider: &Arc<dyn CredentialsProvider>,
1404 server_url: &url::Url,
1405 cx: &AsyncApp,
1406 ) -> Result<()> {
1407 let key = Self::keychain_key(server_url);
1408 credentials_provider.delete_credentials(&key, cx).await
1409 }
1410
1411 fn keychain_key(server_url: &url::Url) -> String {
1412 format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1413 }
1414
1415 fn client_secret_keychain_key(server_url: &url::Url) -> String {
1416 format!(
1417 "mcp-oauth-client-secret:{}",
1418 oauth::canonical_server_uri(server_url)
1419 )
1420 }
1421
1422 async fn load_client_secret(
1423 credentials_provider: &Arc<dyn CredentialsProvider>,
1424 server_url: &url::Url,
1425 cx: &AsyncApp,
1426 ) -> Result<Option<String>> {
1427 let key = Self::client_secret_keychain_key(server_url);
1428 match credentials_provider.read_credentials(&key, cx).await? {
1429 Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)),
1430 None => Ok(None),
1431 }
1432 }
1433
1434 pub async fn store_client_secret(
1435 credentials_provider: &Arc<dyn CredentialsProvider>,
1436 server_url: &url::Url,
1437 secret: &str,
1438 cx: &AsyncApp,
1439 ) -> Result<()> {
1440 let key = Self::client_secret_keychain_key(server_url);
1441 credentials_provider
1442 .write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx)
1443 .await
1444 }
1445
1446 async fn clear_client_secret(
1447 credentials_provider: &Arc<dyn CredentialsProvider>,
1448 server_url: &url::Url,
1449 cx: &AsyncApp,
1450 ) -> Result<()> {
1451 let key = Self::client_secret_keychain_key(server_url);
1452 credentials_provider.delete_credentials(&key, cx).await
1453 }
1454
1455 /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1456 /// session from the keychain and stop the server.
1457 pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1458 let state = self.servers.get(id).context("Context server not found")?;
1459 let configuration = state.configuration();
1460
1461 let server_url = match configuration.as_ref() {
1462 ContextServerConfiguration::Http { url, .. } => url.clone(),
1463 _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1464 };
1465
1466 let id = id.clone();
1467 self.stop_server(&id, cx)?;
1468
1469 cx.spawn(async move |this, cx| {
1470 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1471 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1472 log::error!("{} failed to clear OAuth session: {}", id, err);
1473 }
1474 // Also clear any client secret so the user gets a fresh prompt on
1475 // the next authentication attempt.
1476 Self::clear_client_secret(&credentials_provider, &server_url, &cx)
1477 .await
1478 .log_err();
1479 // Trigger server recreation so the next start uses a fresh
1480 // transport without the old (now-invalidated) token provider.
1481 this.update(cx, |this, cx| {
1482 this.available_context_servers_changed(cx);
1483 })
1484 .log_err();
1485 })
1486 .detach();
1487
1488 Ok(())
1489 }
1490
1491 fn update_server_state(
1492 &mut self,
1493 id: ContextServerId,
1494 state: ContextServerState,
1495 cx: &mut Context<Self>,
1496 ) {
1497 let status = ContextServerStatus::from_state(&state);
1498 self.servers.insert(id.clone(), state);
1499 cx.emit(ServerStatusChangedEvent {
1500 server_id: id,
1501 status,
1502 });
1503 }
1504
1505 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1506 if self.update_servers_task.is_some() {
1507 self.needs_server_update = true;
1508 } else {
1509 self.needs_server_update = false;
1510 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1511 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1512 log::error!("Error maintaining context servers: {}", err);
1513 }
1514
1515 this.update(cx, |this, cx| {
1516 this.populate_server_ids(cx);
1517 cx.notify();
1518 this.update_servers_task.take();
1519 if this.needs_server_update {
1520 this.available_context_servers_changed(cx);
1521 }
1522 })?;
1523
1524 Ok(())
1525 }));
1526 }
1527 }
1528
1529 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
1530 // Don't start context servers if AI is disabled
1531 let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
1532 if ai_disabled {
1533 // Stop all running servers when AI is disabled
1534 this.update(cx, |this, cx| {
1535 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
1536 for id in server_ids {
1537 let _ = this.stop_server(&id, cx);
1538 }
1539 })?;
1540 return Ok(());
1541 }
1542
1543 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
1544 (
1545 this.context_server_settings.clone(),
1546 this.registry.clone(),
1547 this.worktree_store.clone(),
1548 )
1549 })?;
1550
1551 for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
1552 configured_servers
1553 .entry(id)
1554 .or_insert(ContextServerSettings::default_extension());
1555 }
1556
1557 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
1558 configured_servers
1559 .into_iter()
1560 .partition(|(_, settings)| settings.enabled());
1561
1562 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
1563 let id = ContextServerId(id);
1564 ContextServerConfiguration::from_settings(
1565 settings,
1566 id.clone(),
1567 registry.clone(),
1568 worktree_store.clone(),
1569 cx,
1570 )
1571 .map(move |config| (id, config))
1572 }))
1573 .await
1574 .into_iter()
1575 .filter_map(|(id, config)| config.map(|config| (id, config)))
1576 .collect::<HashMap<_, _>>();
1577
1578 let mut servers_to_start = Vec::new();
1579 let mut servers_to_remove = HashSet::default();
1580 let mut servers_to_stop = HashSet::default();
1581
1582 this.update(cx, |this, _cx| {
1583 for server_id in this.servers.keys() {
1584 // All servers that are not in desired_servers should be removed from the store.
1585 // This can happen if the user removed a server from the context server settings.
1586 if !configured_servers.contains_key(server_id) {
1587 if disabled_servers.contains_key(&server_id.0) {
1588 servers_to_stop.insert(server_id.clone());
1589 } else {
1590 servers_to_remove.insert(server_id.clone());
1591 }
1592 }
1593 }
1594
1595 for (id, config) in configured_servers {
1596 let state = this.servers.get(&id);
1597 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
1598 let existing_config = state.as_ref().map(|state| state.configuration());
1599 if existing_config.as_deref() != Some(&config) || is_stopped {
1600 let config = Arc::new(config);
1601 servers_to_start.push((id.clone(), config));
1602 if this.servers.contains_key(&id) {
1603 servers_to_stop.insert(id);
1604 }
1605 }
1606 }
1607
1608 anyhow::Ok(())
1609 })??;
1610
1611 this.update(cx, |this, inner_cx| {
1612 for id in servers_to_stop {
1613 this.stop_server(&id, inner_cx)?;
1614 }
1615 for id in servers_to_remove {
1616 this.remove_server(&id, inner_cx)?;
1617 }
1618 anyhow::Ok(())
1619 })??;
1620
1621 for (id, config) in servers_to_start {
1622 match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
1623 Ok((server, config)) => {
1624 this.update(cx, |this, cx| {
1625 this.run_server(server, config, cx);
1626 })?;
1627 }
1628 Err(err) => {
1629 log::error!("{id} context server failed to create: {err:#}");
1630 this.update(cx, |_this, cx| {
1631 cx.emit(ServerStatusChangedEvent {
1632 server_id: id,
1633 status: ContextServerStatus::Error(err.to_string().into()),
1634 });
1635 cx.notify();
1636 })?;
1637 }
1638 }
1639 }
1640
1641 Ok(())
1642 }
1643}
1644
1645/// Determines the appropriate server state after a start attempt fails.
1646///
1647/// When the error is an HTTP 401 with no static auth header configured,
1648/// attempts OAuth discovery so the UI can offer an authentication flow.
1649async fn resolve_start_failure(
1650 id: &ContextServerId,
1651 err: anyhow::Error,
1652 server: Arc<ContextServer>,
1653 configuration: Arc<ContextServerConfiguration>,
1654 cx: &AsyncApp,
1655) -> ContextServerState {
1656 let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1657 TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1658 });
1659
1660 if www_authenticate.is_some() && configuration.has_static_auth_header() {
1661 log::warn!("{id} received 401 with a static Authorization header configured");
1662 return ContextServerState::Error {
1663 configuration,
1664 server,
1665 error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1666 .into(),
1667 };
1668 }
1669
1670 let server_url = match configuration.as_ref() {
1671 ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1672 url.clone()
1673 }
1674 _ => {
1675 if www_authenticate.is_some() {
1676 log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1677 } else {
1678 log::error!("{id} context server failed to start: {err}");
1679 }
1680 return ContextServerState::Error {
1681 configuration,
1682 server,
1683 error: err.to_string().into(),
1684 };
1685 }
1686 };
1687
1688 // When the error is NOT a 401 but there is a cached OAuth session in the
1689 // keychain, the session is likely stale/expired and caused the failure
1690 // (e.g. timeout because the server rejected the token silently). Clear it
1691 // so the next start attempt can get a clean 401 and trigger the auth flow.
1692 if www_authenticate.is_none() {
1693 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1694 match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1695 Ok(Some(_)) => {
1696 log::info!("{id} start failed with a cached OAuth session present; clearing it");
1697 ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1698 .await
1699 .log_err();
1700 }
1701 _ => {
1702 log::error!("{id} context server failed to start: {err}");
1703 return ContextServerState::Error {
1704 configuration,
1705 server,
1706 error: err.to_string().into(),
1707 };
1708 }
1709 }
1710 }
1711
1712 let default_www_authenticate = oauth::WwwAuthenticate {
1713 resource_metadata: None,
1714 scope: None,
1715 error: None,
1716 error_description: None,
1717 };
1718 let www_authenticate = www_authenticate
1719 .as_ref()
1720 .unwrap_or(&default_www_authenticate);
1721 let http_client = cx.update(|cx| cx.http_client());
1722
1723 match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1724 Ok(discovery) => {
1725 use context_server::oauth::{
1726 ClientRegistrationStrategy, determine_registration_strategy,
1727 };
1728
1729 let has_preregistered_client_id = matches!(
1730 configuration.as_ref(),
1731 ContextServerConfiguration::Http { oauth: Some(_), .. }
1732 );
1733
1734 let strategy = determine_registration_strategy(&discovery.auth_server_metadata);
1735
1736 if matches!(strategy, ClientRegistrationStrategy::Unavailable)
1737 && !has_preregistered_client_id
1738 {
1739 log::error!(
1740 "{id} authorization server supports neither CIMD nor DCR, \
1741 and no pre-registered client_id is configured"
1742 );
1743 return ContextServerState::Error {
1744 configuration,
1745 server,
1746 error: "Authorization server supports neither CIMD nor DCR. \
1747 Configure a pre-registered client_id in your settings \
1748 under the \"oauth\" key."
1749 .into(),
1750 };
1751 }
1752
1753 log::info!(
1754 "{id} requires OAuth authorization (auth server: {})",
1755 discovery.auth_server_metadata.issuer,
1756 );
1757 ContextServerState::AuthRequired {
1758 server,
1759 configuration,
1760 discovery: Arc::new(discovery),
1761 }
1762 }
1763 Err(discovery_err) => {
1764 log::error!("{id} OAuth discovery failed: {discovery_err}");
1765 ContextServerState::Error {
1766 configuration,
1767 server,
1768 error: format!("OAuth discovery failed: {discovery_err}").into(),
1769 }
1770 }
1771 }
1772}