1pub mod extension;
2pub mod registry;
3
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use collections::{HashMap, HashSet};
10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
11use context_server::transport::{HttpTransport, TransportError};
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use credentials_provider::CredentialsProvider;
14use futures::future::Either;
15use futures::{FutureExt as _, StreamExt as _, future::join_all};
16use gpui::{
17 App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, TaskExt, WeakEntity, actions,
18};
19use http_client::HttpClient;
20use itertools::Itertools;
21use rand::Rng as _;
22use registry::ContextServerDescriptorRegistry;
23use remote::RemoteClient;
24use rpc::{AnyProtoClient, TypedEnvelope, proto};
25use settings::{Settings as _, SettingsStore};
26use util::{ResultExt as _, rel_path::RelPath};
27
28use crate::{
29 DisableAiSettings, Project,
30 project_settings::{ContextServerSettings, ProjectSettings},
31 worktree_store::WorktreeStore,
32};
33
/// Upper bound, in seconds, applied to every context server request timeout.
///
/// Configured timeouts larger than this are clamped so that a huge value cannot
/// tie up resources indefinitely.
const MAX_TIMEOUT_SECS: u64 = 10 * 60;
37
/// Initializes the context server subsystem.
///
/// Currently this only delegates to `extension::init`, which sets up
/// extension-provided context servers.
pub fn init(cx: &mut App) {
    extension::init(cx);
}
41
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
49
/// Public, data-free view of a context server's lifecycle, derived from the
/// internal `ContextServerState` via `ContextServerStatus::from_state`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start task is in progress.
    Starting,
    /// The server started and is running.
    Running,
    /// The server is not running (it was stopped or removed).
    Stopped,
    /// The server is in an error state; carries the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
