1pub mod extension;
2pub mod registry;
3
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use collections::{HashMap, HashSet};
10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
11use context_server::transport::{HttpTransport, TransportError};
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use credentials_provider::CredentialsProvider;
14use futures::future::Either;
15use futures::{FutureExt as _, StreamExt as _, future::join_all};
16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
17use http_client::HttpClient;
18use itertools::Itertools;
19use rand::Rng as _;
20use registry::ContextServerDescriptorRegistry;
21use remote::RemoteClient;
22use rpc::{AnyProtoClient, TypedEnvelope, proto};
23use settings::{Settings as _, SettingsStore};
24use util::{ResultExt as _, rel_path::RelPath};
25
26use crate::{
27 DisableAiSettings, Project,
28 project_settings::{ContextServerSettings, ProjectSettings},
29 worktree_store::WorktreeStore,
30};
31
/// Maximum timeout, in seconds, for context server requests.
///
/// Both per-server timeouts and the global `context_server_timeout` are
/// clamped to this value (via `.min`) before being turned into a `Duration`.
/// This prevents extremely large timeout values from tying up resources
/// indefinitely.
const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
35
36pub fn init(cx: &mut App) {
37 extension::init(cx);
38}
39
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
47
/// Public status of a context server.
///
/// This is a payload-light projection of the store's internal per-server
/// state (built by `ContextServerStatus::from_state`). It is the status
/// carried by `ServerStatusChangedEvent`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is being started.
    Starting,
    /// The server started successfully and is available.
    Running,
    /// The server has been stopped.
    Stopped,
    /// The server failed; the payload is the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
61
62impl ContextServerStatus {
63 fn from_state(state: &ContextServerState) -> Self {
64 match state {
65 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
66 ContextServerState::Running { .. } => ContextServerStatus::Running,
67 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
68 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
69 ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
70 ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
71 }
72 }
73}
74
/// Internal per-server state tracked by `ContextServerStore`.
///
/// Every variant keeps the `ContextServer` handle and the configuration it
/// was built from. This lets `server()` / `configuration()` return them
/// regardless of the current state.
enum ContextServerState {
    /// `ContextServer::start` is in progress. `_task` is the spawned start-up
    /// task created by `run_server`.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped (see `stop_server`).
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server failed. `error` is the message exposed through
    /// `ContextServerStatus::Error`.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    /// `_task` is the task spawned by `authenticate_server`.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
}
109
110impl ContextServerState {
111 pub fn server(&self) -> Arc<ContextServer> {
112 match self {
113 ContextServerState::Starting { server, .. }
114 | ContextServerState::Running { server, .. }
115 | ContextServerState::Stopped { server, .. }
116 | ContextServerState::Error { server, .. }
117 | ContextServerState::AuthRequired { server, .. }
118 | ContextServerState::Authenticating { server, .. } => server.clone(),
119 }
120 }
121
122 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
123 match self {
124 ContextServerState::Starting { configuration, .. }
125 | ContextServerState::Running { configuration, .. }
126 | ContextServerState::Stopped { configuration, .. }
127 | ContextServerState::Error { configuration, .. }
128 | ContextServerState::AuthRequired { configuration, .. }
129 | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
130 }
131 }
132}
133
/// Resolved description of how to launch or reach a context server.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A stdio server launched from `command`. It is built from `Stdio` settings,
    /// and also used for commands resolved on a remote host.
    ///
    /// `remote` marks servers whose command should be resolved on the remote
    /// host when the project is remote (see `create_context_server`).
    Custom {
        command: ContextServerCommand,
        remote: bool,
    },
    /// A stdio server provided through the descriptor registry.
    /// - `command` was resolved from the registry descriptor.
    /// - `settings` are the user's settings for this server.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
        remote: bool,
    },
    /// A server reached over HTTP at `url`.
    ///
    /// `headers` are handed to the `HttpTransport`. `timeout` is a per-server
    /// request timeout in seconds that overrides the global
    /// `context_server_timeout` (both are clamped to `MAX_TIMEOUT_SECS`).
    Http {
        url: url::Url,
        headers: HashMap<String, String>,
        timeout: Option<u64>,
    },
}
151
152impl ContextServerConfiguration {
153 pub fn command(&self) -> Option<&ContextServerCommand> {
154 match self {
155 ContextServerConfiguration::Custom { command, .. } => Some(command),
156 ContextServerConfiguration::Extension { command, .. } => Some(command),
157 ContextServerConfiguration::Http { .. } => None,
158 }
159 }
160
161 pub fn has_static_auth_header(&self) -> bool {
162 match self {
163 ContextServerConfiguration::Http { headers, .. } => headers
164 .keys()
165 .any(|k| k.eq_ignore_ascii_case("authorization")),
166 _ => false,
167 }
168 }
169
170 pub fn remote(&self) -> bool {
171 match self {
172 ContextServerConfiguration::Custom { remote, .. } => *remote,
173 ContextServerConfiguration::Extension { remote, .. } => *remote,
174 ContextServerConfiguration::Http { .. } => false,
175 }
176 }
177
178 pub async fn from_settings(
179 settings: ContextServerSettings,
180 id: ContextServerId,
181 registry: Entity<ContextServerDescriptorRegistry>,
182 worktree_store: Entity<WorktreeStore>,
183 cx: &AsyncApp,
184 ) -> Option<Self> {
185 const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
186
187 match settings {
188 ContextServerSettings::Stdio {
189 enabled: _,
190 command,
191 remote,
192 } => Some(ContextServerConfiguration::Custom { command, remote }),
193 ContextServerSettings::Extension {
194 enabled: _,
195 settings,
196 remote,
197 } => {
198 let descriptor =
199 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
200
201 let command_future = descriptor.command(worktree_store, cx);
202 let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
203
204 match futures::future::select(command_future, timeout_future).await {
205 Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
206 command,
207 settings,
208 remote,
209 }),
210 Either::Left((Err(e), _)) => {
211 log::error!(
212 "Failed to create context server configuration from settings: {e:#}"
213 );
214 None
215 }
216 Either::Right(_) => {
217 log::error!(
218 "Timed out resolving command for extension context server {id}"
219 );
220 None
221 }
222 }
223 }
224 ContextServerSettings::Http {
225 enabled: _,
226 url,
227 headers: auth,
228 timeout,
229 } => {
230 let url = url::Url::parse(&url).log_err()?;
231 Some(ContextServerConfiguration::Http {
232 url,
233 headers: auth,
234 timeout,
235 })
236 }
237 }
238 }
239}
240
/// Optional override for constructing `ContextServer`s.
///
/// When a store has a factory installed, `create_context_server` calls it
/// with the server id and the (possibly remote-resolved) configuration. It
/// returns the factory's server instead of building a stdio or HTTP server
/// itself. The `test-support` constructors in this file accept one.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
243
/// Whether the store serves a local project or one hosted remotely.
enum ContextServerStoreState {
    /// The project's worktrees are on this machine.
    Local {
        /// Project id and client set by `shared` once the project is shared.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// `true` when running headless. Only a headless local store answers
        /// `GetContextServerCommand` requests.
        is_headless: bool,
    },
    /// The project lives on a remote host.
    Remote {
        /// Project id sent with `GetContextServerCommand` requests.
        project_id: u64,
        /// Client used to request commands from the remote host and to wrap
        /// them with `build_command`.
        upstream_client: Entity<RemoteClient>,
    },
}
254
/// Tracks the context servers available to a project and the lifecycle
/// state of each one.
pub struct ContextServerStore {
    /// Local vs. remote mode.
    state: ContextServerStoreState,
    /// Copy of `context_servers` from the resolved `ProjectSettings`, keyed by
    /// server id. It is refreshed by the `SettingsStore` observer.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Sorted, de-duplicated ids returned by `server_ids`. Rebuilt by
    /// `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    /// Owning project, used to find the active project directory.
    project: Option<WeakEntity<Project>>,
    /// Descriptor registry used to resolve `Extension` server commands.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): `update_servers_task` and `needs_server_update` are not
    // read in this chunk; presumably used by `available_context_servers_changed`
    // — confirm.
    update_servers_task: Option<Task<Result<()>>>,
    /// Test hook; see `ContextServerFactory`.
    context_server_factory: Option<ContextServerFactory>,
    needs_server_update: bool,
    /// Last observed `DisableAiSettings::disable_ai` value.
    ai_disabled: bool,
    /// Keeps the settings (and, when maintaining servers, registry)
    /// observers alive.
    _subscriptions: Vec<Subscription>,
}
269
/// Emitted by `ContextServerStore` when a server's status changes, including
/// when a server is removed (reported as `Stopped`).
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}

impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
276
277impl ContextServerStore {
278 pub fn local(
279 worktree_store: Entity<WorktreeStore>,
280 weak_project: Option<WeakEntity<Project>>,
281 headless: bool,
282 cx: &mut Context<Self>,
283 ) -> Self {
284 Self::new_internal(
285 !headless,
286 None,
287 ContextServerDescriptorRegistry::default_global(cx),
288 worktree_store,
289 weak_project,
290 ContextServerStoreState::Local {
291 downstream_client: None,
292 is_headless: headless,
293 },
294 cx,
295 )
296 }
297
298 pub fn remote(
299 project_id: u64,
300 upstream_client: Entity<RemoteClient>,
301 worktree_store: Entity<WorktreeStore>,
302 weak_project: Option<WeakEntity<Project>>,
303 cx: &mut Context<Self>,
304 ) -> Self {
305 Self::new_internal(
306 true,
307 None,
308 ContextServerDescriptorRegistry::default_global(cx),
309 worktree_store,
310 weak_project,
311 ContextServerStoreState::Remote {
312 project_id,
313 upstream_client,
314 },
315 cx,
316 )
317 }
318
319 pub fn init_headless(session: &AnyProtoClient) {
320 session.add_entity_request_handler(Self::handle_get_context_server_command);
321 }
322
323 pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
324 if let ContextServerStoreState::Local {
325 downstream_client, ..
326 } = &mut self.state
327 {
328 *downstream_client = Some((project_id, client));
329 }
330 }
331
332 pub fn is_remote_project(&self) -> bool {
333 matches!(self.state, ContextServerStoreState::Remote { .. })
334 }
335
336 /// Returns all configured context server ids, excluding the ones that are disabled
337 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
338 self.context_server_settings
339 .iter()
340 .filter(|(_, settings)| settings.enabled())
341 .map(|(id, _)| ContextServerId(id.clone()))
342 .collect()
343 }
344
345 #[cfg(feature = "test-support")]
346 pub fn test(
347 registry: Entity<ContextServerDescriptorRegistry>,
348 worktree_store: Entity<WorktreeStore>,
349 weak_project: Option<WeakEntity<Project>>,
350 cx: &mut Context<Self>,
351 ) -> Self {
352 Self::new_internal(
353 false,
354 None,
355 registry,
356 worktree_store,
357 weak_project,
358 ContextServerStoreState::Local {
359 downstream_client: None,
360 is_headless: false,
361 },
362 cx,
363 )
364 }
365
366 #[cfg(feature = "test-support")]
367 pub fn test_maintain_server_loop(
368 context_server_factory: Option<ContextServerFactory>,
369 registry: Entity<ContextServerDescriptorRegistry>,
370 worktree_store: Entity<WorktreeStore>,
371 weak_project: Option<WeakEntity<Project>>,
372 cx: &mut Context<Self>,
373 ) -> Self {
374 Self::new_internal(
375 true,
376 context_server_factory,
377 registry,
378 worktree_store,
379 weak_project,
380 ContextServerStoreState::Local {
381 downstream_client: None,
382 is_headless: false,
383 },
384 cx,
385 )
386 }
387
388 #[cfg(feature = "test-support")]
389 pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
390 self.context_server_factory = Some(factory);
391 }
392
393 #[cfg(feature = "test-support")]
394 pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
395 &self.registry
396 }
397
398 #[cfg(feature = "test-support")]
399 pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
400 let configuration = Arc::new(ContextServerConfiguration::Custom {
401 command: ContextServerCommand {
402 path: "test".into(),
403 args: vec![],
404 env: None,
405 timeout: None,
406 },
407 remote: false,
408 });
409 self.run_server(server, configuration, cx);
410 }
411
412 fn new_internal(
413 maintain_server_loop: bool,
414 context_server_factory: Option<ContextServerFactory>,
415 registry: Entity<ContextServerDescriptorRegistry>,
416 worktree_store: Entity<WorktreeStore>,
417 weak_project: Option<WeakEntity<Project>>,
418 state: ContextServerStoreState,
419 cx: &mut Context<Self>,
420 ) -> Self {
421 let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
422 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
423 let ai_was_disabled = this.ai_disabled;
424 this.ai_disabled = ai_disabled;
425
426 let settings =
427 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
428 let settings_changed = &this.context_server_settings != settings;
429
430 if settings_changed {
431 this.context_server_settings = settings.clone();
432 }
433
434 // When AI is disabled, stop all running servers
435 if ai_disabled {
436 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
437 for id in server_ids {
438 this.stop_server(&id, cx).log_err();
439 }
440 return;
441 }
442
443 // Trigger updates if AI was re-enabled or settings changed
444 if maintain_server_loop && (ai_was_disabled || settings_changed) {
445 this.available_context_servers_changed(cx);
446 }
447 })];
448
449 if maintain_server_loop {
450 subscriptions.push(cx.observe(®istry, |this, _registry, cx| {
451 if !DisableAiSettings::get_global(cx).disable_ai {
452 this.available_context_servers_changed(cx);
453 }
454 }));
455 }
456
457 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
458 let mut this = Self {
459 state,
460 _subscriptions: subscriptions,
461 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
462 .context_servers
463 .clone(),
464 worktree_store,
465 project: weak_project,
466 registry,
467 needs_server_update: false,
468 ai_disabled,
469 servers: HashMap::default(),
470 server_ids: Default::default(),
471 update_servers_task: None,
472 context_server_factory,
473 };
474 if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
475 this.available_context_servers_changed(cx);
476 }
477 this
478 }
479
480 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
481 self.servers.get(id).map(|state| state.server())
482 }
483
484 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
485 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
486 Some(server.clone())
487 } else {
488 None
489 }
490 }
491
492 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
493 self.servers.get(id).map(ContextServerStatus::from_state)
494 }
495
496 pub fn configuration_for_server(
497 &self,
498 id: &ContextServerId,
499 ) -> Option<Arc<ContextServerConfiguration>> {
500 self.servers.get(id).map(|state| state.configuration())
501 }
502
503 /// Returns a sorted slice of available unique context server IDs. Within the
504 /// slice, context servers which have `mcp-server-` as a prefix in their ID will
505 /// appear after servers that do not have this prefix in their ID.
506 pub fn server_ids(&self) -> &[ContextServerId] {
507 self.server_ids.as_slice()
508 }
509
510 fn populate_server_ids(&mut self, cx: &App) {
511 self.server_ids = self
512 .servers
513 .keys()
514 .cloned()
515 .chain(
516 self.registry
517 .read(cx)
518 .context_server_descriptors()
519 .into_iter()
520 .map(|(id, _)| ContextServerId(id)),
521 )
522 .chain(
523 self.context_server_settings
524 .keys()
525 .map(|id| ContextServerId(id.clone())),
526 )
527 .unique()
528 .sorted_unstable_by(
529 // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
530 |a, b| {
531 const MCP_PREFIX: &str = "mcp-server-";
532 match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
533 // If one has mcp-server- prefix and other doesn't, non-mcp comes first
534 (Some(_), None) => std::cmp::Ordering::Greater,
535 (None, Some(_)) => std::cmp::Ordering::Less,
536 // If both have same prefix status, sort by appropriate key
537 (Some(a), Some(b)) => a.cmp(b),
538 (None, None) => a.0.cmp(&b.0),
539 }
540 },
541 )
542 .collect();
543 }
544
545 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
546 self.servers
547 .values()
548 .filter_map(|state| {
549 if let ContextServerState::Running { server, .. } = state {
550 Some(server.clone())
551 } else {
552 None
553 }
554 })
555 .collect()
556 }
557
558 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
559 cx.spawn(async move |this, cx| {
560 let this = this.upgrade().context("Context server store dropped")?;
561 let id = server.id();
562 let settings = this
563 .update(cx, |this, _| {
564 this.context_server_settings.get(&id.0).cloned()
565 })
566 .context("Failed to get context server settings")?;
567
568 if !settings.enabled() {
569 return anyhow::Ok(());
570 }
571
572 let (registry, worktree_store) = this.update(cx, |this, _| {
573 (this.registry.clone(), this.worktree_store.clone())
574 });
575 let configuration = ContextServerConfiguration::from_settings(
576 settings,
577 id.clone(),
578 registry,
579 worktree_store,
580 cx,
581 )
582 .await
583 .context("Failed to create context server configuration")?;
584
585 this.update(cx, |this, cx| {
586 this.run_server(server, Arc::new(configuration), cx)
587 });
588 Ok(())
589 })
590 .detach_and_log_err(cx);
591 }
592
    /// Stops the server with `id` and records it as `Stopped`.
    ///
    /// Returns `Ok(())` immediately if the server is already stopped, and an
    /// error if no server with `id` is known. Only a `Running` server has
    /// `ContextServer::stop` called on it. For every other state the previous
    /// state, including any `_task` it holds, is simply dropped. Any error
    /// from `ContextServer::stop` is returned, but only after the state has
    /// been updated to `Stopped`.
    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
        if matches!(
            self.servers.get(id),
            Some(ContextServerState::Stopped { .. })
        ) {
            return Ok(());
        }

        let state = self
            .servers
            .remove(id)
            .context("Context server not found")?;

        let server = state.server();
        let configuration = state.configuration();
        let mut result = Ok(());
        if let ContextServerState::Running { server, .. } = &state {
            result = server.stop();
        }
        // Release the old state (and any task it owns) before publishing the
        // new one.
        drop(state);

        self.update_server_state(
            id.clone(),
            ContextServerState::Stopped {
                configuration,
                server,
            },
            cx,
        );

        result
    }
