From a822711e99c8054c55834112991b4384db383eeb Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 21 Jul 2025 19:46:14 -0300 Subject: [PATCH] Stop generation --- crates/agent_servers/src/codex.rs | 61 +++++++++++++++++---------- crates/context_server/src/client.rs | 48 +++++++++++++++++++-- crates/context_server/src/protocol.rs | 11 +++++ crates/context_server/src/types.rs | 16 ++++--- 4 files changed, 106 insertions(+), 30 deletions(-) diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index 9e301b33fcbb2bd2f725a2a56e565f69ef5f1313..3a35a4e25b87c61ab662e80ddc0df3380737726c 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -2,7 +2,7 @@ use collections::HashMap; use context_server::types::CallToolParams; use context_server::types::requests::CallTool; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; -use futures::channel::mpsc; +use futures::channel::{mpsc, oneshot}; use project::Project; use settings::SettingsStore; use smol::stream::StreamExt; @@ -144,6 +144,7 @@ impl AgentServer for Codex { let connection = CodexAgentConnection { root_dir, codex_mcp: codex_mcp_client, + cancel_request_tx: Default::default(), _handler_task: handler_task, _zed_mcp: zed_mcp_server, }; @@ -162,6 +163,7 @@ impl AgentConnection for CodexAgentConnection { ) -> LocalBoxFuture<'static, Result> { let client = self.codex_mcp.client(); let root_dir = self.root_dir.clone(); + let cancel_request_tx = self.cancel_request_tx.clone(); async move { let client = client.context("Codex MCP server is not initialized")?; @@ -177,34 +179,48 @@ impl AgentConnection for CodexAgentConnection { Err(anyhow!("Authentication not supported")) } AnyAgentRequest::SendUserMessageParams(message) => { + let (new_cancel_tx, cancel_rx) = oneshot::channel(); + cancel_request_tx.borrow_mut().replace(new_cancel_tx); + client - .request::(CallToolParams { - name: "codex".into(), - arguments: Some(serde_json::to_value(CodexToolCallParam { - prompt: message - .chunks - .into_iter() - .filter_map(|chunk| match chunk { - acp::UserMessageChunk::Text { text } => Some(text), - acp::UserMessageChunk::Path { .. } => { - // todo! - None - } - }) - .collect(), - cwd: root_dir, - })?), - meta: None, - }) + .cancellable_request::( + CallToolParams { + name: "codex".into(), + arguments: Some(serde_json::to_value(CodexToolCallParam { + prompt: message + .chunks + .into_iter() + .filter_map(|chunk| match chunk { + acp::UserMessageChunk::Text { text } => Some(text), + acp::UserMessageChunk::Path { .. } => { + // todo! + None + } + }) + .collect(), + cwd: root_dir, + })?), + meta: None, + }, + cancel_rx, + ) .await?; Ok(AnyAgentResult::SendUserMessageResponse( acp::SendUserMessageResponse, )) } - AnyAgentRequest::CancelSendMessageParams(_) => Ok( - AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse), - ), + AnyAgentRequest::CancelSendMessageParams(_) => { + if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() { + if let Some(cancel_tx) = borrow.take() { + cancel_tx.send(()).ok(); + } + } + + Ok(AnyAgentResult::CancelSendMessageResponse( + acp::CancelSendMessageResponse, + )) + } } } .boxed_local() @@ -214,6 +230,7 @@ impl AgentConnection for CodexAgentConnection { struct CodexAgentConnection { codex_mcp: Arc, root_dir: PathBuf, + cancel_request_tx: Rc>>>, _handler_task: Task<()>, _zed_mcp: ZedMcpServer, } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 6b24d9b136efc2d9cc99843e54027058e1602861..960f9717e58f2fc1502d5030da0908cdd8b76e73 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,15 +10,19 @@ use smol::channel; use std::{ fmt, path::PathBuf, + pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::TryFutureExt; +use util::{ResultExt, TryFutureExt}; -use crate::transport::{StdioTransport, Transport}; +use crate::{ + transport::{StdioTransport, Transport}, + types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, +}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -294,6 +298,24 @@ impl Client { &self, method: &str, params: impl Serialize, + ) -> Result { + self.request_impl(method, params, None).await + } + + pub async fn cancellable_request( + &self, + method: &str, + params: impl Serialize, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.request_impl(method, params, Some(cancel_rx)).await + } + + pub async fn request_impl( + &self, + method: &str, + params: impl Serialize, + cancel_rx: Option>, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -330,6 +352,16 @@ impl Client { send?; let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut cancel_fut = pin!( + match cancel_rx { + Some(rx) => future::Either::Left(async { + rx.await.log_err(); + }), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + select! { response = rx.fuse() => { let elapsed = started.elapsed(); @@ -348,6 +380,16 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } + _ = cancel_fut => { + self.notify( + Cancelled::METHOD, + ClientNotification::Cancelled(CancelledParams { + request_id: RequestId::Int(id), + reason: None + }) + ).log_err(); + anyhow::bail!("Request cancelled") + } _ = timeout => { log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); anyhow::bail!("Context server request timeout"); diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 0eb7e9dfa019a9434b840cec46679047e7f9317f..7263f502fa44b05b6bba72eda58c7ad84b52ebf7 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,6 +6,7 @@ //! of messages. use anyhow::Result; +use futures::channel::oneshot; use gpui::AsyncApp; use serde_json::Value; @@ -97,6 +98,16 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } + pub async fn cancellable_request( + &self, + params: T::Params, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.inner + .cancellable_request(T::METHOD, params, cancel_rx) + .await + } + pub fn notify(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 4a6fdcabd3421e14cab3ff89ce6962023935059a..f92c86aa3cd722fdf5e6f68aeaba5df137fcde62 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,6 +3,8 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; +use crate::client::RequestId; + pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -100,6 +102,7 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); + notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -617,11 +620,14 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled { - request_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - reason: Option, - }, + Cancelled(CancelledParams), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CancelledParams { + pub request_id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, } #[derive(Debug, Serialize, Deserialize)]