1pub mod extension;
2pub mod registry;
3
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context as _, Result};
9use collections::{HashMap, HashSet};
10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
11use context_server::transport::{HttpTransport, TransportError};
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use credentials_provider::CredentialsProvider;
14use futures::future::Either;
15use futures::{FutureExt as _, StreamExt as _, future::join_all};
16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
17use http_client::HttpClient;
18use itertools::Itertools;
19use rand::Rng as _;
20use registry::ContextServerDescriptorRegistry;
21use remote::RemoteClient;
22use rpc::{AnyProtoClient, TypedEnvelope, proto};
23use settings::{Settings as _, SettingsStore};
24use util::{ResultExt as _, rel_path::RelPath};
25
26use crate::{
27 DisableAiSettings, Project,
28 project_settings::{ContextServerSettings, ProjectSettings},
29 worktree_store::WorktreeStore,
30};
31
/// Upper bound, in seconds, for any context server request timeout.
///
/// Configured timeouts (per-server or global) are clamped to this value so an
/// extremely large setting cannot tie up resources indefinitely.
const MAX_TIMEOUT_SECS: u64 = 10 * 60;
35
36pub fn init(cx: &mut App) {
37 extension::init(cx);
38}
39
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
47
/// Public, cloneable summary of a context server's lifecycle stage.
///
/// This is what [`ServerStatusChangedEvent`] carries and what
/// `status_for_server` returns.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ContextServerStatus {
    /// Startup is in progress.
    Starting,
    /// Startup succeeded and the server is usable.
    Running,
    /// The server is not running.
    Stopped,
    /// The server is in an error state; the payload is the error message.
    Error(Arc<str>),
    /// The server answered with 401 and needs OAuth authorization; the UI
    /// should offer an "Authenticate" button.
    AuthRequired,
    /// The OAuth browser flow is underway: the user was sent to the
    /// authorization server and the callback has not arrived yet.
    Authenticating,
}
61
62impl ContextServerStatus {
63 fn from_state(state: &ContextServerState) -> Self {
64 match state {
65 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
66 ContextServerState::Running { .. } => ContextServerStatus::Running,
67 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
68 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
69 ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
70 ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
71 }
72 }
73}
74
/// Internal per-server lifecycle state tracked by the store.
///
/// Every variant carries the server handle and its configuration. `Starting`
/// and `Authenticating` additionally own the task driving that stage; field
/// order is left untouched because it also determines drop order when a
/// state is replaced.
enum ContextServerState {
    /// `start` is in progress. `_task` awaits it and then writes the
    /// resulting state back into the store.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// `start` succeeded; the server is usable.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is in an error state described by `error`.
    // NOTE(review): presumably produced by `resolve_start_failure` (outside this view).
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
}
109
110impl ContextServerState {
111 pub fn server(&self) -> Arc<ContextServer> {
112 match self {
113 ContextServerState::Starting { server, .. }
114 | ContextServerState::Running { server, .. }
115 | ContextServerState::Stopped { server, .. }
116 | ContextServerState::Error { server, .. }
117 | ContextServerState::AuthRequired { server, .. }
118 | ContextServerState::Authenticating { server, .. } => server.clone(),
119 }
120 }
121
122 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
123 match self {
124 ContextServerState::Starting { configuration, .. }
125 | ContextServerState::Running { configuration, .. }
126 | ContextServerState::Stopped { configuration, .. }
127 | ContextServerState::Error { configuration, .. }
128 | ContextServerState::AuthRequired { configuration, .. }
129 | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
130 }
131 }
132}
133
134#[derive(Debug, PartialEq, Eq)]
135pub enum ContextServerConfiguration {
136 Custom {
137 command: ContextServerCommand,
138 remote: bool,
139 },
140 Extension {
141 command: ContextServerCommand,
142 settings: serde_json::Value,
143 remote: bool,
144 },
145 Http {
146 url: url::Url,
147 headers: HashMap<String, String>,
148 timeout: Option<u64>,
149 },
150}
151
152impl ContextServerConfiguration {
153 pub fn command(&self) -> Option<&ContextServerCommand> {
154 match self {
155 ContextServerConfiguration::Custom { command, .. } => Some(command),
156 ContextServerConfiguration::Extension { command, .. } => Some(command),
157 ContextServerConfiguration::Http { .. } => None,
158 }
159 }
160
161 pub fn has_static_auth_header(&self) -> bool {
162 match self {
163 ContextServerConfiguration::Http { headers, .. } => headers
164 .keys()
165 .any(|k| k.eq_ignore_ascii_case("authorization")),
166 _ => false,
167 }
168 }
169
170 pub fn remote(&self) -> bool {
171 match self {
172 ContextServerConfiguration::Custom { remote, .. } => *remote,
173 ContextServerConfiguration::Extension { remote, .. } => *remote,
174 ContextServerConfiguration::Http { .. } => false,
175 }
176 }
177
178 pub async fn from_settings(
179 settings: ContextServerSettings,
180 id: ContextServerId,
181 registry: Entity<ContextServerDescriptorRegistry>,
182 worktree_store: Entity<WorktreeStore>,
183 cx: &AsyncApp,
184 ) -> Option<Self> {
185 const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
186
187 match settings {
188 ContextServerSettings::Stdio {
189 enabled: _,
190 command,
191 remote,
192 } => Some(ContextServerConfiguration::Custom { command, remote }),
193 ContextServerSettings::Extension {
194 enabled: _,
195 settings,
196 remote,
197 } => {
198 let descriptor =
199 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
200
201 let command_future = descriptor.command(worktree_store, cx);
202 let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
203
204 match futures::future::select(command_future, timeout_future).await {
205 Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
206 command,
207 settings,
208 remote,
209 }),
210 Either::Left((Err(e), _)) => {
211 log::error!(
212 "Failed to create context server configuration from settings: {e:#}"
213 );
214 None
215 }
216 Either::Right(_) => {
217 log::error!(
218 "Timed out resolving command for extension context server {id}"
219 );
220 None
221 }
222 }
223 }
224 ContextServerSettings::Http {
225 enabled: _,
226 url,
227 headers: auth,
228 timeout,
229 } => {
230 let url = url::Url::parse(&url).log_err()?;
231 Some(ContextServerConfiguration::Http {
232 url,
233 headers: auth,
234 timeout,
235 })
236 }
237 }
238 }
239}
240
241pub type ContextServerFactory =
242 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
243
244enum ContextServerStoreState {
245 Local {
246 downstream_client: Option<(u64, AnyProtoClient)>,
247 is_headless: bool,
248 },
249 Remote {
250 project_id: u64,
251 upstream_client: Entity<RemoteClient>,
252 },
253}
254
/// Tracks every context server of a project and drives its lifecycle
/// (starting, stopping, OAuth authentication, removal).
// Field order is kept as-is: it also determines drop order of the owned
// tasks and subscriptions.
pub struct ContextServerStore {
    /// Local vs. remote mode (see `ContextServerStoreState`).
    state: ContextServerStoreState,
    /// Context server settings keyed by server id, as last resolved from the
    /// project settings.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current lifecycle state of each known server.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Cached, de-duplicated and sorted ids rebuilt by `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    /// Worktrees used to resolve project settings and root directories.
    worktree_store: Entity<WorktreeStore>,
    /// Owning project, used to look up the active project directory.
    project: Option<WeakEntity<Project>>,
    /// Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): managed by the server-update logic outside this view.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for building servers (test-support constructors only in this file).
    context_server_factory: Option<ContextServerFactory>,
    // NOTE(review): presumably marks that another update pass is pending; confirm in
    // `available_context_servers_changed`.
    needs_server_update: bool,
    /// Last observed value of `DisableAiSettings::disable_ai`.
    ai_disabled: bool,
    _subscriptions: Vec<Subscription>,
}
269
270pub struct ServerStatusChangedEvent {
271 pub server_id: ContextServerId,
272 pub status: ContextServerStatus,
273}
274
// Lets observers subscribe to `ServerStatusChangedEvent`s emitted by the store.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
276
277impl ContextServerStore {
278 pub fn local(
279 worktree_store: Entity<WorktreeStore>,
280 weak_project: Option<WeakEntity<Project>>,
281 headless: bool,
282 cx: &mut Context<Self>,
283 ) -> Self {
284 Self::new_internal(
285 !headless,
286 None,
287 ContextServerDescriptorRegistry::default_global(cx),
288 worktree_store,
289 weak_project,
290 ContextServerStoreState::Local {
291 downstream_client: None,
292 is_headless: headless,
293 },
294 cx,
295 )
296 }
297
298 pub fn remote(
299 project_id: u64,
300 upstream_client: Entity<RemoteClient>,
301 worktree_store: Entity<WorktreeStore>,
302 weak_project: Option<WeakEntity<Project>>,
303 cx: &mut Context<Self>,
304 ) -> Self {
305 Self::new_internal(
306 true,
307 None,
308 ContextServerDescriptorRegistry::default_global(cx),
309 worktree_store,
310 weak_project,
311 ContextServerStoreState::Remote {
312 project_id,
313 upstream_client,
314 },
315 cx,
316 )
317 }
318
319 pub fn init_headless(session: &AnyProtoClient) {
320 session.add_entity_request_handler(Self::handle_get_context_server_command);
321 }
322
323 pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
324 if let ContextServerStoreState::Local {
325 downstream_client, ..
326 } = &mut self.state
327 {
328 *downstream_client = Some((project_id, client));
329 }
330 }
331
332 pub fn is_remote_project(&self) -> bool {
333 matches!(self.state, ContextServerStoreState::Remote { .. })
334 }
335
336 /// Returns all configured context server ids, excluding the ones that are disabled
337 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
338 self.context_server_settings
339 .iter()
340 .filter(|(_, settings)| settings.enabled())
341 .map(|(id, _)| ContextServerId(id.clone()))
342 .collect()
343 }
344
345 #[cfg(feature = "test-support")]
346 pub fn test(
347 registry: Entity<ContextServerDescriptorRegistry>,
348 worktree_store: Entity<WorktreeStore>,
349 weak_project: Option<WeakEntity<Project>>,
350 cx: &mut Context<Self>,
351 ) -> Self {
352 Self::new_internal(
353 false,
354 None,
355 registry,
356 worktree_store,
357 weak_project,
358 ContextServerStoreState::Local {
359 downstream_client: None,
360 is_headless: false,
361 },
362 cx,
363 )
364 }
365
366 #[cfg(feature = "test-support")]
367 pub fn test_maintain_server_loop(
368 context_server_factory: Option<ContextServerFactory>,
369 registry: Entity<ContextServerDescriptorRegistry>,
370 worktree_store: Entity<WorktreeStore>,
371 weak_project: Option<WeakEntity<Project>>,
372 cx: &mut Context<Self>,
373 ) -> Self {
374 Self::new_internal(
375 true,
376 context_server_factory,
377 registry,
378 worktree_store,
379 weak_project,
380 ContextServerStoreState::Local {
381 downstream_client: None,
382 is_headless: false,
383 },
384 cx,
385 )
386 }
387
388 #[cfg(feature = "test-support")]
389 pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
390 self.context_server_factory = Some(factory);
391 }
392
393 #[cfg(feature = "test-support")]
394 pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
395 &self.registry
396 }
397
398 #[cfg(feature = "test-support")]
399 pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
400 let configuration = Arc::new(ContextServerConfiguration::Custom {
401 command: ContextServerCommand {
402 path: "test".into(),
403 args: vec![],
404 env: None,
405 timeout: None,
406 },
407 remote: false,
408 });
409 self.run_server(server, configuration, cx);
410 }
411
412 fn new_internal(
413 maintain_server_loop: bool,
414 context_server_factory: Option<ContextServerFactory>,
415 registry: Entity<ContextServerDescriptorRegistry>,
416 worktree_store: Entity<WorktreeStore>,
417 weak_project: Option<WeakEntity<Project>>,
418 state: ContextServerStoreState,
419 cx: &mut Context<Self>,
420 ) -> Self {
421 let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
422 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
423 let ai_was_disabled = this.ai_disabled;
424 this.ai_disabled = ai_disabled;
425
426 let settings =
427 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
428 let settings_changed = &this.context_server_settings != settings;
429
430 if settings_changed {
431 this.context_server_settings = settings.clone();
432 }
433
434 // When AI is disabled, stop all running servers
435 if ai_disabled {
436 let server_ids: Vec<_> = this.servers.keys().cloned().collect();
437 for id in server_ids {
438 this.stop_server(&id, cx).log_err();
439 }
440 return;
441 }
442
443 // Trigger updates if AI was re-enabled or settings changed
444 if maintain_server_loop && (ai_was_disabled || settings_changed) {
445 this.available_context_servers_changed(cx);
446 }
447 })];
448
449 if maintain_server_loop {
450 subscriptions.push(cx.observe(®istry, |this, _registry, cx| {
451 if !DisableAiSettings::get_global(cx).disable_ai {
452 this.available_context_servers_changed(cx);
453 }
454 }));
455 }
456
457 let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
458 let mut this = Self {
459 state,
460 _subscriptions: subscriptions,
461 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
462 .context_servers
463 .clone(),
464 worktree_store,
465 project: weak_project,
466 registry,
467 needs_server_update: false,
468 ai_disabled,
469 servers: HashMap::default(),
470 server_ids: Default::default(),
471 update_servers_task: None,
472 context_server_factory,
473 };
474 if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
475 this.available_context_servers_changed(cx);
476 }
477 this
478 }
479
480 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
481 self.servers.get(id).map(|state| state.server())
482 }
483
484 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
485 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
486 Some(server.clone())
487 } else {
488 None
489 }
490 }
491
492 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
493 self.servers.get(id).map(ContextServerStatus::from_state)
494 }
495
496 pub fn configuration_for_server(
497 &self,
498 id: &ContextServerId,
499 ) -> Option<Arc<ContextServerConfiguration>> {
500 self.servers.get(id).map(|state| state.configuration())
501 }
502
503 /// Returns a sorted slice of available unique context server IDs. Within the
504 /// slice, context servers which have `mcp-server-` as a prefix in their ID will
505 /// appear after servers that do not have this prefix in their ID.
506 pub fn server_ids(&self) -> &[ContextServerId] {
507 self.server_ids.as_slice()
508 }
509
510 fn populate_server_ids(&mut self, cx: &App) {
511 self.server_ids = self
512 .servers
513 .keys()
514 .cloned()
515 .chain(
516 self.registry
517 .read(cx)
518 .context_server_descriptors()
519 .into_iter()
520 .map(|(id, _)| ContextServerId(id)),
521 )
522 .chain(
523 self.context_server_settings
524 .keys()
525 .map(|id| ContextServerId(id.clone())),
526 )
527 .unique()
528 .sorted_unstable_by(
529 // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
530 |a, b| {
531 const MCP_PREFIX: &str = "mcp-server-";
532 match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
533 // If one has mcp-server- prefix and other doesn't, non-mcp comes first
534 (Some(_), None) => std::cmp::Ordering::Greater,
535 (None, Some(_)) => std::cmp::Ordering::Less,
536 // If both have same prefix status, sort by appropriate key
537 (Some(a), Some(b)) => a.cmp(b),
538 (None, None) => a.0.cmp(&b.0),
539 }
540 },
541 )
542 .collect();
543 }
544
545 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
546 self.servers
547 .values()
548 .filter_map(|state| {
549 if let ContextServerState::Running { server, .. } = state {
550 Some(server.clone())
551 } else {
552 None
553 }
554 })
555 .collect()
556 }
557
558 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
559 cx.spawn(async move |this, cx| {
560 let this = this.upgrade().context("Context server store dropped")?;
561 let id = server.id();
562 let settings = this
563 .update(cx, |this, _| {
564 this.context_server_settings.get(&id.0).cloned()
565 })
566 .context("Failed to get context server settings")?;
567
568 if !settings.enabled() {
569 return anyhow::Ok(());
570 }
571
572 let (registry, worktree_store) = this.update(cx, |this, _| {
573 (this.registry.clone(), this.worktree_store.clone())
574 });
575 let configuration = ContextServerConfiguration::from_settings(
576 settings,
577 id.clone(),
578 registry,
579 worktree_store,
580 cx,
581 )
582 .await
583 .context("Failed to create context server configuration")?;
584
585 this.update(cx, |this, cx| {
586 this.run_server(server, Arc::new(configuration), cx)
587 });
588 Ok(())
589 })
590 .detach_and_log_err(cx);
591 }
592
593 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
594 if matches!(
595 self.servers.get(id),
596 Some(ContextServerState::Stopped { .. })
597 ) {
598 return Ok(());
599 }
600
601 let state = self
602 .servers
603 .remove(id)
604 .context("Context server not found")?;
605
606 let server = state.server();
607 let configuration = state.configuration();
608 let mut result = Ok(());
609 if let ContextServerState::Running { server, .. } = &state {
610 result = server.stop();
611 }
612 drop(state);
613
614 self.update_server_state(
615 id.clone(),
616 ContextServerState::Stopped {
617 configuration,
618 server,
619 },
620 cx,
621 );
622
623 result
624 }
625
626 fn run_server(
627 &mut self,
628 server: Arc<ContextServer>,
629 configuration: Arc<ContextServerConfiguration>,
630 cx: &mut Context<Self>,
631 ) {
632 let id = server.id();
633 if matches!(
634 self.servers.get(&id),
635 Some(
636 ContextServerState::Starting { .. }
637 | ContextServerState::Running { .. }
638 | ContextServerState::Authenticating { .. },
639 )
640 ) {
641 self.stop_server(&id, cx).log_err();
642 }
643 let task = cx.spawn({
644 let id = server.id();
645 let server = server.clone();
646 let configuration = configuration.clone();
647
648 async move |this, cx| {
649 let new_state = match server.clone().start(cx).await {
650 Ok(_) => {
651 debug_assert!(server.client().is_some());
652 ContextServerState::Running {
653 server,
654 configuration,
655 }
656 }
657 Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
658 };
659 this.update(cx, |this, cx| {
660 this.update_server_state(id.clone(), new_state, cx)
661 })
662 .log_err();
663 }
664 });
665
666 self.update_server_state(
667 id.clone(),
668 ContextServerState::Starting {
669 configuration,
670 _task: task,
671 server,
672 },
673 cx,
674 );
675 }
676
677 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
678 let state = self
679 .servers
680 .remove(id)
681 .context("Context server not found")?;
682
683 if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
684 let server_url = url.clone();
685 let id = id.clone();
686 cx.spawn(async move |_this, cx| {
687 let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
688 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
689 {
690 log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
691 }
692 })
693 .detach();
694 }
695
696 drop(state);
697 cx.emit(ServerStatusChangedEvent {
698 server_id: id.clone(),
699 status: ContextServerStatus::Stopped,
700 });
701 Ok(())
702 }
703
704 pub async fn create_context_server(
705 this: WeakEntity<Self>,
706 id: ContextServerId,
707 configuration: Arc<ContextServerConfiguration>,
708 cx: &mut AsyncApp,
709 ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
710 let remote = configuration.remote();
711 let needs_remote_command = match configuration.as_ref() {
712 ContextServerConfiguration::Custom { .. }
713 | ContextServerConfiguration::Extension { .. } => remote,
714 ContextServerConfiguration::Http { .. } => false,
715 };
716
717 let (remote_state, is_remote_project) = this.update(cx, |this, _| {
718 let remote_state = match &this.state {
719 ContextServerStoreState::Remote {
720 project_id,
721 upstream_client,
722 } if needs_remote_command => Some((*project_id, upstream_client.clone())),
723 _ => None,
724 };
725 (remote_state, this.is_remote_project())
726 })?;
727
728 let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
729 this.project
730 .as_ref()
731 .and_then(|project| {
732 project
733 .read_with(cx, |project, cx| project.active_project_directory(cx))
734 .ok()
735 .flatten()
736 })
737 .or_else(|| {
738 this.worktree_store.read_with(cx, |store, cx| {
739 store.visible_worktrees(cx).fold(None, |acc, item| {
740 if acc.is_none() {
741 item.read(cx).root_dir()
742 } else {
743 acc
744 }
745 })
746 })
747 })
748 })?;
749
750 let configuration = if let Some((project_id, upstream_client)) = remote_state {
751 let root_dir = root_path.as_ref().map(|p| p.display().to_string());
752
753 let response = upstream_client
754 .update(cx, |client, _| {
755 client
756 .proto_client()
757 .request(proto::GetContextServerCommand {
758 project_id,
759 server_id: id.0.to_string(),
760 root_dir: root_dir.clone(),
761 })
762 })
763 .await?;
764
765 let remote_command = upstream_client.update(cx, |client, _| {
766 client.build_command(
767 Some(response.path),
768 &response.args,
769 &response.env.into_iter().collect(),
770 root_dir,
771 None,
772 )
773 })?;
774
775 let command = ContextServerCommand {
776 path: remote_command.program.into(),
777 args: remote_command.args,
778 env: Some(remote_command.env.into_iter().collect()),
779 timeout: None,
780 };
781
782 Arc::new(ContextServerConfiguration::Custom { command, remote })
783 } else {
784 configuration
785 };
786
787 if let Some(server) = this.update(cx, |this, _| {
788 this.context_server_factory
789 .as_ref()
790 .map(|factory| factory(id.clone(), configuration.clone()))
791 })? {
792 return Ok((server, configuration));
793 }
794
795 let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
796 if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
797 if configuration.has_static_auth_header() {
798 None
799 } else {
800 let credentials_provider =
801 cx.update(|cx| <dyn CredentialsProvider>::global(cx));
802 let http_client = cx.update(|cx| cx.http_client());
803
804 match Self::load_session(&credentials_provider, url, &cx).await {
805 Ok(Some(session)) => {
806 log::info!("{} loaded cached OAuth session from keychain", id);
807 Some(Self::create_oauth_token_provider(
808 &id,
809 url,
810 session,
811 http_client,
812 credentials_provider,
813 cx,
814 ))
815 }
816 Ok(None) => None,
817 Err(err) => {
818 log::warn!("{} failed to load cached OAuth session: {}", id, err);
819 None
820 }
821 }
822 }
823 } else {
824 None
825 };
826
827 let server: Arc<ContextServer> = this.update(cx, |this, cx| {
828 let global_timeout =
829 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
830
831 match configuration.as_ref() {
832 ContextServerConfiguration::Http {
833 url,
834 headers,
835 timeout,
836 } => {
837 let transport = HttpTransport::new_with_token_provider(
838 cx.http_client(),
839 url.to_string(),
840 headers.clone(),
841 cx.background_executor().clone(),
842 cached_token_provider.clone(),
843 );
844 anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
845 id,
846 Arc::new(transport),
847 Some(Duration::from_secs(
848 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
849 )),
850 )))
851 }
852 _ => {
853 let mut command = configuration
854 .command()
855 .context("Missing command configuration for stdio context server")?
856 .clone();
857 command.timeout = Some(
858 command
859 .timeout
860 .unwrap_or(global_timeout)
861 .min(MAX_TIMEOUT_SECS),
862 );
863
864 // Don't pass remote paths as working directory for locally-spawned processes
865 let working_directory = if is_remote_project { None } else { root_path };
866 anyhow::Ok(Arc::new(ContextServer::stdio(
867 id,
868 command,
869 working_directory,
870 )))
871 }
872 }
873 })??;
874
875 Ok((server, configuration))
876 }
877
878 async fn handle_get_context_server_command(
879 this: Entity<Self>,
880 envelope: TypedEnvelope<proto::GetContextServerCommand>,
881 mut cx: AsyncApp,
882 ) -> Result<proto::ContextServerCommand> {
883 let server_id = ContextServerId(envelope.payload.server_id.into());
884
885 let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
886 let ContextServerStoreState::Local {
887 is_headless: true, ..
888 } = &this.state
889 else {
890 anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
891 };
892
893 let settings = this
894 .context_server_settings
895 .get(&server_id.0)
896 .cloned()
897 .or_else(|| {
898 this.registry
899 .read(inner_cx)
900 .context_server_descriptor(&server_id.0)
901 .map(|_| ContextServerSettings::default_extension())
902 })
903 .with_context(|| format!("context server `{}` not found", server_id))?;
904
905 anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
906 })?;
907
908 let configuration = ContextServerConfiguration::from_settings(
909 settings,
910 server_id.clone(),
911 registry,
912 worktree_store,
913 &cx,
914 )
915 .await
916 .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
917
918 let command = configuration
919 .command()
920 .context("context server has no command (HTTP servers don't need RPC)")?;
921
922 Ok(proto::ContextServerCommand {
923 path: command.path.display().to_string(),
924 args: command.args.clone(),
925 env: command
926 .env
927 .clone()
928 .map(|env| env.into_iter().collect())
929 .unwrap_or_default(),
930 })
931 }
932
933 fn resolve_project_settings<'a>(
934 worktree_store: &'a Entity<WorktreeStore>,
935 cx: &'a App,
936 ) -> &'a ProjectSettings {
937 let location = worktree_store
938 .read(cx)
939 .visible_worktrees(cx)
940 .next()
941 .map(|worktree| settings::SettingsLocation {
942 worktree_id: worktree.read(cx).id(),
943 path: RelPath::empty(),
944 });
945 ProjectSettings::get(location, cx)
946 }
947
948 fn create_oauth_token_provider(
949 id: &ContextServerId,
950 server_url: &url::Url,
951 session: OAuthSession,
952 http_client: Arc<dyn HttpClient>,
953 credentials_provider: Arc<dyn CredentialsProvider>,
954 cx: &mut AsyncApp,
955 ) -> Arc<dyn oauth::OAuthTokenProvider> {
956 let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
957 let id = id.clone();
958 let server_url = server_url.clone();
959
960 cx.spawn(async move |cx| {
961 while let Some(refreshed_session) = token_refresh_rx.next().await {
962 if let Err(err) =
963 Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
964 .await
965 {
966 log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
967 }
968 }
969 log::debug!("{} OAuth session persistence task ended", id);
970 })
971 .detach();
972
973 Arc::new(McpOAuthTokenProvider::new(
974 session,
975 http_client,
976 Some(token_refresh_tx),
977 ))
978 }
979
980 /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
981 ///
982 /// This starts a loopback HTTP callback server on an ephemeral port, builds
983 /// the authorization URL, opens the user's browser, waits for the callback,
984 /// exchanges the code for tokens, persists them in the keychain, and restarts
985 /// the server with the new token provider.
986 pub fn authenticate_server(
987 &mut self,
988 id: &ContextServerId,
989 cx: &mut Context<Self>,
990 ) -> Result<()> {
991 let state = self.servers.get(id).context("Context server not found")?;
992
993 let (discovery, server, configuration) = match state {
994 ContextServerState::AuthRequired {
995 discovery,
996 server,
997 configuration,
998 } => (discovery.clone(), server.clone(), configuration.clone()),
999 _ => anyhow::bail!("Server is not in AuthRequired state"),
1000 };
1001
1002 let id = id.clone();
1003
1004 let task = cx.spawn({
1005 let id = id.clone();
1006 let server = server.clone();
1007 let configuration = configuration.clone();
1008 async move |this, cx| {
1009 let result = Self::run_oauth_flow(
1010 this.clone(),
1011 id.clone(),
1012 discovery.clone(),
1013 configuration.clone(),
1014 cx,
1015 )
1016 .await;
1017
1018 if let Err(err) = &result {
1019 log::error!("{} OAuth authentication failed: {:?}", id, err);
1020 // Transition back to AuthRequired so the user can retry
1021 // rather than landing in a terminal Error state.
1022 this.update(cx, |this, cx| {
1023 this.update_server_state(
1024 id.clone(),
1025 ContextServerState::AuthRequired {
1026 server,
1027 configuration,
1028 discovery,
1029 },
1030 cx,
1031 )
1032 })
1033 .log_err();
1034 }
1035 }
1036 });
1037
1038 self.update_server_state(
1039 id,
1040 ContextServerState::Authenticating {
1041 server,
1042 configuration,
1043 _task: task,
1044 },
1045 cx,
1046 );
1047
1048 Ok(())
1049 }
1050
1051 async fn run_oauth_flow(
1052 this: WeakEntity<Self>,
1053 id: ContextServerId,
1054 discovery: Arc<OAuthDiscovery>,
1055 configuration: Arc<ContextServerConfiguration>,
1056 cx: &mut AsyncApp,
1057 ) -> Result<()> {
1058 let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1059 let pkce = oauth::generate_pkce_challenge();
1060
1061 let mut state_bytes = [0u8; 32];
1062 rand::rng().fill(&mut state_bytes);
1063 let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1064
1065 // Start a loopback HTTP server on an ephemeral port. The redirect URI
1066 // includes this port so the browser sends the callback directly to our
1067 // process.
1068 let (redirect_uri, callback_rx) = oauth::start_callback_server()
1069 .await
1070 .context("Failed to start OAuth callback server")?;
1071
1072 let http_client = cx.update(|cx| cx.http_client());
1073 let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
1074 let server_url = match configuration.as_ref() {
1075 ContextServerConfiguration::Http { url, .. } => url.clone(),
1076 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1077 };
1078
1079 let client_registration =
1080 oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1081 .await
1082 .context("Failed to resolve OAuth client registration")?;
1083
1084 let auth_url = oauth::build_authorization_url(
1085 &discovery.auth_server_metadata,
1086 &client_registration.client_id,
1087 &redirect_uri,
1088 &discovery.scopes,
1089 &resource,
1090 &pkce,
1091 &state_param,
1092 );
1093
1094 cx.update(|cx| cx.open_url(auth_url.as_str()));
1095
1096 let callback = callback_rx
1097 .await
1098 .map_err(|_| {
1099 anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1100 })?
1101 .context("OAuth callback server received an invalid request")?;
1102
1103 if callback.state != state_param {
1104 anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1105 }
1106
1107 let tokens = oauth::exchange_code(
1108 &http_client,
1109 &discovery.auth_server_metadata,
1110 &callback.code,
1111 &client_registration.client_id,
1112 &redirect_uri,
1113 &pkce.verifier,
1114 &resource,
1115 )
1116 .await
1117 .context("Failed to exchange authorization code for tokens")?;
1118
1119 let session = OAuthSession {
1120 token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1121 resource: discovery.resource_metadata.resource.clone(),
1122 client_registration,
1123 tokens,
1124 };
1125
1126 Self::store_session(&credentials_provider, &server_url, &session, cx)
1127 .await
1128 .context("Failed to persist OAuth session in keychain")?;
1129
1130 let token_provider = Self::create_oauth_token_provider(
1131 &id,
1132 &server_url,
1133 session,
1134 http_client.clone(),
1135 credentials_provider,
1136 cx,
1137 );
1138
1139 let new_server = this.update(cx, |this, cx| {
1140 let global_timeout =
1141 Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1142
1143 match configuration.as_ref() {
1144 ContextServerConfiguration::Http {
1145 url,
1146 headers,
1147 timeout,
1148 } => {
1149 let transport = HttpTransport::new_with_token_provider(
1150 http_client.clone(),
1151 url.to_string(),
1152 headers.clone(),
1153 cx.background_executor().clone(),
1154 Some(token_provider.clone()),
1155 );
1156 Ok(Arc::new(ContextServer::new_with_timeout(
1157 id.clone(),
1158 Arc::new(transport),
1159 Some(Duration::from_secs(
1160 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1161 )),
1162 )))
1163 }
1164 _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1165 }
1166 })??;
1167
1168 this.update(cx, |this, cx| {
1169 this.run_server(new_server, configuration, cx);
1170 })?;
1171
1172 Ok(())
1173 }
1174
1175 /// Store the full OAuth session in the system keychain, keyed by the
1176 /// server's canonical URI.
1177 async fn store_session(
1178 credentials_provider: &Arc<dyn CredentialsProvider>,
1179 server_url: &url::Url,
1180 session: &OAuthSession,
1181 cx: &AsyncApp,
1182 ) -> Result<()> {
1183 let key = Self::keychain_key(server_url);
1184 let json = serde_json::to_string(session)?;
1185 credentials_provider
1186 .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1187 .await
1188 }
1189
1190 /// Load the full OAuth session from the system keychain for the given
1191 /// server URL.
1192 async fn load_session(
1193 credentials_provider: &Arc<dyn CredentialsProvider>,
1194 server_url: &url::Url,
1195 cx: &AsyncApp,
1196 ) -> Result<Option<OAuthSession>> {
1197 let key = Self::keychain_key(server_url);
1198 match credentials_provider.read_credentials(&key, cx).await? {
1199 Some((_username, password_bytes)) => {
1200 let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1201 Ok(Some(session))
1202 }
1203 None => Ok(None),
1204 }
1205 }
1206
1207 /// Clear the stored OAuth session from the system keychain.
1208 async fn clear_session(
1209 credentials_provider: &Arc<dyn CredentialsProvider>,
1210 server_url: &url::Url,
1211 cx: &AsyncApp,
1212 ) -> Result<()> {
1213 let key = Self::keychain_key(server_url);
1214 credentials_provider.delete_credentials(&key, cx).await
1215 }
1216
1217 fn keychain_key(server_url: &url::Url) -> String {
1218 format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1219 }
1220
1221 /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1222 /// session from the keychain and stop the server.
1223 pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1224 let state = self.servers.get(id).context("Context server not found")?;
1225 let configuration = state.configuration();
1226
1227 let server_url = match configuration.as_ref() {
1228 ContextServerConfiguration::Http { url, .. } => url.clone(),
1229 _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1230 };
1231
1232 let id = id.clone();
1233 self.stop_server(&id, cx)?;
1234
1235 cx.spawn(async move |this, cx| {
1236 let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
1237 if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1238 log::error!("{} failed to clear OAuth session: {}", id, err);
1239 }
1240 // Trigger server recreation so the next start uses a fresh
1241 // transport without the old (now-invalidated) token provider.
1242 this.update(cx, |this, cx| {
1243 this.available_context_servers_changed(cx);
1244 })
1245 .log_err();
1246 })
1247 .detach();
1248
1249 Ok(())
1250 }
1251
1252 fn update_server_state(
1253 &mut self,
1254 id: ContextServerId,
1255 state: ContextServerState,
1256 cx: &mut Context<Self>,
1257 ) {
1258 let status = ContextServerStatus::from_state(&state);
1259 self.servers.insert(id.clone(), state);
1260 cx.emit(ServerStatusChangedEvent {
1261 server_id: id,
1262 status,
1263 });
1264 }
1265
1266 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1267 if self.update_servers_task.is_some() {
1268 self.needs_server_update = true;
1269 } else {
1270 self.needs_server_update = false;
1271 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1272 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1273 log::error!("Error maintaining context servers: {}", err);
1274 }
1275
1276 this.update(cx, |this, cx| {
1277 this.populate_server_ids(cx);
1278 cx.notify();
1279 this.update_servers_task.take();
1280 if this.needs_server_update {
1281 this.available_context_servers_changed(cx);
1282 }
1283 })?;
1284
1285 Ok(())
1286 }));
1287 }
1288 }
1289
    /// Reconciles `this.servers` with the current settings and the extension
    /// descriptor registry.
    ///
    /// If `DisableAiSettings::disable_ai` is set, every known server is stopped
    /// and nothing is started. Otherwise:
    /// 1. Servers from settings are merged with registry descriptors. A
    ///    descriptor without explicit settings gets
    ///    `ContextServerSettings::default_extension()`. The merged set is then
    ///    split into enabled and disabled servers.
    /// 2. A `ContextServerConfiguration` is resolved concurrently for each
    ///    enabled server. Servers whose configuration resolves to `None` are
    ///    dropped from the desired set.
    /// 3. Known servers missing from the desired set are stopped if they are
    ///    disabled, and removed otherwise.
    /// 4. Desired servers that are new, whose configuration changed, or that are
    ///    currently `Stopped` are (re)created and run one at a time. Any existing
    ///    instance is stopped first, in step 3's stop phase.
    ///
    /// A failure to create a server is logged and reported through
    /// `ServerStatusChangedEvent` with `ContextServerStatus::Error`; it does not
    /// abort the pass.
    ///
    /// # Errors
    /// Propagates errors from updating `this` and from `stop_server` /
    /// `remove_server` during the stop/remove phase.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Don't start context servers if AI is disabled
        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
        if ai_disabled {
            // Stop all running servers when AI is disabled. Per-server errors are
            // ignored so that one failure does not prevent stopping the rest.
            this.update(cx, |this, cx| {
                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
                for id in server_ids {
                    let _ = this.stop_server(&id, cx);
                }
            })?;
            return Ok(());
        }

        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Extension-provided servers are included even without explicit settings.
        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve every enabled server's configuration concurrently. Entries that
        // resolve to `None` are filtered out, so they count as "not desired" below.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(move |config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        // Compute the plan with read-only access to the current server states.
        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Every known server that is not in `configured_servers` (the resolved
                // enabled set) leaves the running set. Disabled servers are only
                // stopped; anything else (e.g. removed from settings) is removed.
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start if the server is new, its configuration changed, or it is
                // stopped. An existing instance must be stopped before the restart.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    servers_to_start.push((id.clone(), config));
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }

            anyhow::Ok(())
        })??;

        // Stops happen before removals, and both happen before any new start.
        this.update(cx, |this, inner_cx| {
            for id in servers_to_stop {
                this.stop_server(&id, inner_cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, inner_cx)?;
            }
            anyhow::Ok(())
        })??;

        // Servers are created sequentially; a failure for one does not stop the others.
        for (id, config) in servers_to_start {
            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
                Ok((server, config)) => {
                    this.update(cx, |this, cx| {
                        this.run_server(server, config, cx);
                    })?;
                }
                Err(err) => {
                    log::error!("{id} context server failed to create: {err:#}");
                    // Only an event is emitted here; `this.servers` is not updated.
                    this.update(cx, |_this, cx| {
                        cx.emit(ServerStatusChangedEvent {
                            server_id: id,
                            status: ContextServerStatus::Error(err.to_string().into()),
                        });
                        cx.notify();
                    })?;
                }
            }
        }

        Ok(())
    }