625
    /// Starts `server` with `configuration`, replacing any previous state.
    ///
    /// Any in-flight state (`Starting`, `Running` or `Authenticating`) is
    /// stopped first. The server is then recorded as `Starting` while a task
    /// runs `ContextServer::start`. That task moves the server to `Running`
    /// on success, or to the state chosen by `resolve_start_failure` on error.
    fn run_server(
        &mut self,
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        cx: &mut Context<Self>,
    ) {
        let id = server.id();
        if matches!(
            self.servers.get(&id),
            Some(
                ContextServerState::Starting { .. }
                    | ContextServerState::Running { .. }
                    | ContextServerState::Authenticating { .. },
            )
        ) {
            self.stop_server(&id, cx).log_err();
        }
        let task = cx.spawn({
            let id = server.id();
            let server = server.clone();
            let configuration = configuration.clone();

            async move |this, cx| {
                let new_state = match server.clone().start(cx).await {
                    Ok(_) => {
                        debug_assert!(server.client().is_some());
                        ContextServerState::Running {
                            server,
                            configuration,
                        }
                    }
                    // NOTE(review): `resolve_start_failure` is not visible here;
                    // presumably it yields `Error` or `AuthRequired` — confirm.
                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
                };
                this.update(cx, |this, cx| {
                    this.update_server_state(id.clone(), new_state, cx)
                })
                .log_err();
            }
        });

        // The task is owned by the `Starting` state, so replacing that state
        // drops it.
        self.update_server_state(
            id.clone(),
            ContextServerState::Starting {
                configuration,
                _task: task,
                server,
            },
            cx,
        );
    }
