agent2.rs

  1use anyhow::{Result, anyhow};
  2use chrono::{DateTime, Utc};
  3use futures::{
  4    FutureExt, StreamExt,
  5    channel::{mpsc, oneshot},
  6    select_biased,
  7    stream::{BoxStream, FuturesUnordered},
  8};
  9use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 10use project::Project;
 11use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 12use uuid::Uuid;
 13
/// A backend that hosts conversation threads.
///
/// [`ThreadStore`] wraps an implementation of this trait to list, open, and
/// create threads.
pub trait Agent: 'static {
    /// The per-thread handle this agent hands out.
    type Thread: AgentThread;

    /// Lists summaries of the agent's threads.
    ///
    /// NOTE(review): this trait does not specify an ordering for the result.
    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
    /// Creates a new thread.
    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
    /// Opens the thread identified by `id`.
    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
}
 21
/// A single conversation thread hosted by an [`Agent`].
pub trait AgentThread: 'static {
    /// Returns the entries already recorded in this thread.
    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
    /// Sends `message` and returns a channel of [`ResponseEvent`]s for the
    /// resulting turn.
    ///
    /// [`Thread::send`] treats the channel closing as the end of the turn.
    fn send(
        &self,
        message: Message,
    ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
}
 29
/// An event emitted by an [`AgentThread`] while it responds to a message.
pub enum ResponseEvent {
    /// A message from the agent whose content arrives as a stream of chunks.
    MessageResponse(MessageResponse),
    /// A request for the client to read a file and answer through the request.
    ReadFileRequest(ReadFileRequest),
    // TODO(review): the variants below are commented out; presumably planned
    // event kinds — implement or remove.
    // GlobSearchRequest(SearchRequest),
    // RegexSearchRequest(RegexSearchRequest),
    // RunCommandRequest(RunCommandRequest),
    // WebSearchResponse(WebSearchResponse),
}
 38
 39pub struct MessageResponse {
 40    role: Role,
 41    chunks: BoxStream<'static, Result<MessageChunk>>,
 42}
 43
 44pub struct ReadFileRequest {
 45    path: PathBuf,
 46    range: Range<usize>,
 47    response_tx: oneshot::Sender<Result<FileContent>>,
 48}
 49
 50impl ReadFileRequest {
 51    pub fn respond(self, content: Result<FileContent>) {
 52        self.response_tx.send(content).ok();
 53    }
 54}
 55
 56pub struct ThreadId(Uuid);
 57
 58pub struct FileVersion(u64);
 59
/// Lightweight description of a thread, as listed by [`Agent::threads`].
pub struct AgentThreadSummary {
    /// Identifier accepted by [`Agent::open_thread`].
    pub id: ThreadId,
    /// Title of the thread.
    pub title: String,
    /// When the thread was created, in UTC.
    pub created_at: DateTime<Utc>,
}
 65
/// Text content of a file together with the version it corresponds to.
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version associated with `content`.
    pub version: FileVersion,
    /// NOTE(review): may be the whole file or only a requested range
    /// (see `ReadFileRequest`) — confirm.
    pub content: String,
}
 71
 72pub enum Role {
 73    User,
 74    Assistant,
 75}
 76
/// A message in a thread: who wrote it and its content chunks.
///
/// `chunks` may be empty, e.g. while a streamed response has not produced
/// any content yet.
pub struct Message {
    pub role: Role,
    pub chunks: Vec<MessageChunk>,
}
 81
