Stop generation

Agus Zubiaga created

Change summary

crates/agent_servers/src/codex.rs     | 61 ++++++++++++++++++----------
crates/context_server/src/client.rs   | 48 +++++++++++++++++++++-
crates/context_server/src/protocol.rs | 11 +++++
crates/context_server/src/types.rs    | 16 +++++--
4 files changed, 106 insertions(+), 30 deletions(-)

Detailed changes

crates/agent_servers/src/codex.rs 🔗

@@ -2,7 +2,7 @@ use collections::HashMap;
 use context_server::types::CallToolParams;
 use context_server::types::requests::CallTool;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::channel::mpsc;
+use futures::channel::{mpsc, oneshot};
 use project::Project;
 use settings::SettingsStore;
 use smol::stream::StreamExt;
@@ -144,6 +144,7 @@ impl AgentServer for Codex {
                 let connection = CodexAgentConnection {
                     root_dir,
                     codex_mcp: codex_mcp_client,
+                    cancel_request_tx: Default::default(),
                     _handler_task: handler_task,
                     _zed_mcp: zed_mcp_server,
                 };
@@ -162,6 +163,7 @@ impl AgentConnection for CodexAgentConnection {
     ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
         let client = self.codex_mcp.client();
         let root_dir = self.root_dir.clone();
+        let cancel_request_tx = self.cancel_request_tx.clone();
         async move {
             let client = client.context("Codex MCP server is not initialized")?;
 
@@ -177,34 +179,48 @@ impl AgentConnection for CodexAgentConnection {
                     Err(anyhow!("Authentication not supported"))
                 }
                 AnyAgentRequest::SendUserMessageParams(message) => {
+                    let (new_cancel_tx, cancel_rx) = oneshot::channel();
+                    cancel_request_tx.borrow_mut().replace(new_cancel_tx);
+
                     client
-                        .request::<CallTool>(CallToolParams {
-                            name: "codex".into(),
-                            arguments: Some(serde_json::to_value(CodexToolCallParam {
-                                prompt: message
-                                    .chunks
-                                    .into_iter()
-                                    .filter_map(|chunk| match chunk {
-                                        acp::UserMessageChunk::Text { text } => Some(text),
-                                        acp::UserMessageChunk::Path { .. } => {
-                                            // todo!
-                                            None
-                                        }
-                                    })
-                                    .collect(),
-                                cwd: root_dir,
-                            })?),
-                            meta: None,
-                        })
+                        .cancellable_request::<CallTool>(
+                            CallToolParams {
+                                name: "codex".into(),
+                                arguments: Some(serde_json::to_value(CodexToolCallParam {
+                                    prompt: message
+                                        .chunks
+                                        .into_iter()
+                                        .filter_map(|chunk| match chunk {
+                                            acp::UserMessageChunk::Text { text } => Some(text),
+                                            acp::UserMessageChunk::Path { .. } => {
+                                                // todo!
+                                                None
+                                            }
+                                        })
+                                        .collect(),
+                                    cwd: root_dir,
+                                })?),
+                                meta: None,
+                            },
+                            cancel_rx,
+                        )
                         .await?;
 
                     Ok(AnyAgentResult::SendUserMessageResponse(
                         acp::SendUserMessageResponse,
                     ))
                 }
-                AnyAgentRequest::CancelSendMessageParams(_) => Ok(
-                    AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
-                ),
+                AnyAgentRequest::CancelSendMessageParams(_) => {
+                    if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
+                        if let Some(cancel_tx) = borrow.take() {
+                            cancel_tx.send(()).ok();
+                        }
+                    }
+
+                    Ok(AnyAgentResult::CancelSendMessageResponse(
+                        acp::CancelSendMessageResponse,
+                    ))
+                }
             }
         }
         .boxed_local()
