Auth WIP

Agus Zubiaga created

Change summary

Cargo.lock                                 |   2 
Cargo.toml                                 |   2 
crates/acp_thread/src/acp_thread.rs        |   4 
crates/acp_thread/src/connection.rs        |   8 +
crates/acp_thread/src/old_acp_support.rs   |  15 +++
crates/agent_servers/src/acp_connection.rs | 100 ++++++++++++++++++-----
crates/agent_servers/src/claude.rs         |  10 +
crates/agent_ui/src/acp/thread_view.rs     |  55 ++++++++----
8 files changed, 140 insertions(+), 56 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -139,8 +139,6 @@ dependencies = [
 [[package]]
 name = "agent-client-protocol"
 version = "0.0.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142"
 dependencies = [
  "schemars",
  "serde",

Cargo.toml 🔗

@@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
 #
 
 agentic-coding-protocol = "0.0.10"
-agent-client-protocol = "0.0.13"
+agent-client-protocol = {path="../agent-client-protocol"}
 aho-corasick = "1.1"
 alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
 any_vec = "0.14"

crates/acp_thread/src/acp_thread.rs 🔗

@@ -958,10 +958,6 @@ impl AcpThread {
         cx.notify();
     }
 
-    pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
-        self.connection.authenticate(cx)
-    }
-
     #[cfg(any(test, feature = "test-support"))]
     pub fn send_raw(
         &mut self,

crates/acp_thread/src/connection.rs 🔗

@@ -1,6 +1,6 @@
-use std::{path::Path, rc::Rc};
+use std::{cell::Ref, path::Path, rc::Rc};
 
-use agent_client_protocol as acp;
+use agent_client_protocol::{self as acp};
 use anyhow::Result;
 use gpui::{AsyncApp, Entity, Task};
 use project::Project;
@@ -16,7 +16,9 @@ pub trait AgentConnection {
         cx: &mut AsyncApp,
     ) -> Task<Result<Entity<AcpThread>>>;
 
-    fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
+    fn state(&self) -> Ref<'_, acp::AgentState>;
+
+    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 
     fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
 

crates/acp_thread/src/old_acp_support.rs 🔗

@@ -5,7 +5,13 @@ use anyhow::{Context as _, Result};
 use futures::channel::oneshot;
 use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 use project::Project;
-use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
+use std::{
+    cell::{Ref, RefCell},
+    error::Error,
+    fmt,
+    path::Path,
+    rc::Rc,
+};
 use ui::App;
 
 use crate::{AcpThread, AgentConnection};
@@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection {
     pub name: &'static str,
     pub connection: acp_old::AgentConnection,
     pub child_status: Task<Result<()>>,
+    pub agent_state: Rc<RefCell<acp::AgentState>>,
 }
 
 impl AgentConnection for OldAcpAgentConnection {
@@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection {
         })
     }
 
-    fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
+    fn state(&self) -> Ref<'_, acp::AgentState> {
+        self.agent_state.borrow()
+    }
+
+    fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
         let task = self
             .connection
             .request_any(acp_old::AuthenticateParams.into_any());

crates/agent_servers/src/acp_connection.rs 🔗

@@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use futures::channel::{mpsc, oneshot};
 use project::Project;
 use smol::stream::StreamExt as _;
-use std::cell::RefCell;
+use std::cell::{Ref, RefCell};
 use std::rc::Rc;
 use std::{path::Path, sync::Arc};
-use util::ResultExt;
+use util::{ResultExt, TryFutureExt};
 
 use anyhow::{Context, Result};
 use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
@@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server};
 use acp_thread::{AcpThread, AgentConnection};
 
 pub struct AcpConnection {
+    agent_state: Rc<RefCell<acp::AgentState>>,
     server_name: &'static str,
     client: Arc<context_server::ContextServer>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
-    _notification_handler_task: Task<()>,
+    _agent_state_task: Task<()>,
+    _session_update_task: Task<()>,
 }
 
 impl AcpConnection {
@@ -43,29 +45,55 @@ impl AcpConnection {
         .into();
         ContextServer::start(client.clone(), cx).await?;
 
+        let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
+        let mcp_client = client.client().context("Failed to subscribe")?;
+
+        mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
+            move |notification, _cx| {
+                log::trace!(
+                    "ACP Notification: {}",
+                    serde_json::to_string_pretty(&notification).unwrap()
+                );
+
+                if let Some(state) =
+                    serde_json::from_value::<acp::AgentState>(notification).log_err()
+                {
+                    state_tx.send(state).log_err();
+                }
+            }
+        });
+
         let (notification_tx, mut notification_rx) = mpsc::unbounded();
-        client
-            .client()
-            .context("Failed to subscribe")?
-            .on_notification(acp::AGENT_METHODS.session_update, {
-                move |notification, _cx| {
-                    let notification_tx = notification_tx.clone();
-                    log::trace!(
-                        "ACP Notification: {}",
-                        serde_json::to_string_pretty(&notification).unwrap()
-                    );
-
-                    if let Some(notification) =
-                        serde_json::from_value::<acp::SessionNotification>(notification).log_err()
-                    {
-                        notification_tx.unbounded_send(notification).ok();
-                    }
+        mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
+            move |notification, _cx| {
+                let notification_tx = notification_tx.clone();
+                log::trace!(
+                    "ACP Notification: {}",
+                    serde_json::to_string_pretty(&notification).unwrap()
+                );
+
+                if let Some(notification) =
+                    serde_json::from_value::<acp::SessionNotification>(notification).log_err()
+                {
+                    notification_tx.unbounded_send(notification).ok();
                 }
-            });
+            }
+        });
 
         let sessions = Rc::new(RefCell::new(HashMap::default()));
