Fix auth

Agus Zubiaga created

Change summary

crates/acp_thread/src/acp_thread.rs        |  1 
crates/acp_thread/src/connection.rs        | 14 +++++
crates/acp_thread/src/old_acp_support.rs   | 31 ++++---------
crates/agent_servers/src/acp_connection.rs | 54 +++++++----------------
crates/agent_servers/src/claude.rs         |  8 +--
crates/agent_ui/src/acp/thread_view.rs     | 28 +++--------
6 files changed, 48 insertions(+), 88 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs 🔗

@@ -1,4 +1,4 @@
-use std::{cell::Ref, path::Path, rc::Rc};
+use std::{error::Error, fmt, path::Path, rc::Rc};
 
 use agent_client_protocol::{self as acp};
 use anyhow::Result;
@@ -16,7 +16,7 @@ pub trait AgentConnection {
         cx: &mut AsyncApp,
     ) -> Task<Result<Entity<AcpThread>>>;
 
-    fn state(&self) -> Ref<'_, acp::AgentState>;
+    fn auth_methods(&self) -> Vec<acp::AuthMethod>;
 
     fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 
@@ -24,3 +24,13 @@ pub trait AgentConnection {
 
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 }
+
+#[derive(Debug)]
+pub struct AuthRequired;
+
+impl Error for AuthRequired {}
+impl fmt::Display for AuthRequired {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "AuthRequired")
+    }
+}

crates/acp_thread/src/old_acp_support.rs 🔗

@@ -5,17 +5,11 @@ use anyhow::{Context as _, Result};
 use futures::channel::oneshot;
 use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 use project::Project;
-use std::{
-    cell::{Ref, RefCell},
-    error::Error,
-    fmt,
-    path::Path,
-    rc::Rc,
-};
+use std::{cell::RefCell, path::Path, rc::Rc};
 use ui::App;
 use util::ResultExt as _;
 
-use crate::{AcpThread, AgentConnection};
+use crate::{AcpThread, AgentConnection, AuthRequired};
 
 #[derive(Clone)]
 pub struct OldAcpClientDelegate {
@@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
     }
 }
 
-#[derive(Debug)]
-pub struct Unauthenticated;
-
-impl Error for Unauthenticated {}
-impl fmt::Display for Unauthenticated {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "Unauthenticated")
-    }
-}
-
 pub struct OldAcpAgentConnection {
     pub name: &'static str,
     pub connection: acp_old::AgentConnection,
     pub child_status: Task<Result<()>>,
-    pub agent_state: Rc<RefCell<acp::AgentState>>,
     pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
 }
 
@@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection {
             let result = acp_old::InitializeParams::response_from_any(result)?;
 
             if !result.is_authenticated {
-                anyhow::bail!(Unauthenticated)
+                anyhow::bail!(AuthRequired)
             }
 
             cx.update(|cx| {
@@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection {
         })
     }
 
-    fn state(&self) -> Ref<'_, acp::AgentState> {
-        self.agent_state.borrow()
+    fn auth_methods(&self) -> Vec<acp::AuthMethod> {
+        vec![acp::AuthMethod {
+            id: acp::AuthMethodId("acp-old-no-id".into()),
+            label: "Log in".into(),
+            description: None,
+        }]
     }
 
     fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

crates/agent_servers/src/acp_connection.rs 🔗

@@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use futures::channel::{mpsc, oneshot};
 use project::Project;
 use smol::stream::StreamExt as _;
-use std::cell::{Ref, RefCell};
+use std::cell::RefCell;
 use std::rc::Rc;
 use std::{path::Path, sync::Arc};
-use util::{ResultExt, TryFutureExt};
+use util::ResultExt;
 
 use anyhow::{Context, Result};
 use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 
 use crate::mcp_server::ZedMcpServer;
 use crate::{AgentServerCommand, mcp_server};
-use acp_thread::{AcpThread, AgentConnection};
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 
 pub struct AcpConnection {
-    agent_state: Rc<RefCell<acp::AgentState>>,
+    auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
     server_name: &'static str,
     client: Arc<context_server::ContextServer>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
-    _agent_state_task: Task<()>,
     _session_update_task: Task<()>,
 }
 
@@ -47,24 +46,8 @@ impl AcpConnection {
         .into();
         ContextServer::start(client.clone(), cx).await?;
 
-        let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
         let mcp_client = client.client().context("Failed to subscribe")?;
 
-        mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
-            move |notification, _cx| {
-                log::trace!(
-                    "ACP Notification: {}",
-                    serde_json::to_string_pretty(&notification).unwrap()
-                );
-
-                if let Some(state) =
-                    serde_json::from_value::<acp::AgentState>(notification).log_err()
-                {
-                    state_tx.send(state).log_err();
-                }
-            }
-        });
-
         let (notification_tx, mut notification_rx) = mpsc::unbounded();
         mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
             move |notification, _cx| {
@@ -83,17 +66,6 @@ impl AcpConnection {
         });
 
         let sessions = Rc::new(RefCell::new(HashMap::default()));
