Detailed changes
@@ -1595,7 +1595,6 @@ mod tests {
connection,
child_status: io_task,
current_thread: thread_rc,
- agent_state: Default::default(),
};
AcpThread::new(
@@ -1,4 +1,4 @@
-use std::{cell::Ref, path::Path, rc::Rc};
+use std::{error::Error, fmt, path::Path, rc::Rc};
use agent_client_protocol::{self as acp};
use anyhow::Result;
@@ -16,7 +16,7 @@ pub trait AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
- fn state(&self) -> Ref<'_, acp::AgentState>;
+ fn auth_methods(&self) -> Vec<acp::AuthMethod>;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
@@ -24,3 +24,13 @@ pub trait AgentConnection {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
}
+
+#[derive(Debug)]
+pub struct AuthRequired;
+
+impl Error for AuthRequired {}
+impl fmt::Display for AuthRequired {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "AuthRequired")
+ }
+}
@@ -5,17 +5,11 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
-use std::{
- cell::{Ref, RefCell},
- error::Error,
- fmt,
- path::Path,
- rc::Rc,
-};
+use std::{cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
-use crate::{AcpThread, AgentConnection};
+use crate::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
pub struct OldAcpClientDelegate {
@@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
}
}
-#[derive(Debug)]
-pub struct Unauthenticated;
-
-impl Error for Unauthenticated {}
-impl fmt::Display for Unauthenticated {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "Unauthenticated")
- }
-}
-
pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
- pub agent_state: Rc<RefCell<acp::AgentState>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
@@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
- anyhow::bail!(Unauthenticated)
+ anyhow::bail!(AuthRequired)
}
cx.update(|cx| {
@@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
- fn state(&self) -> Ref<'_, acp::AgentState> {
- self.agent_state.borrow()
+ fn auth_methods(&self) -> Vec<acp::AuthMethod> {
+ vec![acp::AuthMethod {
+ id: acp::AuthMethodId("acp-old-no-id".into()),
+ label: "Log in".into(),
+ description: None,
+ }]
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
@@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use smol::stream::StreamExt as _;
-use std::cell::{Ref, RefCell};
+use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
-use util::{ResultExt, TryFutureExt};
+use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServerCommand, mcp_server};
-use acp_thread::{AcpThread, AgentConnection};
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection {
- agent_state: Rc<RefCell<acp::AgentState>>,
+ auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
server_name: &'static str,
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
- _agent_state_task: Task<()>,
_session_update_task: Task<()>,
}
@@ -47,24 +46,8 @@ impl AcpConnection {
.into();
ContextServer::start(client.clone(), cx).await?;
- let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
let mcp_client = client.client().context("Failed to subscribe")?;
- mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
- move |notification, _cx| {
- log::trace!(
- "ACP Notification: {}",
- serde_json::to_string_pretty(¬ification).unwrap()
- );
-
- if let Some(state) =
- serde_json::from_value::<acp::AgentState>(notification).log_err()
- {
- state_tx.send(state).log_err();
- }
- }
- });
-
let (notification_tx, mut notification_rx) = mpsc::unbounded();
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
move |notification, _cx| {
@@ -83,17 +66,6 @@ impl AcpConnection {
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
- let initial_state = state_rx.recv().await?;
- let agent_state = Rc::new(RefCell::new(initial_state));
-
- let agent_state_task = cx.foreground_executor().spawn({
- let agent_state = agent_state.clone();
- async move {
- while let Some(state) = state_rx.recv().log_err().await {
- agent_state.replace(state);
- }
- }
- });
let session_update_handler_task = cx.spawn({
let sessions = sessions.clone();
@@ -105,11 +77,10 @@ impl AcpConnection {
});
Ok(Self {
+ auth_methods: Default::default(),
server_name,
client,
sessions,
- agent_state,
- _agent_state_task: agent_state_task,
_session_update_task: session_update_handler_task,
})
}
@@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection {
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
+ let auth_methods = self.auth_methods.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
@@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection {
response.structured_content.context("Empty response")?,
)?;
+ auth_methods.replace(result.auth_methods);
+
+ let Some(session_id) = result.session_id else {
+ anyhow::bail!(AuthRequired);
+ };
+
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
- result.session_id.clone(),
+ session_id.clone(),
cx,
)
})?;
@@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection {
cancel_tx: None,
_mcp_server: mcp_server,
};
- sessions.borrow_mut().insert(result.session_id, session);
+ sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
- fn state(&self) -> Ref<'_, acp::AgentState> {
- self.agent_state.borrow()
+ fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
+ self.auth_methods.borrow().clone()
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
-use std::cell::{Ref, RefCell};
+use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode {
_cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection {
- agent_state: Default::default(),
sessions: Default::default(),
};
@@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode {
}
struct ClaudeAgentConnection {
- agent_state: Rc<RefCell<acp::AgentState>>,
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
@@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
- fn state(&self) -> Ref<'_, acp::AgentState> {
- self.agent_state.borrow()
+ fn auth_methods(&self) -> Vec<acp::AuthMethod> {
+ vec![]
}
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
@@ -216,15 +216,6 @@ impl AcpThreadView {
}
};
- if connection.state().needs_authentication {
- this.update(cx, |this, cx| {
- this.thread_state = ThreadState::Unauthenticated { connection };
- cx.notify();
- })
- .ok();
- return;
- }
-
let result = match connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
@@ -233,7 +224,7 @@ impl AcpThreadView {
Err(e) => {
let mut cx = cx.clone();
// todo! remove duplication
- if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
+ if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
@@ -2219,17 +2210,14 @@ impl Render for AcpThreadView {
.justify_center()
.child(self.render_pending_auth_state())
.child(h_flex().mt_1p5().justify_center().children(
- connection.state().auth_methods.iter().map(|method| {
- Button::new(
- SharedString::from(method.id.0.clone()),
- method.label.clone(),
- )
- .on_click({
- let method_id = method.id.clone();
- cx.listener(move |this, _, window, cx| {
- this.authenticate(method_id.clone(), window, cx)
+ connection.auth_methods().into_iter().map(|method| {
+ Button::new(SharedString::from(method.id.0.clone()), method.label)
+ .on_click({
+ let method_id = method.id.clone();
+ cx.listener(move |this, _, window, cx| {
+ this.authenticate(method_id.clone(), window, cx)
+ })
})
- })
}),
)),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),