acp: Support loading a session for external agents (#46992)

Ben Brandt created

Release Notes:

- N/A

Change summary

crates/acp_thread/src/connection.rs    |   2 
crates/agent/src/agent.rs              |   2 
crates/agent_servers/src/acp.rs        | 268 ++++++++++++++++++---------
crates/agent_ui/src/acp/thread_view.rs |  27 +-
crates/agent_ui/src/agent_panel.rs     |  38 +--
crates/agent_ui_v2/src/agents_panel.rs |   2 
6 files changed, 209 insertions(+), 130 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs 🔗

@@ -38,7 +38,7 @@ pub trait AgentConnection {
     ) -> Task<Result<Entity<AcpThread>>>;
 
     /// Whether this agent supports loading existing sessions.
-    fn supports_load_session(&self) -> bool {
+    fn supports_load_session(&self, _cx: &App) -> bool {
         false
     }
 

crates/agent/src/agent.rs 🔗

@@ -1220,7 +1220,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         })
     }
 
-    fn supports_load_session(&self) -> bool {
+    fn supports_load_session(&self, _cx: &App) -> bool {
         true
     }
 

crates/agent_servers/src/acp.rs 🔗

@@ -351,92 +351,20 @@ impl AgentConnection for AcpConnection {
         cx: &mut App,
     ) -> Task<Result<Entity<AcpThread>>> {
         let name = self.server_name.clone();
-        let conn = self.connection.clone();
-        let sessions = self.sessions.clone();
-        let default_mode = self.default_mode.clone();
-        let default_model = self.default_model.clone();
-        let default_config_options = self.default_config_options.clone();
         let cwd = cwd.to_path_buf();
-        let context_server_store = project.read(cx).context_server_store().read(cx);
-        let is_local = project.read(cx).is_local();
-        let mcp_servers = context_server_store
-            .configured_server_ids()
-            .iter()
-            .filter_map(|id| {
-                let configuration = context_server_store.configuration_for_server(id)?;
-                match &*configuration {
-                    project::context_server_store::ContextServerConfiguration::Custom {
-                        command,
-                        remote,
-                        ..
-                    }
-                    | project::context_server_store::ContextServerConfiguration::Extension {
-                        command,
-                        remote,
-                        ..
-                    } if is_local || *remote => Some(acp::McpServer::Stdio(
-                        acp::McpServerStdio::new(id.0.to_string(), &command.path)
-                            .args(command.args.clone())
-                            .env(if let Some(env) = command.env.as_ref() {
-                                env.iter()
-                                    .map(|(name, value)| acp::EnvVariable::new(name, value))
-                                    .collect()
-                            } else {
-                                vec![]
-                            }),
-                    )),
-                    project::context_server_store::ContextServerConfiguration::Http {
-                        url,
-                        headers,
-                        timeout: _,
-                    } => Some(acp::McpServer::Http(
-                        acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers(
-                            headers
-                                .iter()
-                                .map(|(name, value)| acp::HttpHeader::new(name, value))
-                                .collect(),
-                        ),
-                    )),
-                    _ => None,
-                }
-            })
-            .collect();
+        let mcp_servers = mcp_servers_for_project(&project, cx);
 
         cx.spawn(async move |cx| {
-            let response = conn
+            let response = self.connection
                 .new_session(acp::NewSessionRequest::new(cwd).mcp_servers(mcp_servers))
                 .await
-                .map_err(|err| {
-                    if err.code == acp::ErrorCode::AuthRequired {
-                        let mut error = AuthRequired::new();
-
-                        if err.message != acp::ErrorCode::AuthRequired.to_string() {
-                            error = error.with_description(err.message);
-                        }
+                .map_err(map_acp_error)?;
 
-                        anyhow!(error)
-                    } else {
-                        anyhow!(err)
-                    }
-                })?;
-
-            let use_config_options = cx.update(|cx| cx.has_flag::<AcpBetaFeatureFlag>());
-
-            // Config options take precedence over legacy modes/models
-            let (modes, models, config_options) = if use_config_options && let Some(opts) = response.config_options {
-                (
-                    None,
-                    None,
-                    Some(Rc::new(RefCell::new(opts))),
-                )
-            } else {
-                // Fall back to legacy modes/models
-                let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
-                let models = response.models.map(|models| Rc::new(RefCell::new(models)));
-                (modes, models, None)
-            };
+            let (modes, models, config_options) = cx.update(|cx| {
+                config_state(cx, response.modes, response.models, response.config_options)
+            });
 
-            if let Some(default_mode) = default_mode {
+            if let Some(default_mode) = self.default_mode.clone() {
                 if let Some(modes) = modes.as_ref() {
                     let mut modes_ref = modes.borrow_mut();
                     let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode);
@@ -448,7 +376,7 @@ impl AgentConnection for AcpConnection {
                             let default_mode = default_mode.clone();
                             let session_id = response.session_id.clone();
                             let modes = modes.clone();
-                            let conn = conn.clone();
+                            let conn = self.connection.clone();
                             async move |_| {
                                 let result = conn.set_session_mode(acp::SetSessionModeRequest::new(session_id, default_mode))
                                 .await.log_err();
@@ -479,7 +407,7 @@ impl AgentConnection for AcpConnection {
                 }
             }
 
-            if let Some(default_model) = default_model {
+            if let Some(default_model) = self.default_model.clone() {
                 if let Some(models) = models.as_ref() {
                     let mut models_ref = models.borrow_mut();
                     let has_model = models_ref.available_models.iter().any(|model| model.model_id == default_model);
@@ -491,7 +419,7 @@ impl AgentConnection for AcpConnection {
                             let default_model = default_model.clone();
                             let session_id = response.session_id.clone();
                             let models = models.clone();
-                            let conn = conn.clone();
+                            let conn = self.connection.clone();
                             async move |_| {
                                 let result = conn.set_session_model(acp::SetSessionModelRequest::new(session_id, default_model))
                                 .await.log_err();
@@ -528,7 +456,7 @@ impl AgentConnection for AcpConnection {
                     config_opts_ref
                         .iter()
                         .filter_map(|config_option| {
-                            let default_value = default_config_options.get(&*config_option.id.0)?;
+                            let default_value = self.default_config_options.get(&*config_option.id.0)?;
 
                             let is_valid = match &config_option.kind {
                                 acp::SessionConfigKind::Select(select) => match &select.options {
@@ -570,7 +498,7 @@ impl AgentConnection for AcpConnection {
                         let session_id = response.session_id.clone();
                         let config_id_clone = config_id.clone();
                         let config_opts = config_opts.clone();
-                        let conn = conn.clone();
+                        let conn = self.connection.clone();
                         async move |_| {
                             let result = conn
                                 .set_session_config_option(
@@ -608,7 +536,6 @@ impl AgentConnection for AcpConnection {
                 }
             }
 
-            let session_id = response.session_id;
             let action_log = cx.new(|_| ActionLog::new(project.clone()));
             let thread: Entity<AcpThread> = cx.new(|cx| {
                 AcpThread::new(
@@ -616,22 +543,99 @@ impl AgentConnection for AcpConnection {
                     self.clone(),
                     project,
                     action_log,
-                    session_id.clone(),
+                    response.session_id.clone(),
                     // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
                     watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
                     cx,
                 )
             });
 
+            self.sessions.borrow_mut().insert(
+                response.session_id,
+                AcpSession {
+                    thread: thread.downgrade(),
+                    suppress_abort_err: false,
+                    session_modes: modes,
+                    models,
+                    config_options: config_options.map(ConfigOptions::new),
+                },
+            );
 
-            let session = AcpSession {
+            if let Some(session_list) = &self.session_list {
+                session_list.notify_update();
+            }
+
+            Ok(thread)
+        })
+    }
+
+    fn supports_load_session(&self, cx: &App) -> bool {
+        cx.has_flag::<AcpBetaFeatureFlag>() && self.agent_capabilities.load_session
+    }
+
+    fn load_session(
+        self: Rc<Self>,
+        session: AgentSessionInfo,
+        project: Entity<Project>,
+        cwd: &Path,
+        cx: &mut App,
+    ) -> Task<Result<Entity<AcpThread>>> {
+        if !cx.has_flag::<AcpBetaFeatureFlag>() || !self.agent_capabilities.load_session {
+            return Task::ready(Err(anyhow!(LoadError::Other(
+                "Loading sessions is not supported by this agent.".into()
+            ))));
+        }
+
+        let cwd = cwd.to_path_buf();
+        let mcp_servers = mcp_servers_for_project(&project, cx);
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let thread: Entity<AcpThread> = cx.new(|cx| {
+            AcpThread::new(
+                self.server_name.clone(),
+                self.clone(),
+                project,
+                action_log,
+                session.session_id.clone(),
+                watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
+                cx,
+            )
+        });
+
+        self.sessions.borrow_mut().insert(
+            session.session_id.clone(),
+            AcpSession {
                 thread: thread.downgrade(),
                 suppress_abort_err: false,
-                session_modes: modes,
-                models,
-                config_options: config_options.map(|opts| ConfigOptions::new(opts))
+                session_modes: None,
+                models: None,
+                config_options: None,
+            },
+        );
+
+        cx.spawn(async move |cx| {
+            let response = match self
+                .connection
+                .load_session(
+                    acp::LoadSessionRequest::new(session.session_id.clone(), cwd)
+                        .mcp_servers(mcp_servers),
+                )
+                .await
+            {
+                Ok(response) => response,
+                Err(err) => {
+                    self.sessions.borrow_mut().remove(&session.session_id);
+                    return Err(map_acp_error(err));
+                }
             };
-            sessions.borrow_mut().insert(session_id, session);
+
+            let (modes, models, config_options) = cx.update(|cx| {
+                config_state(cx, response.modes, response.models, response.config_options)
+            });
+            if let Some(session) = self.sessions.borrow_mut().get_mut(&session.session_id) {
+                session.session_modes = modes;
+                session.models = models;
+                session.config_options = config_options.map(ConfigOptions::new);
+            }
 
             if let Some(session_list) = &self.session_list {
                 session_list.notify_update();
@@ -801,6 +805,88 @@ impl AgentConnection for AcpConnection {
     }
 }
 
+fn map_acp_error(err: acp::Error) -> anyhow::Error {
+    if err.code == acp::ErrorCode::AuthRequired {
+        let mut error = AuthRequired::new();
+
+        if err.message != acp::ErrorCode::AuthRequired.to_string() {
+            error = error.with_description(err.message);
+        }
+
+        anyhow!(error)
+    } else {
+        anyhow!(err)
+    }
+}
+
+fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {
+    let context_server_store = project.read(cx).context_server_store().read(cx);
+    let is_local = project.read(cx).is_local();
+    context_server_store
+        .configured_server_ids()
+        .iter()
+        .filter_map(|id| {
+            let configuration = context_server_store.configuration_for_server(id)?;
+            match &*configuration {
+                project::context_server_store::ContextServerConfiguration::Custom {
+                    command,
+                    remote,
+                    ..
+                }
+                | project::context_server_store::ContextServerConfiguration::Extension {
+                    command,
+                    remote,
+                    ..
+                } if is_local || *remote => Some(acp::McpServer::Stdio(
+                    acp::McpServerStdio::new(id.0.to_string(), &command.path)
+                        .args(command.args.clone())
+                        .env(if let Some(env) = command.env.as_ref() {
+                            env.iter()
+                                .map(|(name, value)| acp::EnvVariable::new(name, value))
+                                .collect()
+                        } else {
+                            vec![]
+                        }),
+                )),
+                project::context_server_store::ContextServerConfiguration::Http {
+                    url,
+                    headers,
+                    timeout: _,
+                } => Some(acp::McpServer::Http(
+                    acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers(
+                        headers
+                            .iter()
+                            .map(|(name, value)| acp::HttpHeader::new(name, value))
+                            .collect(),
+                    ),
+                )),
+                _ => None,
+            }
+        })
+        .collect()
+}
+
+fn config_state(
+    cx: &App,
+    modes: Option<acp::SessionModeState>,
+    models: Option<acp::SessionModelState>,
+    config_options: Option<Vec<acp::SessionConfigOption>>,
+) -> (
+    Option<Rc<RefCell<acp::SessionModeState>>>,
+    Option<Rc<RefCell<acp::SessionModelState>>>,
+    Option<Rc<RefCell<Vec<acp::SessionConfigOption>>>>,
+) {
+    if cx.has_flag::<AcpBetaFeatureFlag>()
+        && let Some(opts) = config_options
+    {
+        return (None, None, Some(Rc::new(RefCell::new(opts))));
+    }
+
+    let modes = modes.map(|modes| Rc::new(RefCell::new(modes)));
+    let models = models.map(|models| Rc::new(RefCell::new(models)));
+    (modes, models, None)
+}
+
 struct AcpSessionModes {
     session_id: acp::SessionId,
     connection: Rc<acp::ClientSideConnection>,

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -651,28 +651,25 @@ impl AcpThreadView {
             }
 
             let result = if let Some(resume) = resume_thread.clone() {
-                if connection.supports_load_session() {
-                    let session_cwd = resume
-                        .cwd
-                        .clone()
-                        .unwrap_or_else(|| fallback_cwd.as_ref().to_path_buf());
-                    cx.update(|_, cx| {
+                cx.update(|_, cx| {
+                    if connection.supports_load_session(cx) {
+                        let session_cwd = resume
+                            .cwd
+                            .clone()
+                            .unwrap_or_else(|| fallback_cwd.as_ref().to_path_buf());
                         connection.clone().load_session(
                             resume,
                             project.clone(),
                             session_cwd.as_path(),
                             cx,
                         )
-                    })
-                    .log_err()
-                } else {
-                    cx.update(|_, _| {
+                    } else {
                         Task::ready(Err(anyhow!(LoadError::Other(
                             "Loading sessions is not supported by this agent.".into()
                         ))))
-                    })
-                    .log_err()
-                }
+                    }
+                })
+                .log_err()
             } else {
                 cx.update(|_, cx| {
                     connection
@@ -723,7 +720,7 @@ impl AcpThreadView {
 
                         let connection = thread.read(cx).connection().clone();
                         let session_id = thread.read(cx).session_id().clone();
-                        let session_list = if connection.supports_load_session() {
+                        let session_list = if connection.supports_load_session(cx) {
                             connection.session_list(cx)
                         } else {
                             None
@@ -6673,7 +6670,7 @@ impl AcpThreadView {
                 MentionUri::Thread { id, name } => {
                     if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
                         panel.update(cx, |panel, cx| {
-                            panel.load_agent_thread(
+                            panel.open_thread(
                                 AgentSessionInfo {
                                     session_id: id,
                                     cwd: None,

crates/agent_ui/src/agent_panel.rs 🔗

@@ -553,13 +553,7 @@ impl AgentPanel {
             window,
             |this, _, event, window, cx| match event {
                 ThreadHistoryEvent::Open(thread) => {
-                    this.external_thread(
-                        Some(crate::ExternalAgent::NativeAgent),
-                        Some(thread.clone()),
-                        None,
-                        window,
-                        cx,
-                    );
+                    this.load_agent_thread(thread.clone(), window, cx);
                 }
             },
         )
@@ -1390,13 +1384,7 @@ impl AgentPanel {
                             let entry = entry.clone();
                             panel
                                 .update(cx, move |this, cx| {
-                                    this.external_thread(
-                                        Some(ExternalAgent::NativeAgent),
-                                        Some(entry.clone()),
-                                        None,
-                                        window,
-                                        cx,
-                                    );
+                                    this.load_agent_thread(entry.clone(), window, cx);
                                 })
                                 .ok();
                         }
@@ -1450,6 +1438,17 @@ impl AgentPanel {
         self.selected_agent.clone()
     }
 
+    fn selected_external_agent(&self) -> Option<ExternalAgent> {
+        match &self.selected_agent {
+            AgentType::NativeAgent => Some(ExternalAgent::NativeAgent),
+            AgentType::Gemini => Some(ExternalAgent::Gemini),
+            AgentType::ClaudeCode => Some(ExternalAgent::ClaudeCode),
+            AgentType::Codex => Some(ExternalAgent::Codex),
+            AgentType::Custom { name } => Some(ExternalAgent::Custom { name: name.clone() }),
+            AgentType::TextThread => None,
+        }
+    }
+
     fn sync_agent_servers_from_extensions(&mut self, cx: &mut Context<Self>) {
         if let Some(extension_store) = ExtensionStore::try_global(cx) {
             let (manifests, extensions_dir) = {
@@ -1527,13 +1526,10 @@ impl AgentPanel {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        self.external_thread(
-            Some(ExternalAgent::NativeAgent),
-            Some(thread),
-            None,
-            window,
-            cx,
-        );
+        let Some(agent) = self.selected_external_agent() else {
+            return;
+        };
+        self.external_thread(Some(agent), Some(thread), None, window, cx);
     }
 
     fn _external_thread(

crates/agent_ui_v2/src/agents_panel.rs 🔗

@@ -151,7 +151,7 @@ impl AgentsPanel {
             };
 
             cx.update(|cx| {
-                if connection.supports_load_session()
+                if connection.supports_load_session(cx)
                     && let Some(session_list) = connection.session_list(cx)
                 {
                     history_handle.update(cx, |history, cx| {