@@ -214,6 +230,7 @@ impl AgentConnection for CodexAgentConnection {
 struct CodexAgentConnection {
     codex_mcp: Arc<context_server::ContextServer>,
     root_dir: PathBuf,
+    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
     _handler_task: Task<()>,
     _zed_mcp: ZedMcpServer,
 }

crates/context_server/src/client.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{Context as _, Result, anyhow};
 use collections::HashMap;
-use futures::{FutureExt, StreamExt, channel::oneshot, select};
+use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
 use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
 use parking_lot::Mutex;
 use postage::barrier;
@@ -10,15 +10,19 @@ use smol::channel;
 use std::{
     fmt,
     path::PathBuf,
+    pin::pin,
     sync::{
         Arc,
         atomic::{AtomicI32, Ordering::SeqCst},
     },
     time::{Duration, Instant},
 };
-use util::TryFutureExt;
+use util::{ResultExt, TryFutureExt};
 
-use crate::transport::{StdioTransport, Transport};
+use crate::{
+    transport::{StdioTransport, Transport},
+    types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
+};
 
 const JSON_RPC_VERSION: &str = "2.0";
 const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@@ -294,6 +298,24 @@ impl Client {
         &self,
         method: &str,
         params: impl Serialize,
+    ) -> Result<T> {
+        self.request_impl(method, params, None).await
+    }
+
+    pub async fn cancellable_request<T: DeserializeOwned>(
+        &self,
+        method: &str,
+        params: impl Serialize,
+        cancel_rx: oneshot::Receiver<()>,
+    ) -> Result<T> {
+        self.request_impl(method, params, Some(cancel_rx)).await
+    }
+
+    pub async fn request_impl<T: DeserializeOwned>(
+        &self,
+        method: &str,
+        params: impl Serialize,
+        cancel_rx: Option<oneshot::Receiver<()>>,
     ) -> Result<T> {
         let id = self.next_id.fetch_add(1, SeqCst);
         let request = serde_json::to_string(&Request {
@@ -330,6 +352,16 @@ impl Client {
         send?;
 
         let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
+        let mut cancel_fut = pin!(
+            match cancel_rx {
+                Some(rx) => future::Either::Left(async {
+                    rx.await.log_err();
+                }),
+                None => future::Either::Right(future::pending()),
+            }
+            .fuse()
+        );
+
         select! {
             response = rx.fuse() => {
                 let elapsed = started.elapsed();
@@ -348,6 +380,16 @@ impl Client {
                     Err(_) => anyhow::bail!("cancelled")
                 }
             }
+            _ = cancel_fut => {
+                self.notify(
+                    Cancelled::METHOD,
+                    ClientNotification::Cancelled(CancelledParams {
+                        request_id: RequestId::Int(id),
+                        reason: None
+                    })
+                ).log_err();
+                anyhow::bail!("Request cancelled")
+            }
             _ = timeout => {
                 log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
                 anyhow::bail!("Context server request timeout");

crates/context_server/src/protocol.rs 🔗

@@ -6,6 +6,7 @@
 //! of messages.
 
 use anyhow::Result;
+use futures::channel::oneshot;
 use gpui::AsyncApp;
 use serde_json::Value;
 
@@ -97,6 +98,16 @@ impl InitializedContextServerProtocol {
         self.inner.request(T::METHOD, params).await
     }
 
+    pub async fn cancellable_request<T: Request>(
+        &self,
+        params: T::Params,
+        cancel_rx: oneshot::Receiver<()>,
+    ) -> Result<T::Response> {
+        self.inner
+            .cancellable_request(T::METHOD, params, cancel_rx)
+            .await
+    }
+
     pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
         self.inner.notify(T::METHOD, params)
     }

crates/context_server/src/types.rs 🔗

@@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 use url::Url;
 
+use crate::client::RequestId;
+
 pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
 pub const VERSION_2024_11_05: &str = "2024-11-05";
 
@@ -100,6 +102,7 @@ pub mod notifications {
     notification!("notifications/initialized", Initialized, ());
     notification!("notifications/progress", Progress, ProgressParams);
     notification!("notifications/message", Message, MessageParams);
+    notification!("notifications/cancelled", Cancelled, CancelledParams);
     notification!(
         "notifications/resources/updated",
         ResourcesUpdated,
@@ -617,11 +620,14 @@ pub enum ClientNotification {
     Initialized,
     Progress(ProgressParams),
     RootsListChanged,
-    Cancelled {
-        request_id: String,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        reason: Option<String>,
-    },
+    Cancelled(CancelledParams),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CancelledParams {
+    pub request_id: RequestId,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<String>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]