agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{DateTime, Utc};
  6use futures::{
  7    FutureExt, StreamExt,
  8    channel::{mpsc, oneshot},
  9    select_biased,
 10    stream::{BoxStream, FuturesUnordered},
 11};
 12use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
 13use project::Project;
 14use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 15
 16#[async_trait(?Send)]
 17pub trait Agent: 'static {
 18    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
 19    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 20    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 21    async fn thread_entries(
 22        &self,
 23        id: ThreadId,
 24        cx: &mut AsyncApp,
 25    ) -> Result<Vec<AgentThreadEntryContent>>;
 26    async fn send_thread_message(
 27        &self,
 28        thread_id: ThreadId,
 29        message: Message,
 30        cx: &mut AsyncApp,
 31    ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
 32}
 33
 34pub enum ResponseEvent {
 35    MessageResponse(MessageResponse),
 36    ReadFileRequest(ReadFileRequest),
 37    // GlobSearchRequest(SearchRequest),
 38    // RegexSearchRequest(RegexSearchRequest),
 39    // RunCommandRequest(RunCommandRequest),
 40    // WebSearchResponse(WebSearchResponse),
 41}
 42
 43pub struct MessageResponse {
 44    role: Role,
 45    chunks: BoxStream<'static, Result<MessageChunk>>,
 46}
 47
 48#[derive(Debug)]
 49pub struct ReadFileRequest {
 50    path: PathBuf,
 51    range: Range<usize>,
 52    response_tx: oneshot::Sender<Result<FileContent>>,
 53}
 54
 55impl ReadFileRequest {
 56    pub fn respond(self, content: Result<FileContent>) {
 57        self.response_tx.send(content).ok();
 58    }
 59}
 60
 61#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 62pub struct ThreadId(SharedString);
 63
 64#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 65pub struct FileVersion(u64);
 66
 67#[derive(Debug)]
 68pub struct AgentThreadSummary {
 69    pub id: ThreadId,
 70    pub title: String,
 71    pub created_at: DateTime<Utc>,
 72}
 73
 74#[derive(Debug, PartialEq, Eq)]
 75pub struct FileContent {
 76    pub path: PathBuf,
 77    pub version: FileVersion,
 78    pub content: String,
 79}
 80
 81#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 82pub enum Role {
 83    User,
 84    Assistant,
 85}
 86
 87#[derive(Debug, Eq, PartialEq)]
 88pub struct Message {
 89    pub role: Role,
 90    pub chunks: Vec<MessageChunk>,
 91}
 92
 93#[derive(Debug, Eq, PartialEq)]
 94pub enum MessageChunk {
 95    Text {
 96        chunk: String,
 97    },
 98    File {
 99        content: FileContent,
100    },
101    Directory {
102        path: PathBuf,
103        contents: Vec<FileContent>,
104    },
105    Symbol {
106        path: PathBuf,
107        range: Range<u64>,
108        version: FileVersion,
109        name: String,
110        content: String,
111    },
112    Thread {
113        title: String,
114        content: Vec<AgentThreadEntryContent>,
115    },
116    Fetch {
117        url: String,
118        content: String,
119    },
120}
121
122impl From<&str> for MessageChunk {
123    fn from(chunk: &str) -> Self {
124        MessageChunk::Text {
125            chunk: chunk.to_string(),
126        }
127    }
128}
129
130#[derive(Debug, Eq, PartialEq)]
131pub enum AgentThreadEntryContent {
132    Message(Message),
133    ReadFile { path: PathBuf, content: String },
134}
135
/// Identifier of an entry within a [`Thread`].
///
/// Ids are handed out in increasing order via [`ThreadEntryId::post_inc`], and
/// the type is ordered so entry lists can be binary-searched by id.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(usize);
138
139impl ThreadEntryId {
140    pub fn post_inc(&mut self) -> Self {
141        let id = *self;
142        self.0 += 1;
143        id
144    }
145}
146
147#[derive(Debug)]
148pub struct ThreadEntry {
149    pub id: ThreadEntryId,
150    pub content: AgentThreadEntryContent,
151}
152
153pub struct ThreadStore {
154    threads: Vec<AgentThreadSummary>,
155    agent: Arc<dyn Agent>,
156    project: Entity<Project>,
157}
158
159impl ThreadStore {
160    pub async fn load(
161        agent: Arc<dyn Agent>,
162        project: Entity<Project>,
163        cx: &mut AsyncApp,
164    ) -> Result<Entity<Self>> {
165        let threads = agent.threads(cx).await?;
166        cx.new(|_cx| Self {
167            threads,
168            agent,
169            project,
170        })
171    }
172
173    /// Returns the threads in reverse chronological order.
174    pub fn threads(&self) -> &[AgentThreadSummary] {
175        &self.threads
176    }
177
178    /// Opens a thread with the given ID.
179    pub fn open_thread(
180        &self,
181        id: ThreadId,
182        cx: &mut Context<Self>,
183    ) -> Task<Result<Entity<Thread>>> {
184        let agent = self.agent.clone();
185        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
186    }
187
188    /// Creates a new thread.
189    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
190        let agent = self.agent.clone();
191        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
192    }
193}
194
195pub struct Thread {
196    id: ThreadId,
197    next_entry_id: ThreadEntryId,
198    entries: Vec<ThreadEntry>,
199    agent: Arc<dyn Agent>,
200    project: Entity<Project>,
201}
202
203impl Thread {
204    pub async fn load(
205        agent: Arc<dyn Agent>,
206        thread_id: ThreadId,
207        project: Entity<Project>,
208        cx: &mut AsyncApp,
209    ) -> Result<Entity<Self>> {
210        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
211        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
212    }
213
214    pub fn new(
215        agent: Arc<dyn Agent>,
216        thread_id: ThreadId,
217        entries: Vec<AgentThreadEntryContent>,
218        project: Entity<Project>,
219        cx: &mut Context<Self>,
220    ) -> Self {
221        let mut next_entry_id = ThreadEntryId(0);
222        Self {
223            entries: entries
224                .into_iter()
225                .map(|entry| ThreadEntry {
226                    id: next_entry_id.post_inc(),
227                    content: entry,
228                })
229                .collect(),
230            agent,
231            id: thread_id,
232            next_entry_id,
233            project,
234        }
235    }
236
237    pub fn entries(&self) -> &[ThreadEntry] {
238        &self.entries
239    }
240
241    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
242        self.entries.push(ThreadEntry {
243            id: self.next_entry_id.post_inc(),
244            content: entry,
245        });
246        cx.notify();
247    }
248
249    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
250        let agent = self.agent.clone();
251        let id = self.id.clone();
252        cx.spawn(async move |this, cx| {
253            let mut events = agent.send_thread_message(id, message, cx).await?;
254            let mut pending_event_handlers = FuturesUnordered::new();
255
256            loop {
257                let mut next_event_handler_result = pin!(
258                    async {
259                        if pending_event_handlers.is_empty() {
260                            future::pending::<()>().await;
261                        }
262
263                        pending_event_handlers.next().await
264                    }
265                    .fuse()
266                );
267
268                select_biased! {
269                    event = events.next() => {
270                        let Some(event) = event else {
271                            while let Some(result) = pending_event_handlers.next().await {
272                                result?;
273                            }
274
275                            break;
276                        };
277
278                        let task = match event {
279                            Ok(ResponseEvent::MessageResponse(message)) => {
280                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
281                            }
282                            Ok(ResponseEvent::ReadFileRequest(request)) => {
283                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
284                            }
285                            Err(_) => todo!(),
286                        };
287                        pending_event_handlers.push(task);
288                    }
289                    result = next_event_handler_result => {
290                        // Event handlers should only return errors that are
291                        // unrecoverable and should therefore stop this turn of
292                        // the agentic loop.
293                        result.unwrap()?;
294                    }
295                }
296            }
297
298            Ok(())
299        })
300    }
301
302    fn handle_message_response(
303        &mut self,
304        mut message: MessageResponse,
305        cx: &mut Context<Self>,
306    ) -> Task<Result<()>> {
307        let entry_id = self.next_entry_id.post_inc();
308        self.entries.push(ThreadEntry {
309            id: entry_id,
310            content: AgentThreadEntryContent::Message(Message {
311                role: message.role,
312                chunks: Vec::new(),
313            }),
314        });
315        cx.notify();
316
317        cx.spawn(async move |this, cx| {
318            while let Some(chunk) = message.chunks.next().await {
319                match chunk {
320                    Ok(chunk) => {
321                        this.update(cx, |this, cx| {
322                            let ix = this
323                                .entries
324                                .binary_search_by_key(&entry_id, |entry| entry.id)
325                                .map_err(|_| anyhow!("message not found"))?;
326                            let AgentThreadEntryContent::Message(message) =
327                                &mut this.entries[ix].content
328                            else {
329                                unreachable!()
330                            };
331                            message.chunks.push(chunk);
332                            cx.notify();
333                            anyhow::Ok(())
334                        })??;
335                    }
336                    Err(err) => todo!("show error"),
337                }
338            }
339
340            Ok(())
341        })
342    }
343
344    fn handle_read_file_request(
345        &mut self,
346        request: ReadFileRequest,
347        cx: &mut Context<Self>,
348    ) -> Task<Result<()>> {
349        todo!()
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs logging and the settings globals the project needs.
    fn init_test(cx: &mut TestAppContext) {
        // `init` panics if a global logger is already installed, which happens as
        // soon as a second test in the same process calls this helper.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the 'test/foo' file and output all of its contents.".into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "test/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns the Gemini CLI in ACP mode and connects an [`AcpAgent`] to its stdio.
    ///
    /// Expects a `gemini-cli` checkout at `../../../gemini-cli` relative to the
    /// working directory the test runs in (presumably the crate directory —
    /// confirm). Also requires the `GEMINI_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Returns an error if `GEMINI_API_KEY` is unset or if `node` cannot be
    /// spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY")
            .context("GEMINI_API_KEY must be set to run the Gemini integration test")?;
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn the Gemini CLI with `node`")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}