context_server_store.rs

  1pub mod extension;
  2pub mod registry;
  3
  4use std::path::Path;
  5use std::sync::Arc;
  6use std::time::Duration;
  7
  8use anyhow::{Context as _, Result};
  9use collections::{HashMap, HashSet};
 10use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 11use futures::{FutureExt as _, future::Either, future::join_all};
 12use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
 13use itertools::Itertools;
 14use registry::ContextServerDescriptorRegistry;
 15use remote::RemoteClient;
 16use rpc::{AnyProtoClient, TypedEnvelope, proto};
 17use settings::{Settings as _, SettingsStore};
 18use util::{ResultExt as _, rel_path::RelPath};
 19
 20use crate::{
 21    DisableAiSettings, Project,
 22    project_settings::{ContextServerSettings, ProjectSettings},
 23    worktree_store::WorktreeStore,
 24};
 25
 26/// Maximum timeout for context server requests
 27/// Prevents extremely large timeout values from tying up resources indefinitely.
 28const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
 29
 30pub fn init(cx: &mut App) {
 31    extension::init(cx);
 32}
 33
// Actions declared under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
 41
/// Public, lifecycle-only view of a context server's state, as reported to
/// observers through [`ServerStatusChangedEvent`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// Startup has been requested and is in progress.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// Startup failed; carries the error message.
    Error(Arc<str>),
}
 49
 50impl ContextServerStatus {
 51    fn from_state(state: &ContextServerState) -> Self {
 52        match state {
 53            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
 54            ContextServerState::Running { .. } => ContextServerStatus::Running,
 55            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
 56            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
 57        }
 58    }
 59}
 60
/// Internal lifecycle state of a single context server tracked by the store.
///
/// Every variant keeps the server handle and its configuration, so both can be
/// read back from any state via `server()` / `configuration()`.
enum ContextServerState {
    /// Startup has been requested and is in progress.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Startup task spawned by `run_server`; stored here so it lives exactly as
        // long as this state. NOTE(review): presumably dropping it cancels the
        // startup (gpui `Task` drop semantics) — confirm.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped via `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Startup failed.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Startup error text, surfaced as `ContextServerStatus::Error`.
        error: Arc<str>,
    },
}
 81
 82impl ContextServerState {
 83    pub fn server(&self) -> Arc<ContextServer> {
 84        match self {
 85            ContextServerState::Starting { server, .. } => server.clone(),
 86            ContextServerState::Running { server, .. } => server.clone(),
 87            ContextServerState::Stopped { server, .. } => server.clone(),
 88            ContextServerState::Error { server, .. } => server.clone(),
 89        }
 90    }
 91
 92    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 93        match self {
 94            ContextServerState::Starting { configuration, .. } => configuration.clone(),
 95            ContextServerState::Running { configuration, .. } => configuration.clone(),
 96            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
 97            ContextServerState::Error { configuration, .. } => configuration.clone(),
 98        }
 99    }
