1pub mod extension;
2pub mod registry;
3
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use collections::{HashMap, HashSet};
10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
11use context_server::transport::{HttpTransport, TransportError};
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use credentials_provider::CredentialsProvider;
14use futures::future::Either;
15use futures::{FutureExt as _, StreamExt as _, future::join_all};
16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
17use http_client::HttpClient;
18use itertools::Itertools;
19use rand::Rng as _;
20use registry::ContextServerDescriptorRegistry;
21use remote::RemoteClient;
22use rpc::{AnyProtoClient, TypedEnvelope, proto};
23use settings::{Settings as _, SettingsStore};
24use util::{ResultExt as _, rel_path::RelPath};
25
26use crate::{
27 DisableAiSettings, Project,
28 project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
29 worktree_store::WorktreeStore,
30};
31
/// Maximum timeout, in seconds, for context server requests.
///
/// Prevents extremely large timeout values from tying up resources indefinitely.
const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
35
/// Initializes this module by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
39
// Declares the actions that live in the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
47
/// Public view of a context server's lifecycle state.
///
/// This mirrors the store's internal per-server state variant for variant,
/// keeping only the error message as payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's startup task has been spawned and has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is known to the store but is not running.
    Stopped,
    /// The server failed; the payload is the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The server has a pre-registered OAuth client_id, but a client_secret
    /// is needed and not available in settings or the keychain. The UI should
    /// show a text input to collect it.
    ClientSecretRequired,
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
65
66impl ContextServerStatus {
67 fn from_state(state: &ContextServerState) -> Self {
68 match state {
69 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
70 ContextServerState::Running { .. } => ContextServerStatus::Running,
71 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
72 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
73 ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
74 ContextServerState::ClientSecretRequired { .. } => {
75 ContextServerStatus::ClientSecretRequired
76 }
77 ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
78 }
79 }
80}
81
/// Internal per-server lifecycle state tracked by the store.
///
/// Every variant carries the server handle and the configuration the server
/// was created from; some variants hold extra data for the next transition.
enum ContextServerState {
    /// The startup task has been spawned and has not reported a result yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The spawned startup task. It is owned by the state, so replacing
        // or dropping the state also drops the task.
        _task: Task<()>,
    },
    /// The server's `start` call completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is not running but is still known to the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server failed; `error` holds the message that is surfaced through
    /// `ContextServerStatus::Error`.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// A pre-registered client_id is configured but no client_secret was found
    /// in settings or the keychain. The user needs to provide it interactively.
    ClientSecretRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The spawned authentication task, owned by the state.
        _task: Task<()>,
    },
}
123
124impl ContextServerState {
125 pub fn server(&self) -> Arc<ContextServer> {
126 match self {
127 ContextServerState::Starting { server, .. }
128 | ContextServerState::Running { server, .. }
129 | ContextServerState::Stopped { server, .. }
130 | ContextServerState::Error { server, .. }
131 | ContextServerState::AuthRequired { server, .. }
132 | ContextServerState::ClientSecretRequired { server, .. }
133 | ContextServerState::Authenticating { server, .. } => server.clone(),
134 }
135 }
136
137 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
138 match self {
139 ContextServerState::Starting { configuration, .. }
140 | ContextServerState::Running { configuration, .. }
141 | ContextServerState::Stopped { configuration, .. }
142 | ContextServerState::Error { configuration, .. }
143 | ContextServerState::AuthRequired { configuration, .. }
144 | ContextServerState::ClientSecretRequired { configuration, .. }
145 | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
146 }
147 }
148}
149
/// Resolved configuration for a single context server.
///
/// Built from `ContextServerSettings` by `from_settings`.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server launched from a command given directly in the settings.
    Custom {
        command: ContextServerCommand,
        /// Copied from the settings; in a remote project, a `true` value makes
        /// `create_context_server` ask the upstream host for the command.
        remote: bool,
    },
    /// A server provided by an extension; the command is resolved through the
    /// extension's descriptor.
    Extension {
        command: ContextServerCommand,
        /// Extension-specific settings, passed through as raw JSON.
        settings: serde_json::Value,
        /// Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP rather than through a spawned command.
    Http {
        url: url::Url,
        /// Extra headers handed to the HTTP transport.
        headers: HashMap<String, String>,
        /// Optional per-server request timeout in seconds.
        timeout: Option<u64>,
        /// Optional pre-registered OAuth client settings.
        oauth: Option<OAuthClientSettings>,
    },
}
168
169impl ContextServerConfiguration {
170 pub fn command(&self) -> Option<&ContextServerCommand> {
171 match self {
172 ContextServerConfiguration::Custom { command, .. } => Some(command),
173 ContextServerConfiguration::Extension { command, .. } => Some(command),
174 ContextServerConfiguration::Http { .. } => None,
175 }
176 }
177
178 pub fn has_static_auth_header(&self) -> bool {
179 match self {
180 ContextServerConfiguration::Http { headers, .. } => headers
181 .keys()
182 .any(|k| k.eq_ignore_ascii_case("authorization")),
183 _ => false,
184 }
185 }
186
187 pub fn remote(&self) -> bool {
188 match self {
189 ContextServerConfiguration::Custom { remote, .. } => *remote,
190 ContextServerConfiguration::Extension { remote, .. } => *remote,
191 ContextServerConfiguration::Http { .. } => false,
192 }
193 }
194
195 pub async fn from_settings(
196 settings: ContextServerSettings,
197 id: ContextServerId,
198 registry: Entity<ContextServerDescriptorRegistry>,
199 worktree_store: Entity<WorktreeStore>,
200 cx: &AsyncApp,
201 ) -> Option<Self> {
202 const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
203
204 match settings {
205 ContextServerSettings::Stdio {
206 enabled: _,
207 command,
208 remote,
209 } => Some(ContextServerConfiguration::Custom { command, remote }),
210 ContextServerSettings::Extension {
211 enabled: _,
212 settings,
213 remote,
214 } => {
215 let descriptor =
216 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
217
218 let command_future = descriptor.command(worktree_store, cx);
219 let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
220
221 match futures::future::select(command_future, timeout_future).await {
222 Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
223 command,
224 settings,
225 remote,
226 }),
227 Either::Left((Err(e), _)) => {
228 log::error!(
229 "Failed to create context server configuration from settings: {e:#}"
230 );
231 None
232 }
233 Either::Right(_) => {
234 log::error!(
235 "Timed out resolving command for extension context server {id}"
236 );
237 None
238 }
239 }
240 }
241 ContextServerSettings::Http {
242 enabled: _,
243 url,
244 headers: auth,
245 timeout,
246 oauth,
247 } => {
248 let url = url::Url::parse(&url).log_err()?;
249 Some(ContextServerConfiguration::Http {
250 url,
251 headers: auth,
252 timeout,
253 oauth,
254 })
255 }
256 }
257 }
258}
259
/// Callback that builds a `ContextServer` for the given id and configuration.
///
/// When a factory is installed on the store, `create_context_server` uses it
/// instead of constructing the server itself.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
262
/// Whether the store belongs to a local project or to a remote one.
enum ContextServerStoreState {
    /// The store lives on the machine that owns the project.
    Local {
        /// Project id and client recorded by `shared`; `None` until then.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// Set from the `headless` argument of `local`. Only headless stores
        /// answer `GetContextServerCommand` requests.
        is_headless: bool,
    },
    /// The store mirrors a project hosted elsewhere.
    Remote {
        project_id: u64,
        /// Client used to ask the upstream host for server commands.
        upstream_client: Entity<RemoteClient>,
    },
}
273
/// Tracks the context servers of a project: their settings, lifecycle state
/// and the data needed to create and start them.
pub struct ContextServerStore {
    /// Local or remote mode, with the matching client handles.
    state: ContextServerStoreState,
    /// Latest `context_servers` project settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Lifecycle state of every server the store currently knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Cached, sorted id list rebuilt by `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    /// Owning project, if any; used to find the active project directory.
    project: Option<WeakEntity<Project>>,
    /// Registry of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): `update_servers_task` and `needs_server_update` are only
    // initialized in the visible part of this file; presumably they coordinate
    // the server maintenance loop — confirm against the rest of the impl.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override used by `create_context_server` to build servers.
    context_server_factory: Option<ContextServerFactory>,
    needs_server_update: bool,
    /// Last observed value of the "disable AI" setting.
    ai_disabled: bool,
    /// Observers registered in `new_internal`, kept alive with the store.
    _subscriptions: Vec<Subscription>,
}
288
/// Event emitted by the store when a server's status changes.
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The status the server has now.
    pub status: ContextServerStatus,
}