/// One piece of a [`Message`]: plain text or an embedded resource.
pub enum MessageChunk {
    /// Plain text.
    Text {
        chunk: String,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory at `path` together with file contents from it.
    /// NOTE(review): unclear whether `contents` is recursive — confirm.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol excerpted from the file at `path`.
    /// NOTE(review): `range` is `Range<u64>` while `ReadFileRequest` uses
    /// `Range<usize>`; the unit (bytes vs. lines) is not specified here.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: String,
        content: String,
    },
    /// Another thread's title and entries.
    Thread {
        title: String,
        content: Vec<AgentThreadEntry>,
    },
    /// Content retrieved from `url`.
    Fetch {
        url: String,
        content: String,
    },
}
109
/// An entry in a thread's history. Currently the only kind is a message.
pub enum AgentThreadEntry {
    Message(Message),
}
113
/// Identifier of an entry within a [`Thread`].
///
/// Ids are allocated in increasing order via [`ThreadEntryId::post_inc`], so
/// entries appended in allocation order stay sorted by id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the next one, like a
    /// postfix `i++`.
    pub fn post_inc(&mut self) -> Self {
        let id = *self;
        self.0 += 1;
        id
    }
}
124
/// An [`AgentThreadEntry`] tagged with its id within a [`Thread`].
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub entry: AgentThreadEntry,
}
129
/// Holds the thread summaries reported by an [`Agent`] and opens or creates
/// [`Thread`] entities backed by it.
pub struct ThreadStore<T: Agent> {
    /// Summaries fetched from the agent when the store was loaded.
    threads: Vec<AgentThreadSummary>,
    agent: Arc<T>,
    /// Project handed to every thread opened or created through this store.
    project: Entity<Project>,
}
135
136impl<T: Agent> ThreadStore<T> {
137    pub async fn load(
138        agent: Arc<T>,
139        project: Entity<Project>,
140        cx: &mut AsyncApp,
141    ) -> Result<Entity<Self>> {
142        let threads = agent.threads().await?;
143        cx.new(|cx| Self {
144            threads,
145            agent,
146            project,
147        })
148    }
149
150    /// Returns the threads in reverse chronological order.
151    pub fn threads(&self) -> &[AgentThreadSummary] {
152        &self.threads
153    }
154
155    /// Opens a thread with the given ID.
156    pub fn open_thread(
157        &self,
158        id: ThreadId,
159        cx: &mut Context<Self>,
160    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
161        let agent = self.agent.clone();
162        let project = self.project.clone();
163        cx.spawn(async move |_, cx| {
164            let agent_thread = agent.open_thread(id).await?;
165            Thread::load(Arc::new(agent_thread), project, cx).await
166        })
167    }
168
169    /// Creates a new thread.
170    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
171        let agent = self.agent.clone();
172        let project = self.project.clone();
173        cx.spawn(async move |_, cx| {
174            let agent_thread = agent.create_thread().await?;
175            Thread::load(Arc::new(agent_thread), project, cx).await
176        })
177    }
178}
179
/// A conversation with an agent, holding its entries in insertion order.
pub struct Thread<T: AgentThread> {
    /// Id that will be assigned to the next entry appended to `entries`.
    next_entry_id: ThreadEntryId,
    /// Entries sorted by id (ids are allocated increasingly and entries are
    /// only appended), which is what allows binary-searching by id.
    entries: Vec<ThreadEntry>,
    /// The agent-side thread that messages are sent to.
    agent_thread: Arc<T>,
    /// NOTE(review): not read by any method yet; presumably intended for
    /// answering file requests — confirm.
    project: Entity<Project>,
}
186
187impl<T: AgentThread> Thread<T> {
188    pub async fn load(
189        agent_thread: Arc<T>,
190        project: Entity<Project>,
191        cx: &mut AsyncApp,
192    ) -> Result<Entity<Self>> {
193        let entries = agent_thread.entries().await?;
194        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
195    }
196
197    pub fn new(
198        agent_thread: Arc<T>,
199        entries: Vec<AgentThreadEntry>,
200        project: Entity<Project>,
201        cx: &mut Context<Self>,
202    ) -> Self {
203        let mut next_entry_id = ThreadEntryId(0);
204        Self {
205            entries: entries
206                .into_iter()
207                .map(|entry| ThreadEntry {
208                    id: next_entry_id.post_inc(),
209                    entry,
210                })
211                .collect(),
212            next_entry_id,
213            agent_thread,
214            project,
215        }
216    }
217
218    async fn handle_message(
219        this: WeakEntity<Self>,
220        role: Role,
221        mut chunks: BoxStream<'static, Result<MessageChunk>>,
222        cx: &mut AsyncApp,
223    ) -> Result<()> {
224        let entry_id = this.update(cx, |this, cx| {
225            let entry_id = this.next_entry_id.post_inc();
226            this.entries.push(ThreadEntry {
227                id: entry_id,
228                entry: AgentThreadEntry::Message(Message {
229                    role,
230                    chunks: Vec::new(),
231                }),
232            });
233            cx.notify();
234            entry_id
235        })?;
236
237        while let Some(chunk) = chunks.next().await {
238            match chunk {
239                Ok(chunk) => {
240                    this.update(cx, |this, cx| {
241                        let ix = this
242                            .entries
243                            .binary_search_by_key(&entry_id, |entry| entry.id)
244                            .map_err(|_| anyhow!("message not found"))?;
245                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
246                            unreachable!()
247                        };
248                        message.chunks.push(chunk);
249                        cx.notify();
250                        anyhow::Ok(())
251                    })??;
252                }
253                Err(err) => todo!("show error"),
254            }
255        }
256
257        Ok(())
258    }
259
260    pub fn entries(&self) -> &[ThreadEntry] {
261        &self.entries
262    }
263
264    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
265        let agent_thread = self.agent_thread.clone();
266        cx.spawn(async move |this, cx| {
267            let mut events = agent_thread.send(message).await?;
268            let mut pending_event_handlers = FuturesUnordered::new();
269
270            loop {
271                let mut next_event_handler_result = pin!(async {
272                    if pending_event_handlers.is_empty() {
273                        future::pending::<()>().await;
274                    }
275
276                    pending_event_handlers.next().await
277                }.fuse());
278
279                select_biased! {
280                    event = events.next() => {
281                        let Some(event) = event else {
282                            while let Some(result) = pending_event_handlers.next().await {
283                                result?;
284                            }
285
286                            break;
287                        };
288
289                        let task = match event {
290                            Ok(ResponseEvent::MessageResponse(message)) => {
291                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
292                            }
293                            Ok(ResponseEvent::ReadFileRequest(request)) => {
294                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
295                            }
296                            Err(_) => todo!(),
297                        };
298                        pending_event_handlers.push(task);
299                    }
300                    result = next_event_handler_result => {
301                        // Event handlers should only return errors that are
302                        // unrecoverable and should therefore stop this turn of
303                        // the agentic loop.
304                        result.unwrap()?;
305                    }
306                }
307            }
308
309            Ok(())
310        })
311    }
312
313    fn handle_message_response(
314        &mut self,
315        mut message: MessageResponse,
316        cx: &mut Context<Self>,
317    ) -> Task<Result<()>> {
318        let entry_id = self.next_entry_id.post_inc();
319        self.entries.push(ThreadEntry {
320            id: entry_id,
321            entry: AgentThreadEntry::Message(Message {
322                role: message.role,
323                chunks: Vec::new(),
324            }),
325        });
326        cx.notify();
327
328        cx.spawn(async move |this, cx| {
329            while let Some(chunk) = message.chunks.next().await {
330                match chunk {
331                    Ok(chunk) => {
332                        this.update(cx, |this, cx| {
333                            let ix = this
334                                .entries
335                                .binary_search_by_key(&entry_id, |entry| entry.id)
336                                .map_err(|_| anyhow!("message not found"))?;
337                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
338                            else {
339                                unreachable!()
340                            };
341                            message.chunks.push(chunk);
342                            cx.notify();
343                            anyhow::Ok(())
344                        })??;
345                    }
346                    Err(err) => todo!("show error"),
347                }
348            }
349
350            Ok(())
351        })
352    }
353
354    fn handle_read_file_request(
355        &mut self,
356        request: ReadFileRequest,
357        cx: &mut Context<Self>,
358    ) -> Task<Result<()>> {
359        todo!()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{BackgroundExecutor, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use std::path::Path;
    use util::path;

    /// Smoke test: builds a fake project and loads a `ThreadStore` backed by
    /// the placeholder `GeminiAgent`.
    ///
    /// NOTE: `GeminiAgent::threads` is still `todo!()`, so `ThreadStore::load`
    /// currently panics here.
    #[gpui::test]
    async fn test_basic(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let fake_fs = FakeFs::new(cx.executor());
        let tree = json!({"foo": "foo", "bar": "bar", "baz": "baz"});
        fake_fs.insert_tree(path!("/test"), tree).await;

        let project = Project::test(fake_fs, [path!("/test").as_ref()], cx).await;
        let executor = cx.executor();
        let gemini = GeminiAgent::start("~/gemini-cli/change-me.js", &executor)
            .await
            .unwrap();
        let mut async_cx = cx.to_async();
        let _thread_store = ThreadStore::load(Arc::new(gemini), project, &mut async_cx)
            .await
            .unwrap();
    }

    /// Placeholder agent; every operation is still a stub.
    struct GeminiAgent {}

    impl GeminiAgent {
        /// Resolves immediately to a placeholder agent; `_path` is not used yet.
        pub fn start(_path: impl AsRef<Path>, executor: &BackgroundExecutor) -> Task<Result<Self>> {
            executor.spawn(async { Ok(Self {}) })
        }
    }

    impl Agent for GeminiAgent {
        type Thread = GeminiAgentThread;

        async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
            todo!()
        }

        async fn create_thread(&self) -> Result<Self::Thread> {
            todo!()
        }

        async fn open_thread(&self, _id: ThreadId) -> Result<Self::Thread> {
            todo!()
        }
    }

    /// Placeholder thread; every operation is still a stub.
    struct GeminiAgentThread {}

    impl AgentThread for GeminiAgentThread {
        async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
            todo!()
        }

        async fn send(
            &self,
            _message: Message,
        ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>> {
            todo!()
        }
    }
}