100}
101
/// Resolved description of how to launch, or connect to, a context server.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A stdio server with an explicit command: either from `Stdio` settings, or
    /// the command built for a server resolved on a remote host.
    Custom {
        command: ContextServerCommand,
        // When true and the project is remote, the launch command is requested
        // from the remote host (see `create_context_server`).
        remote: bool,
    },
    /// A stdio server whose command was resolved from an extension descriptor.
    Extension {
        command: ContextServerCommand,
        // Extension-specific settings, taken from the user's configuration.
        settings: serde_json::Value,
        // Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP.
    Http {
        url: url::Url,
        headers: HashMap<String, String>,
        // Per-server request timeout in seconds. When `None`, the global
        // `context_server_timeout` is used instead.
        timeout: Option<u64>,
    },
}
119
120impl ContextServerConfiguration {
121    pub fn command(&self) -> Option<&ContextServerCommand> {
122        match self {
123            ContextServerConfiguration::Custom { command, .. } => Some(command),
124            ContextServerConfiguration::Extension { command, .. } => Some(command),
125            ContextServerConfiguration::Http { .. } => None,
126        }
127    }
128
129    pub fn remote(&self) -> bool {
130        match self {
131            ContextServerConfiguration::Custom { remote, .. } => *remote,
132            ContextServerConfiguration::Extension { remote, .. } => *remote,
133            ContextServerConfiguration::Http { .. } => false,
134        }
135    }
136
137    pub async fn from_settings(
138        settings: ContextServerSettings,
139        id: ContextServerId,
140        registry: Entity<ContextServerDescriptorRegistry>,
141        worktree_store: Entity<WorktreeStore>,
142        cx: &AsyncApp,
143    ) -> Option<Self> {
144        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
145
146        match settings {
147            ContextServerSettings::Stdio {
148                enabled: _,
149                command,
150                remote,
151            } => Some(ContextServerConfiguration::Custom { command, remote }),
152            ContextServerSettings::Extension {
153                enabled: _,
154                settings,
155                remote,
156            } => {
157                let descriptor =
158                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
159
160                let command_future = descriptor.command(worktree_store, cx);
161                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
162
163                match futures::future::select(command_future, timeout_future).await {
164                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
165                        command,
166                        settings,
167                        remote,
168                    }),
169                    Either::Left((Err(e), _)) => {
170                        log::error!(
171                            "Failed to create context server configuration from settings: {e:#}"
172                        );
173                        None
174                    }
175                    Either::Right(_) => {
176                        log::error!(
177                            "Timed out resolving command for extension context server {id}"
178                        );
179                        None
180                    }
181                }
182            }
183            ContextServerSettings::Http {
184                enabled: _,
185                url,
186                headers: auth,
187                timeout,
188            } => {
189                let url = url::Url::parse(&url).log_err()?;
190                Some(ContextServerConfiguration::Http {
191                    url,
192                    headers: auth,
193                    timeout,
194                })
195            }
196        }
197    }
198}
199
/// Constructor hook that, when set on the store, is used instead of the
/// built-in HTTP/stdio server construction in `create_context_server`.
/// In the code visible here it is only supplied by test-support constructors.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
202
/// Whether the store manages servers for a local project or mirrors a remote one.
enum ContextServerStoreState {
    Local {
        // Project id and client of a downstream peer, set by `shared`.
        downstream_client: Option<(u64, AnyProtoClient)>,
        // True for a store created with `local(.., headless: true, ..)`. Only such
        // stores answer `GetContextServerCommand` requests, and they do not run
        // the maintenance loop.
        is_headless: bool,
    },
    Remote {
        project_id: u64,
        // Client used to ask the remote host for server launch commands.
        upstream_client: Entity<RemoteClient>,
    },
}
213
/// Tracks and manages the lifecycle of a project's context servers.
///
/// Servers come from user settings and from extension descriptors in the
/// registry. Status changes are broadcast as [`ServerStatusChangedEvent`]s.
pub struct ContextServerStore {
    // Local vs. remote mode, together with the RPC clients that go with it.
    state: ContextServerStoreState,
    // Latest context server settings keyed by server id; refreshed whenever the
    // global settings change.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Lifecycle state of every server currently tracked by the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Sorted, de-duplicated id cache returned by `server_ids()`.
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    // Used to find the active project directory for newly created servers.
    project: Option<WeakEntity<Project>>,
    // Source of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation pass, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for constructing servers.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while one is already running.
    needs_server_update: bool,
    // Keeps the settings observer (and, when maintaining, the registry observer) alive.
    _subscriptions: Vec<Subscription>,
}
227
/// Emitted whenever a server's lifecycle status changes, including when a
/// server is removed from the store (reported as `Stopped`).
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}