// Lets `ContextServerStore` emit `ServerStatusChangedEvent`s to subscribers.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
295
296impl ContextServerStore {
297 pub fn local(
298 worktree_store: Entity<WorktreeStore>,
299 weak_project: Option<WeakEntity<Project>>,
300 headless: bool,
301 cx: &mut Context<Self>,
302 ) -> Self {
303 Self::new_internal(
304 !headless,
305 None,
306 ContextServerDescriptorRegistry::default_global(cx),
307 worktree_store,
308 weak_project,
309 ContextServerStoreState::Local {
310 downstream_client: None,
311 is_headless: headless,
312 },
313 cx,
314 )
315 }
316
317 pub fn remote(
318 project_id: u64,
319 upstream_client: Entity<RemoteClient>,
320 worktree_store: Entity<WorktreeStore>,
321 weak_project: Option<WeakEntity<Project>>,
322 cx: &mut Context<Self>,
323 ) -> Self {
324 Self::new_internal(
325 true,
326 None,
327 ContextServerDescriptorRegistry::default_global(cx),
328 worktree_store,
329 weak_project,
330 ContextServerStoreState::Remote {
331 project_id,
332 upstream_client,
333 },
334 cx,
335 )
336 }
337
/// Registers the `GetContextServerCommand` request handler on `session`.
pub fn init_headless(session: &AnyProtoClient) {
    session.add_entity_request_handler(Self::handle_get_context_server_command);
}
341
342 pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
343 if let ContextServerStoreState::Local {
344 downstream_client, ..
345 } = &mut self.state
346 {
347 *downstream_client = Some((project_id, client));
348 }
349 }
350
351 pub fn is_remote_project(&self) -> bool {
352 matches!(self.state, ContextServerStoreState::Remote { .. })
353 }
354
355 /// Returns all configured context server ids, excluding the ones that are disabled
356 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
357 self.context_server_settings
358 .iter()
359 .filter(|(_, settings)| settings.enabled())
360 .map(|(id, _)| ContextServerId(id.clone()))
361 .collect()
362 }
363
/// Creates a local, non-headless store for tests with the server maintenance
/// loop disabled and no server factory.
#[cfg(feature = "test-support")]
pub fn test(
    registry: Entity<ContextServerDescriptorRegistry>,
    worktree_store: Entity<WorktreeStore>,
    weak_project: Option<WeakEntity<Project>>,
    cx: &mut Context<Self>,
) -> Self {
    let maintain_server_loop = false;
    let state = ContextServerStoreState::Local {
        downstream_client: None,
        is_headless: false,
    };

    Self::new_internal(
        maintain_server_loop,
        None,
        registry,
        worktree_store,
        weak_project,
        state,
        cx,
    )
}
384
/// Creates a local, non-headless store for tests with the server maintenance
/// loop enabled and an optional server factory.
#[cfg(feature = "test-support")]
pub fn test_maintain_server_loop(
    context_server_factory: Option<ContextServerFactory>,
    registry: Entity<ContextServerDescriptorRegistry>,
    worktree_store: Entity<WorktreeStore>,
    weak_project: Option<WeakEntity<Project>>,
    cx: &mut Context<Self>,
) -> Self {
    let maintain_server_loop = true;
    let state = ContextServerStoreState::Local {
        downstream_client: None,
        is_headless: false,
    };

    Self::new_internal(
        maintain_server_loop,
        context_server_factory,
        registry,
        worktree_store,
        weak_project,
        state,
        cx,
    )
}
406
/// Installs (or replaces) the factory used to build servers (test support only).
#[cfg(feature = "test-support")]
pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
    self.context_server_factory = Some(factory);
}
411
/// Returns the descriptor registry this store was created with (test support only).
#[cfg(feature = "test-support")]
pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
    &self.registry
}
416
/// Starts `server` with a placeholder, non-remote custom configuration whose
/// command path is `"test"` (test support only).
#[cfg(feature = "test-support")]
pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
    let command = ContextServerCommand {
        path: "test".into(),
        args: Vec::new(),
        env: None,
        timeout: None,
    };
    let configuration = ContextServerConfiguration::Custom {
        command,
        remote: false,
    };
    self.run_server(server, Arc::new(configuration), cx);
}
430
431 fn new_internal(
432 maintain_server_loop: bool,
433 context_server_factory: Option<ContextServerFactory>,
434 registry: Entity<ContextServerDescriptorRegistry>,
435 worktree_store: Entity<WorktreeStore>,
436 weak_project: Option<WeakEntity<Project>>,
437 state: ContextServerStoreState,
438 cx: &mut Context<Self>,
439 ) -> Self {
440 let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
441 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
442 let ai_was_disabled = this.ai_disabled;
443 this.ai_disabled = ai_disabled;
444
445 let settings =
446 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
447 let settings_changed = &this.context_server_settings != settings;
448
449 if settings_changed {
450 this.context_server_settings = settings.clone();
451 }
452
453 // When AI is disabled, stop all running servers
454 if ai_disabled {
455 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
456 for id in server_ids {
457 this.stop_server(&id, cx).log_err();
458 }
459 return;
460 }
461
462 // Trigger updates if AI was re-enabled or settings changed
463 if maintain_server_loop && (ai_was_disabled || settings_changed) {
464 this.available_context_servers_changed(cx);
465 }
466 })];
467
468 if maintain_server_loop {
469 subscriptions.push(cx.observe(®istry, |this, _registry, cx| {
470 if !DisableAiSettings::get_global(cx).disable_ai {
471 this.available_context_servers_changed(cx);
472 }
473 }));
474 }
475
476 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
477 let mut this = Self {
478 state,
479 _subscriptions: subscriptions,
480 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
481 .context_servers
482 .clone(),
483 worktree_store,
484 project: weak_project,
485 registry,
486 needs_server_update: false,
487 ai_disabled,
488 servers: HashMap::default(),
489 server_ids: Default::default(),
490 update_servers_task: None,
491 context_server_factory,
492 };
493 if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
494 this.available_context_servers_changed(cx);
495 }
496 this
497 }
498
499 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
500 self.servers.get(id).map(|state| state.server())
501 }
502
503 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
504 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
505 Some(server.clone())
506 } else {
507 None
508 }
509 }
510
511 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
512 self.servers.get(id).map(ContextServerStatus::from_state)
513 }
514
515 pub fn configuration_for_server(
516 &self,
517 id: &ContextServerId,
518 ) -> Option<Arc<ContextServerConfiguration>> {
519 self.servers.get(id).map(|state| state.configuration())
520 }
521
522 /// Returns a sorted slice of available unique context server IDs. Within the
523 /// slice, context servers which have `mcp-server-` as a prefix in their ID will
524 /// appear after servers that do not have this prefix in their ID.
525 pub fn server_ids(&self) -> &[ContextServerId] {
526 self.server_ids.as_slice()
527 }
528
529 fn populate_server_ids(&mut self, cx: &App) {
530 self.server_ids = self
531 .servers
532 .keys()
533 .cloned()
534 .chain(
535 self.registry
536 .read(cx)
537 .context_server_descriptors()
538 .into_iter()
539 .map(|(id, _)| ContextServerId(id)),
540 )
541 .chain(
542 self.context_server_settings
543 .keys()
544 .map(|id| ContextServerId(id.clone())),
545 )
546 .unique()
547 .sorted_unstable_by(
548 // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
549 |a, b| {
550 const MCP_PREFIX: &str = "mcp-server-";
551 match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
552 // If one has mcp-server- prefix and other doesn't, non-mcp comes first
553 (Some(_), None) => std::cmp::Ordering::Greater,
554 (None, Some(_)) => std::cmp::Ordering::Less,
555 // If both have same prefix status, sort by appropriate key
556 (Some(a), Some(b)) => a.cmp(b),
557 (None, None) => a.0.cmp(&b.0),
558 }
559 },
560 )
561 .collect();
562 }
563
564 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
565 self.servers
566 .values()
567 .filter_map(|state| {
568 if let ContextServerState::Running { server, .. } = state {
569 Some(server.clone())
570 } else {
571 None
572 }
573 })
574 .collect()
575 }
576
577 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
578 cx.spawn(async move |this, cx| {
579 let this = this.upgrade().context("Context server store dropped")?;
580 let id = server.id();
581 let settings = this
582 .update(cx, |this, _| {
583 this.context_server_settings.get(&id.0).cloned()
584 })
585 .context("Failed to get context server settings")?;
586
587 if !settings.enabled() {
588 return anyhow::Ok(());
589 }
590
591 let (registry, worktree_store) = this.update(cx, |this, _| {
592 (this.registry.clone(), this.worktree_store.clone())
593 });
594 let configuration = ContextServerConfiguration::from_settings(
595 settings,
596 id.clone(),
597 registry,
598 worktree_store,
599 cx,
600 )
601 .await
602 .context("Failed to create context server configuration")?;
603
604 this.update(cx, |this, cx| {
605 this.run_server(server, Arc::new(configuration), cx)
606 });
607 Ok(())
608 })
609 .detach_and_log_err(cx);
610 }
611
612 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
613 if matches!(
614 self.servers.get(id),
615 Some(ContextServerState::Stopped { .. })
616 ) {
617 return Ok(());
618 }
619
620 let state = self
621 .servers
622 .remove(id)
623 .context("Context server not found")?;
624
625 let server = state.server();
626 let configuration = state.configuration();
627 let mut result = Ok(());
628 if let ContextServerState::Running { server, .. } = &state {
629 result = server.stop();
630 }
631 drop(state);
632
633 self.update_server_state(
634 id.clone(),
635 ContextServerState::Stopped {
636 configuration,
637 server,
638 },
639 cx,
640 );
641
642 result
643 }
644
645 fn run_server(
646 &mut self,
647 server: Arc<ContextServer>,
648 configuration: Arc<ContextServerConfiguration>,
649 cx: &mut Context<Self>,
650 ) {
651 let id = server.id();
652 if matches!(
653 self.servers.get(&id),
654 Some(
655 ContextServerState::Starting { .. }
656 | ContextServerState::Running { .. }
657 | ContextServerState::Authenticating { .. },
658 )
659 ) {
660 self.stop_server(&id, cx).log_err();
661 }
662 let task = cx.spawn({
663 let id = server.id();
664 let server = server.clone();
665 let configuration = configuration.clone();
666
667 async move |this, cx| {
668 let new_state = match server.clone().start(cx).await {
669 Ok(_) => {
670 debug_assert!(server.client().is_some());
671 ContextServerState::Running {
672 server,
673 configuration,
674 }
675 }
676 Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
677 };
678 this.update(cx, |this, cx| {
679 this.update_server_state(id.clone(), new_state, cx)
680 })
681 .log_err();
682 }
683 });
684
685 self.update_server_state(
686 id.clone(),
687 ContextServerState::Starting {
688 configuration,
689 _task: task,
690 server,
691 },
692 cx,
693 );
694 }
695
696 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
697 let state = self
698 .servers
699 .remove(id)
700 .context("Context server not found")?;
701
702 if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
703 let server_url = url.clone();
704 let id = id.clone();
705 cx.spawn(async move |_this, cx| {
706 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
707 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
708 {
709 log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
710 }
711 })
712 .detach();
713 }
714
715 drop(state);
716 cx.emit(ServerStatusChangedEvent {
717 server_id: id.clone(),
718 status: ContextServerStatus::Stopped,
719 });
720 Ok(())
721 }
722
723 pub async fn create_context_server(
724 this: WeakEntity<Self>,
725 id: ContextServerId,
726 configuration: Arc<ContextServerConfiguration>,
727 cx: &mut AsyncApp,
728 ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
729 let remote = configuration.remote();
730 let needs_remote_command = match configuration.as_ref() {
731 ContextServerConfiguration::Custom { .. }
732 | ContextServerConfiguration::Extension { .. } => remote,
733 ContextServerConfiguration::Http { .. } => false,
734 };
735
736 let (remote_state, is_remote_project) = this.update(cx, |this, _| {
737 let remote_state = match &this.state {
738 ContextServerStoreState::Remote {
739 project_id,
740 upstream_client,
741 } if needs_remote_command => Some((*project_id, upstream_client.clone())),
742 _ => None,
743 };
744 (remote_state, this.is_remote_project())
745 })?;
746
747 let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
748 this.project
749 .as_ref()
750 .and_then(|project| {
751 project
752 .read_with(cx, |project, cx| project.active_project_directory(cx))
753 .ok()
754 .flatten()
755 })
756 .or_else(|| {
757 this.worktree_store.read_with(cx, |store, cx| {
758 store.visible_worktrees(cx).fold(None, |acc, item| {
759 if acc.is_none() {
760 item.read(cx).root_dir()
761 } else {
762 acc
763 }
764 })
765 })
766 })
767 })?;
768
769 let configuration = if let Some((project_id, upstream_client)) = remote_state {
770 let root_dir = root_path.as_ref().map(|p| p.display().to_string());
771
772 let response = upstream_client
773 .update(cx, |client, _| {
774 client
775 .proto_client()
776 .request(proto::GetContextServerCommand {
777 project_id,
778 server_id: id.0.to_string(),
779 root_dir: root_dir.clone(),
780 })
781 })
782 .await?;
783
784 let remote_command = upstream_client.update(cx, |client, _| {
785 client.build_command(
786 Some(response.path),
787 &response.args,
788 &response.env.into_iter().collect(),
789 root_dir,
790 None,
791 )
792 })?;
793
794 let command = ContextServerCommand {
795 path: remote_command.program.into(),
796 args: remote_command.args,
797 env: Some(remote_command.env.into_iter().collect()),
798 timeout: None,
799 };
800
801 Arc::new(ContextServerConfiguration::Custom { command, remote })
802 } else {
803 configuration
804 };
805
806 if let Some(server) = this.update(cx, |this, _| {
807 this.context_server_factory
808 .as_ref()
809 .map(|factory| factory(id.clone(), configuration.clone()))
810 })? {
811 return Ok((server, configuration));
812 }
813
814 let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
815 if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
816 if configuration.has_static_auth_header() {
817 None
818 } else {
819 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
820 let http_client = cx.update(|cx| cx.http_client());
821
822 match Self::load_session(&credentials_provider, url, &cx).await {
823 Ok(Some(session)) => {
824 log::info!("{} loaded cached OAuth session from keychain", id);
825 Some(Self::create_oauth_token_provider(
826 &id,
827 url,
828 session,
829 http_client,
830 credentials_provider,
831 cx,
832 ))
833 }
834 Ok(None) => None,
835 Err(err) => {
836 log::warn!("{} failed to load cached OAuth session: {}", id, err);
837 None
838 }
839 }
840 }
841 } else {
842 None
843 };
844
845 let server: Arc<ContextServer> = this.update(cx, |this, cx| {
846 let global_timeout =
847 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
848
849 match configuration.as_ref() {
850 ContextServerConfiguration::Http {
851 url,
852 headers,
853 timeout,
854 oauth: _,
855 } => {
856 let transport = HttpTransport::new_with_token_provider(
857 cx.http_client(),
858 url.to_string(),
859 headers.clone(),
860 cx.background_executor().clone(),
861 cached_token_provider.clone(),
862 );
863 anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
864 id,
865 Arc::new(transport),
866 Some(Duration::from_secs(
867 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
868 )),
869 )))
870 }
871 _ => {
872 let mut command = configuration
873 .command()
874 .context("Missing command configuration for stdio context server")?
875 .clone();
876 command.timeout = Some(
877 command
878 .timeout
879 .unwrap_or(global_timeout)
880 .min(MAX_TIMEOUT_SECS),
881 );
882
883 // Don't pass remote paths as working directory for locally-spawned processes
884 let working_directory = if is_remote_project { None } else { root_path };
885 anyhow::Ok(Arc::new(ContextServer::stdio(
886 id,
887 command,
888 working_directory,
889 )))
890 }
891 }
892 })??;
893
894 Ok((server, configuration))
895 }
896
897 async fn handle_get_context_server_command(
898 this: Entity<Self>,
899 envelope: TypedEnvelope<proto::GetContextServerCommand>,
900 mut cx: AsyncApp,
901 ) -> Result<proto::ContextServerCommand> {
902 let server_id = ContextServerId(envelope.payload.server_id.into());
903
904 let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
905 let ContextServerStoreState::Local {
906 is_headless: true, ..
907 } = &this.state
908 else {
909 anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
910 };
911
912 let settings = this
913 .context_server_settings
914 .get(&server_id.0)
915 .cloned()
916 .or_else(|| {
917 this.registry
918 .read(inner_cx)
919 .context_server_descriptor(&server_id.0)
920 .map(|_| ContextServerSettings::default_extension())
921 })
922 .with_context(|| format!("context server `{}` not found", server_id))?;
923
924 anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
925 })?;
926
927 let configuration = ContextServerConfiguration::from_settings(
928 settings,
929 server_id.clone(),
930 registry,
931 worktree_store,
932 &cx,
933 )
934 .await
935 .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
936
937 let command = configuration
938 .command()
939 .context("context server has no command (HTTP servers don't need RPC)")?;
940
941 Ok(proto::ContextServerCommand {
942 path: command.path.display().to_string(),
943 args: command.args.clone(),
944 env: command
945 .env
946 .clone()
947 .map(|env| env.into_iter().collect())
948 .unwrap_or_default(),
949 })
950 }
951
952 fn resolve_project_settings<'a>(
953 worktree_store: &'a Entity<WorktreeStore>,
954 cx: &'a App,
955 ) -> &'a ProjectSettings {
956 let location = worktree_store
957 .read(cx)
958 .visible_worktrees(cx)
959 .next()
960 .map(|worktree| settings::SettingsLocation {
961 worktree_id: worktree.read(cx).id(),
962 path: RelPath::empty(),
963 });
964 ProjectSettings::get(location, cx)
965 }
966
967 fn create_oauth_token_provider(
968 id: &ContextServerId,
969 server_url: &url::Url,
970 session: OAuthSession,
971 http_client: Arc<dyn HttpClient>,
972 credentials_provider: Arc<dyn CredentialsProvider>,
973 cx: &mut AsyncApp,
974 ) -> Arc<dyn oauth::OAuthTokenProvider> {
975 let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
976 let id = id.clone();
977 let server_url = server_url.clone();
978
979 cx.spawn(async move |cx| {
980 while let Some(refreshed_session) = token_refresh_rx.next().await {
981 if let Err(err) =
982 Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
983 .await
984 {
985 log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
986 }
987 }
988 log::debug!("{} OAuth session persistence task ended", id);
989 })
990 .detach();
991
992 Arc::new(McpOAuthTokenProvider::new(
993 session,
994 http_client,
995 Some(token_refresh_tx),
996 ))
997 }
998
999 /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
1000 ///
1001 /// This starts a loopback HTTP callback server on an ephemeral port, builds
1002 /// the authorization URL, opens the user's browser, waits for the callback,
1003 /// exchanges the code for tokens, persists them in the keychain, and restarts
1004 /// the server with the new token provider.
1005 pub fn authenticate_server(
1006 &mut self,
1007 id: &ContextServerId,
1008 cx: &mut Context<Self>,
1009 ) -> Result<()> {
1010 let state = self.servers.get(id).context("Context server not found")?;
1011
1012 let (discovery, server, configuration) = match state {
1013 ContextServerState::AuthRequired {
1014 discovery,
1015 server,
1016 configuration,
1017 } => (discovery.clone(), server.clone(), configuration.clone()),
1018 _ => anyhow::bail!("Server is not in AuthRequired state"),
1019 };
1020
1021 // Check if the configuration has pre-registered OAuth credentials that
1022 // need a client_secret we don't have yet.
1023 let needs_secret_prompt = match configuration.as_ref() {
1024 ContextServerConfiguration::Http {
1025 url,
1026 oauth: Some(oauth_settings),
1027 ..
1028 } if oauth_settings.client_secret.is_none() => Some(url.clone()),
1029 _ => None,
1030 };
1031
1032 let id = id.clone();
1033
1034 if let Some(server_url) = needs_secret_prompt {
1035 // Check keychain for the secret asynchronously.
1036 let task = cx.spawn({
1037 let id = id.clone();
1038 let server = server.clone();
1039 let configuration = configuration.clone();
1040 async move |this, cx| {
1041 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1042 let keychain_secret =
1043 Self::load_client_secret(&credentials_provider, &server_url, cx)
1044 .await
1045 .ok()
1046 .flatten();
1047
1048 if keychain_secret.is_some() {
1049 // Secret found in keychain, proceed with OAuth flow.
1050 let result = Self::run_oauth_flow(
1051 this.clone(),
1052 id.clone(),
1053 discovery.clone(),
1054 configuration.clone(),
1055 cx,
1056 )
1057 .await;
1058
1059 if let Err(err) = &result {
1060 log::error!("{} OAuth authentication failed: {:?}", id, err);
1061 this.update(cx, |this, cx| {
1062 this.update_server_state(
1063 id.clone(),
1064 ContextServerState::Error {
1065 server,
1066 configuration,
1067 error: format!("{err:#}").into(),
1068 },
1069 cx,
1070 )
1071 })
1072 .log_err();
1073 }
1074 } else {
1075 // No secret anywhere — prompt the user.
1076 this.update(cx, |this, cx| {
1077 this.update_server_state(
1078 id.clone(),
1079 ContextServerState::ClientSecretRequired {
1080 server,
1081 configuration,
1082 discovery,
1083 },
1084 cx,
1085 );
1086 })
1087 .log_err();
1088 }
1089 }
1090 });
1091
1092 self.update_server_state(
1093 id,
1094 ContextServerState::Authenticating {
1095 server,
1096 configuration,
1097 _task: task,
1098 },
1099 cx,
1100 );
1101 } else {
1102 // No pre-registration, or secret already in settings — proceed directly.
1103 let task = cx.spawn({
1104 let id = id.clone();
1105 let server = server.clone();
1106 let configuration = configuration.clone();
1107 async move |this, cx| {
1108 let result = Self::run_oauth_flow(
1109 this.clone(),
1110 id.clone(),
1111 discovery.clone(),
1112 configuration.clone(),
1113 cx,
1114 )
1115 .await;
1116
1117 if let Err(err) = &result {
1118 log::error!("{} OAuth authentication failed: {:?}", id, err);
1119 this.update(cx, |this, cx| {
1120 this.update_server_state(
1121 id.clone(),
1122 ContextServerState::Error {
1123 server,
1124 configuration,
1125 error: format!("{err:#}").into(),
1126 },
1127 cx,
1128 )
1129 })
1130 .log_err();
1131 }
1132 }
1133 });
1134
1135 self.update_server_state(
1136 id,
1137 ContextServerState::Authenticating {
1138 server,
1139 configuration,
1140 _task: task,
1141 },
1142 cx,
1143 );
1144 }
1145
1146 Ok(())
1147 }
1148
1149 /// Store an interactively-provided client secret and proceed with authentication.
1150 pub fn submit_client_secret(
1151 &mut self,
1152 id: &ContextServerId,
1153 secret: String,
1154 cx: &mut Context<Self>,
1155 ) -> Result<()> {
1156 let state = self.servers.get(id).context("Context server not found")?;
1157
1158 let (server, configuration, discovery) = match state {
1159 ContextServerState::ClientSecretRequired {
1160 server,
1161 configuration,
1162 discovery,
1163 } => (server.clone(), configuration.clone(), discovery.clone()),
1164 _ => anyhow::bail!("Server is not in ClientSecretRequired state"),
1165 };
1166
1167 let server_url = match configuration.as_ref() {
1168 ContextServerConfiguration::Http { url, .. } => url.clone(),
1169 _ => anyhow::bail!("OAuth only supported for HTTP servers"),
1170 };
1171
1172 let id = id.clone();
1173
1174 let task = cx.spawn({
1175 let id = id.clone();
1176 let server = server.clone();
1177 let configuration = configuration.clone();
1178 async move |this, cx| {
1179 // Store the secret if non-empty (empty means public client / skip).
1180 if !secret.is_empty() {
1181 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1182 if let Err(err) =
1183 Self::store_client_secret(&credentials_provider, &server_url, &secret, cx)
1184 .await
1185 {
1186 log::error!(
1187 "{} failed to store client secret in keychain: {:?}",
1188 id,
1189 err
1190 );
1191 }
1192 }
1193
1194 let result = Self::run_oauth_flow(
1195 this.clone(),
1196 id.clone(),
1197 discovery.clone(),
1198 configuration.clone(),
1199 cx,
1200 )
1201 .await;
1202
1203 if let Err(err) = &result {
1204 log::error!("{} OAuth authentication failed: {:?}", id, err);
1205 this.update(cx, |this, cx| {
1206 this.update_server_state(
1207 id.clone(),
1208 ContextServerState::Error {
1209 server,
1210 configuration,
1211 error: format!("{err:#}").into(),
1212 },
1213 cx,
1214 )
1215 })
1216 .log_err();
1217 }
1218 }
1219 });
1220
1221 self.update_server_state(
1222 id,
1223 ContextServerState::Authenticating {
1224 server,
1225 configuration,
1226 _task: task,
1227 },
1228 cx,
1229 );
1230
1231 Ok(())
1232 }
1233
1234 async fn run_oauth_flow(
1235 this: WeakEntity<Self>,
1236 id: ContextServerId,
1237 discovery: Arc<OAuthDiscovery>,
1238 configuration: Arc<ContextServerConfiguration>,
1239 cx: &mut AsyncApp,
1240 ) -> Result<()> {
1241 let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1242 let pkce = oauth::generate_pkce_challenge();
1243
1244 let mut state_bytes = [0u8; 32];
1245 rand::rng().fill(&mut state_bytes);
1246 let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1247
1248 // Start a loopback HTTP server on an ephemeral port. The redirect URI
1249 // includes this port so the browser sends the callback directly to our
1250 // process.
1251 let (redirect_uri, callback_rx) = oauth::start_callback_server()
1252 .await
1253 .context("Failed to start OAuth callback server")?;
1254
1255 let http_client = cx.update(|cx| cx.http_client());
1256 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1257 let server_url = match configuration.as_ref() {
1258 ContextServerConfiguration::Http { url, .. } => url.clone(),
1259 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1260 };
1261
1262 let client_registration = match configuration.as_ref() {
1263 ContextServerConfiguration::Http {
1264 url,
1265 oauth: Some(oauth_settings),
1266 ..
1267 } => {
1268 // Pre-registered client. Resolve the secret from settings, then keychain.
1269 let client_secret = if oauth_settings.client_secret.is_some() {
1270 oauth_settings.client_secret.clone()
1271 } else {
1272 Self::load_client_secret(&credentials_provider, url, cx)
1273 .await
1274 .ok()
1275 .flatten()
1276 };
1277 oauth::OAuthClientRegistration {
1278 client_id: oauth_settings.client_id.clone(),
1279 client_secret,
1280 }
1281 }
1282 _ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1283 .await
1284 .context("Failed to resolve OAuth client registration")?,
1285 };
1286
1287 let auth_url = oauth::build_authorization_url(
1288 &discovery.auth_server_metadata,
1289 &client_registration.client_id,
1290 &redirect_uri,
1291 &discovery.scopes,
1292 &resource,
1293 &pkce,
1294 &state_param,
1295 );
1296
1297 cx.update(|cx| cx.open_url(auth_url.as_str()));
1298
1299 let callback = callback_rx
1300 .await
1301 .map_err(|_| {
1302 anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1303 })?
1304 .context("OAuth callback server received an invalid request")?;
1305
1306 if callback.state != state_param {
1307 anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1308 }
1309
1310 let tokens = oauth::exchange_code(
1311 &http_client,
1312 &discovery.auth_server_metadata,
1313 &callback.code,
1314 &client_registration.client_id,
1315 &redirect_uri,
1316 &pkce.verifier,
1317 &resource,
1318 client_registration.client_secret.as_deref(),
1319 )
1320 .await
1321 .context("Failed to exchange authorization code for tokens")?;
1322
1323 let session = OAuthSession {
1324 token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1325 resource: discovery.resource_metadata.resource.clone(),
1326 client_registration,
1327 tokens,
1328 };
1329
1330 Self::store_session(&credentials_provider, &server_url, &session, cx)
1331 .await
1332 .context("Failed to persist OAuth session in keychain")?;
1333
1334 let token_provider = Self::create_oauth_token_provider(
1335 &id,
1336 &server_url,
1337 session,
1338 http_client.clone(),
1339 credentials_provider,
1340 cx,
1341 );
1342
1343 let new_server = this.update(cx, |this, cx| {
1344 let global_timeout =
1345 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1346
1347 match configuration.as_ref() {
1348 ContextServerConfiguration::Http {
1349 url,
1350 headers,
1351 timeout,
1352 oauth: _,
1353 } => {
1354 let transport = HttpTransport::new_with_token_provider(
1355 http_client.clone(),
1356 url.to_string(),
1357 headers.clone(),
1358 cx.background_executor().clone(),
1359 Some(token_provider.clone()),
1360 );
1361 Ok(Arc::new(ContextServer::new_with_timeout(
1362 id.clone(),
1363 Arc::new(transport),
1364 Some(Duration::from_secs(
1365 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1366 )),
1367 )))
1368 }
1369 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1370 }
1371 })??;
1372
1373 this.update(cx, |this, cx| {
1374 this.run_server(new_server, configuration, cx);
1375 })?;
1376
1377 Ok(())
1378 }
1379
1380 /// Store the full OAuth session in the system keychain, keyed by the
1381 /// server's canonical URI.
1382 async fn store_session(
1383 credentials_provider: &Arc<dyn CredentialsProvider>,
1384 server_url: &url::Url,
1385 session: &OAuthSession,
1386 cx: &AsyncApp,
1387 ) -> Result<()> {
1388 let key = Self::keychain_key(server_url);
1389 let json = serde_json::to_string(session)?;
1390 credentials_provider
1391 .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1392 .await
1393 }
1394
1395 /// Load the full OAuth session from the system keychain for the given
1396 /// server URL.
1397 async fn load_session(
1398 credentials_provider: &Arc<dyn CredentialsProvider>,
1399 server_url: &url::Url,
1400 cx: &AsyncApp,
1401 ) -> Result<Option<OAuthSession>> {
1402 let key = Self::keychain_key(server_url);
1403 match credentials_provider.read_credentials(&key, cx).await? {
1404 Some((_username, password_bytes)) => {
1405 let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1406 Ok(Some(session))
1407 }
1408 None => Ok(None),
1409 }
1410 }
1411
1412 /// Clear the stored OAuth session from the system keychain.
1413 async fn clear_session(
1414 credentials_provider: &Arc<dyn CredentialsProvider>,
1415 server_url: &url::Url,
1416 cx: &AsyncApp,
1417 ) -> Result<()> {
1418 let key = Self::keychain_key(server_url);
1419 credentials_provider.delete_credentials(&key, cx).await
1420 }
1421
1422 fn keychain_key(server_url: &url::Url) -> String {
1423 format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1424 }
1425
1426 fn client_secret_keychain_key(server_url: &url::Url) -> String {
1427 format!(
1428 "mcp-oauth-client-secret:{}",
1429 oauth::canonical_server_uri(server_url)
1430 )
1431 }
1432
1433 async fn load_client_secret(
1434 credentials_provider: &Arc<dyn CredentialsProvider>,
1435 server_url: &url::Url,
1436 cx: &AsyncApp,
1437 ) -> Result<Option<String>> {
1438 let key = Self::client_secret_keychain_key(server_url);
1439 match credentials_provider.read_credentials(&key, cx).await? {
1440 Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)),
1441 None => Ok(None),
1442 }
1443 }
1444
1445 pub async fn store_client_secret(
1446 credentials_provider: &Arc<dyn CredentialsProvider>,
1447 server_url: &url::Url,
1448 secret: &str,
1449 cx: &AsyncApp,
1450 ) -> Result<()> {
1451 let key = Self::client_secret_keychain_key(server_url);
1452 credentials_provider
1453 .write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx)
1454 .await
1455 }
1456
1457 async fn clear_client_secret(
1458 credentials_provider: &Arc<dyn CredentialsProvider>,
1459 server_url: &url::Url,
1460 cx: &AsyncApp,
1461 ) -> Result<()> {
1462 let key = Self::client_secret_keychain_key(server_url);
1463 credentials_provider.delete_credentials(&key, cx).await
1464 }
1465
1466 /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1467 /// session from the keychain and stop the server.
1468 pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1469 let state = self.servers.get(id).context("Context server not found")?;
1470 let configuration = state.configuration();
1471
1472 let server_url = match configuration.as_ref() {
1473 ContextServerConfiguration::Http { url, .. } => url.clone(),
1474 _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1475 };
1476
1477 let id = id.clone();
1478 self.stop_server(&id, cx)?;
1479
1480 cx.spawn(async move |this, cx| {
1481 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1482 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1483 log::error!("{} failed to clear OAuth session: {}", id, err);
1484 }
1485 // Also clear any interactively-provided client secret so the user
1486 // gets a fresh prompt on the next authentication attempt.
1487 Self::clear_client_secret(&credentials_provider, &server_url, &cx)
1488 .await
1489 .log_err();
1490 // Trigger server recreation so the next start uses a fresh
1491 // transport without the old (now-invalidated) token provider.
1492 this.update(cx, |this, cx| {
1493 this.available_context_servers_changed(cx);
1494 })
1495 .log_err();
1496 })
1497 .detach();
1498
1499 Ok(())
1500 }
1501
1502 fn update_server_state(
1503 &mut self,
1504 id: ContextServerId,
1505 state: ContextServerState,
1506 cx: &mut Context<Self>,
1507 ) {
1508 let status = ContextServerStatus::from_state(&state);
1509 self.servers.insert(id.clone(), state);
1510 cx.emit(ServerStatusChangedEvent {
1511 server_id: id,
1512 status,
1513 });
1514 }
1515
1516 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1517 if self.update_servers_task.is_some() {
1518 self.needs_server_update = true;
1519 } else {
1520 self.needs_server_update = false;
1521 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1522 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1523 log::error!("Error maintaining context servers: {}", err);
1524 }
1525
1526 this.update(cx, |this, cx| {
1527 this.populate_server_ids(cx);
1528 cx.notify();
1529 this.update_servers_task.take();
1530 if this.needs_server_update {
1531 this.available_context_servers_changed(cx);
1532 }
1533 })?;
1534
1535 Ok(())
1536 }));
1537 }
1538 }
1539
1540 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
1541 // Don't start context servers if AI is disabled
1542 let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
1543 if ai_disabled {
1544 // Stop all running servers when AI is disabled
1545 this.update(cx, |this, cx| {
1546 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
1547 for id in server_ids {
1548 let _ = this.stop_server(&id, cx);
1549 }
1550 })?;
1551 return Ok(());
1552 }
1553
1554 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
1555 (
1556 this.context_server_settings.clone(),
1557 this.registry.clone(),
1558 this.worktree_store.clone(),
1559 )
1560 })?;
1561
1562 for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
1563 configured_servers
1564 .entry(id)
1565 .or_insert(ContextServerSettings::default_extension());
1566 }
1567
1568 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
1569 configured_servers
1570 .into_iter()
1571 .partition(|(_, settings)| settings.enabled());
1572
1573 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
1574 let id = ContextServerId(id);
1575 ContextServerConfiguration::from_settings(
1576 settings,
1577 id.clone(),
1578 registry.clone(),
1579 worktree_store.clone(),
1580 cx,
1581 )
1582 .map(move |config| (id, config))
1583 }))
1584 .await
1585 .into_iter()
1586 .filter_map(|(id, config)| config.map(|config| (id, config)))
1587 .collect::<HashMap<_, _>>();
1588
1589 let mut servers_to_start = Vec::new();
1590 let mut servers_to_remove = HashSet::default();
1591 let mut servers_to_stop = HashSet::default();
1592
1593 this.update(cx, |this, _cx| {
1594 for server_id in this.servers.keys() {
1595 // All servers that are not in desired_servers should be removed from the store.
1596 // This can happen if the user removed a server from the context server settings.
1597 if !configured_servers.contains_key(server_id) {
1598 if disabled_servers.contains_key(&server_id.0) {
1599 servers_to_stop.insert(server_id.clone());
1600 } else {
1601 servers_to_remove.insert(server_id.clone());
1602 }
1603 }
1604 }
1605
1606 for (id, config) in configured_servers {
1607 let state = this.servers.get(&id);
1608 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
1609 let existing_config = state.as_ref().map(|state| state.configuration());
1610 if existing_config.as_deref() != Some(&config) || is_stopped {
1611 let config = Arc::new(config);
1612 servers_to_start.push((id.clone(), config));
1613 if this.servers.contains_key(&id) {
1614 servers_to_stop.insert(id);
1615 }
1616 }
1617 }
1618
1619 anyhow::Ok(())
1620 })??;
1621
1622 this.update(cx, |this, inner_cx| {
1623 for id in servers_to_stop {
1624 this.stop_server(&id, inner_cx)?;
1625 }
1626 for id in servers_to_remove {
1627 this.remove_server(&id, inner_cx)?;
1628 }
1629 anyhow::Ok(())
1630 })??;
1631
1632 for (id, config) in servers_to_start {
1633 match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
1634 Ok((server, config)) => {
1635 this.update(cx, |this, cx| {
1636 this.run_server(server, config, cx);
1637 })?;
1638 }
1639 Err(err) => {
1640 log::error!("{id} context server failed to create: {err:#}");
1641 this.update(cx, |_this, cx| {
1642 cx.emit(ServerStatusChangedEvent {
1643 server_id: id,
1644 status: ContextServerStatus::Error(err.to_string().into()),
1645 });
1646 cx.notify();
1647 })?;
1648 }
1649 }
1650 }
1651
1652 Ok(())
1653 }
1654}
1655
1656/// Determines the appropriate server state after a start attempt fails.
1657///
1658/// When the error is an HTTP 401 with no static auth header configured,
1659/// attempts OAuth discovery so the UI can offer an authentication flow.
1660async fn resolve_start_failure(
1661 id: &ContextServerId,
1662 err: anyhow::Error,
1663 server: Arc<ContextServer>,
1664 configuration: Arc<ContextServerConfiguration>,
1665 cx: &AsyncApp,
1666) -> ContextServerState {
1667 let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1668 TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1669 });
1670
1671 if www_authenticate.is_some() && configuration.has_static_auth_header() {
1672 log::warn!("{id} received 401 with a static Authorization header configured");
1673 return ContextServerState::Error {
1674 configuration,
1675 server,
1676 error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1677 .into(),
1678 };
1679 }
1680
1681 let server_url = match configuration.as_ref() {
1682 ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1683 url.clone()
1684 }
1685 _ => {
1686 if www_authenticate.is_some() {
1687 log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1688 } else {
1689 log::error!("{id} context server failed to start: {err}");
1690 }
1691 return ContextServerState::Error {
1692 configuration,
1693 server,
1694 error: err.to_string().into(),
1695 };
1696 }
1697 };
1698
1699 // When the error is NOT a 401 but there is a cached OAuth session in the
1700 // keychain, the session is likely stale/expired and caused the failure
1701 // (e.g. timeout because the server rejected the token silently). Clear it
1702 // so the next start attempt can get a clean 401 and trigger the auth flow.
1703 if www_authenticate.is_none() {
1704 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1705 match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1706 Ok(Some(_)) => {
1707 log::info!("{id} start failed with a cached OAuth session present; clearing it");
1708 ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1709 .await
1710 .log_err();
1711 }
1712 _ => {
1713 log::error!("{id} context server failed to start: {err}");
1714 return ContextServerState::Error {
1715 configuration,
1716 server,
1717 error: err.to_string().into(),
1718 };
1719 }
1720 }
1721 }
1722
1723 let default_www_authenticate = oauth::WwwAuthenticate {
1724 resource_metadata: None,
1725 scope: None,
1726 error: None,
1727 error_description: None,
1728 };
1729 let www_authenticate = www_authenticate
1730 .as_ref()
1731 .unwrap_or(&default_www_authenticate);
1732 let http_client = cx.update(|cx| cx.http_client());
1733
1734 match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1735 Ok(discovery) => {
1736 use context_server::oauth::{
1737 ClientRegistrationStrategy, determine_registration_strategy,
1738 };
1739
1740 let has_preregistered_client_id = matches!(
1741 configuration.as_ref(),
1742 ContextServerConfiguration::Http { oauth: Some(_), .. }
1743 );
1744
1745 let strategy = determine_registration_strategy(&discovery.auth_server_metadata);
1746
1747 if matches!(strategy, ClientRegistrationStrategy::Unavailable)
1748 && !has_preregistered_client_id
1749 {
1750 log::error!(
1751 "{id} authorization server supports neither CIMD nor DCR, \
1752 and no pre-registered client_id is configured"
1753 );
1754 return ContextServerState::Error {
1755 configuration,
1756 server,
1757 error: "Authorization server supports neither CIMD nor DCR. \
1758 Configure a pre-registered client_id in your settings \
1759 under the \"oauth\" key."
1760 .into(),
1761 };
1762 }
1763
1764 log::info!(
1765 "{id} requires OAuth authorization (auth server: {})",
1766 discovery.auth_server_metadata.issuer,
1767 );
1768 ContextServerState::AuthRequired {
1769 server,
1770 configuration,
1771 discovery: Arc::new(discovery),
1772 }
1773 }
1774 Err(discovery_err) => {
1775 log::error!("{id} OAuth discovery failed: {discovery_err}");
1776 ContextServerState::Error {
1777 configuration,
1778 server,
1779 error: format!("OAuth discovery failed: {discovery_err}").into(),
1780 }
1781 }
1782 }
1783}