agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use futures::{
  6    FutureExt, StreamExt,
  7    channel::{mpsc, oneshot},
  8    select_biased,
  9    stream::{BoxStream, FuturesUnordered},
 10};
 11use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 12use project::Project;
 13use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 14
 15pub trait Agent: 'static {
 16    type Thread: AgentThread;
 17
 18    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
 19    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
 20    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
 21}
 22
 23pub trait AgentThread: 'static {
 24    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
 25    fn send(
 26        &self,
 27        message: Message,
 28    ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
 29}
 30
/// An event emitted by an [`AgentThread`] while it responds to a message.
///
/// Events arrive on the channel returned by [`AgentThread::send`];
/// `Thread::send` dispatches each one to its own handler task and runs those
/// handlers concurrently.
pub enum ResponseEvent {
    /// The agent is streaming a message back into the thread.
    MessageResponse(MessageResponse),
    /// The agent asks to read (part of) a file and expects an answer via
    /// [`ReadFileRequest::respond`].
    ReadFileRequest(ReadFileRequest),
    // Possible future variants (currently disabled; their payload types do
    // not exist yet):
    // GlobSearchRequest(SearchRequest),
    // RegexSearchRequest(RegexSearchRequest),
    // RunCommandRequest(RunCommandRequest),
    // WebSearchResponse(WebSearchResponse),
}
 39
/// A message streamed from the agent.
///
/// `Thread` creates one message entry for it and appends each chunk to that
/// entry as the chunk arrives.
pub struct MessageResponse {
    /// Author of the message.
    role: Role,
    /// The message body, delivered incrementally; an `Err` item reports a
    /// failure in the stream.
    chunks: BoxStream<'static, Result<MessageChunk>>,
}
 44
 45#[derive(Debug)]
 46pub struct ReadFileRequest {
 47    path: PathBuf,
 48    range: Range<usize>,
 49    response_tx: oneshot::Sender<Result<FileContent>>,
 50}
 51
 52impl ReadFileRequest {
 53    pub fn respond(self, content: Result<FileContent>) {
 54        self.response_tx.send(content).ok();
 55    }
 56}
 57
 58#[derive(Debug, Clone)]
 59pub struct ThreadId(String);
 60
 61#[derive(Debug, Clone, Copy)]
 62pub struct FileVersion(u64);
 63
/// Lightweight description of a thread, as listed by [`Agent::threads`].
#[derive(Debug, Clone)]
pub struct AgentThreadSummary {
    /// Identifier to pass to [`Agent::open_thread`].
    pub id: ThreadId,
    /// Display title of the thread.
    pub title: String,
    /// When the thread was created, in UTC.
    pub created_at: DateTime<Utc>,
}
 70
/// A snapshot of a file's text at a particular version.
#[derive(Debug, Clone)]
pub struct FileContent {
    /// Path of the file.
    /// NOTE(review): absolute vs. project-relative is not established here.
    pub path: PathBuf,
    /// Version that `content` corresponds to.
    pub version: FileVersion,
    /// The file's text; presumably only the requested portion when answering
    /// a ranged `ReadFileRequest` — confirm with implementers.
    pub content: String,
}
 77
 78#[derive(Debug, Clone)]
 79pub enum Role {
 80    User,
 81    Assistant,
 82}
 83
/// A message in a thread: its author and its content, split into chunks.
#[derive(Debug, Clone)]
pub struct Message {
    /// Who wrote the message.
    pub role: Role,
    /// Message content in order; streamed responses grow this list one chunk
    /// at a time.
    pub chunks: Vec<MessageChunk>,
}
 89