+        let initial_state = state_rx.recv().await?;
+        let agent_state = Rc::new(RefCell::new(initial_state));
+
+        let agent_state_task = cx.foreground_executor().spawn({
+            let agent_state = agent_state.clone();
+            async move {
+                while let Some(state) = state_rx.recv().log_err().await {
+                    agent_state.replace(state);
+                }
+            }
+        });
 
-        let notification_handler_task = cx.spawn({
+        let session_update_handler_task = cx.spawn({
             let sessions = sessions.clone();
             async move |cx| {
                 while let Some(notification) = notification_rx.next().await {
@@ -78,7 +106,9 @@ impl AcpConnection {
             server_name,
             client,
             sessions,
-            _notification_handler_task: notification_handler_task,
+            agent_state,
+            _agent_state_task: agent_state_task,
+            _session_update_task: session_update_handler_task,
         })
     }
 
@@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection {
         })
     }
 
-    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
-        Task::ready(Err(anyhow!("Authentication not supported")))
+    fn state(&self) -> Ref<'_, acp::AgentState> {
+        self.agent_state.borrow()
+    }
+
+    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
+        let client = self.client.client();
+        cx.foreground_executor().spawn(async move {
+            let params = acp::AuthenticateArguments { method_id };
+
+            let response = client
+                .context("MCP server is not initialized yet")?
+                .request::<requests::CallTool>(context_server::types::CallToolParams {
+                    name: acp::AGENT_METHODS.authenticate.into(),
+                    arguments: Some(serde_json::to_value(params)?),
+                    meta: None,
+                })
+                .await?;
+
+            if response.is_error.unwrap_or_default() {
+                Err(anyhow!(response.text_contents()))
+            } else {
+                Ok(())
+            }
+        })
     }
 
     fn prompt(

crates/agent_servers/src/claude.rs 🔗

@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
 use project::Project;
 use settings::SettingsStore;
 use smol::process::Child;
-use std::cell::RefCell;
+use std::cell::{Ref, RefCell};
 use std::fmt::Display;
 use std::path::Path;
 use std::rc::Rc;
@@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode {
         _cx: &mut App,
     ) -> Task<Result<Rc<dyn AgentConnection>>> {
         let connection = ClaudeAgentConnection {
+            agent_state: Default::default(),
             sessions: Default::default(),
         };
 
@@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode {
 }
 
 struct ClaudeAgentConnection {
+    agent_state: Rc<RefCell<acp::AgentState>>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
 }
 
@@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection {
         })
     }
 
-    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
+    fn state(&self) -> Ref<'_, acp::AgentState> {
+        self.agent_state.borrow()
+    }
+
+    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
         Task::ready(Err(anyhow!("Authentication not supported")))
     }
 

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -216,6 +216,15 @@ impl AcpThreadView {
                 }
             };
 
+            if connection.state().needs_authentication {
+                this.update(cx, |this, cx| {
+                    this.thread_state = ThreadState::Unauthenticated { connection };
+                    cx.notify();
+                })
+                .ok();
+                return;
+            }
+
             let result = match connection
                 .clone()
                 .new_thread(project.clone(), &root_dir, cx)
@@ -223,6 +232,7 @@ impl AcpThreadView {
             {
                 Err(e) => {
                     let mut cx = cx.clone();
+                    // todo! remove duplication
                     if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
                         this.update(&mut cx, |this, cx| {
                             this.thread_state = ThreadState::Unauthenticated { connection };
@@ -640,13 +650,18 @@ impl AcpThreadView {
         Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
     }
 
-    fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+    fn authenticate(
+        &mut self,
+        method: acp::AuthMethodId,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
         let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
             return;
         };
 
         self.last_error.take();
-        let authenticate = connection.authenticate(cx);
+        let authenticate = connection.authenticate(method, cx);
         self.auth_task = Some(cx.spawn_in(window, {
             let project = self.project.clone();
             let agent = self.agent.clone();
@@ -2197,22 +2212,26 @@ impl Render for AcpThreadView {
             .on_action(cx.listener(Self::next_history_message))
             .on_action(cx.listener(Self::open_agent_diff))
             .child(match &self.thread_state {
-                ThreadState::Unauthenticated { .. } => {
-                    v_flex()
-                        .p_2()
-                        .flex_1()
-                        .items_center()
-                        .justify_center()
-                        .child(self.render_pending_auth_state())
-                        .child(
-                            h_flex().mt_1p5().justify_center().child(
-                                Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
-                                    .on_click(cx.listener(|this, _, window, cx| {
-                                        this.authenticate(window, cx)
-                                    })),
-                            ),
-                        )
-                }
+                ThreadState::Unauthenticated { connection } => v_flex()
+                    .p_2()
+                    .flex_1()
+                    .items_center()
+                    .justify_center()
+                    .child(self.render_pending_auth_state())
+                    .child(h_flex().mt_1p5().justify_center().children(
+                        connection.state().auth_methods.iter().map(|method| {
+                            Button::new(
+                                SharedString::from(method.id.0.clone()),
+                                method.label.clone(),
+                            )
+                            .on_click({
+                                let method_id = method.id.clone();
+                                cx.listener(move |this, _, window, cx| {
+                                    this.authenticate(method_id.clone(), window, cx)
+                                })
+                            })
+                        }),
+                    )),
                 ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
                 ThreadState::LoadError(e) => v_flex()
                     .p_2()