From f4e2d38c29ac9192011a8cb6172d2537cc7c3b31 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 25 Jun 2025 13:54:31 -0300 Subject: [PATCH] --wip-- --- Cargo.lock | 1 + crates/agent2/Cargo.toml | 9 +- crates/agent2/src/acp.rs | 176 ++++++++++++++++++++++++++++-------- crates/agent2/src/agent2.rs | 57 ++++++------ 4 files changed, 174 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 085b679e0da74a88824c7277cb2553b42f51326e..5caf3a3c2c0450ad30f6a6b0b62d55a332d4fc60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,7 @@ dependencies = [ "agentic-coding-protocol", "anyhow", "async-trait", + "base64 0.22.1", "chrono", "collections", "env_logger 0.11.8", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index a4009759afaf6c60ced95a4e59be6787833b5fc9..bcb4379b67482a20c14038fb29c69dc527a5a98e 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -19,20 +19,21 @@ test-support = [ ] [dependencies] +agentic-coding-protocol = { path = "../../../agentic-coding-protocol" } anyhow.workspace = true async-trait.workspace = true -collections.workspace = true +base64.workspace = true chrono.workspace = true +collections.workspace = true futures.workspace = true -language.workspace = true gpui.workspace = true +language.workspace = true parking_lot.workspace = true project.workspace = true smol.workspace = true +util.workspace = true uuid.workspace = true workspace-hack.workspace = true -util.workspace = true -agentic-coding-protocol = { path = "../../../agentic-coding-protocol" } [dev-dependencies] env_logger.workspace = true diff --git a/crates/agent2/src/acp.rs b/crates/agent2/src/acp.rs index 192e81a9c000baeaa0a0c7bf65c92cded541365c..e2b2e1bd1ab9b08cf0928350cfaba336a29493fe 100644 --- a/crates/agent2/src/acp.rs +++ b/crates/agent2/src/acp.rs @@ -1,18 +1,19 @@ use std::{ + io::{Cursor, Write as _}, path::Path, sync::{Arc, Weak}, }; use crate::{ Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, - ResponseEvent, Role, ThreadId, + ResponseEvent, Role, Thread, ThreadEntry, ThreadId, }; -use agentic_coding_protocol::{self as acp}; +use agentic_coding_protocol::{self as acp, TurnId}; use anyhow::{Context as _, Result}; use async_trait::async_trait; use collections::HashMap; use futures::channel::mpsc::UnboundedReceiver; -use gpui::{AppContext, AsyncApp, Entity, Task}; +use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity}; use parking_lot::Mutex; use project::Project; use smol::process::Child; @@ -20,19 +21,41 @@ use util::ResultExt; pub struct AcpAgent { connection: Arc, - threads: Mutex>>, + threads: Arc>>>, _handler_task: Task<()>, _io_task: Task<()>, } struct AcpClientDelegate { project: Entity, + threads: Arc>>>, cx: AsyncApp, // sent_buffer_versions: HashMap, HashMap>, } #[async_trait(?Send)] impl acp::Client for AcpClientDelegate { + async fn stat(&self, params: acp::StatParams) -> Result { + let cx = &mut self.cx.clone(); + self.project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(Path::new(¶ms.path), cx) + .context("Failed to get project path")?; + + match project.entry_for_path(&path, cx) { + // todo! refresh entry? + None => Ok(acp::StatResponse { + exists: false, + is_directory: false, + }), + Some(entry) => Ok(acp::StatResponse { + exists: entry.is_created(), + is_directory: entry.is_dir(), + }), + } + })? + } + async fn stream_message_chunk( &self, request: acp::StreamMessageChunkParams, @@ -40,7 +63,10 @@ impl acp::Client for AcpClientDelegate { Ok(acp::StreamMessageChunkResponse) } - async fn read_file(&self, request: acp::ReadFileParams) -> Result { + async fn read_text_file( + &self, + request: acp::ReadTextFileParams, + ) -> Result { let cx = &mut self.cx.clone(); let buffer = self .project @@ -52,8 +78,77 @@ impl acp::Client for AcpClientDelegate { })?? .await?; - buffer.update(cx, |buffer, _| acp::ReadFileResponse { - content: buffer.text(), + buffer.update(cx, |buffer, _| { + let start = language::Point::new(request.line_offset.unwrap_or(0), 0); + let end = match request.line_limit { + None => buffer.max_point(), + Some(limit) => start + language::Point::new(limit + 1, 0), + }; + + let content = buffer.text_for_range(start..end).collect(); + + if let Some(thread) = self.threads.lock().get(&request.thread_id) { + thread.update(cx, |thread, cx| { + thread.push_entry(ThreadEntry { + content: AgentThreadEntryContent::ReadFile { + path: request.path.clone(), + content: content.clone(), + }, + }); + }) + } + + acp::ReadTextFileResponse { + content, + version: acp::FileVersion(0), + } + }) + } + + async fn read_binary_file( + &self, + request: acp::ReadBinaryFileParams, + ) -> Result { + let cx = &mut self.cx.clone(); + let file = self + .project + .update(cx, |project, cx| { + let (worktree, path) = project + .find_worktree(Path::new(&request.path), cx) + .context("Failed to get project path")?; + + let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx)); + anyhow::Ok(task) + })?? + .await?; + + // todo! test + let content = cx + .background_spawn(async move { + let start = request.byte_offset.unwrap_or(0) as usize; + let end = request + .byte_limit + .map(|limit| (start + limit as usize).min(file.content.len())) + .unwrap_or(file.content.len()); + + let range_content = &file.content[start..end]; + + let mut base64_content = Vec::new(); + let mut base64_encoder = base64::write::EncoderWriter::new( + Cursor::new(&mut base64_content), + &base64::engine::general_purpose::STANDARD, + ); + base64_encoder.write_all(range_content)?; + drop(base64_encoder); + + // SAFETY: The base64 encoder should not produce non-UTF8. + unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) } + }) + .await?; + + Ok(acp::ReadBinaryFileResponse { + content, + // todo! version: acp::FileVersion(0), }) } @@ -95,9 +190,8 @@ impl AcpAgent { } } +#[async_trait] impl Agent for AcpAgent { - type Thread = AcpAgentThread; - async fn threads(&self) -> Result> { let response = self.connection.request(acp::GetThreadsParams).await?; response @@ -118,7 +212,10 @@ impl Agent for AcpAgent { let thread = Arc::new(AcpAgentThread { id: response.thread_id.clone(), connection: self.connection.clone(), - state: Mutex::new(AcpAgentThreadState { turn: None }), + state: Mutex::new(AcpAgentThreadState { + turn: None, + next_turn_id: TurnId::default(), + }), }); self.threads .lock() @@ -126,25 +223,11 @@ impl Agent for AcpAgent { Ok(thread) } - async fn open_thread(&self, id: ThreadId) -> Result> { + async fn open_thread(&self, id: ThreadId) -> Result { todo!() } -} - -pub struct AcpAgentThread { - id: acp::ThreadId, - connection: Arc, - state: Mutex, -} - -struct AcpAgentThreadState { - turn: Option, -} -struct AcpAgentThreadTurn {} - -impl AgentThread for AcpAgentThread { - async fn entries(&self) -> Result> { + async fn thread_entries(&self, thread_id: ThreadId) -> Result> { let response = self .connection .request(acp::GetThreadEntriesParams { @@ -178,14 +261,22 @@ impl AgentThread for AcpAgentThread { .collect()) } - async fn send( + async fn send_thread_message( &self, + thread_id: ThreadId, message: crate::Message, ) -> Result>> { + let turn_id = { + let mut state = self.state.lock(); + let turn_id = state.next_turn_id.post_inc(); + state.turn = Some(AcpAgentThreadTurn { id: turn_id }); + turn_id + }; let response = self .connection .request(acp::SendMessageParams { thread_id: self.id.clone(), + turn_id, message: acp::Message { role: match message.role { Role::User => acp::Role::User, @@ -196,17 +287,11 @@ impl AgentThread for AcpAgentThread { .into_iter() .map(|chunk| match chunk { MessageChunk::Text { chunk } => acp::MessageChunk::Text { chunk }, - MessageChunk::File { content } => todo!(), - MessageChunk::Directory { path, contents } => todo!(), - MessageChunk::Symbol { - path, - range, - version, - name, - content, - } => todo!(), - MessageChunk::Thread { title, content } => todo!(), - MessageChunk::Fetch { url, content } => todo!(), + MessageChunk::File { .. } => todo!(), + MessageChunk::Directory { .. } => todo!(), + MessageChunk::Symbol { .. } => todo!(), + MessageChunk::Thread { .. } => todo!(), + MessageChunk::Fetch { .. } => todo!(), }) .collect(), }, @@ -216,6 +301,21 @@ impl AgentThread for AcpAgentThread { } } +pub struct AcpAgentThread { + id: acp::ThreadId, + connection: Arc, + state: Mutex, +} + +struct AcpAgentThreadState { + next_turn_id: acp::TurnId, + turn: Option, +} + +struct AcpAgentThreadTurn { + id: acp::TurnId, +} + impl From for ThreadId { fn from(thread_id: acp::ThreadId) -> Self { Self(thread_id.0) diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index f23b38164088eb8739e3f6a8818d88a25b695458..9c77a441cdd4f9ec60907487e686a293a6d4ba73 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,6 +1,7 @@ mod acp; use anyhow::{Result, anyhow}; +use async_trait::async_trait; use chrono::{DateTime, Utc}; use futures::{ FutureExt, StreamExt, @@ -8,24 +9,21 @@ use futures::{ select_biased, stream::{BoxStream, FuturesUnordered}, }; -use gpui::{AppContext, AsyncApp, Context, Entity, Task}; +use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task}; use project::Project; use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc}; +#[async_trait] pub trait Agent: 'static { - type Thread: AgentThread; - - fn threads(&self) -> impl Future>>; - fn create_thread(&self) -> impl Future>>; - fn open_thread(&self, id: ThreadId) -> impl Future>>; -} - -pub trait AgentThread: 'static { - fn entries(&self) -> impl Future>>; - fn send( + async fn threads(&self) -> Result>; + async fn create_thread(&self) -> Result>; + async fn open_thread(&self, id: ThreadId) -> Result>; + async fn thread_entries(&self, id: ThreadId) -> Result>; + async fn send_thread_message( &self, + thread_id: ThreadId, message: Message, - ) -> impl Future>>>; + ) -> Result>>; } pub enum ResponseEvent { @@ -56,7 +54,7 @@ impl ReadFileRequest { } #[derive(Debug, Clone)] -pub struct ThreadId(String); +pub struct ThreadId(SharedString); #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct FileVersion(u64); @@ -177,7 +175,7 @@ impl ThreadStore { &self, id: ThreadId, cx: &mut Context, - ) -> Task>>> { + ) -> Task>> { let agent = self.agent.clone(); let project = self.project.clone(); cx.spawn(async move |_, cx| { @@ -187,7 +185,7 @@ impl ThreadStore { } /// Creates a new thread. - pub fn create_thread(&self, cx: &mut Context) -> Task>>> { + pub fn create_thread(&self, cx: &mut Context) -> Task>> { let agent = self.agent.clone(); let project = self.project.clone(); cx.spawn(async move |_, cx| { @@ -197,25 +195,28 @@ impl ThreadStore { } } -pub struct Thread { +pub struct Thread { + id: ThreadId, next_entry_id: ThreadEntryId, entries: Vec, - agent_thread: Arc, + agent: Arc, project: Entity, } -impl Thread { +impl Thread { pub async fn load( - agent_thread: Arc, + agent: Arc, + thread_id: ThreadId, project: Entity, cx: &mut AsyncApp, ) -> Result> { - let entries = agent_thread.entries().await?; - cx.new(|cx| Self::new(agent_thread, entries, project, cx)) + let entries = agent.thread_entries(thread_id.clone()).await?; + cx.new(|cx| Self::new(agent, thread_id, entries, project, cx)) } pub fn new( - agent_thread: Arc, + agent: Arc, + thread_id: ThreadId, entries: Vec, project: Entity, cx: &mut Context, @@ -229,8 +230,9 @@ impl Thread { content: entry, }) .collect(), + agent, + id: thread_id, next_entry_id, - agent_thread, project, } } @@ -240,9 +242,10 @@ impl Thread { } pub fn send(&mut self, message: Message, cx: &mut Context) -> Task> { - let agent_thread = self.agent_thread.clone(); + let agent = self.agent.clone(); + let id = self.id; cx.spawn(async move |this, cx| { - let mut events = agent_thread.send(message).await?; + let mut events = agent.send_thread_message(id, message).await?; let mut pending_event_handlers = FuturesUnordered::new(); loop { @@ -400,7 +403,7 @@ mod tests { }) .await .unwrap(); - thread.read_with(cx, |thread, cx| { + thread.read_with(cx, |thread, _| { assert!( thread.entries().iter().any(|entry| { entry.content @@ -419,7 +422,7 @@ mod tests { let child = util::command::new_smol_command("node") .arg("../../../gemini-cli/packages/cli") .arg("--acp") - // .args(["--model", "gemini-2.5-flash"]) + .args(["--model", "gemini-2.5-flash"]) .env("GEMINI_API_KEY", env::var("GEMINI_API_KEY").unwrap()) .stdin(Stdio::piped()) .stdout(Stdio::piped())