Detailed changes
@@ -2,7 +2,7 @@ use collections::HashMap;
use context_server::types::CallToolParams;
use context_server::types::requests::CallTool;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::channel::mpsc;
+use futures::channel::{mpsc, oneshot};
use project::Project;
use settings::SettingsStore;
use smol::stream::StreamExt;
@@ -144,6 +144,7 @@ impl AgentServer for Codex {
let connection = CodexAgentConnection {
root_dir,
codex_mcp: codex_mcp_client,
+ cancel_request_tx: Default::default(),
_handler_task: handler_task,
_zed_mcp: zed_mcp_server,
};
@@ -162,6 +163,7 @@ impl AgentConnection for CodexAgentConnection {
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
let client = self.codex_mcp.client();
let root_dir = self.root_dir.clone();
+ let cancel_request_tx = self.cancel_request_tx.clone();
async move {
let client = client.context("Codex MCP server is not initialized")?;
@@ -177,34 +179,48 @@ impl AgentConnection for CodexAgentConnection {
Err(anyhow!("Authentication not supported"))
}
AnyAgentRequest::SendUserMessageParams(message) => {
+ let (new_cancel_tx, cancel_rx) = oneshot::channel();
+ cancel_request_tx.borrow_mut().replace(new_cancel_tx);
+
client
- .request::<CallTool>(CallToolParams {
- name: "codex".into(),
- arguments: Some(serde_json::to_value(CodexToolCallParam {
- prompt: message
- .chunks
- .into_iter()
- .filter_map(|chunk| match chunk {
- acp::UserMessageChunk::Text { text } => Some(text),
- acp::UserMessageChunk::Path { .. } => {
- // todo!
- None
- }
- })
- .collect(),
- cwd: root_dir,
- })?),
- meta: None,
- })
+ .cancellable_request::<CallTool>(
+ CallToolParams {
+ name: "codex".into(),
+ arguments: Some(serde_json::to_value(CodexToolCallParam {
+ prompt: message
+ .chunks
+ .into_iter()
+ .filter_map(|chunk| match chunk {
+ acp::UserMessageChunk::Text { text } => Some(text),
+ acp::UserMessageChunk::Path { .. } => {
+ // todo!
+ None
+ }
+ })
+ .collect(),
+ cwd: root_dir,
+ })?),
+ meta: None,
+ },
+ cancel_rx,
+ )
.await?;
Ok(AnyAgentResult::SendUserMessageResponse(
acp::SendUserMessageResponse,
))
}
- AnyAgentRequest::CancelSendMessageParams(_) => Ok(
- AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
- ),
+ AnyAgentRequest::CancelSendMessageParams(_) => {
+ if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
+ if let Some(cancel_tx) = borrow.take() {
+ cancel_tx.send(()).ok();
+ }
+ }
+
+ Ok(AnyAgentResult::CancelSendMessageResponse(
+ acp::CancelSendMessageResponse,
+ ))
+ }
}
}
.boxed_local()
@@ -214,6 +230,7 @@ impl AgentConnection for CodexAgentConnection {
struct CodexAgentConnection {
codex_mcp: Arc<context_server::ContextServer>,
root_dir: PathBuf,
+ cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
_handler_task: Task<()>,
_zed_mcp: ZedMcpServer,
}
@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
-use futures::{FutureExt, StreamExt, channel::oneshot, select};
+use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
@@ -10,15 +10,19 @@ use smol::channel;
use std::{
fmt,
path::PathBuf,
+ pin::pin,
sync::{
Arc,
atomic::{AtomicI32, Ordering::SeqCst},
},
time::{Duration, Instant},
};
-use util::TryFutureExt;
+use util::{ResultExt, TryFutureExt};
-use crate::transport::{StdioTransport, Transport};
+use crate::{
+ transport::{StdioTransport, Transport},
+ types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
+};
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@@ -294,6 +298,24 @@ impl Client {
&self,
method: &str,
params: impl Serialize,
+ ) -> Result<T> {
+ self.request_impl(method, params, None).await
+ }
+
+ pub async fn cancellable_request<T: DeserializeOwned>(
+ &self,
+ method: &str,
+ params: impl Serialize,
+ cancel_rx: oneshot::Receiver<()>,
+ ) -> Result<T> {
+ self.request_impl(method, params, Some(cancel_rx)).await
+ }
+
+ pub async fn request_impl<T: DeserializeOwned>(
+ &self,
+ method: &str,
+ params: impl Serialize,
+ cancel_rx: Option<oneshot::Receiver<()>>,
) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request {
@@ -330,6 +352,16 @@ impl Client {
send?;
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
+ let mut cancel_fut = pin!(
+ match cancel_rx {
+ Some(rx) => future::Either::Left(async {
+ rx.await.log_err();
+ }),
+ None => future::Either::Right(future::pending()),
+ }
+ .fuse()
+ );
+
select! {
response = rx.fuse() => {
let elapsed = started.elapsed();
@@ -348,6 +380,16 @@ impl Client {
Err(_) => anyhow::bail!("cancelled")
}
}
+ _ = cancel_fut => {
+ self.notify(
+ Cancelled::METHOD,
+ ClientNotification::Cancelled(CancelledParams {
+ request_id: RequestId::Int(id),
+ reason: None
+ })
+ ).log_err();
+ anyhow::bail!("Request cancelled")
+ }
_ = timeout => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
anyhow::bail!("Context server request timeout");
@@ -6,6 +6,7 @@
//! of messages.
use anyhow::Result;
+use futures::channel::oneshot;
use gpui::AsyncApp;
use serde_json::Value;
@@ -97,6 +98,16 @@ impl InitializedContextServerProtocol {
self.inner.request(T::METHOD, params).await
}
+ pub async fn cancellable_request<T: Request>(
+ &self,
+ params: T::Params,
+ cancel_rx: oneshot::Receiver<()>,
+ ) -> Result<T::Response> {
+ self.inner
+ .cancellable_request(T::METHOD, params, cancel_rx)
+ .await
+ }
+
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params)
}
@@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use url::Url;
+use crate::client::RequestId;
+
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05";
@@ -100,6 +102,7 @@ pub mod notifications {
notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams);
+ notification!("notifications/cancelled", Cancelled, CancelledParams);
notification!(
"notifications/resources/updated",
ResourcesUpdated,
@@ -617,11 +620,14 @@ pub enum ClientNotification {
Initialized,
Progress(ProgressParams),
RootsListChanged,
- Cancelled {
- request_id: String,
- #[serde(skip_serializing_if = "Option::is_none")]
- reason: Option<String>,
- },
+ Cancelled(CancelledParams),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CancelledParams {
+ pub request_id: RequestId,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub reason: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]