agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{DateTime, Utc};
  6use futures::{
  7    FutureExt, StreamExt,
  8    channel::{mpsc, oneshot},
  9    select_biased,
 10    stream::{BoxStream, FuturesUnordered},
 11};
 12use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
 13use project::Project;
 14use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 15
 16#[async_trait]
 17pub trait Agent: 'static {
 18    async fn threads(&self) -> Result<Vec<AgentThreadSummary>>;
 19    async fn create_thread(&self) -> Result<Entity<Thread>>;
 20    async fn open_thread(&self, id: ThreadId) -> Result<Entity<Thread>>;
 21    async fn thread_entries(&self, id: ThreadId) -> Result<Vec<AgentThreadEntryContent>>;
 22    async fn send_thread_message(
 23        &self,
 24        thread_id: ThreadId,
 25        message: Message,
 26    ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
 27}
 28
 29pub enum ResponseEvent {
 30    MessageResponse(MessageResponse),
 31    ReadFileRequest(ReadFileRequest),
 32    // GlobSearchRequest(SearchRequest),
 33    // RegexSearchRequest(RegexSearchRequest),
 34    // RunCommandRequest(RunCommandRequest),
 35    // WebSearchResponse(WebSearchResponse),
 36}
 37
 38pub struct MessageResponse {
 39    role: Role,
 40    chunks: BoxStream<'static, Result<MessageChunk>>,
 41}
 42
 43#[derive(Debug)]
 44pub struct ReadFileRequest {
 45    path: PathBuf,
 46    range: Range<usize>,
 47    response_tx: oneshot::Sender<Result<FileContent>>,
 48}
 49
 50impl ReadFileRequest {
 51    pub fn respond(self, content: Result<FileContent>) {
 52        self.response_tx.send(content).ok();
 53    }
 54}
 55
 56#[derive(Debug, Clone)]
 57pub struct ThreadId(SharedString);
 58
 59#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 60pub struct FileVersion(u64);
 61
 62#[derive(Debug)]
 63pub struct AgentThreadSummary {
 64    pub id: ThreadId,
 65    pub title: String,
 66    pub created_at: DateTime<Utc>,
 67}
 68
 69#[derive(Debug, PartialEq, Eq)]
 70pub struct FileContent {
 71    pub path: PathBuf,
 72    pub version: FileVersion,
 73    pub content: String,
 74}
 75
 76#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 77pub enum Role {
 78    User,
 79    Assistant,
 80}
 81
 82#[derive(Debug, Eq, PartialEq)]
 83pub struct Message {
 84    pub role: Role,
 85    pub chunks: Vec<MessageChunk>,
 86}
 87
 88#[derive(Debug, Eq, PartialEq)]
 89pub enum MessageChunk {
 90    Text {
 91        chunk: String,
 92    },
 93    File {
 94        content: FileContent,
 95    },
 96    Directory {
 97        path: PathBuf,
 98        contents: Vec<FileContent>,
 99    },
100    Symbol {
101        path: PathBuf,
102        range: Range<u64>,
103        version: FileVersion,
104        name: String,
105        content: String,
106    },
107    Thread {
108        title: String,
109        content: Vec<AgentThreadEntryContent>,
110    },
111    Fetch {
112        url: String,
113        content: String,
114    },
115}
116
117impl From<&str> for MessageChunk {
118    fn from(chunk: &str) -> Self {
119        MessageChunk::Text {
120            chunk: chunk.to_string(),
121        }
122    }
123}
124
125#[derive(Debug, Eq, PartialEq)]
126pub enum AgentThreadEntryContent {
127    Message(Message),
128    ReadFile { path: PathBuf, content: String },
129}
130
/// Identifier of an entry within a [`Thread`].
///
/// Ids are totally ordered; a thread hands them out in increasing order, so its
/// entry list stays sorted by id.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the next one, like a
    /// post-increment (`i++`).
    pub fn post_inc(&mut self) -> Self {
        let current = *self;
        *self = Self(current.0 + 1);
        current
    }
}
141
142#[derive(Debug)]
143pub struct ThreadEntry {
144    pub id: ThreadEntryId,
145    pub content: AgentThreadEntryContent,
146}
147
148pub struct ThreadStore<T: Agent> {
149    threads: Vec<AgentThreadSummary>,
150    agent: Arc<T>,
151    project: Entity<Project>,
152}
153
154impl<T: Agent> ThreadStore<T> {
155    pub async fn load(
156        agent: Arc<T>,
157        project: Entity<Project>,
158        cx: &mut AsyncApp,
159    ) -> Result<Entity<Self>> {
160        let threads = agent.threads().await?;
161        cx.new(|cx| Self {
162            threads,
163            agent,
164            project,
165        })
166    }
167
168    /// Returns the threads in reverse chronological order.
169    pub fn threads(&self) -> &[AgentThreadSummary] {
170        &self.threads
171    }
172
173    /// Opens a thread with the given ID.
174    pub fn open_thread(
175        &self,
176        id: ThreadId,
177        cx: &mut Context<Self>,
178    ) -> Task<Result<Entity<Thread>>> {
179        let agent = self.agent.clone();
180        let project = self.project.clone();
181        cx.spawn(async move |_, cx| {
182            let agent_thread = agent.open_thread(id).await?;
183            Thread::load(agent_thread, project, cx).await
184        })
185    }
186
187    /// Creates a new thread.
188    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
189        let agent = self.agent.clone();
190        let project = self.project.clone();
191        cx.spawn(async move |_, cx| {
192            let agent_thread = agent.create_thread().await?;
193            Thread::load(agent_thread, project, cx).await
194        })
195    }
196}
197
198pub struct Thread {
199    id: ThreadId,
200    next_entry_id: ThreadEntryId,
201    entries: Vec<ThreadEntry>,
202    agent: Arc<dyn Agent>,
203    project: Entity<Project>,
204}
205
206impl Thread {
207    pub async fn load(
208        agent: Arc<dyn Agent>,
209        thread_id: ThreadId,
210        project: Entity<Project>,
211        cx: &mut AsyncApp,
212    ) -> Result<Entity<Self>> {
213        let entries = agent.thread_entries(thread_id.clone()).await?;
214        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
215    }
216
217    pub fn new(
218        agent: Arc<dyn Agent>,
219        thread_id: ThreadId,
220        entries: Vec<AgentThreadEntryContent>,
221        project: Entity<Project>,
222        cx: &mut Context<Self>,
223    ) -> Self {
224        let mut next_entry_id = ThreadEntryId(0);
225        Self {
226            entries: entries
227                .into_iter()
228                .map(|entry| ThreadEntry {
229                    id: next_entry_id.post_inc(),
230                    content: entry,
231                })
232                .collect(),
233            agent,
234            id: thread_id,
235            next_entry_id,
236            project,
237        }
238    }
239
240    pub fn entries(&self) -> &[ThreadEntry] {
241        &self.entries
242    }
243
244    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
245        let agent = self.agent.clone();
246        let id = self.id;
247        cx.spawn(async move |this, cx| {
248            let mut events = agent.send_thread_message(id, message).await?;
249            let mut pending_event_handlers = FuturesUnordered::new();
250
251            loop {
252                let mut next_event_handler_result = pin!(
253                    async {
254                        if pending_event_handlers.is_empty() {
255                            future::pending::<()>().await;
256                        }
257
258                        pending_event_handlers.next().await
259                    }
260                    .fuse()
261                );
262
263                select_biased! {
264                    event = events.next() => {
265                        let Some(event) = event else {
266                            while let Some(result) = pending_event_handlers.next().await {
267                                result?;
268                            }
269
270                            break;
271                        };
272
273                        let task = match event {
274                            Ok(ResponseEvent::MessageResponse(message)) => {
275                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
276                            }
277                            Ok(ResponseEvent::ReadFileRequest(request)) => {
278                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
279                            }
280                            Err(_) => todo!(),
281                        };
282                        pending_event_handlers.push(task);
283                    }
284                    result = next_event_handler_result => {
285                        // Event handlers should only return errors that are
286                        // unrecoverable and should therefore stop this turn of
287                        // the agentic loop.
288                        result.unwrap()?;
289                    }
290                }
291            }
292
293            Ok(())
294        })
295    }
296
297    fn handle_message_response(
298        &mut self,
299        mut message: MessageResponse,
300        cx: &mut Context<Self>,
301    ) -> Task<Result<()>> {
302        let entry_id = self.next_entry_id.post_inc();
303        self.entries.push(ThreadEntry {
304            id: entry_id,
305            content: AgentThreadEntryContent::Message(Message {
306                role: message.role,
307                chunks: Vec::new(),
308            }),
309        });
310        cx.notify();
311
312        cx.spawn(async move |this, cx| {
313            while let Some(chunk) = message.chunks.next().await {
314                match chunk {
315                    Ok(chunk) => {
316                        this.update(cx, |this, cx| {
317                            let ix = this
318                                .entries
319                                .binary_search_by_key(&entry_id, |entry| entry.id)
320                                .map_err(|_| anyhow!("message not found"))?;
321                            let AgentThreadEntryContent::Message(message) =
322                                &mut this.entries[ix].content
323                            else {
324                                unreachable!()
325                            };
326                            message.chunks.push(chunk);
327                            cx.notify();
328                            anyhow::Ok(())
329                        })??;
330                    }
331                    Err(err) => todo!("show error"),
332                }
333            }
334
335            Ok(())
336        })
337    }
338
339    fn handle_read_file_request(
340        &mut self,
341        request: ReadFileRequest,
342        cx: &mut Context<Self>,
343    ) -> Task<Result<()>> {
344        todo!()
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs logging and the global settings needed by `Project`.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` rather than `init`: the global logger can only be set once
        // per process, and several tests may share a process.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    /// End-to-end test against a locally checked-out gemini-cli.
    ///
    /// Requires `GEMINI_API_KEY` and `node`, with gemini-cli available at the
    /// path used by [`gemini_agent`].
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the 'test/foo' file and output all of its contents.".into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "test/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns gemini-cli in ACP mode and connects an [`AcpAgent`] to its stdio.
    ///
    /// # Errors
    ///
    /// Fails if `GEMINI_API_KEY` is not set or the `node` process cannot be
    /// spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY")
            .context("GEMINI_API_KEY must be set to run the Gemini agent")?;
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn the gemini-cli agent process")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}