63
64impl ContextServerStatus {
65 fn from_state(state: &ContextServerState) -> Self {
66 match state {
67 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
68 ContextServerState::Running { .. } => ContextServerStatus::Running,
69 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
70 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
71 ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
72 ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
73 }
74 }
75}
76
/// Internal lifecycle state of a single context server. Every variant keeps the
/// server handle and the configuration it was created from.
enum ContextServerState {
    /// `ContextServer::start` is running in `_task` (spawned by `run_server`).
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        /// Handle to the start task; owned by this state.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is stopped (see `stop_server`).
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is in an error state described by `error`.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        /// Handle to the OAuth flow task spawned by `authenticate_server`.
        _task: Task<()>,
    },
}
111
112impl ContextServerState {
113 pub fn server(&self) -> Arc<ContextServer> {
114 match self {
115 ContextServerState::Starting { server, .. }
116 | ContextServerState::Running { server, .. }
117 | ContextServerState::Stopped { server, .. }
118 | ContextServerState::Error { server, .. }
119 | ContextServerState::AuthRequired { server, .. }
120 | ContextServerState::Authenticating { server, .. } => server.clone(),
121 }
122 }
123
124 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
125 match self {
126 ContextServerState::Starting { configuration, .. }
127 | ContextServerState::Running { configuration, .. }
128 | ContextServerState::Stopped { configuration, .. }
129 | ContextServerState::Error { configuration, .. }
130 | ContextServerState::AuthRequired { configuration, .. }
131 | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
132 }
133 }
134}
135
/// Fully resolved configuration used to construct a [`ContextServer`].
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-configured stdio command (from `ContextServerSettings::Stdio`).
    Custom {
        command: ContextServerCommand,
        /// When true, the command is resolved on the remote host in a remote
        /// project (see `create_context_server`).
        remote: bool,
    },
    /// A command provided by an extension's context server descriptor.
    Extension {
        command: ContextServerCommand,
        /// The `settings` value from `ContextServerSettings::Extension`.
        settings: serde_json::Value,
        /// Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP.
    Http {
        url: url::Url,
        /// Extra request headers; an `Authorization` header here disables the
        /// cached-OAuth lookup (see `has_static_auth_header`).
        headers: HashMap<String, String>,
        /// Per-server request timeout in seconds; falls back to the project's
        /// `context_server_timeout` and is capped at `MAX_TIMEOUT_SECS`.
        timeout: Option<u64>,
    },
}
153
154impl ContextServerConfiguration {
155 pub fn command(&self) -> Option<&ContextServerCommand> {
156 match self {
157 ContextServerConfiguration::Custom { command, .. } => Some(command),
158 ContextServerConfiguration::Extension { command, .. } => Some(command),
159 ContextServerConfiguration::Http { .. } => None,
160 }
161 }
162
163 pub fn has_static_auth_header(&self) -> bool {
164 match self {
165 ContextServerConfiguration::Http { headers, .. } => headers
166 .keys()
167 .any(|k| k.eq_ignore_ascii_case("authorization")),
168 _ => false,
169 }
170 }
171
172 pub fn remote(&self) -> bool {
173 match self {
174 ContextServerConfiguration::Custom { remote, .. } => *remote,
175 ContextServerConfiguration::Extension { remote, .. } => *remote,
176 ContextServerConfiguration::Http { .. } => false,
177 }
178 }
179
180 pub async fn from_settings(
181 settings: ContextServerSettings,
182 id: ContextServerId,
183 registry: Entity<ContextServerDescriptorRegistry>,
184 worktree_store: Entity<WorktreeStore>,
185 cx: &AsyncApp,
186 ) -> Option<Self> {
187 const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
188
189 match settings {
190 ContextServerSettings::Stdio {
191 enabled: _,
192 command,
193 remote,
194 } => Some(ContextServerConfiguration::Custom { command, remote }),
195 ContextServerSettings::Extension {
196 enabled: _,
197 settings,
198 remote,
199 } => {
200 let descriptor =
201 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
202
203 let command_future = descriptor.command(worktree_store, cx);
204 let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
205
206 match futures::future::select(command_future, timeout_future).await {
207 Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
208 command,
209 settings,
210 remote,
211 }),
212 Either::Left((Err(e), _)) => {
213 log::error!(
214 "Failed to create context server configuration from settings: {e:#}"
215 );
216 None
217 }
218 Either::Right(_) => {
219 log::error!(
220 "Timed out resolving command for extension context server {id}"
221 );
222 None
223 }
224 }
225 }
226 ContextServerSettings::Http {
227 enabled: _,
228 url,
229 headers: auth,
230 timeout,
231 } => {
232 let url = url::Url::parse(&url).log_err()?;
233 Some(ContextServerConfiguration::Http {
234 url,
235 headers: auth,
236 timeout,
237 })
238 }
239 }
240 }
241}
242
/// Optional override for building servers: when set on the store,
/// `create_context_server` calls it instead of constructing a stdio/HTTP
/// server itself. Installed by the test-support constructors in this file.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
245
/// Where this store's servers live.
enum ContextServerStoreState {
    /// Servers are managed in this process.
    Local {
        /// `(project_id, client)` set by `shared` once the project is shared.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// True for a headless store; only a headless local store answers
        /// `GetContextServerCommand` requests.
        is_headless: bool,
    },
    /// The project is remote; commands for `remote` servers are requested
    /// through `upstream_client`.
    Remote {
        project_id: u64,
        upstream_client: Entity<RemoteClient>,
    },
}
256
/// Tracks a project's context servers and their lifecycle state, and emits
/// [`ServerStatusChangedEvent`]s as they change.
pub struct ContextServerStore {
    /// Whether servers are managed locally or through a remote host.
    state: ContextServerStoreState,
    /// The `context_servers` section of the resolved project settings.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current lifecycle state of each known server.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Sorted, de-duplicated ids (rebuilt by `populate_server_ids`).
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    /// The owning project, if any; used to find the active project directory.
    project: Option<WeakEntity<Project>>,
    /// Source of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): `update_servers_task` and `needs_server_update` are driven
    // by code outside this view (presumably the server update loop) — confirm.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override used by `create_context_server` to build servers.
    context_server_factory: Option<ContextServerFactory>,
    needs_server_update: bool,
    /// Last observed value of `DisableAiSettings::disable_ai`.
    ai_disabled: bool,
    _subscriptions: Vec<Subscription>,
}
271
/// Emitted by [`ContextServerStore`] when a server's status changes (for
/// example, `remove_server` emits it with `ContextServerStatus::Stopped`).
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}
276
// Lets observers subscribe to server status changes on the store entity.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
278
279impl ContextServerStore {
280 pub fn local(
281 worktree_store: Entity<WorktreeStore>,
282 weak_project: Option<WeakEntity<Project>>,
283 headless: bool,
284 cx: &mut Context<Self>,
285 ) -> Self {
286 Self::new_internal(
287 !headless,
288 None,
289 ContextServerDescriptorRegistry::default_global(cx),
290 worktree_store,
291 weak_project,
292 ContextServerStoreState::Local {
293 downstream_client: None,
294 is_headless: headless,
295 },
296 cx,
297 )
298 }
299
300 pub fn remote(
301 project_id: u64,
302 upstream_client: Entity<RemoteClient>,
303 worktree_store: Entity<WorktreeStore>,
304 weak_project: Option<WeakEntity<Project>>,
305 cx: &mut Context<Self>,
306 ) -> Self {
307 Self::new_internal(
308 true,
309 None,
310 ContextServerDescriptorRegistry::default_global(cx),
311 worktree_store,
312 weak_project,
313 ContextServerStoreState::Remote {
314 project_id,
315 upstream_client,
316 },
317 cx,
318 )
319 }
320
321 pub fn init_headless(session: &AnyProtoClient) {
322 session.add_entity_request_handler(Self::handle_get_context_server_command);
323 }
324
325 pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
326 if let ContextServerStoreState::Local {
327 downstream_client, ..
328 } = &mut self.state
329 {
330 *downstream_client = Some((project_id, client));
331 }
332 }
333
334 pub fn is_remote_project(&self) -> bool {
335 matches!(self.state, ContextServerStoreState::Remote { .. })
336 }
337
338 /// Returns all configured context server ids, excluding the ones that are disabled
339 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
340 self.context_server_settings
341 .iter()
342 .filter(|(_, settings)| settings.enabled())
343 .map(|(id, _)| ContextServerId(id.clone()))
344 .collect()
345 }
346
347 #[cfg(feature = "test-support")]
348 pub fn test(
349 registry: Entity<ContextServerDescriptorRegistry>,
350 worktree_store: Entity<WorktreeStore>,
351 weak_project: Option<WeakEntity<Project>>,
352 cx: &mut Context<Self>,
353 ) -> Self {
354 Self::new_internal(
355 false,
356 None,
357 registry,
358 worktree_store,
359 weak_project,
360 ContextServerStoreState::Local {
361 downstream_client: None,
362 is_headless: false,
363 },
364 cx,
365 )
366 }
367
368 #[cfg(feature = "test-support")]
369 pub fn test_maintain_server_loop(
370 context_server_factory: Option<ContextServerFactory>,
371 registry: Entity<ContextServerDescriptorRegistry>,
372 worktree_store: Entity<WorktreeStore>,
373 weak_project: Option<WeakEntity<Project>>,
374 cx: &mut Context<Self>,
375 ) -> Self {
376 Self::new_internal(
377 true,
378 context_server_factory,
379 registry,
380 worktree_store,
381 weak_project,
382 ContextServerStoreState::Local {
383 downstream_client: None,
384 is_headless: false,
385 },
386 cx,
387 )
388 }
389
390 #[cfg(feature = "test-support")]
391 pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
392 self.context_server_factory = Some(factory);
393 }
394
    /// Returns the descriptor registry this store observes (test-support only).
    #[cfg(feature = "test-support")]
    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
        &self.registry
    }