676
    /// Forgets the server with `id` entirely and emits a `Stopped` status
    /// event for it.
    ///
    /// For HTTP servers, a detached task also calls `clear_session` to remove
    /// the stored OAuth session for the server URL. Failures there are only
    /// logged. Returns an error if no server with `id` is known.
    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
        let state = self
            .servers
            .remove(id)
            .context("Context server not found")?;

        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
            let server_url = url.clone();
            let id = id.clone();
            cx.spawn(async move |_this, cx| {
                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
                {
                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
                }
            })
            .detach();
        }

        drop(state);
        cx.emit(ServerStatusChangedEvent {
            server_id: id.clone(),
            status: ContextServerStatus::Stopped,
        });
        Ok(())
    }
703
704 pub async fn create_context_server(
705 this: WeakEntity<Self>,
706 id: ContextServerId,
707 configuration: Arc<ContextServerConfiguration>,
708 cx: &mut AsyncApp,
709 ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
710 let remote = configuration.remote();
711 let needs_remote_command = match configuration.as_ref() {
712 ContextServerConfiguration::Custom { .. }
713 | ContextServerConfiguration::Extension { .. } => remote,
714 ContextServerConfiguration::Http { .. } => false,
715 };
716
717 let (remote_state, is_remote_project) = this.update(cx, |this, _| {
718 let remote_state = match &this.state {
719 ContextServerStoreState::Remote {
720 project_id,
721 upstream_client,
722 } if needs_remote_command => Some((*project_id, upstream_client.clone())),
723 _ => None,
724 };
725 (remote_state, this.is_remote_project())
726 })?;
727
728 let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
729 this.project
730 .as_ref()
731 .and_then(|project| {
732 project
733 .read_with(cx, |project, cx| project.active_project_directory(cx))
734 .ok()
735 .flatten()
736 })
737 .or_else(|| {
738 this.worktree_store.read_with(cx, |store, cx| {
739 store.visible_worktrees(cx).fold(None, |acc, item| {
740 if acc.is_none() {
741 item.read(cx).root_dir()
742 } else {
743 acc
744 }
745 })
746 })
747 })
748 })?;
749
750 let configuration = if let Some((project_id, upstream_client)) = remote_state {
751 let root_dir = root_path.as_ref().map(|p| p.display().to_string());
752
753 let response = upstream_client
754 .update(cx, |client, _| {
755 client
756 .proto_client()
757 .request(proto::GetContextServerCommand {
758 project_id,
759 server_id: id.0.to_string(),
760 root_dir: root_dir.clone(),
761 })
762 })
763 .await?;
764
765 let remote_command = upstream_client.update(cx, |client, _| {
766 client.build_command(
767 Some(response.path),
768 &response.args,
769 &response.env.into_iter().collect(),
770 root_dir,
771 None,
772 )
773 })?;
774
775 let command = ContextServerCommand {
776 path: remote_command.program.into(),
777 args: remote_command.args,
778 env: Some(remote_command.env.into_iter().collect()),
779 timeout: None,
780 };
781
782 Arc::new(ContextServerConfiguration::Custom { command, remote })
783 } else {
784 configuration
785 };
786
787 if let Some(server) = this.update(cx, |this, _| {
788 this.context_server_factory
789 .as_ref()
790 .map(|factory| factory(id.clone(), configuration.clone()))
791 })? {
792 return Ok((server, configuration));
793 }
794
795 let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
796 if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
797 if configuration.has_static_auth_header() {
798 None
799 } else {
800 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
801 let http_client = cx.update(|cx| cx.http_client());
802
803 match Self::load_session(&credentials_provider, url, &cx).await {
804 Ok(Some(session)) => {
805 log::info!("{} loaded cached OAuth session from keychain", id);
806 Some(Self::create_oauth_token_provider(
807 &id,
808 url,
809 session,
810 http_client,
811 credentials_provider,
812 cx,
813 ))
814 }
815 Ok(None) => None,
816 Err(err) => {
817 log::warn!("{} failed to load cached OAuth session: {}", id, err);
818 None
819 }
820 }
821 }
822 } else {
823 None
824 };
825
826 let server: Arc<ContextServer> = this.update(cx, |this, cx| {
827 let global_timeout =
828 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
829
830 match configuration.as_ref() {
831 ContextServerConfiguration::Http {
832 url,
833 headers,
834 timeout,
835 } => {
836 let transport = HttpTransport::new_with_token_provider(
837 cx.http_client(),
838 url.to_string(),
839 headers.clone(),
840 cx.background_executor().clone(),
841 cached_token_provider.clone(),
842 );
843 anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
844 id,
845 Arc::new(transport),
846 Some(Duration::from_secs(
847 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
848 )),
849 )))
850 }
851 _ => {
852 let mut command = configuration
853 .command()
854 .context("Missing command configuration for stdio context server")?
855 .clone();
856 command.timeout = Some(
857 command
858 .timeout
859 .unwrap_or(global_timeout)
860 .min(MAX_TIMEOUT_SECS),
861 );
862
863 // Don't pass remote paths as working directory for locally-spawned processes
864 let working_directory = if is_remote_project { None } else { root_path };
865 anyhow::Ok(Arc::new(ContextServer::stdio(
866 id,
867 command,
868 working_directory,
869 )))
870 }
871 }
872 })??;
873
874 Ok((server, configuration))
875 }
876
    /// RPC handler for `GetContextServerCommand`, served by a headless store.
    ///
    /// The server's settings are looked up in this store's settings. If the
    /// server is not in settings but the registry has a descriptor for it,
    /// default extension settings are used. The configuration is then built
    /// and its stdio command (path, args, env) is returned.
    ///
    /// # Errors
    ///
    /// Fails in four cases:
    /// - the store is not a headless local store;
    /// - the server is unknown;
    /// - the configuration cannot be built;
    /// - the server has no command (HTTP).
    async fn handle_get_context_server_command(
        this: Entity<Self>,
        envelope: TypedEnvelope<proto::GetContextServerCommand>,
        mut cx: AsyncApp,
    ) -> Result<proto::ContextServerCommand> {
        let server_id = ContextServerId(envelope.payload.server_id.into());

        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
            // NOTE(review): this also rejects local stores that are not
            // headless, although the message only mentions non-local projects.
            let ContextServerStoreState::Local {
                is_headless: true, ..
            } = &this.state
            else {
                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
            };

            let settings = this
                .context_server_settings
                .get(&server_id.0)
                .cloned()
                .or_else(|| {
                    this.registry
                        .read(inner_cx)
                        .context_server_descriptor(&server_id.0)
                        .map(|_| ContextServerSettings::default_extension())
                })
                .with_context(|| format!("context server `{}` not found", server_id))?;

            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
        })?;

        let configuration = ContextServerConfiguration::from_settings(
            settings,
            server_id.clone(),
            registry,
            worktree_store,
            &cx,
        )
        .await
        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;

        let command = configuration
            .command()
            .context("context server has no command (HTTP servers don't need RPC)")?;

        Ok(proto::ContextServerCommand {
            path: command.path.display().to_string(),
            args: command.args.clone(),
            env: command
                .env
                .clone()
                .map(|env| env.into_iter().collect())
                .unwrap_or_default(),
        })
    }
