Ensure context servers are spawned in the workspace directory (#35271)

Finn Evers and Danilo Leal created

This fixes an issue where we were not setting the context server working
directory at all.

Release Notes:

- Context servers will now be spawned in the currently active project
root.

---------

Co-authored-by: Danilo Leal <danilo@zed.dev>

Change summary

crates/agent_servers/src/codex.rs                      |  2 
crates/context_server/src/client.rs                    |  3 
crates/context_server/src/context_server.rs            | 16 +
crates/context_server/src/transport/stdio_transport.rs | 11 +
crates/project/src/context_server_store.rs             | 82 +++++++++--
crates/project/src/project.rs                          | 12 +
6 files changed, 97 insertions(+), 29 deletions(-)

Detailed changes

crates/agent_servers/src/codex.rs 🔗

@@ -47,6 +47,7 @@ impl AgentServer for Codex {
         cx: &mut App,
     ) -> Task<Result<Rc<dyn AgentConnection>>> {
         let project = project.clone();
+        let working_directory = project.read(cx).active_project_directory(cx);
         cx.spawn(async move |cx| {
             let settings = cx.read_global(|settings: &SettingsStore, _| {
                 settings.get::<AllAgentServersSettings>(None).codex.clone()
@@ -65,6 +66,7 @@ impl AgentServer for Codex {
                     args: command.args,
                     env: command.env,
                 },
+                working_directory,
             )
             .into();
             ContextServer::start(client.clone(), cx).await?;

crates/context_server/src/client.rs 🔗

@@ -158,6 +158,7 @@ impl Client {
     pub fn stdio(
         server_id: ContextServerId,
         binary: ModelContextServerBinary,
+        working_directory: &Option<PathBuf>,
         cx: AsyncApp,
     ) -> Result<Self> {
         log::info!(
@@ -172,7 +173,7 @@ impl Client {
             .map(|name| name.to_string_lossy().to_string())
             .unwrap_or_else(String::new);
 
-        let transport = Arc::new(StdioTransport::new(binary, &cx)?);
+        let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?);
         Self::new(server_id, server_name.into(), transport, cx)
     }
 

crates/context_server/src/context_server.rs 🔗

@@ -53,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand {
 }
 
 enum ContextServerTransport {
-    Stdio(ContextServerCommand),
+    Stdio(ContextServerCommand, Option<PathBuf>),
     Custom(Arc<dyn crate::transport::Transport>),
 }
 
@@ -64,11 +64,18 @@ pub struct ContextServer {
 }
 
 impl ContextServer {
-    pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
+    pub fn stdio(
+        id: ContextServerId,
+        command: ContextServerCommand,
+        working_directory: Option<Arc<Path>>,
+    ) -> Self {
         Self {
             id,
             client: RwLock::new(None),
-            configuration: ContextServerTransport::Stdio(command),
+            configuration: ContextServerTransport::Stdio(
+                command,
+                working_directory.map(|directory| directory.to_path_buf()),
+            ),
         }
     }
 
@@ -90,13 +97,14 @@ impl ContextServer {
 
     pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
         let client = match &self.configuration {
-            ContextServerTransport::Stdio(command) => Client::stdio(
+            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
                 client::ContextServerId(self.id.0.clone()),
                 client::ModelContextServerBinary {
                     executable: Path::new(&command.path).to_path_buf(),
                     args: command.args.clone(),
                     env: command.env.clone(),
                 },
+                working_directory,
                 cx.clone(),
             )?,
             ContextServerTransport::Custom(transport) => Client::new(

crates/context_server/src/transport/stdio_transport.rs 🔗

@@ -1,3 +1,4 @@
+use std::path::PathBuf;
 use std::pin::Pin;
 
 use anyhow::{Context as _, Result};
@@ -22,7 +23,11 @@ pub struct StdioTransport {
 }
 
 impl StdioTransport {
-    pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
+    pub fn new(
+        binary: ModelContextServerBinary,
+        working_directory: &Option<PathBuf>,
+        cx: &AsyncApp,
+    ) -> Result<Self> {
         let mut command = util::command::new_smol_command(&binary.executable);
         command
             .args(&binary.args)
@@ -32,6 +37,10 @@ impl StdioTransport {
             .stderr(std::process::Stdio::piped())
             .kill_on_drop(true);
 
+        if let Some(working_directory) = working_directory {
+            command.current_dir(working_directory);
+        }
+
         let mut server = command.spawn().with_context(|| {
             format!(
                 "failed to spawn command. (path={:?}, args={:?})",

crates/project/src/context_server_store.rs 🔗

@@ -13,6 +13,7 @@ use settings::{Settings as _, SettingsStore};
 use util::ResultExt as _;
 
 use crate::{
+    Project,
     project_settings::{ContextServerSettings, ProjectSettings},
     worktree_store::WorktreeStore,
 };
@@ -144,6 +145,7 @@ pub struct ContextServerStore {
     context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
     servers: HashMap<ContextServerId, ContextServerState>,
     worktree_store: Entity<WorktreeStore>,
+    project: WeakEntity<Project>,
     registry: Entity<ContextServerDescriptorRegistry>,
     update_servers_task: Option<Task<Result<()>>>,
     context_server_factory: Option<ContextServerFactory>,
@@ -161,12 +163,17 @@ pub enum Event {
 impl EventEmitter<Event> for ContextServerStore {}
 
 impl ContextServerStore {
-    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
+    pub fn new(
+        worktree_store: Entity<WorktreeStore>,
+        weak_project: WeakEntity<Project>,
+        cx: &mut Context<Self>,
+    ) -> Self {
         Self::new_internal(
             true,
             None,
             ContextServerDescriptorRegistry::default_global(cx),
             worktree_store,
+            weak_project,
             cx,
         )
     }
@@ -184,9 +191,10 @@ impl ContextServerStore {
     pub fn test(
         registry: Entity<ContextServerDescriptorRegistry>,
         worktree_store: Entity<WorktreeStore>,
+        weak_project: WeakEntity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
-        Self::new_internal(false, None, registry, worktree_store, cx)
+        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -194,6 +202,7 @@ impl ContextServerStore {
         context_server_factory: ContextServerFactory,
         registry: Entity<ContextServerDescriptorRegistry>,
         worktree_store: Entity<WorktreeStore>,
+        weak_project: WeakEntity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
         Self::new_internal(
@@ -201,6 +210,7 @@ impl ContextServerStore {
             Some(context_server_factory),
             registry,
             worktree_store,
+            weak_project,
             cx,
         )
     }
@@ -210,6 +220,7 @@ impl ContextServerStore {
         context_server_factory: Option<ContextServerFactory>,
         registry: Entity<ContextServerDescriptorRegistry>,
         worktree_store: Entity<WorktreeStore>,
+        weak_project: WeakEntity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
         let subscriptions = if maintain_server_loop {
@@ -235,6 +246,7 @@ impl ContextServerStore {
             context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
                 .clone(),
             worktree_store,
+            project: weak_project,
             registry,
             needs_server_update: false,
             servers: HashMap::default(),
@@ -360,7 +372,7 @@ impl ContextServerStore {
             let configuration = state.configuration();
 
             self.stop_server(&state.server().id(), cx)?;
-            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
+            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
             self.run_server(new_server, configuration, cx);
         }
         Ok(())
@@ -449,14 +461,33 @@ impl ContextServerStore {
         &self,
         id: ContextServerId,
         configuration: Arc<ContextServerConfiguration>,
-    ) -> Result<Arc<ContextServer>> {
+        cx: &mut Context<Self>,
+    ) -> Arc<ContextServer> {
+        let root_path = self
+            .project
+            .read_with(cx, |project, cx| project.active_project_directory(cx))
+            .ok()
+            .flatten()
+            .or_else(|| {
+                self.worktree_store.read_with(cx, |store, cx| {
+                    store.visible_worktrees(cx).fold(None, |acc, item| {
+                        if acc.is_none() {
+                            item.read(cx).root_dir()
+                        } else {
+                            acc
+                        }
+                    })
+                })
+            });
+
         if let Some(factory) = self.context_server_factory.as_ref() {
-            Ok(factory(id, configuration))
+            factory(id, configuration)
         } else {
-            Ok(Arc::new(ContextServer::stdio(
+            Arc::new(ContextServer::stdio(
                 id,
                 configuration.command().clone(),
-            )))
+                root_path,
+            ))
         }
     }
 
@@ -553,7 +584,7 @@ impl ContextServerStore {
         let mut servers_to_remove = HashSet::default();
         let mut servers_to_stop = HashSet::default();
 
-        this.update(cx, |this, _cx| {
+        this.update(cx, |this, cx| {
             for server_id in this.servers.keys() {
                 // All servers that are not in desired_servers should be removed from the store.
                 // This can happen if the user removed a server from the context server settings.
@@ -572,14 +603,10 @@ impl ContextServerStore {
                 let existing_config = state.as_ref().map(|state| state.configuration());
                 if existing_config.as_deref() != Some(&config) || is_stopped {
                     let config = Arc::new(config);
-                    if let Some(server) = this
-                        .create_context_server(id.clone(), config.clone())
-                        .log_err()
-                    {
-                        servers_to_start.push((server, config));
-                        if this.servers.contains_key(&id) {
-                            servers_to_stop.insert(id);
-                        }
+                    let server = this.create_context_server(id.clone(), config.clone(), cx);
+                    servers_to_start.push((server, config));
+                    if this.servers.contains_key(&id) {
+                        servers_to_stop.insert(id);
                     }
                 }
             }
@@ -630,7 +657,12 @@ mod tests {
 
         let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
         let store = cx.new(|cx| {
-            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+            ContextServerStore::test(
+                registry.clone(),
+                project.read(cx).worktree_store(),
+                project.downgrade(),
+                cx,
+            )
         });
 
         let server_1_id = ContextServerId(SERVER_1_ID.into());
@@ -705,7 +737,12 @@ mod tests {
 
         let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
         let store = cx.new(|cx| {
-            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+            ContextServerStore::test(
+                registry.clone(),
+                project.read(cx).worktree_store(),
+                project.downgrade(),
+                cx,
+            )
         });
 
         let server_1_id = ContextServerId(SERVER_1_ID.into());
@@ -758,7 +795,12 @@ mod tests {
 
         let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
         let store = cx.new(|cx| {
-            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+            ContextServerStore::test(
+                registry.clone(),
+                project.read(cx).worktree_store(),
+                project.downgrade(),
+                cx,
+            )
         });
 
         let server_id = ContextServerId(SERVER_1_ID.into());
@@ -842,6 +884,7 @@ mod tests {
                 }),
                 registry.clone(),
                 project.read(cx).worktree_store(),
+                project.downgrade(),
                 cx,
             )
         });
@@ -1074,6 +1117,7 @@ mod tests {
                 }),
                 registry.clone(),
                 project.read(cx).worktree_store(),
+                project.downgrade(),
                 cx,
             )
         });

crates/project/src/project.rs 🔗

@@ -998,8 +998,9 @@ impl Project {
             cx.subscribe(&worktree_store, Self::on_worktree_store_event)
                 .detach();
 
+            let weak_self = cx.weak_entity();
             let context_server_store =
-                cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
+                cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx));
 
             let environment = cx.new(|_| ProjectEnvironment::new(env));
             let manifest_tree = ManifestTree::new(worktree_store.clone(), cx);
@@ -1167,8 +1168,9 @@ impl Project {
             cx.subscribe(&worktree_store, Self::on_worktree_store_event)
                 .detach();
 
+            let weak_self = cx.weak_entity();
             let context_server_store =
-                cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
+                cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx));
 
             let buffer_store = cx.new(|cx| {
                 BufferStore::remote(
@@ -1428,8 +1430,6 @@ impl Project {
         let image_store = cx.new(|cx| {
             ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx)
         })?;
-        let context_server_store =
-            cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?;
 
         let environment = cx.new(|_| ProjectEnvironment::new(None))?;
 
@@ -1496,6 +1496,10 @@ impl Project {
 
             let snippets = SnippetProvider::new(fs.clone(), BTreeSet::from_iter([]), cx);
 
+            let weak_self = cx.weak_entity();
+            let context_server_store =
+                cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx));
+
             let mut worktrees = Vec::new();
             for worktree in response.payload.worktrees {
                 let worktree =