agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use chrono::{DateTime, Utc};
  5use futures::{
  6    FutureExt, StreamExt,
  7    channel::{mpsc, oneshot},
  8    select_biased,
  9    stream::{BoxStream, FuturesUnordered},
 10};
 11use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 12use project::Project;
 13use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 14
 15pub trait Agent: 'static {
 16    type Thread: AgentThread;
 17
 18    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
 19    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
 20    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
 21}
 22
 23pub trait AgentThread: 'static {
 24    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
 25    fn send(
 26        &self,
 27        message: Message,
 28    ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
 29}
 30
 31pub enum ResponseEvent {
 32    MessageResponse(MessageResponse),
 33    ReadFileRequest(ReadFileRequest),
 34    // GlobSearchRequest(SearchRequest),
 35    // RegexSearchRequest(RegexSearchRequest),
 36    // RunCommandRequest(RunCommandRequest),
 37    // WebSearchResponse(WebSearchResponse),
 38}
 39
 40pub struct MessageResponse {
 41    role: Role,
 42    chunks: BoxStream<'static, Result<MessageChunk>>,
 43}
 44
 45pub struct ReadFileRequest {
 46    path: PathBuf,
 47    range: Range<usize>,
 48    response_tx: oneshot::Sender<Result<FileContent>>,
 49}
 50
 51impl ReadFileRequest {
 52    pub fn respond(self, content: Result<FileContent>) {
 53        self.response_tx.send(content).ok();
 54    }
 55}
 56
 57pub struct ThreadId(String);
 58
 59pub struct FileVersion(u64);
 60
 61pub struct AgentThreadSummary {
 62    pub id: ThreadId,
 63    pub title: String,
 64    pub created_at: DateTime<Utc>,
 65}
 66
 67pub struct FileContent {
 68    pub path: PathBuf,
 69    pub version: FileVersion,
 70    pub content: String,
 71}
 72
 73pub enum Role {
 74    User,
 75    Assistant,
 76}
 77
 78pub struct Message {
 79    pub role: Role,
 80    pub chunks: Vec<MessageChunk>,
 81}
 82
 83pub enum MessageChunk {
 84    Text {
 85        chunk: String,
 86    },
 87    File {
 88        content: FileContent,
 89    },
 90    Directory {
 91        path: PathBuf,
 92        contents: Vec<FileContent>,
 93    },
 94    Symbol {
 95        path: PathBuf,
 96        range: Range<u64>,
 97        version: FileVersion,
 98        name: String,
 99        content: String,
100    },
101    Thread {
102        title: String,
103        content: Vec<AgentThreadEntry>,
104    },
105    Fetch {
106        url: String,
107        content: String,
108    },
109}
110
111pub enum AgentThreadEntry {
112    Message(Message),
113}
114
/// Identifier of an entry inside a [`Thread`]; ids are handed out in increasing order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Bumps this counter by one and returns the value it held beforehand,
    /// like a postfix `x++`.
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
125
126pub struct ThreadEntry {
127    pub id: ThreadEntryId,
128    pub entry: AgentThreadEntry,
129}
130
131pub struct ThreadStore<T: Agent> {
132    threads: Vec<AgentThreadSummary>,
133    agent: Arc<T>,
134    project: Entity<Project>,
135}
136
137impl<T: Agent> ThreadStore<T> {
138    pub async fn load(
139        agent: Arc<T>,
140        project: Entity<Project>,
141        cx: &mut AsyncApp,
142    ) -> Result<Entity<Self>> {
143        let threads = agent.threads().await?;
144        cx.new(|cx| Self {
145            threads,
146            agent,
147            project,
148        })
149    }
150
151    /// Returns the threads in reverse chronological order.
152    pub fn threads(&self) -> &[AgentThreadSummary] {
153        &self.threads
154    }
155
156    /// Opens a thread with the given ID.
157    pub fn open_thread(
158        &self,
159        id: ThreadId,
160        cx: &mut Context<Self>,
161    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
162        let agent = self.agent.clone();
163        let project = self.project.clone();
164        cx.spawn(async move |_, cx| {
165            let agent_thread = agent.open_thread(id).await?;
166            Thread::load(Arc::new(agent_thread), project, cx).await
167        })
168    }
169
170    /// Creates a new thread.
171    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
172        let agent = self.agent.clone();
173        let project = self.project.clone();
174        cx.spawn(async move |_, cx| {
175            let agent_thread = agent.create_thread().await?;
176            Thread::load(Arc::new(agent_thread), project, cx).await
177        })
178    }
179}
180
181pub struct Thread<T: AgentThread> {
182    next_entry_id: ThreadEntryId,
183    entries: Vec<ThreadEntry>,
184    agent_thread: Arc<T>,
185    project: Entity<Project>,
186}
187
188impl<T: AgentThread> Thread<T> {
189    pub async fn load(
190        agent_thread: Arc<T>,
191        project: Entity<Project>,
192        cx: &mut AsyncApp,
193    ) -> Result<Entity<Self>> {
194        let entries = agent_thread.entries().await?;
195        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
196    }
197
198    pub fn new(
199        agent_thread: Arc<T>,
200        entries: Vec<AgentThreadEntry>,
201        project: Entity<Project>,
202        cx: &mut Context<Self>,
203    ) -> Self {
204        let mut next_entry_id = ThreadEntryId(0);
205        Self {
206            entries: entries
207                .into_iter()
208                .map(|entry| ThreadEntry {
209                    id: next_entry_id.post_inc(),
210                    entry,
211                })
212                .collect(),
213            next_entry_id,
214            agent_thread,
215            project,
216        }
217    }
218
219    async fn handle_message(
220        this: WeakEntity<Self>,
221        role: Role,
222        mut chunks: BoxStream<'static, Result<MessageChunk>>,
223        cx: &mut AsyncApp,
224    ) -> Result<()> {
225        let entry_id = this.update(cx, |this, cx| {
226            let entry_id = this.next_entry_id.post_inc();
227            this.entries.push(ThreadEntry {
228                id: entry_id,
229                entry: AgentThreadEntry::Message(Message {
230                    role,
231                    chunks: Vec::new(),
232                }),
233            });
234            cx.notify();
235            entry_id
236        })?;
237
238        while let Some(chunk) = chunks.next().await {
239            match chunk {
240                Ok(chunk) => {
241                    this.update(cx, |this, cx| {
242                        let ix = this
243                            .entries
244                            .binary_search_by_key(&entry_id, |entry| entry.id)
245                            .map_err(|_| anyhow!("message not found"))?;
246                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
247                            unreachable!()
248                        };
249                        message.chunks.push(chunk);
250                        cx.notify();
251                        anyhow::Ok(())
252                    })??;
253                }
254                Err(err) => todo!("show error"),
255            }
256        }
257
258        Ok(())
259    }
260
261    pub fn entries(&self) -> &[ThreadEntry] {
262        &self.entries
263    }
264
265    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
266        let agent_thread = self.agent_thread.clone();
267        cx.spawn(async move |this, cx| {
268            let mut events = agent_thread.send(message).await?;
269            let mut pending_event_handlers = FuturesUnordered::new();
270
271            loop {
272                let mut next_event_handler_result = pin!(async {
273                    if pending_event_handlers.is_empty() {
274                        future::pending::<()>().await;
275                    }
276
277                    pending_event_handlers.next().await
278                }.fuse());
279
280                select_biased! {
281                    event = events.next() => {
282                        let Some(event) = event else {
283                            while let Some(result) = pending_event_handlers.next().await {
284                                result?;
285                            }
286
287                            break;
288                        };
289
290                        let task = match event {
291                            Ok(ResponseEvent::MessageResponse(message)) => {
292                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
293                            }
294                            Ok(ResponseEvent::ReadFileRequest(request)) => {
295                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
296                            }
297                            Err(_) => todo!(),
298                        };
299                        pending_event_handlers.push(task);
300                    }
301                    result = next_event_handler_result => {
302                        // Event handlers should only return errors that are
303                        // unrecoverable and should therefore stop this turn of
304                        // the agentic loop.
305                        result.unwrap()?;
306                    }
307                }
308            }
309
310            Ok(())
311        })
312    }
313
314    fn handle_message_response(
315        &mut self,
316        mut message: MessageResponse,
317        cx: &mut Context<Self>,
318    ) -> Task<Result<()>> {
319        let entry_id = self.next_entry_id.post_inc();
320        self.entries.push(ThreadEntry {
321            id: entry_id,
322            entry: AgentThreadEntry::Message(Message {
323                role: message.role,
324                chunks: Vec::new(),
325            }),
326        });
327        cx.notify();
328
329        cx.spawn(async move |this, cx| {
330            while let Some(chunk) = message.chunks.next().await {
331                match chunk {
332                    Ok(chunk) => {
333                        this.update(cx, |this, cx| {
334                            let ix = this
335                                .entries
336                                .binary_search_by_key(&entry_id, |entry| entry.id)
337                                .map_err(|_| anyhow!("message not found"))?;
338                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
339                            else {
340                                unreachable!()
341                            };
342                            message.chunks.push(chunk);
343                            cx.notify();
344                            anyhow::Ok(())
345                        })??;
346                    }
347                    Err(err) => todo!("show error"),
348                }
349            }
350
351            Ok(())
352        })
353    }
354
355    fn handle_read_file_request(
356        &mut self,
357        request: ReadFileRequest,
358        cx: &mut Context<Self>,
359    ) -> Task<Result<()>> {
360        todo!()
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use agentic_coding_protocol::Client;
    use anyhow::Context as _;
    use gpui::{BackgroundExecutor, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::process::Child;
    use std::env;
    use util::path;

    /// Installs the test settings store and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    /// Smoke test: starts a Gemini agent and loads its thread store.
    ///
    /// Requires `GEMINI_API_KEY` to be set and a `gemini-cli` checkout at `../gemini-cli`.
    #[gpui::test]
    async fn test_basic(cx: &mut TestAppContext) {
        init_test(cx);

        // Presumably needed because the agent is a real child process doing real I/O.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = GeminiAgent::start(&cx.executor()).await.unwrap();
        let _thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
    }

    // NOTE(review): `async_trait`, `ReadFileParams` and `ReadFileResponse` are not
    // imported in this module — presumably they come from the `async_trait` and
    // `agentic_coding_protocol` crates; confirm and add the imports.
    struct TestClient;

    #[async_trait]
    impl Client for TestClient {
        async fn read_file(&self, _request: ReadFileParams) -> Result<ReadFileResponse> {
            Ok(ReadFileResponse {
                version: FileVersion(0),
                content: "the content".into(),
            })
        }
    }

    struct GeminiAgent {
        /// The `gemini-cli` process; spawned with `kill_on_drop(true)`.
        child: Child,
        /// Background task for the agent; currently a placeholder.
        _task: Task<()>,
    }

    impl GeminiAgent {
        /// Spawns `gemini-cli` with `--acp` on the background executor.
        ///
        /// # Errors
        /// Fails if `GEMINI_API_KEY` is unset or the process cannot be spawned.
        pub fn start(executor: &BackgroundExecutor) -> Task<Result<Self>> {
            executor.spawn(async move {
                let api_key = env::var("GEMINI_API_KEY")
                    .context("GEMINI_API_KEY must be set to run this test")?;
                // todo!
                let child = util::command::new_smol_command("node")
                    .arg("../gemini-cli/packages/cli")
                    .arg("--acp")
                    .env("GEMINI_API_KEY", api_key)
                    .kill_on_drop(true)
                    .spawn()
                    .context("failed to spawn gemini-cli")?;

                anyhow::Ok(GeminiAgent {
                    child,
                    // TODO: store the task that drives the agent connection once it exists.
                    _task: Task::ready(()),
                })
            })
        }
    }

    impl Agent for GeminiAgent {
        type Thread = GeminiAgentThread;

        async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
            todo!()
        }

        async fn create_thread(&self) -> Result<Self::Thread> {
            todo!()
        }

        async fn open_thread(&self, _id: ThreadId) -> Result<Self::Thread> {
            todo!()
        }
    }

    struct GeminiAgentThread {}

    impl AgentThread for GeminiAgentThread {
        async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
            todo!()
        }

        async fn send(
            &self,
            _message: Message,
        ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>> {
            todo!()
        }
    }
}