931
932 fn resolve_project_settings<'a>(
933 worktree_store: &'a Entity<WorktreeStore>,
934 cx: &'a App,
935 ) -> &'a ProjectSettings {
936 let location = worktree_store
937 .read(cx)
938 .visible_worktrees(cx)
939 .next()
940 .map(|worktree| settings::SettingsLocation {
941 worktree_id: worktree.read(cx).id(),
942 path: RelPath::empty(),
943 });
944 ProjectSettings::get(location, cx)
945 }
946
947 fn create_oauth_token_provider(
948 id: &ContextServerId,
949 server_url: &url::Url,
950 session: OAuthSession,
951 http_client: Arc<dyn HttpClient>,
952 credentials_provider: Arc<dyn CredentialsProvider>,
953 cx: &mut AsyncApp,
954 ) -> Arc<dyn oauth::OAuthTokenProvider> {
955 let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
956 let id = id.clone();
957 let server_url = server_url.clone();
958
959 cx.spawn(async move |cx| {
960 while let Some(refreshed_session) = token_refresh_rx.next().await {
961 if let Err(err) =
962 Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
963 .await
964 {
965 log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
966 }
967 }
968 log::debug!("{} OAuth session persistence task ended", id);
969 })
970 .detach();
971
972 Arc::new(McpOAuthTokenProvider::new(
973 session,
974 http_client,
975 Some(token_refresh_tx),
976 ))
977 }
978
    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
    ///
    /// The flow itself (`run_oauth_flow`) does the following:
    /// - starts a loopback HTTP callback server on an ephemeral port;
    /// - builds the authorization URL and opens the user's browser;
    /// - waits for the callback and exchanges the code for tokens;
    /// - persists the tokens in the keychain;
    /// - restarts the server with the new token provider.
    ///
    /// The server is moved to `Authenticating`, which owns the task running
    /// the flow. If the flow fails, the server is moved back to
    /// `AuthRequired` so the user can retry.
    ///
    /// # Errors
    ///
    /// Returns an error if the server is unknown or not in `AuthRequired`.
    pub fn authenticate_server(
        &mut self,
        id: &ContextServerId,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let state = self.servers.get(id).context("Context server not found")?;

        let (discovery, server, configuration) = match state {
            ContextServerState::AuthRequired {
                discovery,
                server,
                configuration,
            } => (discovery.clone(), server.clone(), configuration.clone()),
            _ => anyhow::bail!("Server is not in AuthRequired state"),
        };

        let id = id.clone();

        let task = cx.spawn({
            let id = id.clone();
            let server = server.clone();
            let configuration = configuration.clone();
            async move |this, cx| {
                let result = Self::run_oauth_flow(
                    this.clone(),
                    id.clone(),
                    discovery.clone(),
                    configuration.clone(),
                    cx,
                )
                .await;

                if let Err(err) = &result {
                    log::error!("{} OAuth authentication failed: {:?}", id, err);
                    // Transition back to AuthRequired so the user can retry
                    // rather than landing in a terminal Error state.
                    this.update(cx, |this, cx| {
                        this.update_server_state(
                            id.clone(),
                            ContextServerState::AuthRequired {
                                server,
                                configuration,
                                discovery,
                            },
                            cx,
                        )
                    })
                    .log_err();
                }
            }
        });

        self.update_server_state(
            id,
            ContextServerState::Authenticating {
                server,
                configuration,
                _task: task,
            },
            cx,
        );

        Ok(())
    }
