Detailed changes
@@ -139,8 +139,6 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142"
dependencies = [
"schemars",
"serde",
@@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
-agent-client-protocol = "0.0.13"
+agent-client-protocol = {path="../agent-client-protocol"}
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -958,10 +958,6 @@ impl AcpThread {
cx.notify();
}
- pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
- self.connection.authenticate(cx)
- }
-
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
&mut self,
@@ -1,6 +1,6 @@
-use std::{path::Path, rc::Rc};
+use std::{cell::Ref, path::Path, rc::Rc};
-use agent_client_protocol as acp;
+use agent_client_protocol::{self as acp};
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
use project::Project;
@@ -16,7 +16,9 @@ pub trait AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
- fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
+ fn state(&self) -> Ref<'_, acp::AgentState>;
+
+ fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
@@ -5,7 +5,13 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
-use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
+use std::{
+ cell::{Ref, RefCell},
+ error::Error,
+ fmt,
+ path::Path,
+ rc::Rc,
+};
use ui::App;
use crate::{AcpThread, AgentConnection};
@@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
+ pub agent_state: Rc<RefCell<acp::AgentState>>,
}
impl AgentConnection for OldAcpAgentConnection {
@@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
- fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
+ fn state(&self) -> Ref<'_, acp::AgentState> {
+ self.agent_state.borrow()
+ }
+
+ fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
@@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use smol::stream::StreamExt as _;
-use std::cell::RefCell;
+use std::cell::{Ref, RefCell};
use std::rc::Rc;
use std::{path::Path, sync::Arc};
-use util::ResultExt;
+use util::{ResultExt, TryFutureExt};
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
@@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server};
use acp_thread::{AcpThread, AgentConnection};
pub struct AcpConnection {
+ agent_state: Rc<RefCell<acp::AgentState>>,
server_name: &'static str,
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
- _notification_handler_task: Task<()>,
+ _agent_state_task: Task<()>,
+ _session_update_task: Task<()>,
}
impl AcpConnection {
@@ -43,29 +45,55 @@ impl AcpConnection {
.into();
ContextServer::start(client.clone(), cx).await?;
+ let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
+ let mcp_client = client.client().context("Failed to subscribe")?;
+
+ mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
+ move |notification, _cx| {
+ log::trace!(
+ "ACP Notification: {}",
+ serde_json::to_string_pretty(¬ification).unwrap()
+ );
+
+ if let Some(state) =
+ serde_json::from_value::<acp::AgentState>(notification).log_err()
+ {
+ state_tx.send(state).log_err();
+ }
+ }
+ });
+
let (notification_tx, mut notification_rx) = mpsc::unbounded();
- client
- .client()
- .context("Failed to subscribe")?
- .on_notification(acp::AGENT_METHODS.session_update, {
- move |notification, _cx| {
- let notification_tx = notification_tx.clone();
- log::trace!(
- "ACP Notification: {}",
- serde_json::to_string_pretty(¬ification).unwrap()
- );
-
- if let Some(notification) =
- serde_json::from_value::<acp::SessionNotification>(notification).log_err()
- {
- notification_tx.unbounded_send(notification).ok();
- }
+ mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
+ move |notification, _cx| {
+ let notification_tx = notification_tx.clone();
+ log::trace!(
+ "ACP Notification: {}",
+ serde_json::to_string_pretty(¬ification).unwrap()
+ );
+
+ if let Some(notification) =
+ serde_json::from_value::<acp::SessionNotification>(notification).log_err()
+ {
+ notification_tx.unbounded_send(notification).ok();
}
- });
+ }
+ });
let sessions = Rc::new(RefCell::new(HashMap::default()));
+ let initial_state = state_rx.recv().await?;
+ let agent_state = Rc::new(RefCell::new(initial_state));
+
+ let agent_state_task = cx.foreground_executor().spawn({
+ let agent_state = agent_state.clone();
+ async move {
+ while let Some(state) = state_rx.recv().log_err().await {
+ agent_state.replace(state);
+ }
+ }
+ });
- let notification_handler_task = cx.spawn({
+ let session_update_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
@@ -78,7 +106,9 @@ impl AcpConnection {
server_name,
client,
sessions,
- _notification_handler_task: notification_handler_task,
+ agent_state,
+ _agent_state_task: agent_state_task,
+ _session_update_task: session_update_handler_task,
})
}
@@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection {
})
}
- fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
- Task::ready(Err(anyhow!("Authentication not supported")))
+ fn state(&self) -> Ref<'_, acp::AgentState> {
+ self.agent_state.borrow()
+ }
+
+ fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
+ let client = self.client.client();
+ cx.foreground_executor().spawn(async move {
+ let params = acp::AuthenticateArguments { method_id };
+
+ let response = client
+ .context("MCP server is not initialized yet")?
+ .request::<requests::CallTool>(context_server::types::CallToolParams {
+ name: acp::AGENT_METHODS.authenticate.into(),
+ arguments: Some(serde_json::to_value(params)?),
+ meta: None,
+ })
+ .await?;
+
+ if response.is_error.unwrap_or_default() {
+ Err(anyhow!(response.text_contents()))
+ } else {
+ Ok(())
+ }
+ })
}
fn prompt(
@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
-use std::cell::RefCell;
+use std::cell::{Ref, RefCell};
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode {
_cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection {
+ agent_state: Default::default(),
sessions: Default::default(),
};
@@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode {
}
struct ClaudeAgentConnection {
+ agent_state: Rc<RefCell<acp::AgentState>>,
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
@@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
- fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
+ fn state(&self) -> Ref<'_, acp::AgentState> {
+ self.agent_state.borrow()
+ }
+
+ fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
@@ -216,6 +216,15 @@ impl AcpThreadView {
}
};
+ if connection.state().needs_authentication {
+ this.update(cx, |this, cx| {
+ this.thread_state = ThreadState::Unauthenticated { connection };
+ cx.notify();
+ })
+ .ok();
+ return;
+ }
+
let result = match connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
@@ -223,6 +232,7 @@ impl AcpThreadView {
{
Err(e) => {
let mut cx = cx.clone();
+ // todo! remove duplication
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
@@ -640,13 +650,18 @@ impl AcpThreadView {
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
}
- fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ fn authenticate(
+ &mut self,
+ method: acp::AuthMethodId,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return;
};
self.last_error.take();
- let authenticate = connection.authenticate(cx);
+ let authenticate = connection.authenticate(method, cx);
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
let agent = self.agent.clone();
@@ -2197,22 +2212,26 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::next_history_message))
.on_action(cx.listener(Self::open_agent_diff))
.child(match &self.thread_state {
- ThreadState::Unauthenticated { .. } => {
- v_flex()
- .p_2()
- .flex_1()
- .items_center()
- .justify_center()
- .child(self.render_pending_auth_state())
- .child(
- h_flex().mt_1p5().justify_center().child(
- Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
- .on_click(cx.listener(|this, _, window, cx| {
- this.authenticate(window, cx)
- })),
- ),
- )
- }
+ ThreadState::Unauthenticated { connection } => v_flex()
+ .p_2()
+ .flex_1()
+ .items_center()
+ .justify_center()
+ .child(self.render_pending_auth_state())
+ .child(h_flex().mt_1p5().justify_center().children(
+ connection.state().auth_methods.iter().map(|method| {
+ Button::new(
+ SharedString::from(method.id.0.clone()),
+ method.label.clone(),
+ )
+ .on_click({
+ let method_id = method.id.clone();
+ cx.listener(move |this, _, window, cx| {
+ this.authenticate(method_id.clone(), window, cx)
+ })
+ })
+ }),
+ )),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex()
.p_2()