agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use futures::{
  6    FutureExt, StreamExt,
  7    channel::{mpsc, oneshot},
  8    select_biased,
  9    stream::{BoxStream, FuturesUnordered},
 10};
 11use gpui::{AppContext, AsyncApp, Context, Entity, Task};
 12use project::Project;
 13use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 14
 15pub trait Agent: 'static {
 16    type Thread: AgentThread;
 17
 18    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
 19    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
 20    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
 21}
 22
 23pub trait AgentThread: 'static {
 24    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntryContent>>>;
 25    fn send(
 26        &self,
 27        message: Message,
 28    ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
 29}
 30
 31pub enum ResponseEvent {
 32    MessageResponse(MessageResponse),
 33    ReadFileRequest(ReadFileRequest),
 34    // GlobSearchRequest(SearchRequest),
 35    // RegexSearchRequest(RegexSearchRequest),
 36    // RunCommandRequest(RunCommandRequest),
 37    // WebSearchResponse(WebSearchResponse),
 38}
 39
 40pub struct MessageResponse {
 41    role: Role,
 42    chunks: BoxStream<'static, Result<MessageChunk>>,
 43}
 44
 45#[derive(Debug)]
 46pub struct ReadFileRequest {
 47    path: PathBuf,
 48    range: Range<usize>,
 49    response_tx: oneshot::Sender<Result<FileContent>>,
 50}
 51
 52impl ReadFileRequest {
 53    pub fn respond(self, content: Result<FileContent>) {
 54        self.response_tx.send(content).ok();
 55    }
 56}
 57
 58#[derive(Debug, Clone)]
 59pub struct ThreadId(String);
 60
 61#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 62pub struct FileVersion(u64);
 63
 64#[derive(Debug)]
 65pub struct AgentThreadSummary {
 66    pub id: ThreadId,
 67    pub title: String,
 68    pub created_at: DateTime<Utc>,
 69}
 70
 71#[derive(Debug, PartialEq, Eq)]
 72pub struct FileContent {
 73    pub path: PathBuf,
 74    pub version: FileVersion,
 75    pub content: String,
 76}
 77
 78#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 79pub enum Role {
 80    User,
 81    Assistant,
 82}
 83
 84#[derive(Debug, Eq, PartialEq)]
 85pub struct Message {
 86    pub role: Role,
 87    pub chunks: Vec<MessageChunk>,
 88}
 89
 90#[derive(Debug, Eq, PartialEq)]
 91pub enum MessageChunk {
 92    Text {
 93        chunk: String,
 94    },
 95    File {
 96        content: FileContent,
 97    },
 98    Directory {
 99        path: PathBuf,
100        contents: Vec<FileContent>,
101    },
102    Symbol {
103        path: PathBuf,
104        range: Range<u64>,
105        version: FileVersion,
106        name: String,
107        content: String,
108    },
109    Thread {
110        title: String,
111        content: Vec<AgentThreadEntryContent>,
112    },
113    Fetch {
114        url: String,
115        content: String,
116    },
117}
118
119impl From<&str> for MessageChunk {
120    fn from(chunk: &str) -> Self {
121        MessageChunk::Text {
122            chunk: chunk.to_string(),
123        }
124    }
125}
126
127#[derive(Debug, Eq, PartialEq)]
128pub enum AgentThreadEntryContent {
129    Message(Message),
130    ReadFile { path: PathBuf, content: String },
131}
132
/// Identifier of an entry within a [`Thread`]; ids are ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Advances this counter by one and hands back the value it held before.
    pub fn post_inc(&mut self) -> Self {
        let successor = ThreadEntryId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
143
144#[derive(Debug)]
145pub struct ThreadEntry {
146    pub id: ThreadEntryId,
147    pub content: AgentThreadEntryContent,
148}
149
150pub struct ThreadStore<T: Agent> {
151    threads: Vec<AgentThreadSummary>,
152    agent: Arc<T>,
153    project: Entity<Project>,
154}
155
156impl<T: Agent> ThreadStore<T> {
157    pub async fn load(
158        agent: Arc<T>,
159        project: Entity<Project>,
160        cx: &mut AsyncApp,
161    ) -> Result<Entity<Self>> {
162        let threads = agent.threads().await?;
163        cx.new(|cx| Self {
164            threads,
165            agent,
166            project,
167        })
168    }
169
170    /// Returns the threads in reverse chronological order.
171    pub fn threads(&self) -> &[AgentThreadSummary] {
172        &self.threads
173    }
174
175    /// Opens a thread with the given ID.
176    pub fn open_thread(
177        &self,
178        id: ThreadId,
179        cx: &mut Context<Self>,
180    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
181        let agent = self.agent.clone();
182        let project = self.project.clone();
183        cx.spawn(async move |_, cx| {
184            let agent_thread = agent.open_thread(id).await?;
185            Thread::load(Arc::new(agent_thread), project, cx).await
186        })
187    }
188
189    /// Creates a new thread.
190    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
191        let agent = self.agent.clone();
192        let project = self.project.clone();
193        cx.spawn(async move |_, cx| {
194            let agent_thread = agent.create_thread().await?;
195            Thread::load(Arc::new(agent_thread), project, cx).await
196        })
197    }
198}
199
200pub struct Thread<T: AgentThread> {
201    next_entry_id: ThreadEntryId,
202    entries: Vec<ThreadEntry>,
203    agent_thread: Arc<T>,
204    project: Entity<Project>,
205}
206
207impl<T: AgentThread> Thread<T> {
208    pub async fn load(
209        agent_thread: Arc<T>,
210        project: Entity<Project>,
211        cx: &mut AsyncApp,
212    ) -> Result<Entity<Self>> {
213        let entries = agent_thread.entries().await?;
214        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
215    }
216
217    pub fn new(
218        agent_thread: Arc<T>,
219        entries: Vec<AgentThreadEntryContent>,
220        project: Entity<Project>,
221        cx: &mut Context<Self>,
222    ) -> Self {
223        let mut next_entry_id = ThreadEntryId(0);
224        Self {
225            entries: entries
226                .into_iter()
227                .map(|entry| ThreadEntry {
228                    id: next_entry_id.post_inc(),
229                    content: entry,
230                })
231                .collect(),
232            next_entry_id,
233            agent_thread,
234            project,
235        }
236    }
237
238    pub fn entries(&self) -> &[ThreadEntry] {
239        &self.entries
240    }
241
242    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
243        let agent_thread = self.agent_thread.clone();
244        cx.spawn(async move |this, cx| {
245            let mut events = agent_thread.send(message).await?;
246            let mut pending_event_handlers = FuturesUnordered::new();
247
248            loop {
249                let mut next_event_handler_result = pin!(
250                    async {
251                        if pending_event_handlers.is_empty() {
252                            future::pending::<()>().await;
253                        }
254
255                        pending_event_handlers.next().await
256                    }
257                    .fuse()
258                );
259
260                select_biased! {
261                    event = events.next() => {
262                        let Some(event) = event else {
263                            while let Some(result) = pending_event_handlers.next().await {
264                                result?;
265                            }
266
267                            break;
268                        };
269
270                        let task = match event {
271                            Ok(ResponseEvent::MessageResponse(message)) => {
272                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
273                            }
274                            Ok(ResponseEvent::ReadFileRequest(request)) => {
275                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
276                            }
277                            Err(_) => todo!(),
278                        };
279                        pending_event_handlers.push(task);
280                    }
281                    result = next_event_handler_result => {
282                        // Event handlers should only return errors that are
283                        // unrecoverable and should therefore stop this turn of
284                        // the agentic loop.
285                        result.unwrap()?;
286                    }
287                }
288            }
289
290            Ok(())
291        })
292    }
293
294    fn handle_message_response(
295        &mut self,
296        mut message: MessageResponse,
297        cx: &mut Context<Self>,
298    ) -> Task<Result<()>> {
299        let entry_id = self.next_entry_id.post_inc();
300        self.entries.push(ThreadEntry {
301            id: entry_id,
302            content: AgentThreadEntryContent::Message(Message {
303                role: message.role,
304                chunks: Vec::new(),
305            }),
306        });
307        cx.notify();
308
309        cx.spawn(async move |this, cx| {
310            while let Some(chunk) = message.chunks.next().await {
311                match chunk {
312                    Ok(chunk) => {
313                        this.update(cx, |this, cx| {
314                            let ix = this
315                                .entries
316                                .binary_search_by_key(&entry_id, |entry| entry.id)
317                                .map_err(|_| anyhow!("message not found"))?;
318                            let AgentThreadEntryContent::Message(message) =
319                                &mut this.entries[ix].content
320                            else {
321                                unreachable!()
322                            };
323                            message.chunks.push(chunk);
324                            cx.notify();
325                            anyhow::Ok(())
326                        })??;
327                    }
328                    Err(err) => todo!("show error"),
329                }
330            }
331
332            Ok(())
333        })
334    }
335
336    fn handle_read_file_request(
337        &mut self,
338        request: ReadFileRequest,
339        cx: &mut Context<Self>,
340    ) -> Task<Result<()>> {
341        todo!()
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs a logger (once per process) and the test settings store.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics if a logger is already installed, which
        // happens as soon as a second test in this binary calls `init_test`.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async())
            .expect("failed to start the Gemini agent");
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the 'test/foo' file and output all of its contents.".into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "test/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns the gemini-cli in ACP mode over stdio and wraps it in an
    /// [`AcpAgent`].
    ///
    /// # Errors
    ///
    /// Fails if `GEMINI_API_KEY` is unset or the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY")
            .context("GEMINI_API_KEY must be set to run the Gemini agent")?;
        let child = util::command::new_smol_command("node")
            // Path is relative to the test's working directory.
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            // .args(["--model", "gemini-2.5-flash"])
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn gemini-cli via `node`")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}