1049
1050 async fn run_oauth_flow(
1051 this: WeakEntity<Self>,
1052 id: ContextServerId,
1053 discovery: Arc<OAuthDiscovery>,
1054 configuration: Arc<ContextServerConfiguration>,
1055 cx: &mut AsyncApp,
1056 ) -> Result<()> {
1057 let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1058 let pkce = oauth::generate_pkce_challenge();
1059
1060 let mut state_bytes = [0u8; 32];
1061 rand::rng().fill(&mut state_bytes);
1062 let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1063
1064 // Start a loopback HTTP server on an ephemeral port. The redirect URI
1065 // includes this port so the browser sends the callback directly to our
1066 // process.
1067 let (redirect_uri, callback_rx) = oauth::start_callback_server()
1068 .await
1069 .context("Failed to start OAuth callback server")?;
1070
1071 let http_client = cx.update(|cx| cx.http_client());
1072 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1073 let server_url = match configuration.as_ref() {
1074 ContextServerConfiguration::Http { url, .. } => url.clone(),
1075 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1076 };
1077
1078 let client_registration =
1079 oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1080 .await
1081 .context("Failed to resolve OAuth client registration")?;
1082
1083 let auth_url = oauth::build_authorization_url(
1084 &discovery.auth_server_metadata,
1085 &client_registration.client_id,
1086 &redirect_uri,
1087 &discovery.scopes,
1088 &resource,
1089 &pkce,
1090 &state_param,
1091 );
1092
1093 cx.update(|cx| cx.open_url(auth_url.as_str()));
1094
1095 let callback = callback_rx
1096 .await
1097 .map_err(|_| {
1098 anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1099 })?
1100 .context("OAuth callback server received an invalid request")?;
1101
1102 if callback.state != state_param {
1103 anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1104 }
1105
1106 let tokens = oauth::exchange_code(
1107 &http_client,
1108 &discovery.auth_server_metadata,
1109 &callback.code,
1110 &client_registration.client_id,
1111 &redirect_uri,
1112 &pkce.verifier,
1113 &resource,
1114 )
1115 .await
1116 .context("Failed to exchange authorization code for tokens")?;
1117
1118 let session = OAuthSession {
1119 token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1120 resource: discovery.resource_metadata.resource.clone(),
1121 client_registration,
1122 tokens,
1123 };
1124
1125 Self::store_session(&credentials_provider, &server_url, &session, cx)
1126 .await
1127 .context("Failed to persist OAuth session in keychain")?;
1128
1129 let token_provider = Self::create_oauth_token_provider(
1130 &id,
1131 &server_url,
1132 session,
1133 http_client.clone(),
1134 credentials_provider,
1135 cx,
1136 );
1137
1138 let new_server = this.update(cx, |this, cx| {
1139 let global_timeout =
1140 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1141
1142 match configuration.as_ref() {
1143 ContextServerConfiguration::Http {
1144 url,
1145 headers,
1146 timeout,
1147 } => {
1148 let transport = HttpTransport::new_with_token_provider(
1149 http_client.clone(),
1150 url.to_string(),
1151 headers.clone(),
1152 cx.background_executor().clone(),
1153 Some(token_provider.clone()),
1154 );
1155 Ok(Arc::new(ContextServer::new_with_timeout(
1156 id.clone(),
1157 Arc::new(transport),
1158 Some(Duration::from_secs(
1159 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1160 )),
1161 )))
1162 }
1163 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1164 }
1165 })??;
1166
1167 this.update(cx, |this, cx| {
1168 this.run_server(new_server, configuration, cx);
1169 })?;
1170
1171 Ok(())
1172 }
1173
1174 /// Store the full OAuth session in the system keychain, keyed by the
1175 /// server's canonical URI.
1176 async fn store_session(
1177 credentials_provider: &Arc<dyn CredentialsProvider>,
1178 server_url: &url::Url,
1179 session: &OAuthSession,
1180 cx: &AsyncApp,
1181 ) -> Result<()> {
1182 let key = Self::keychain_key(server_url);
1183 let json = serde_json::to_string(session)?;
1184 credentials_provider
1185 .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1186 .await
1187 }
1188
1189 /// Load the full OAuth session from the system keychain for the given
1190 /// server URL.
1191 async fn load_session(
1192 credentials_provider: &Arc<dyn CredentialsProvider>,
1193 server_url: &url::Url,
1194 cx: &AsyncApp,
1195 ) -> Result<Option<OAuthSession>> {
1196 let key = Self::keychain_key(server_url);
1197 match credentials_provider.read_credentials(&key, cx).await? {
1198 Some((_username, password_bytes)) => {
1199 let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1200 Ok(Some(session))
1201 }
1202 None => Ok(None),
1203 }
1204 }
1205
1206 /// Clear the stored OAuth session from the system keychain.
1207 async fn clear_session(
1208 credentials_provider: &Arc<dyn CredentialsProvider>,
1209 server_url: &url::Url,
1210 cx: &AsyncApp,
1211 ) -> Result<()> {
1212 let key = Self::keychain_key(server_url);
1213 credentials_provider.delete_credentials(&key, cx).await
1214 }
1215
1216 fn keychain_key(server_url: &url::Url) -> String {
1217 format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1218 }
1219
1220 /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1221 /// session from the keychain and stop the server.
1222 pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1223 let state = self.servers.get(id).context("Context server not found")?;
1224 let configuration = state.configuration();
1225
1226 let server_url = match configuration.as_ref() {
1227 ContextServerConfiguration::Http { url, .. } => url.clone(),
1228 _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1229 };
1230
1231 let id = id.clone();
1232 self.stop_server(&id, cx)?;
1233
1234 cx.spawn(async move |this, cx| {
1235 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1236 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1237 log::error!("{} failed to clear OAuth session: {}", id, err);
1238 }
1239 // Trigger server recreation so the next start uses a fresh
1240 // transport without the old (now-invalidated) token provider.
1241 this.update(cx, |this, cx| {
1242 this.available_context_servers_changed(cx);
1243 })
1244 .log_err();
1245 })
1246 .detach();
1247
1248 Ok(())
1249 }
1250
1251 fn update_server_state(
1252 &mut self,
1253 id: ContextServerId,
1254 state: ContextServerState,
1255 cx: &mut Context<Self>,
1256 ) {
1257 let status = ContextServerStatus::from_state(&state);
1258 self.servers.insert(id.clone(), state);
1259 cx.emit(ServerStatusChangedEvent {
1260 server_id: id,
1261 status,
1262 });
1263 }
1264
1265 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1266 if self.update_servers_task.is_some() {
1267 self.needs_server_update = true;
1268 } else {
1269 self.needs_server_update = false;
1270 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1271 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1272 log::error!("Error maintaining context servers: {}", err);
1273 }
1274
1275 this.update(cx, |this, cx| {
1276 this.populate_server_ids(cx);
1277 cx.notify();
1278 this.update_servers_task.take();
1279 if this.needs_server_update {
1280 this.available_context_servers_changed(cx);
1281 }
1282 })?;
1283
1284 Ok(())
1285 }));
1286 }
1287 }
1288
    /// Reconciles the store's servers with the current settings and with the
    /// extension-provided server registry.
    ///
    /// If `DisableAiSettings::disable_ai` is set, every known server is stopped
    /// (stop errors are ignored) and nothing else happens. Otherwise:
    /// - registry descriptors without a settings entry get
    ///   `ContextServerSettings::default_extension()`;
    /// - an enabled server whose configuration resolves through
    ///   `ContextServerConfiguration::from_settings` is (re)started when it is
    ///   new, when its configuration changed, or when it is currently `Stopped`.
    ///   An existing instance is stopped first;
    /// - known servers that are disabled in settings are stopped;
    /// - all other known servers are removed from the store. This covers servers
    ///   removed from settings and enabled servers whose configuration did not
    ///   resolve.
    ///
    /// When a server fails to be created, the failure is logged and reported
    /// via a `ServerStatusChangedEvent` with `ContextServerStatus::Error`. This
    /// path does not write any state into `self.servers`.
    ///
    /// # Errors
    ///
    /// Fails if the store entity has been dropped, or if stopping or removing a
    /// server fails. The first such failure aborts the remaining stops/removals.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Don't start context servers if AI is disabled
        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
        if ai_disabled {
            // Stop all running servers when AI is disabled
            this.update(cx, |this, cx| {
                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
                for id in server_ids {
                    let _ = this.stop_server(&id, cx);
                }
            })?;
            return Ok(());
        }

        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Extension-provided servers are enabled by default unless the user's
        // settings already contain an entry for them.
        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve all enabled configurations concurrently; servers whose
        // configuration resolves to `None` are dropped here.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(move |config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        // First pass: only compute the plan. Nothing is stopped or started while
        // iterating over `this.servers`.
        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Known servers missing from `configured_servers` (enabled and
                // resolved) are stopped if they are disabled in settings and
                // removed from the store otherwise, e.g. because the user deleted
                // them from the context server settings.
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start when the server is new, its configuration changed,
                // or it is currently stopped; an existing instance is stopped first.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    servers_to_start.push((id.clone(), config));
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }

            anyhow::Ok(())
        })??;

        // Apply stops before removals; the first failure aborts the rest.
        this.update(cx, |this, inner_cx| {
            for id in servers_to_stop {
                this.stop_server(&id, inner_cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, inner_cx)?;
            }
            anyhow::Ok(())
        })??;

        // Create and run servers one at a time. A creation failure is reported
        // for that server only and does not abort the remaining starts.
        for (id, config) in servers_to_start {
            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
                Ok((server, config)) => {
                    this.update(cx, |this, cx| {
                        this.run_server(server, config, cx);
                    })?;
                }
                Err(err) => {
                    log::error!("{id} context server failed to create: {err:#}");
                    this.update(cx, |_this, cx| {
                        cx.emit(ServerStatusChangedEvent {
                            server_id: id,
                            status: ContextServerStatus::Error(err.to_string().into()),
                        });
                        cx.notify();
                    })?;
                }
            }
        }

        Ok(())
    }