1404}
1405
1406/// Determines the appropriate server state after a start attempt fails.
1407///
1408/// When the error is an HTTP 401 with no static auth header configured,
1409/// attempts OAuth discovery so the UI can offer an authentication flow.
1410async fn resolve_start_failure(
1411 id: &ContextServerId,
1412 err: anyhow::Error,
1413 server: Arc<ContextServer>,
1414 configuration: Arc<ContextServerConfiguration>,
1415 cx: &AsyncApp,
1416) -> ContextServerState {
1417 let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1418 TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1419 });
1420
1421 if www_authenticate.is_some() && configuration.has_static_auth_header() {
1422 log::warn!("{id} received 401 with a static Authorization header configured");
1423 return ContextServerState::Error {
1424 configuration,
1425 server,
1426 error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1427 .into(),
1428 };
1429 }
1430
1431 let server_url = match configuration.as_ref() {
1432 ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1433 url.clone()
1434 }
1435 _ => {
1436 if www_authenticate.is_some() {
1437 log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1438 } else {
1439 log::error!("{id} context server failed to start: {err}");
1440 }
1441 return ContextServerState::Error {
1442 configuration,
1443 server,
1444 error: err.to_string().into(),
1445 };
1446 }
1447 };
1448
1449 // When the error is NOT a 401 but there is a cached OAuth session in the
1450 // keychain, the session is likely stale/expired and caused the failure
1451 // (e.g. timeout because the server rejected the token silently). Clear it
1452 // so the next start attempt can get a clean 401 and trigger the auth flow.
1453 if www_authenticate.is_none() {
1454 let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
1455 match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1456 Ok(Some(_)) => {
1457 log::info!("{id} start failed with a cached OAuth session present; clearing it");
1458 ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1459 .await
1460 .log_err();
1461 }
1462 _ => {
1463 log::error!("{id} context server failed to start: {err}");
1464 return ContextServerState::Error {
1465 configuration,
1466 server,
1467 error: err.to_string().into(),
1468 };
1469 }
1470 }
1471 }
1472
1473 let default_www_authenticate = oauth::WwwAuthenticate {
1474 resource_metadata: None,
1475 scope: None,
1476 error: None,
1477 error_description: None,
1478 };
1479 let www_authenticate = www_authenticate
1480 .as_ref()
1481 .unwrap_or(&default_www_authenticate);
1482 let http_client = cx.update(|cx| cx.http_client());
1483
1484 match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1485 Ok(discovery) => {
1486 log::info!(
1487 "{id} requires OAuth authorization (auth server: {})",
1488 discovery.auth_server_metadata.issuer,
1489 );
1490 ContextServerState::AuthRequired {
1491 server,
1492 configuration,
1493 discovery: Arc::new(discovery),
1494 }
1495 }
1496 Err(discovery_err) => {
1497 log::error!("{id} OAuth discovery failed: {discovery_err}");
1498 ContextServerState::Error {
1499 configuration,
1500 server,
1501 error: format!("OAuth discovery failed: {discovery_err}").into(),
1502 }
1503 }
1504 }
1505}