/// One piece of message content.
///
/// Besides plain text, a chunk can embed context: files, directories, symbols,
/// other threads, or content retrieved from a URL.
#[derive(Debug, Clone)]
pub enum MessageChunk {
    /// Plain text.
    Text {
        chunk: String,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory together with file contents from inside it
    /// (whether all files or only some is not established here).
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol within a file, with its source text.
    ///
    /// NOTE(review): the unit of `range` (bytes vs. lines) is not established
    /// here; confirm before relying on it.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: String,
        content: String,
    },
    /// Another thread embedded by title together with its entries.
    Thread {
        title: String,
        content: Vec<AgentThreadEntry>,
    },
    /// Content retrieved from `url`.
    Fetch {
        url: String,
        content: String,
    },
}
118
/// One item in a thread's history.
///
/// Currently the only kind of entry is a message.
#[derive(Debug, Clone)]
pub enum AgentThreadEntry {
    Message(Message),
}
123
/// Sequential identifier of an entry within a `Thread`.
///
/// Ids are handed out in increasing order, so they also sort entries by
/// insertion time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the following one, like a
    /// postfix `i++`.
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
134
/// An [`AgentThreadEntry`] tagged with the id its `Thread` assigned to it.
#[derive(Debug, Clone)]
pub struct ThreadEntry {
    /// Position-stable identifier used to find the entry again later.
    pub id: ThreadEntryId,
    /// The entry itself.
    pub entry: AgentThreadEntry,
}
140
/// Entity that lists an [`Agent`]'s threads and opens or creates them.
pub struct ThreadStore<T: Agent> {
    /// Thread summaries fetched from the agent when the store was loaded.
    threads: Vec<AgentThreadSummary>,
    /// Agent used to open and create threads.
    agent: Arc<T>,
    /// Project handed to every `Thread` this store opens or creates.
    project: Entity<Project>,
}
146
147impl<T: Agent> ThreadStore<T> {
148    pub async fn load(
149        agent: Arc<T>,
150        project: Entity<Project>,
151        cx: &mut AsyncApp,
152    ) -> Result<Entity<Self>> {
153        let threads = agent.threads().await?;
154        cx.new(|cx| Self {
155            threads,
156            agent,
157            project,
158        })
159    }
160
161    /// Returns the threads in reverse chronological order.
162    pub fn threads(&self) -> &[AgentThreadSummary] {
163        &self.threads
164    }
165
166    /// Opens a thread with the given ID.
167    pub fn open_thread(
168        &self,
169        id: ThreadId,
170        cx: &mut Context<Self>,
171    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
172        let agent = self.agent.clone();
173        let project = self.project.clone();
174        cx.spawn(async move |_, cx| {
175            let agent_thread = agent.open_thread(id).await?;
176            Thread::load(Arc::new(agent_thread), project, cx).await
177        })
178    }
179
180    /// Creates a new thread.
181    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
182        let agent = self.agent.clone();
183        let project = self.project.clone();
184        cx.spawn(async move |_, cx| {
185            let agent_thread = agent.create_thread().await?;
186            Thread::load(Arc::new(agent_thread), project, cx).await
187        })
188    }
189}
190
/// Entity mirroring a single conversation backed by an [`AgentThread`].
pub struct Thread<T: AgentThread> {
    /// Id to assign to the next entry pushed onto `entries`.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order. Ids are strictly increasing, which is what
    /// lets lookups use binary search.
    entries: Vec<ThreadEntry>,
    /// Agent-side thread that messages are sent through.
    agent_thread: Arc<T>,
    /// Project this thread belongs to (not yet read by any method here).
    project: Entity<Project>,
}
197
198impl<T: AgentThread> Thread<T> {
199    pub async fn load(
200        agent_thread: Arc<T>,
201        project: Entity<Project>,
202        cx: &mut AsyncApp,
203    ) -> Result<Entity<Self>> {
204        let entries = agent_thread.entries().await?;
205        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
206    }
207
208    pub fn new(
209        agent_thread: Arc<T>,
210        entries: Vec<AgentThreadEntry>,
211        project: Entity<Project>,
212        cx: &mut Context<Self>,
213    ) -> Self {
214        let mut next_entry_id = ThreadEntryId(0);
215        Self {
216            entries: entries
217                .into_iter()
218                .map(|entry| ThreadEntry {
219                    id: next_entry_id.post_inc(),
220                    entry,
221                })
222                .collect(),
223            next_entry_id,
224            agent_thread,
225            project,
226        }
227    }
228
229    async fn handle_message(
230        this: WeakEntity<Self>,
231        role: Role,
232        mut chunks: BoxStream<'static, Result<MessageChunk>>,
233        cx: &mut AsyncApp,
234    ) -> Result<()> {
235        let entry_id = this.update(cx, |this, cx| {
236            let entry_id = this.next_entry_id.post_inc();
237            this.entries.push(ThreadEntry {
238                id: entry_id,
239                entry: AgentThreadEntry::Message(Message {
240                    role,
241                    chunks: Vec::new(),
242                }),
243            });
244            cx.notify();
245            entry_id
246        })?;
247
248        while let Some(chunk) = chunks.next().await {
249            match chunk {
250                Ok(chunk) => {
251                    this.update(cx, |this, cx| {
252                        let ix = this
253                            .entries
254                            .binary_search_by_key(&entry_id, |entry| entry.id)
255                            .map_err(|_| anyhow!("message not found"))?;
256                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
257                            unreachable!()
258                        };
259                        message.chunks.push(chunk);
260                        cx.notify();
261                        anyhow::Ok(())
262                    })??;
263                }
264                Err(err) => todo!("show error"),
265            }
266        }
267
268        Ok(())
269    }
270
271    pub fn entries(&self) -> &[ThreadEntry] {
272        &self.entries
273    }
274
275    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
276        let agent_thread = self.agent_thread.clone();
277        cx.spawn(async move |this, cx| {
278            let mut events = agent_thread.send(message).await?;
279            let mut pending_event_handlers = FuturesUnordered::new();
280
281            loop {
282                let mut next_event_handler_result = pin!(async {
283                    if pending_event_handlers.is_empty() {
284                        future::pending::<()>().await;
285                    }
286
287                    pending_event_handlers.next().await
288                }.fuse());
289
290                select_biased! {
291                    event = events.next() => {
292                        let Some(event) = event else {
293                            while let Some(result) = pending_event_handlers.next().await {
294                                result?;
295                            }
296
297                            break;
298                        };
299
300                        let task = match event {
301                            Ok(ResponseEvent::MessageResponse(message)) => {
302                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
303                            }
304                            Ok(ResponseEvent::ReadFileRequest(request)) => {
305                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
306                            }
307                            Err(_) => todo!(),
308                        };
309                        pending_event_handlers.push(task);
310                    }
311                    result = next_event_handler_result => {
312                        // Event handlers should only return errors that are
313                        // unrecoverable and should therefore stop this turn of
314                        // the agentic loop.
315                        result.unwrap()?;
316                    }
317                }
318            }
319
320            Ok(())
321        })
322    }
323
324    fn handle_message_response(
325        &mut self,
326        mut message: MessageResponse,
327        cx: &mut Context<Self>,
328    ) -> Task<Result<()>> {
329        let entry_id = self.next_entry_id.post_inc();
330        self.entries.push(ThreadEntry {
331            id: entry_id,
332            entry: AgentThreadEntry::Message(Message {
333                role: message.role,
334                chunks: Vec::new(),
335            }),
336        });
337        cx.notify();
338
339        cx.spawn(async move |this, cx| {
340            while let Some(chunk) = message.chunks.next().await {
341                match chunk {
342                    Ok(chunk) => {
343                        this.update(cx, |this, cx| {
344                            let ix = this
345                                .entries
346                                .binary_search_by_key(&entry_id, |entry| entry.id)
347                                .map_err(|_| anyhow!("message not found"))?;
348                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
349                            else {
350                                unreachable!()
351                            };
352                            message.chunks.push(chunk);
353                            cx.notify();
354                            anyhow::Ok(())
355                        })??;
356                    }
357                    Err(err) => todo!("show error"),
358                }
359            }
360
361            Ok(())
362        })
363    }
364
365    fn handle_read_file_request(
366        &mut self,
367        request: ReadFileRequest,
368        cx: &mut Context<Self>,
369    ) -> Task<Result<()>> {
370        todo!()
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs logging and the settings globals that `Project` needs.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics if a logger is already installed (e.g. by
        // another test in the same process); `try_init` tolerates that.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let _thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
    }

    /// Spawns the Gemini CLI in ACP mode and connects an `AcpAgent` to it over
    /// stdio.
    ///
    /// Requires `node`, a gemini-cli checkout at the relative path below, and
    /// the `GEMINI_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns an error if `GEMINI_API_KEY` is unset or the process cannot be
    /// spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY must be set")?;
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn the gemini CLI with node")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}