1403}
1404
1405/// Determines the appropriate server state after a start attempt fails.
1406///
1407/// When the error is an HTTP 401 with no static auth header configured,
1408/// attempts OAuth discovery so the UI can offer an authentication flow.
1409async fn resolve_start_failure(
1410 id: &ContextServerId,
1411 err: anyhow::Error,
1412 server: Arc<ContextServer>,
1413 configuration: Arc<ContextServerConfiguration>,
1414 cx: &AsyncApp,
1415) -> ContextServerState {
1416 let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1417 TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1418 });
1419
1420 if www_authenticate.is_some() && configuration.has_static_auth_header() {
1421 log::warn!("{id} received 401 with a static Authorization header configured");
1422 return ContextServerState::Error {
1423 configuration,
1424 server,
1425 error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1426 .into(),
1427 };
1428 }
1429
1430 let server_url = match configuration.as_ref() {
1431 ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1432 url.clone()
1433 }
1434 _ => {
1435 if www_authenticate.is_some() {
1436 log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1437 } else {
1438 log::error!("{id} context server failed to start: {err}");
1439 }
1440 return ContextServerState::Error {
1441 configuration,
1442 server,
1443 error: err.to_string().into(),
1444 };
1445 }
1446 };
1447
1448 // When the error is NOT a 401 but there is a cached OAuth session in the
1449 // keychain, the session is likely stale/expired and caused the failure
1450 // (e.g. timeout because the server rejected the token silently). Clear it
1451 // so the next start attempt can get a clean 401 and trigger the auth flow.
1452 if www_authenticate.is_none() {
1453 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1454 match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1455 Ok(Some(_)) => {
1456 log::info!("{id} start failed with a cached OAuth session present; clearing it");
1457 ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1458 .await
1459 .log_err();
1460 }
1461 _ => {
1462 log::error!("{id} context server failed to start: {err}");
1463 return ContextServerState::Error {
1464 configuration,
1465 server,
1466 error: err.to_string().into(),
1467 };
1468 }
1469 }
1470 }
1471
1472 let default_www_authenticate = oauth::WwwAuthenticate {
1473 resource_metadata: None,
1474 scope: None,
1475 error: None,
1476 error_description: None,
1477 };
1478 let www_authenticate = www_authenticate
1479 .as_ref()
1480 .unwrap_or(&default_www_authenticate);
1481 let http_client = cx.update(|cx| cx.http_client());
1482
1483 match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1484 Ok(discovery) => {
1485 log::info!(
1486 "{id} requires OAuth authorization (auth server: {})",
1487 discovery.auth_server_metadata.issuer,
1488 );
1489 ContextServerState::AuthRequired {
1490 server,
1491 configuration,
1492 discovery: Arc::new(discovery),
1493 }
1494 }
1495 Err(discovery_err) => {
1496 log::error!("{id} OAuth discovery failed: {discovery_err}");
1497 ContextServerState::Error {
1498 configuration,
1499 server,
1500 error: format!("OAuth discovery failed: {discovery_err}").into(),
1501 }
1502 }
1503 }
1504}