399
400 #[cfg(feature = "test-support")]
401 pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
402 let configuration = Arc::new(ContextServerConfiguration::Custom {
403 command: ContextServerCommand {
404 path: "test".into(),
405 args: vec![],
406 env: None,
407 timeout: None,
408 },
409 remote: false,
410 });
411 self.run_server(server, configuration, cx);
412 }
413
414 fn new_internal(
415 maintain_server_loop: bool,
416 context_server_factory: Option<ContextServerFactory>,
417 registry: Entity<ContextServerDescriptorRegistry>,
418 worktree_store: Entity<WorktreeStore>,
419 weak_project: Option<WeakEntity<Project>>,
420 state: ContextServerStoreState,
421 cx: &mut Context<Self>,
422 ) -> Self {
423 let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
424 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
425 let ai_was_disabled = this.ai_disabled;
426 this.ai_disabled = ai_disabled;
427
428 let settings =
429 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
430 let settings_changed = &this.context_server_settings != settings;
431
432 if settings_changed {
433 this.context_server_settings = settings.clone();
434 }
435
436 // When AI is disabled, stop all running servers
437 if ai_disabled {
438 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
439 for id in server_ids {
440 this.stop_server(&id, cx).log_err();
441 }
442 return;
443 }
444
445 // Trigger updates if AI was re-enabled or settings changed
446 if maintain_server_loop && (ai_was_disabled || settings_changed) {
447 this.available_context_servers_changed(cx);
448 }
449 })];
450
451 if maintain_server_loop {
452 subscriptions.push(cx.observe(®istry, |this, _registry, cx| {
453 if !DisableAiSettings::get_global(cx).disable_ai {
454 this.available_context_servers_changed(cx);
455 }
456 }));
457 }
458
459 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
460 let mut this = Self {
461 state,
462 _subscriptions: subscriptions,
463 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
464 .context_servers
465 .clone(),
466 worktree_store,
467 project: weak_project,
468 registry,
469 needs_server_update: false,
470 ai_disabled,
471 servers: HashMap::default(),
472 server_ids: Default::default(),
473 update_servers_task: None,
474 context_server_factory,
475 };
476 if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
477 this.available_context_servers_changed(cx);
478 }
479 this
480 }
481
482 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
483 self.servers.get(id).map(|state| state.server())
484 }
485
486 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
487 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
488 Some(server.clone())
489 } else {
490 None
491 }
492 }
493
494 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
495 self.servers.get(id).map(ContextServerStatus::from_state)
496 }
497
498 pub fn configuration_for_server(
499 &self,
500 id: &ContextServerId,
501 ) -> Option<Arc<ContextServerConfiguration>> {
502 self.servers.get(id).map(|state| state.configuration())
503 }
504
505 /// Returns a sorted slice of available unique context server IDs. Within the
506 /// slice, context servers which have `mcp-server-` as a prefix in their ID will
507 /// appear after servers that do not have this prefix in their ID.
508 pub fn server_ids(&self) -> &[ContextServerId] {
509 self.server_ids.as_slice()
510 }
511
512 fn populate_server_ids(&mut self, cx: &App) {
513 self.server_ids = self
514 .servers
515 .keys()
516 .cloned()
517 .chain(
518 self.registry
519 .read(cx)
520 .context_server_descriptors()
521 .into_iter()
522 .map(|(id, _)| ContextServerId(id)),
523 )
524 .chain(
525 self.context_server_settings
526 .keys()
527 .map(|id| ContextServerId(id.clone())),
528 )
529 .unique()
530 .sorted_unstable_by(
531 // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
532 |a, b| {
533 const MCP_PREFIX: &str = "mcp-server-";
534 match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
535 // If one has mcp-server- prefix and other doesn't, non-mcp comes first
536 (Some(_), None) => std::cmp::Ordering::Greater,
537 (None, Some(_)) => std::cmp::Ordering::Less,
538 // If both have same prefix status, sort by appropriate key
539 (Some(a), Some(b)) => a.cmp(b),
540 (None, None) => a.0.cmp(&b.0),
541 }
542 },
543 )
544 .collect();
545 }
546
547 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
548 self.servers
549 .values()
550 .filter_map(|state| {
551 if let ContextServerState::Running { server, .. } = state {
552 Some(server.clone())
553 } else {
554 None
555 }
556 })
557 .collect()
558 }
559
560 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
561 cx.spawn(async move |this, cx| {
562 let this = this.upgrade().context("Context server store dropped")?;
563 let id = server.id();
564 let settings = this
565 .update(cx, |this, _| {
566 this.context_server_settings.get(&id.0).cloned()
567 })
568 .context("Failed to get context server settings")?;
569
570 if !settings.enabled() {
571 return anyhow::Ok(());
572 }
573
574 let (registry, worktree_store) = this.update(cx, |this, _| {
575 (this.registry.clone(), this.worktree_store.clone())
576 });
577 let configuration = ContextServerConfiguration::from_settings(
578 settings,
579 id.clone(),
580 registry,
581 worktree_store,
582 cx,
583 )
584 .await
585 .context("Failed to create context server configuration")?;
586
587 this.update(cx, |this, cx| {
588 this.run_server(server, Arc::new(configuration), cx)
589 });
590 Ok(())
591 })
592 .detach_and_log_err(cx);
593 }
594
595 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
596 if matches!(
597 self.servers.get(id),
598 Some(ContextServerState::Stopped { .. })
599 ) {
600 return Ok(());
601 }
602
603 let state = self
604 .servers
605 .remove(id)
606 .context("Context server not found")?;
607
608 let server = state.server();
609 let configuration = state.configuration();
610 let mut result = Ok(());
611 if let ContextServerState::Running { server, .. } = &state {
612 result = server.stop();
613 }
614 drop(state);
615
616 self.update_server_state(
617 id.clone(),
618 ContextServerState::Stopped {
619 configuration,
620 server,
621 },
622 cx,
623 );
624
625 result
626 }
627
628 fn run_server(
629 &mut self,
630 server: Arc<ContextServer>,
631 configuration: Arc<ContextServerConfiguration>,
632 cx: &mut Context<Self>,
633 ) {
634 let id = server.id();
635 if matches!(
636 self.servers.get(&id),
637 Some(
638 ContextServerState::Starting { .. }
639 | ContextServerState::Running { .. }
640 | ContextServerState::Authenticating { .. },
641 )
642 ) {
643 self.stop_server(&id, cx).log_err();
644 }
645 let task = cx.spawn({
646 let id = server.id();
647 let server = server.clone();
648 let configuration = configuration.clone();
649
650 async move |this, cx| {
651 let new_state = match server.clone().start(cx).await {
652 Ok(_) => {
653 debug_assert!(server.client().is_some());
654 ContextServerState::Running {
655 server,
656 configuration,
657 }
658 }
659 Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
660 };
661 this.update(cx, |this, cx| {
662 this.update_server_state(id.clone(), new_state, cx)
663 })
664 .log_err();
665 }
666 });
667
668 self.update_server_state(
669 id.clone(),
670 ContextServerState::Starting {
671 configuration,
672 _task: task,
673 server,
674 },
675 cx,
676 );
677 }
678
679 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
680 let state = self
681 .servers
682 .remove(id)
683 .context("Context server not found")?;
684
685 if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
686 let server_url = url.clone();
687 let id = id.clone();
688 cx.spawn(async move |_this, cx| {
689 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
690 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
691 {
692 log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
693 }
694 })
695 .detach();
696 }
697
698 drop(state);
699 cx.emit(ServerStatusChangedEvent {
700 server_id: id.clone(),
701 status: ContextServerStatus::Stopped,
702 });
703 Ok(())
704 }
705
706 pub async fn create_context_server(
707 this: WeakEntity<Self>,
708 id: ContextServerId,
709 configuration: Arc<ContextServerConfiguration>,
710 cx: &mut AsyncApp,
711 ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
712 let remote = configuration.remote();
713 let needs_remote_command = match configuration.as_ref() {
714 ContextServerConfiguration::Custom { .. }
715 | ContextServerConfiguration::Extension { .. } => remote,
716 ContextServerConfiguration::Http { .. } => false,
717 };
718
719 let (remote_state, is_remote_project) = this.update(cx, |this, _| {
720 let remote_state = match &this.state {
721 ContextServerStoreState::Remote {
722 project_id,
723 upstream_client,
724 } if needs_remote_command => Some((*project_id, upstream_client.clone())),
725 _ => None,
726 };
727 (remote_state, this.is_remote_project())
728 })?;
729
730 let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
731 this.project
732 .as_ref()
733 .and_then(|project| {
734 project
735 .read_with(cx, |project, cx| project.active_project_directory(cx))
736 .ok()
737 .flatten()
738 })
739 .or_else(|| {
740 this.worktree_store.read_with(cx, |store, cx| {
741 store.visible_worktrees(cx).fold(None, |acc, item| {
742 if acc.is_none() {
743 item.read(cx).root_dir()
744 } else {
745 acc
746 }
747 })
748 })
749 })
750 })?;
751
752 let configuration = if let Some((project_id, upstream_client)) = remote_state {
753 let root_dir = root_path.as_ref().map(|p| p.display().to_string());
754
755 let response = upstream_client
756 .update(cx, |client, _| {
757 client
758 .proto_client()
759 .request(proto::GetContextServerCommand {
760 project_id,
761 server_id: id.0.to_string(),
762 root_dir: root_dir.clone(),
763 })
764 })
765 .await?;
766
767 let remote_command = upstream_client.update(cx, |client, _| {
768 client.build_command(
769 Some(response.path),
770 &response.args,
771 &response.env.into_iter().collect(),
772 root_dir,
773 None,
774 )
775 })?;
776
777 let command = ContextServerCommand {
778 path: remote_command.program.into(),
779 args: remote_command.args,
780 env: Some(remote_command.env.into_iter().collect()),
781 timeout: None,
782 };
783
784 Arc::new(ContextServerConfiguration::Custom { command, remote })
785 } else {
786 configuration
787 };
788
789 if let Some(server) = this.update(cx, |this, _| {
790 this.context_server_factory
791 .as_ref()
792 .map(|factory| factory(id.clone(), configuration.clone()))
793 })? {
794 return Ok((server, configuration));
795 }
796
797 let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
798 if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
799 if configuration.has_static_auth_header() {
800 None
801 } else {
802 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
803 let http_client = cx.update(|cx| cx.http_client());
804
805 match Self::load_session(&credentials_provider, url, &cx).await {
806 Ok(Some(session)) => {
807 log::info!("{} loaded cached OAuth session from keychain", id);
808 Some(Self::create_oauth_token_provider(
809 &id,
810 url,
811 session,
812 http_client,
813 credentials_provider,
814 cx,
815 ))
816 }
817 Ok(None) => None,
818 Err(err) => {
819 log::warn!("{} failed to load cached OAuth session: {}", id, err);
820 None
821 }
822 }
823 }
824 } else {
825 None
826 };
827
828 let server: Arc<ContextServer> = this.update(cx, |this, cx| {
829 let global_timeout =
830 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
831
832 match configuration.as_ref() {
833 ContextServerConfiguration::Http {
834 url,
835 headers,
836 timeout,
837 } => {
838 let transport = HttpTransport::new_with_token_provider(
839 cx.http_client(),
840 url.to_string(),
841 headers.clone(),
842 cx.background_executor().clone(),
843 cached_token_provider.clone(),
844 );
845 anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
846 id,
847 Arc::new(transport),
848 Some(Duration::from_secs(
849 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
850 )),
851 )))
852 }
853 _ => {
854 let mut command = configuration
855 .command()
856 .context("Missing command configuration for stdio context server")?
857 .clone();
858 command.timeout = Some(
859 command
860 .timeout
861 .unwrap_or(global_timeout)
862 .min(MAX_TIMEOUT_SECS),
863 );
864
865 // Don't pass remote paths as working directory for locally-spawned processes
866 let working_directory = if is_remote_project { None } else { root_path };
867 anyhow::Ok(Arc::new(ContextServer::stdio(
868 id,
869 command,
870 working_directory,
871 )))
872 }
873 }
874 })??;
875
876 Ok((server, configuration))
877 }
878
879 async fn handle_get_context_server_command(
880 this: Entity<Self>,
881 envelope: TypedEnvelope<proto::GetContextServerCommand>,
882 mut cx: AsyncApp,
883 ) -> Result<proto::ContextServerCommand> {
884 let server_id = ContextServerId(envelope.payload.server_id.into());
885
886 let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
887 let ContextServerStoreState::Local {
888 is_headless: true, ..
889 } = &this.state
890 else {
891 anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
892 };
893
894 let settings = this
895 .context_server_settings
896 .get(&server_id.0)
897 .cloned()
898 .or_else(|| {
899 this.registry
900 .read(inner_cx)
901 .context_server_descriptor(&server_id.0)
902 .map(|_| ContextServerSettings::default_extension())
903 })
904 .with_context(|| format!("context server `{}` not found", server_id))?;
905
906 anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
907 })?;
908
909 let configuration = ContextServerConfiguration::from_settings(
910 settings,
911 server_id.clone(),
912 registry,
913 worktree_store,
914 &cx,
915 )
916 .await
917 .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
918
919 let command = configuration
920 .command()
921 .context("context server has no command (HTTP servers don't need RPC)")?;
922
923 Ok(proto::ContextServerCommand {
924 path: command.path.display().to_string(),
925 args: command.args.clone(),
926 env: command
927 .env
928 .clone()
929 .map(|env| env.into_iter().collect())
930 .unwrap_or_default(),
931 })
932 }
933
934 fn resolve_project_settings<'a>(
935 worktree_store: &'a Entity<WorktreeStore>,
936 cx: &'a App,
937 ) -> &'a ProjectSettings {
938 let location = worktree_store
939 .read(cx)
940 .visible_worktrees(cx)
941 .next()
942 .map(|worktree| settings::SettingsLocation {
943 worktree_id: worktree.read(cx).id(),
944 path: RelPath::empty(),
945 });
946 ProjectSettings::get(location, cx)
947 }
948
949 fn create_oauth_token_provider(
950 id: &ContextServerId,
951 server_url: &url::Url,
952 session: OAuthSession,
953 http_client: Arc<dyn HttpClient>,
954 credentials_provider: Arc<dyn CredentialsProvider>,
955 cx: &mut AsyncApp,
956 ) -> Arc<dyn oauth::OAuthTokenProvider> {
957 let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
958 let id = id.clone();
959 let server_url = server_url.clone();
960
961 cx.spawn(async move |cx| {
962 while let Some(refreshed_session) = token_refresh_rx.next().await {
963 if let Err(err) =
964 Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
965 .await
966 {
967 log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
968 }
969 }
970 log::debug!("{} OAuth session persistence task ended", id);
971 })
972 .detach();
973
974 Arc::new(McpOAuthTokenProvider::new(
975 session,
976 http_client,
977 Some(token_refresh_tx),
978 ))
979 }
980
    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
    ///
    /// This starts a loopback HTTP callback server on an ephemeral port, builds
    /// the authorization URL, opens the user's browser, waits for the callback,
    /// exchanges the code for tokens, persists them in the keychain, and restarts
    /// the server with the new token provider.
    ///
    /// The flow itself (`run_oauth_flow`) runs in a spawned task. This method
    /// only validates the current state and switches the server to
    /// `Authenticating`, which owns that task's handle.
    ///
    /// # Errors
    /// Returns an error if `id` is unknown or the server is not in the
    /// `AuthRequired` state. Failures of the flow itself are logged, and the
    /// server goes back to `AuthRequired`; they are not returned here.
    pub fn authenticate_server(
        &mut self,
        id: &ContextServerId,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let state = self.servers.get(id).context("Context server not found")?;

        // Clone out everything the flow needs so the borrow of `self.servers`
        // ends before `update_server_state` below.
        let (discovery, server, configuration) = match state {
            ContextServerState::AuthRequired {
                discovery,
                server,
                configuration,
            } => (discovery.clone(), server.clone(), configuration.clone()),
            _ => anyhow::bail!("Server is not in AuthRequired state"),
        };

        let id = id.clone();

        let task = cx.spawn({
            let id = id.clone();
            let server = server.clone();
            let configuration = configuration.clone();
            async move |this, cx| {
                let result = Self::run_oauth_flow(
                    this.clone(),
                    id.clone(),
                    discovery.clone(),
                    configuration.clone(),
                    cx,
                )
                .await;

                if let Err(err) = &result {
                    log::error!("{} OAuth authentication failed: {:?}", id, err);
                    // Transition back to AuthRequired so the user can retry
                    // rather than landing in a terminal Error state.
                    this.update(cx, |this, cx| {
                        this.update_server_state(
                            id.clone(),
                            ContextServerState::AuthRequired {
                                server,
                                configuration,
                                discovery,
                            },
                            cx,
                        )
                    })
                    .log_err();
                }
            }
        });

        // The task handle is stored in the `Authenticating` state, tying it to
        // that state's lifetime.
        self.update_server_state(
            id,
            ContextServerState::Authenticating {
                server,
                configuration,
                _task: task,
            },
            cx,
        );

        Ok(())
    }