-        let initial_state = state_rx.recv().await?;
-        let agent_state = Rc::new(RefCell::new(initial_state));
-
-        let agent_state_task = cx.foreground_executor().spawn({
-            let agent_state = agent_state.clone();
-            async move {
-                while let Some(state) = state_rx.recv().log_err().await {
-                    agent_state.replace(state);
-                }
-            }
-        });
 
         let session_update_handler_task = cx.spawn({
             let sessions = sessions.clone();
@@ -105,11 +77,10 @@ impl AcpConnection {
         });
 
         Ok(Self {
+            auth_methods: Default::default(),
             server_name,
             client,
             sessions,
-            agent_state,
-            _agent_state_task: agent_state_task,
             _session_update_task: session_update_handler_task,
         })
     }
@@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection {
     ) -> Task<Result<Entity<AcpThread>>> {
         let client = self.client.client();
         let sessions = self.sessions.clone();
+        let auth_methods = self.auth_methods.clone();
         let cwd = cwd.to_path_buf();
         cx.spawn(async move |cx| {
             let client = client.context("MCP server is not initialized yet")?;
@@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection {
                 response.structured_content.context("Empty response")?,
             )?;
 
+            auth_methods.replace(result.auth_methods);
+
+            let Some(session_id) = result.session_id else {
+                anyhow::bail!(AuthRequired);
+            };
+
             let thread = cx.new(|cx| {
                 AcpThread::new(
                     self.server_name,
                     self.clone(),
                     project,
-                    result.session_id.clone(),
+                    session_id.clone(),
                     cx,
                 )
             })?;
@@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection {
                 cancel_tx: None,
                 _mcp_server: mcp_server,
             };
-            sessions.borrow_mut().insert(result.session_id, session);
+            sessions.borrow_mut().insert(session_id, session);
 
             Ok(thread)
         })
     }
 
-    fn state(&self) -> Ref<'_, acp::AgentState> {
-        self.agent_state.borrow()
+    fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
+        self.auth_methods.borrow().clone()
     }
 
     fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

crates/agent_servers/src/claude.rs 🔗

@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
 use project::Project;
 use settings::SettingsStore;
 use smol::process::Child;
-use std::cell::{Ref, RefCell};
+use std::cell::RefCell;
 use std::fmt::Display;
 use std::path::Path;
 use std::rc::Rc;
@@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode {
         _cx: &mut App,
     ) -> Task<Result<Rc<dyn AgentConnection>>> {
         let connection = ClaudeAgentConnection {
-            agent_state: Default::default(),
             sessions: Default::default(),
         };
 
@@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode {
 }
 
 struct ClaudeAgentConnection {
-    agent_state: Rc<RefCell<acp::AgentState>>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
 }
 
@@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection {
         })
     }
 
-    fn state(&self) -> Ref<'_, acp::AgentState> {
-        self.agent_state.borrow()
+    fn auth_methods(&self) -> Vec<acp::AuthMethod> {
+        vec![]
     }
 
     fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -216,15 +216,6 @@ impl AcpThreadView {
                 }
             };
 
-            if connection.state().needs_authentication {
-                this.update(cx, |this, cx| {
-                    this.thread_state = ThreadState::Unauthenticated { connection };
-                    cx.notify();
-                })
-                .ok();
-                return;
-            }
-
             let result = match connection
                 .clone()
                 .new_thread(project.clone(), &root_dir, cx)
@@ -233,7 +224,7 @@ impl AcpThreadView {
                 Err(e) => {
                     let mut cx = cx.clone();
                     // todo! remove duplication
-                    if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
+                    if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
                         this.update(&mut cx, |this, cx| {
                             this.thread_state = ThreadState::Unauthenticated { connection };
                             cx.notify();
@@ -2219,17 +2210,14 @@ impl Render for AcpThreadView {
                     .justify_center()
                     .child(self.render_pending_auth_state())
                     .child(h_flex().mt_1p5().justify_center().children(
-                        connection.state().auth_methods.iter().map(|method| {
-                            Button::new(
-                                SharedString::from(method.id.0.clone()),
-                                method.label.clone(),
-                            )
-                            .on_click({
-                                let method_id = method.id.clone();
-                                cx.listener(move |this, _, window, cx| {
-                                    this.authenticate(method_id.clone(), window, cx)
+                        connection.auth_methods().into_iter().map(|method| {
+                            Button::new(SharedString::from(method.id.0.clone()), method.label)
+                                .on_click({
+                                    let method_id = method.id.clone();
+                                    cx.listener(move |this, _, window, cx| {
+                                        this.authenticate(method_id.clone(), window, cx)
+                                    })
                                 })
-                            })
                         }),
                     )),
                 ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),