// Lets observers subscribe to the store's `ServerStatusChangedEvent`s.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
234
235impl ContextServerStore {
236    pub fn local(
237        worktree_store: Entity<WorktreeStore>,
238        weak_project: Option<WeakEntity<Project>>,
239        headless: bool,
240        cx: &mut Context<Self>,
241    ) -> Self {
242        Self::new_internal(
243            !headless,
244            None,
245            ContextServerDescriptorRegistry::default_global(cx),
246            worktree_store,
247            weak_project,
248            ContextServerStoreState::Local {
249                downstream_client: None,
250                is_headless: headless,
251            },
252            cx,
253        )
254    }
255
256    pub fn remote(
257        project_id: u64,
258        upstream_client: Entity<RemoteClient>,
259        worktree_store: Entity<WorktreeStore>,
260        weak_project: Option<WeakEntity<Project>>,
261        cx: &mut Context<Self>,
262    ) -> Self {
263        Self::new_internal(
264            true,
265            None,
266            ContextServerDescriptorRegistry::default_global(cx),
267            worktree_store,
268            weak_project,
269            ContextServerStoreState::Remote {
270                project_id,
271                upstream_client,
272            },
273            cx,
274        )
275    }
276
277    pub fn init_headless(session: &AnyProtoClient) {
278        session.add_entity_request_handler(Self::handle_get_context_server_command);
279    }
280
281    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
282        if let ContextServerStoreState::Local {
283            downstream_client, ..
284        } = &mut self.state
285        {
286            *downstream_client = Some((project_id, client));
287        }
288    }
289
290    pub fn is_remote_project(&self) -> bool {
291        matches!(self.state, ContextServerStoreState::Remote { .. })
292    }
293
294    /// Returns all configured context server ids, excluding the ones that are disabled
295    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
296        self.context_server_settings
297            .iter()
298            .filter(|(_, settings)| settings.enabled())
299            .map(|(id, _)| ContextServerId(id.clone()))
300            .collect()
301    }
302
303    #[cfg(feature = "test-support")]
304    pub fn test(
305        registry: Entity<ContextServerDescriptorRegistry>,
306        worktree_store: Entity<WorktreeStore>,
307        weak_project: Option<WeakEntity<Project>>,
308        cx: &mut Context<Self>,
309    ) -> Self {
310        Self::new_internal(
311            false,
312            None,
313            registry,
314            worktree_store,
315            weak_project,
316            ContextServerStoreState::Local {
317                downstream_client: None,
318                is_headless: false,
319            },
320            cx,
321        )
322    }
323
324    #[cfg(feature = "test-support")]
325    pub fn test_maintain_server_loop(
326        context_server_factory: Option<ContextServerFactory>,
327        registry: Entity<ContextServerDescriptorRegistry>,
328        worktree_store: Entity<WorktreeStore>,
329        weak_project: Option<WeakEntity<Project>>,
330        cx: &mut Context<Self>,
331    ) -> Self {
332        Self::new_internal(
333            true,
334            context_server_factory,
335            registry,
336            worktree_store,
337            weak_project,
338            ContextServerStoreState::Local {
339                downstream_client: None,
340                is_headless: false,
341            },
342            cx,
343        )
344    }
345
346    #[cfg(feature = "test-support")]
347    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
348        self.context_server_factory = Some(factory);
349    }
350
351    #[cfg(feature = "test-support")]
352    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
353        &self.registry
354    }
355
356    #[cfg(feature = "test-support")]
357    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
358        let configuration = Arc::new(ContextServerConfiguration::Custom {
359            command: ContextServerCommand {
360                path: "test".into(),
361                args: vec![],
362                env: None,
363                timeout: None,
364            },
365            remote: false,
366        });
367        self.run_server(server, configuration, cx);
368    }
369
370    fn new_internal(
371        maintain_server_loop: bool,
372        context_server_factory: Option<ContextServerFactory>,
373        registry: Entity<ContextServerDescriptorRegistry>,
374        worktree_store: Entity<WorktreeStore>,
375        weak_project: Option<WeakEntity<Project>>,
376        state: ContextServerStoreState,
377        cx: &mut Context<Self>,
378    ) -> Self {
379        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
380            let settings =
381                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
382            if &this.context_server_settings == settings {
383                return;
384            }
385            this.context_server_settings = settings.clone();
386            if maintain_server_loop {
387                this.available_context_servers_changed(cx);
388            }
389        })];
390
391        if maintain_server_loop {
392            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
393                this.available_context_servers_changed(cx);
394            }));
395        }
396
397        let mut this = Self {
398            state,
399            _subscriptions: subscriptions,
400            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
401                .context_servers
402                .clone(),
403            worktree_store,
404            project: weak_project,
405            registry,
406            needs_server_update: false,
407            servers: HashMap::default(),
408            server_ids: Default::default(),
409            update_servers_task: None,
410            context_server_factory,
411        };
412        if maintain_server_loop {
413            this.available_context_servers_changed(cx);
414        }
415        this
416    }
417
418    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
419        self.servers.get(id).map(|state| state.server())
420    }
421
422    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
423        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
424            Some(server.clone())
425        } else {
426            None
427        }
428    }
429
430    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
431        self.servers.get(id).map(ContextServerStatus::from_state)
432    }
433
434    pub fn configuration_for_server(
435        &self,
436        id: &ContextServerId,
437    ) -> Option<Arc<ContextServerConfiguration>> {
438        self.servers.get(id).map(|state| state.configuration())
439    }
440
441    /// Returns a sorted slice of available unique context server IDs. Within the
442    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
443    /// appear after servers that do not have this prefix in their ID.
444    pub fn server_ids(&self) -> &[ContextServerId] {
445        self.server_ids.as_slice()
446    }
447
448    fn populate_server_ids(&mut self, cx: &App) {
449        self.server_ids = self
450            .servers
451            .keys()
452            .cloned()
453            .chain(
454                self.registry
455                    .read(cx)
456                    .context_server_descriptors()
457                    .into_iter()
458                    .map(|(id, _)| ContextServerId(id)),
459            )
460            .chain(
461                self.context_server_settings
462                    .keys()
463                    .map(|id| ContextServerId(id.clone())),
464            )
465            .unique()
466            .sorted_unstable_by(
467                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
468                |a, b| {
469                    const MCP_PREFIX: &str = "mcp-server-";
470                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
471                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
472                        (Some(_), None) => std::cmp::Ordering::Greater,
473                        (None, Some(_)) => std::cmp::Ordering::Less,
474                        // If both have same prefix status, sort by appropriate key
475                        (Some(a), Some(b)) => a.cmp(b),
476                        (None, None) => a.0.cmp(&b.0),
477                    }
478                },
479            )
480            .collect();
481    }
482
483    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
484        self.servers
485            .values()
486            .filter_map(|state| {
487                if let ContextServerState::Running { server, .. } = state {
488                    Some(server.clone())
489                } else {
490                    None
491                }
492            })
493            .collect()
494    }
495
496    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
497        cx.spawn(async move |this, cx| {
498            let this = this.upgrade().context("Context server store dropped")?;
499            let settings = this
500                .update(cx, |this, _| {
501                    this.context_server_settings.get(&server.id().0).cloned()
502                })
503                .context("Failed to get context server settings")?;
504
505            if !settings.enabled() {
506                return anyhow::Ok(());
507            }
508
509            let (registry, worktree_store) = this.update(cx, |this, _| {
510                (this.registry.clone(), this.worktree_store.clone())
511            });
512            let configuration = ContextServerConfiguration::from_settings(
513                settings,
514                server.id(),
515                registry,
516                worktree_store,
517                cx,
518            )
519            .await
520            .context("Failed to create context server configuration")?;
521
522            this.update(cx, |this, cx| {
523                this.run_server(server, Arc::new(configuration), cx)
524            });
525            Ok(())
526        })
527        .detach_and_log_err(cx);
528    }
529
530    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
531        if matches!(
532            self.servers.get(id),
533            Some(ContextServerState::Stopped { .. })
534        ) {
535            return Ok(());
536        }
537
538        let state = self
539            .servers
540            .remove(id)
541            .context("Context server not found")?;
542
543        let server = state.server();
544        let configuration = state.configuration();
545        let mut result = Ok(());
546        if let ContextServerState::Running { server, .. } = &state {
547            result = server.stop();
548        }
549        drop(state);
550
551        self.update_server_state(
552            id.clone(),
553            ContextServerState::Stopped {
554                configuration,
555                server,
556            },
557            cx,
558        );
559
560        result
561    }
562
563    fn run_server(
564        &mut self,
565        server: Arc<ContextServer>,
566        configuration: Arc<ContextServerConfiguration>,
567        cx: &mut Context<Self>,
568    ) {
569        let id = server.id();
570        if matches!(
571            self.servers.get(&id),
572            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
573        ) {
574            self.stop_server(&id, cx).log_err();
575        }
576        let task = cx.spawn({
577            let id = server.id();
578            let server = server.clone();
579            let configuration = configuration.clone();
580
581            async move |this, cx| {
582                match server.clone().start(cx).await {
583                    Ok(_) => {
584                        debug_assert!(server.client().is_some());
585
586                        this.update(cx, |this, cx| {
587                            this.update_server_state(
588                                id.clone(),
589                                ContextServerState::Running {
590                                    server,
591                                    configuration,
592                                },
593                                cx,
594                            )
595                        })
596                        .log_err()
597                    }
598                    Err(err) => {
599                        log::error!("{} context server failed to start: {}", id, err);
600                        this.update(cx, |this, cx| {
601                            this.update_server_state(
602                                id.clone(),
603                                ContextServerState::Error {
604                                    configuration,
605                                    server,
606                                    error: err.to_string().into(),
607                                },
608                                cx,
609                            )
610                        })
611                        .log_err()
612                    }
613                };
614            }
615        });
616
617        self.update_server_state(
618            id.clone(),
619            ContextServerState::Starting {
620                configuration,
621                _task: task,
622                server,
623            },
624            cx,
625        );
626    }
627
628    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
629        let state = self
630            .servers
631            .remove(id)
632            .context("Context server not found")?;
633        drop(state);
634        cx.emit(ServerStatusChangedEvent {
635            server_id: id.clone(),
636            status: ContextServerStatus::Stopped,
637        });
638        Ok(())
639    }
640
641    pub async fn create_context_server(
642        this: WeakEntity<Self>,
643        id: ContextServerId,
644        configuration: Arc<ContextServerConfiguration>,
645        cx: &mut AsyncApp,
646    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
647        let remote = configuration.remote();
648        let needs_remote_command = match configuration.as_ref() {
649            ContextServerConfiguration::Custom { .. }
650            | ContextServerConfiguration::Extension { .. } => remote,
651            ContextServerConfiguration::Http { .. } => false,
652        };
653
654        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
655            let remote_state = match &this.state {
656                ContextServerStoreState::Remote {
657                    project_id,
658                    upstream_client,
659                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
660                _ => None,
661            };
662            (remote_state, this.is_remote_project())
663        })?;
664
665        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
666            this.project
667                .as_ref()
668                .and_then(|project| {
669                    project
670                        .read_with(cx, |project, cx| project.active_project_directory(cx))
671                        .ok()
672                        .flatten()
673                })
674                .or_else(|| {
675                    this.worktree_store.read_with(cx, |store, cx| {
676                        store.visible_worktrees(cx).fold(None, |acc, item| {
677                            if acc.is_none() {
678                                item.read(cx).root_dir()
679                            } else {
680                                acc
681                            }
682                        })
683                    })
684                })
685        })?;
686
687        let configuration = if let Some((project_id, upstream_client)) = remote_state {
688            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
689
690            let response = upstream_client
691                .update(cx, |client, _| {
692                    client
693                        .proto_client()
694                        .request(proto::GetContextServerCommand {
695                            project_id,
696                            server_id: id.0.to_string(),
697                            root_dir: root_dir.clone(),
698                        })
699                })
700                .await?;
701
702            let remote_command = upstream_client.update(cx, |client, _| {
703                client.build_command(
704                    Some(response.path),
705                    &response.args,
706                    &response.env.into_iter().collect(),
707                    root_dir,
708                    None,
709                )
710            })?;
711
712            let command = ContextServerCommand {
713                path: remote_command.program.into(),
714                args: remote_command.args,
715                env: Some(remote_command.env.into_iter().collect()),
716                timeout: None,
717            };
718
719            Arc::new(ContextServerConfiguration::Custom { command, remote })
720        } else {
721            configuration
722        };
723
724        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
725            let global_timeout =
726                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
727
728            if let Some(factory) = this.context_server_factory.as_ref() {
729                return anyhow::Ok(factory(id.clone(), configuration.clone()));
730            }
731
732            match configuration.as_ref() {
733                ContextServerConfiguration::Http {
734                    url,
735                    headers,
736                    timeout,
737                } => anyhow::Ok(Arc::new(ContextServer::http(
738                    id,
739                    url,
740                    headers.clone(),
741                    cx.http_client(),
742                    cx.background_executor().clone(),
743                    Some(Duration::from_secs(
744                        timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
745                    )),
746                )?)),
747                _ => {
748                    let mut command = configuration
749                        .command()
750                        .context("Missing command configuration for stdio context server")?
751                        .clone();
752                    command.timeout = Some(
753                        command
754                            .timeout
755                            .unwrap_or(global_timeout)
756                            .min(MAX_TIMEOUT_SECS),
757                    );
758
759                    // Don't pass remote paths as working directory for locally-spawned processes
760                    let working_directory = if is_remote_project { None } else { root_path };
761                    anyhow::Ok(Arc::new(ContextServer::stdio(
762                        id,
763                        command,
764                        working_directory,
765                    )))
766                }
767            }
768        })??;
769
770        Ok((server, configuration))
771    }
772
773    async fn handle_get_context_server_command(
774        this: Entity<Self>,
775        envelope: TypedEnvelope<proto::GetContextServerCommand>,
776        mut cx: AsyncApp,
777    ) -> Result<proto::ContextServerCommand> {
778        let server_id = ContextServerId(envelope.payload.server_id.into());
779
780        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
781            let ContextServerStoreState::Local {
782                is_headless: true, ..
783            } = &this.state
784            else {
785                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
786            };
787
788            let settings = this
789                .context_server_settings
790                .get(&server_id.0)
791                .cloned()
792                .or_else(|| {
793                    this.registry
794                        .read(inner_cx)
795                        .context_server_descriptor(&server_id.0)
796                        .map(|_| ContextServerSettings::default_extension())
797                })
798                .with_context(|| format!("context server `{}` not found", server_id))?;
799
800            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
801        })?;
802
803        let configuration = ContextServerConfiguration::from_settings(
804            settings,
805            server_id.clone(),
806            registry,
807            worktree_store,
808            &cx,
809        )
810        .await
811        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
812
813        let command = configuration
814            .command()
815            .context("context server has no command (HTTP servers don't need RPC)")?;
816
817        Ok(proto::ContextServerCommand {
818            path: command.path.display().to_string(),
819            args: command.args.clone(),
820            env: command
821                .env
822                .clone()
823                .map(|env| env.into_iter().collect())
824                .unwrap_or_default(),
825        })
826    }
827
828    fn resolve_project_settings<'a>(
829        worktree_store: &'a Entity<WorktreeStore>,
830        cx: &'a App,
831    ) -> &'a ProjectSettings {
832        let location = worktree_store
833            .read(cx)
834            .visible_worktrees(cx)
835            .next()
836            .map(|worktree| settings::SettingsLocation {
837                worktree_id: worktree.read(cx).id(),
838                path: RelPath::empty(),
839            });
840        ProjectSettings::get(location, cx)
841    }
842
843    fn update_server_state(
844        &mut self,
845        id: ContextServerId,
846        state: ContextServerState,
847        cx: &mut Context<Self>,
848    ) {
849        let status = ContextServerStatus::from_state(&state);
850        self.servers.insert(id.clone(), state);
851        cx.emit(ServerStatusChangedEvent {
852            server_id: id,
853            status,
854        });
855    }
856
857    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
858        if self.update_servers_task.is_some() {
859            self.needs_server_update = true;
860        } else {
861            self.needs_server_update = false;
862            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
863                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
864                    log::error!("Error maintaining context servers: {}", err);
865                }
866
867                this.update(cx, |this, cx| {
868                    this.populate_server_ids(cx);
869                    cx.notify();
870                    this.update_servers_task.take();
871                    if this.needs_server_update {
872                        this.available_context_servers_changed(cx);
873                    }
874                })?;
875
876                Ok(())
877            }));
878        }
879    }
880
881    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
882        // Don't start context servers if AI is disabled
883        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
884        if ai_disabled {
885            // Stop all running servers when AI is disabled
886            this.update(cx, |this, cx| {
887                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
888                for id in server_ids {
889                    let _ = this.stop_server(&id, cx);
890                }
891            })?;
892            return Ok(());
893        }
894
895        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
896            (
897                this.context_server_settings.clone(),
898                this.registry.clone(),
899                this.worktree_store.clone(),
900            )
901        })?;
902
903        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
904            configured_servers
905                .entry(id)
906                .or_insert(ContextServerSettings::default_extension());
907        }
908
909        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
910            configured_servers
911                .into_iter()
912                .partition(|(_, settings)| settings.enabled());
913
914        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
915            let id = ContextServerId(id);
916            ContextServerConfiguration::from_settings(
917                settings,
918                id.clone(),
919                registry.clone(),
920                worktree_store.clone(),
921                cx,
922            )
923            .map(move |config| (id, config))
924        }))
925        .await
926        .into_iter()
927        .filter_map(|(id, config)| config.map(|config| (id, config)))
928        .collect::<HashMap<_, _>>();
929
930        let mut servers_to_start = Vec::new();
931        let mut servers_to_remove = HashSet::default();
932        let mut servers_to_stop = HashSet::default();
933
934        this.update(cx, |this, _cx| {
935            for server_id in this.servers.keys() {
936                // All servers that are not in desired_servers should be removed from the store.
937                // This can happen if the user removed a server from the context server settings.
938                if !configured_servers.contains_key(server_id) {
939                    if disabled_servers.contains_key(&server_id.0) {
940                        servers_to_stop.insert(server_id.clone());
941                    } else {
942                        servers_to_remove.insert(server_id.clone());
943                    }
944                }
945            }
946
947            for (id, config) in configured_servers {
948                let state = this.servers.get(&id);
949                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
950                let existing_config = state.as_ref().map(|state| state.configuration());
951                if existing_config.as_deref() != Some(&config) || is_stopped {
952                    let config = Arc::new(config);
953                    servers_to_start.push((id.clone(), config));
954                    if this.servers.contains_key(&id) {
955                        servers_to_stop.insert(id);
956                    }
957                }
958            }
959
960            anyhow::Ok(())
961        })??;
962
963        this.update(cx, |this, inner_cx| {
964            for id in servers_to_stop {
965                this.stop_server(&id, inner_cx)?;
966            }
967            for id in servers_to_remove {
968                this.remove_server(&id, inner_cx)?;
969            }
970            anyhow::Ok(())
971        })??;
972
973        for (id, config) in servers_to_start {
974            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
975                Ok((server, config)) => {
976                    this.update(cx, |this, cx| {
977                        this.run_server(server, config, cx);
978                    })?;
979                }
980                Err(err) => {
981                    log::error!("{id} context server failed to create: {err:#}");
982                    this.update(cx, |_this, cx| {
983                        cx.emit(ServerStatusChangedEvent {
984                            server_id: id,
985                            status: ContextServerStatus::Error(err.to_string().into()),
986                        });
987                        cx.notify();
988                    })?;
989                }
990            }
991        }
992
993        Ok(())
994    }
995}