1051
1052 async fn run_oauth_flow(
1053 this: WeakEntity<Self>,
1054 id: ContextServerId,
1055 discovery: Arc<OAuthDiscovery>,
1056 configuration: Arc<ContextServerConfiguration>,
1057 cx: &mut AsyncApp,
1058 ) -> Result<()> {
1059 let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1060 let pkce = oauth::generate_pkce_challenge();
1061
1062 let mut state_bytes = [0u8; 32];
1063 rand::rng().fill(&mut state_bytes);
1064 let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1065
1066 // Start a loopback HTTP server on an ephemeral port. The redirect URI
1067 // includes this port so the browser sends the callback directly to our
1068 // process.
1069 let (redirect_uri, callback_rx) = oauth::start_callback_server()
1070 .await
1071 .context("Failed to start OAuth callback server")?;
1072
1073 let http_client = cx.update(|cx| cx.http_client());
1074 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1075 let server_url = match configuration.as_ref() {
1076 ContextServerConfiguration::Http { url, .. } => url.clone(),
1077 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1078 };
1079
1080 let client_registration =
1081 oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1082 .await
1083 .context("Failed to resolve OAuth client registration")?;
1084
1085 let auth_url = oauth::build_authorization_url(
1086 &discovery.auth_server_metadata,
1087 &client_registration.client_id,
1088 &redirect_uri,
1089 &discovery.scopes,
1090 &resource,
1091 &pkce,
1092 &state_param,
1093 );
1094
1095 cx.update(|cx| cx.open_url(auth_url.as_str()));
1096
1097 let callback = callback_rx
1098 .await
1099 .map_err(|_| {
1100 anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1101 })?
1102 .context("OAuth callback server received an invalid request")?;
1103
1104 if callback.state != state_param {
1105 anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1106 }
1107
1108 let tokens = oauth::exchange_code(
1109 &http_client,
1110 &discovery.auth_server_metadata,
1111 &callback.code,
1112 &client_registration.client_id,
1113 &redirect_uri,
1114 &pkce.verifier,
1115 &resource,
1116 )
1117 .await
1118 .context("Failed to exchange authorization code for tokens")?;
1119
1120 let session = OAuthSession {
1121 token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1122 resource: discovery.resource_metadata.resource.clone(),
1123 client_registration,
1124 tokens,
1125 };
1126
1127 Self::store_session(&credentials_provider, &server_url, &session, cx)
1128 .await
1129 .context("Failed to persist OAuth session in keychain")?;
1130
1131 let token_provider = Self::create_oauth_token_provider(
1132 &id,
1133 &server_url,
1134 session,
1135 http_client.clone(),
1136 credentials_provider,
1137 cx,
1138 );
1139
1140 let new_server = this.update(cx, |this, cx| {
1141 let global_timeout =
1142 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1143
1144 match configuration.as_ref() {
1145 ContextServerConfiguration::Http {
1146 url,
1147 headers,
1148 timeout,
1149 } => {
1150 let transport = HttpTransport::new_with_token_provider(
1151 http_client.clone(),
1152 url.to_string(),
1153 headers.clone(),
1154 cx.background_executor().clone(),
1155 Some(token_provider.clone()),
1156 );
1157 Ok(Arc::new(ContextServer::new_with_timeout(
1158 id.clone(),
1159 Arc::new(transport),
1160 Some(Duration::from_secs(
1161 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1162 )),
1163 )))
1164 }
1165 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1166 }
1167 })??;
1168
1169 this.update(cx, |this, cx| {
1170 this.run_server(new_server, configuration, cx);
1171 })?;
1172
1173 Ok(())
1174 }
1175
1176 /// Store the full OAuth session in the system keychain, keyed by the
1177 /// server's canonical URI.
1178 async fn store_session(
1179 credentials_provider: &Arc<dyn CredentialsProvider>,
1180 server_url: &url::Url,
1181 session: &OAuthSession,
1182 cx: &AsyncApp,
1183 ) -> Result<()> {
1184 let key = Self::keychain_key(server_url);
1185 let json = serde_json::to_string(session)?;
1186 credentials_provider
1187 .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1188 .await
1189 }
1190
1191 /// Load the full OAuth session from the system keychain for the given
1192 /// server URL.
1193 async fn load_session(
1194 credentials_provider: &Arc<dyn CredentialsProvider>,
1195 server_url: &url::Url,
1196 cx: &AsyncApp,
1197 ) -> Result<Option<OAuthSession>> {
1198 let key = Self::keychain_key(server_url);
1199 match credentials_provider.read_credentials(&key, cx).await? {
1200 Some((_username, password_bytes)) => {
1201 let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1202 Ok(Some(session))
1203 }
1204 None => Ok(None),
1205 }
1206 }
1207
1208 /// Clear the stored OAuth session from the system keychain.
1209 async fn clear_session(
1210 credentials_provider: &Arc<dyn CredentialsProvider>,
1211 server_url: &url::Url,
1212 cx: &AsyncApp,
1213 ) -> Result<()> {
1214 let key = Self::keychain_key(server_url);
1215 credentials_provider.delete_credentials(&key, cx).await
1216 }
1217
1218 fn keychain_key(server_url: &url::Url) -> String {
1219 format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1220 }
1221
1222 /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1223 /// session from the keychain and stop the server.
1224 pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1225 let state = self.servers.get(id).context("Context server not found")?;
1226 let configuration = state.configuration();
1227
1228 let server_url = match configuration.as_ref() {
1229 ContextServerConfiguration::Http { url, .. } => url.clone(),
1230 _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1231 };
1232
1233 let id = id.clone();
1234 self.stop_server(&id, cx)?;
1235
1236 cx.spawn(async move |this, cx| {
1237 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1238 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1239 log::error!("{} failed to clear OAuth session: {}", id, err);
1240 }
1241 // Trigger server recreation so the next start uses a fresh
1242 // transport without the old (now-invalidated) token provider.
1243 this.update(cx, |this, cx| {
1244 this.available_context_servers_changed(cx);
1245 })
1246 .log_err();
1247 })
1248 .detach();
1249
1250 Ok(())
1251 }
1252
1253 fn update_server_state(
1254 &mut self,
1255 id: ContextServerId,
1256 state: ContextServerState,
1257 cx: &mut Context<Self>,
1258 ) {
1259 let status = ContextServerStatus::from_state(&state);
1260 self.servers.insert(id.clone(), state);
1261 cx.emit(ServerStatusChangedEvent {
1262 server_id: id,
1263 status,
1264 });
1265 }
1266
1267 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1268 if self.update_servers_task.is_some() {
1269 self.needs_server_update = true;
1270 } else {
1271 self.needs_server_update = false;
1272 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1273 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1274 log::error!("Error maintaining context servers: {}", err);
1275 }
1276
1277 this.update(cx, |this, cx| {
1278 this.populate_server_ids(cx);
1279 cx.notify();
1280 this.update_servers_task.take();
1281 if this.needs_server_update {
1282 this.available_context_servers_changed(cx);
1283 }
1284 })?;
1285
1286 Ok(())
1287 }));
1288 }
1289 }
1290
1291 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
1292 // Don't start context servers if AI is disabled
1293 let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
1294 if ai_disabled {
1295 // Stop all running servers when AI is disabled
1296 this.update(cx, |this, cx| {
1297 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
1298 for id in server_ids {
1299 let _ = this.stop_server(&id, cx);
1300 }
1301 })?;
1302 return Ok(());
1303 }
1304
1305 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
1306 (
1307 this.context_server_settings.clone(),
1308 this.registry.clone(),
1309 this.worktree_store.clone(),
1310 )
1311 })?;
1312
1313 for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
1314 configured_servers
1315 .entry(id)
1316 .or_insert(ContextServerSettings::default_extension());
1317 }
1318
1319 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
1320 configured_servers
1321 .into_iter()
1322 .partition(|(_, settings)| settings.enabled());
1323
1324 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
1325 let id = ContextServerId(id);
1326 ContextServerConfiguration::from_settings(
1327 settings,
1328 id.clone(),
1329 registry.clone(),
1330 worktree_store.clone(),
1331 cx,
1332 )
1333 .map(move |config| (id, config))
1334 }))
1335 .await
1336 .into_iter()
1337 .filter_map(|(id, config)| config.map(|config| (id, config)))
1338 .collect::<HashMap<_, _>>();
1339
1340 let mut servers_to_start = Vec::new();
1341 let mut servers_to_remove = HashSet::default();
1342 let mut servers_to_stop = HashSet::default();
1343
1344 this.update(cx, |this, _cx| {
1345 for server_id in this.servers.keys() {
1346 // All servers that are not in desired_servers should be removed from the store.
1347 // This can happen if the user removed a server from the context server settings.
1348 if !configured_servers.contains_key(server_id) {
1349 if disabled_servers.contains_key(&server_id.0) {
1350 servers_to_stop.insert(server_id.clone());
1351 } else {
1352 servers_to_remove.insert(server_id.clone());
1353 }
1354 }
1355 }
1356
1357 for (id, config) in configured_servers {
1358 let state = this.servers.get(&id);
1359 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
1360 let existing_config = state.as_ref().map(|state| state.configuration());
1361 if existing_config.as_deref() != Some(&config) || is_stopped {
1362 let config = Arc::new(config);
1363 servers_to_start.push((id.clone(), config));
1364 if this.servers.contains_key(&id) {
1365 servers_to_stop.insert(id);
1366 }
1367 }
1368 }
1369
1370 anyhow::Ok(())
1371 })??;
1372
1373 this.update(cx, |this, inner_cx| {
1374 for id in servers_to_stop {
1375 this.stop_server(&id, inner_cx)?;
1376 }
1377 for id in servers_to_remove {
1378 this.remove_server(&id, inner_cx)?;
1379 }
1380 anyhow::Ok(())
1381 })??;
1382
1383 for (id, config) in servers_to_start {
1384 match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
1385 Ok((server, config)) => {
1386 this.update(cx, |this, cx| {
1387 this.run_server(server, config, cx);
1388 })?;
1389 }
1390 Err(err) => {
1391 log::error!("{id} context server failed to create: {err:#}");
1392 this.update(cx, |_this, cx| {
1393 cx.emit(ServerStatusChangedEvent {
1394 server_id: id,
1395 status: ContextServerStatus::Error(err.to_string().into()),
1396 });
1397 cx.notify();
1398 })?;
1399 }
1400 }
1401 }
1402
1403 Ok(())
1404 }
1405}
1406
1407/// Determines the appropriate server state after a start attempt fails.
1408///
1409/// When the error is an HTTP 401 with no static auth header configured,
1410/// attempts OAuth discovery so the UI can offer an authentication flow.
1411async fn resolve_start_failure(
1412 id: &ContextServerId,
1413 err: anyhow::Error,
1414 server: Arc<ContextServer>,
1415 configuration: Arc<ContextServerConfiguration>,
1416 cx: &AsyncApp,
1417) -> ContextServerState {
1418 let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1419 TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1420 });
1421
1422 if www_authenticate.is_some() && configuration.has_static_auth_header() {
1423 log::warn!("{id} received 401 with a static Authorization header configured");
1424 return ContextServerState::Error {
1425 configuration,
1426 server,
1427 error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1428 .into(),
1429 };
1430 }
1431
1432 let server_url = match configuration.as_ref() {
1433 ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1434 url.clone()
1435 }
1436 _ => {
1437 if www_authenticate.is_some() {
1438 log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1439 } else {
1440 log::error!("{id} context server failed to start: {err}");
1441 }
1442 return ContextServerState::Error {
1443 configuration,
1444 server,
1445 error: err.to_string().into(),
1446 };
1447 }
1448 };
1449
1450 // When the error is NOT a 401 but there is a cached OAuth session in the
1451 // keychain, the session is likely stale/expired and caused the failure
1452 // (e.g. timeout because the server rejected the token silently). Clear it
1453 // so the next start attempt can get a clean 401 and trigger the auth flow.
1454 if www_authenticate.is_none() {
1455 let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1456 match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1457 Ok(Some(_)) => {
1458 log::info!("{id} start failed with a cached OAuth session present; clearing it");
1459 ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1460 .await
1461 .log_err();
1462 }
1463 _ => {
1464 log::error!("{id} context server failed to start: {err}");
1465 return ContextServerState::Error {
1466 configuration,
1467 server,
1468 error: err.to_string().into(),
1469 };
1470 }
1471 }
1472 }
1473
1474 let default_www_authenticate = oauth::WwwAuthenticate {
1475 resource_metadata: None,
1476 scope: None,
1477 error: None,
1478 error_description: None,
1479 };
1480 let www_authenticate = www_authenticate
1481 .as_ref()
1482 .unwrap_or(&default_www_authenticate);
1483 let http_client = cx.update(|cx| cx.http_client());
1484
1485 match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1486 Ok(discovery) => {
1487 log::info!(
1488 "{id} requires OAuth authorization (auth server: {})",
1489 discovery.auth_server_metadata.issuer,
1490 );
1491 ContextServerState::AuthRequired {
1492 server,
1493 configuration,
1494 discovery: Arc::new(discovery),
1495 }
1496 }
1497 Err(discovery_err) => {
1498 log::error!("{id} OAuth discovery failed: {discovery_err}");
1499 ContextServerState::Error {
1500 configuration,
1501 server,
1502 error: format!("OAuth discovery failed: {discovery_err}").into(),
1503 }
1504 }
1505 }
1506}