agent2.rs

  1use anyhow::{Result, anyhow};
  2use chrono::{DateTime, Utc};
  3use futures::{StreamExt, channel::oneshot, stream::BoxStream};
  4use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  5use project::Project;
  6use std::{ops::Range, path::PathBuf, sync::Arc};
  7use uuid::Uuid;
  8
/// A backend that hosts conversation threads and can list, create, and open them.
pub trait Agent: 'static {
    /// Thread handle produced by [`Agent::create_thread`] and [`Agent::open_thread`].
    type Thread: AgentThread;

    /// Lists summaries of the threads this agent knows about.
    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
    /// Creates a new thread on the agent.
    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
    /// Opens the existing thread identified by `id`.
    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
}
 16
/// A single conversation thread hosted by an [`Agent`].
pub trait AgentThread: 'static {
    /// Returns the entries already recorded in this thread.
    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
    /// Sends `message` to the agent and returns the stream of [`ResponseEvent`]s that
    /// make up its reply. Each item of the stream may itself be an error.
    fn send(
        &self,
        message: Message,
    ) -> impl Future<Output = Result<BoxStream<'static, Result<ResponseEvent>>>>;
}
 24
/// An event yielded by the stream returned from [`AgentThread::send`].
pub enum ResponseEvent {
    /// The agent started a new message whose content arrives as a stream of chunks.
    MessageResponse(MessageResponse),
    /// The agent asks the client to read part of a file; the client answers through
    /// [`ReadFileRequest::respond`].
    ReadFileRequest(ReadFileRequest),
    // TODO: planned events not modelled yet: `GlobSearchRequest`, `RegexSearchRequest`,
    // `RunCommandRequest` and `WebSearchResponse`.
}
 33
 34pub struct MessageResponse {
 35    role: Role,
 36    chunks: BoxStream<'static, Result<MessageChunk>>,
 37}
 38
 39pub struct ReadFileRequest {
 40    path: PathBuf,
 41    range: Range<usize>,
 42    response_tx: oneshot::Sender<Result<FileContent>>,
 43}
 44
 45impl ReadFileRequest {
 46    pub fn respond(self, content: Result<FileContent>) {
 47        self.response_tx.send(content).ok();
 48    }
 49}
 50
 51pub struct ThreadId(Uuid);
 52
 53pub struct FileVersion(u64);
 54
 55pub struct AgentThreadSummary {
 56    pub id: ThreadId,
 57    pub title: String,
 58    pub created_at: DateTime<Utc>,
 59}
 60
 61pub struct FileContent {
 62    pub path: PathBuf,
 63    pub version: FileVersion,
 64    pub content: String,
 65}
 66
 67pub enum Role {
 68    User,
 69    Assistant,
 70}
 71
/// A message in a thread: its author and the sequence of chunks that make it up.
pub struct Message {
    pub role: Role,
    /// Chunks in arrival order; a streamed message starts empty and grows as chunks arrive.
    pub chunks: Vec<MessageChunk>,
}
 76
/// One piece of a [`Message`]: text or a piece of attached context.
pub enum MessageChunk {
    /// Plain text.
    Text {
        chunk: String,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory together with the contents of files in it.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// The source of a named symbol within a file at a given version.
    // NOTE(review): `range` is `Range<u64>` here while `ReadFileRequest::range` is
    // `Range<usize>`, and the unit (bytes vs. lines) is not established — confirm.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: String,
        content: String,
    },
    /// The entries of another thread, identified by its title.
    Thread {
        title: String,
        content: Vec<AgentThreadEntry>,
    },
    /// Content retrieved from `url`.
    Fetch {
        url: String,
        content: String,
    },
}
104
/// An entry in a thread's history. Currently the only kind of entry is a message.
pub enum AgentThreadEntry {
    Message(Message),
}
108
/// Sequential identifier of an entry within a `Thread`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Postfix increment (`i++`): hands back the current id and advances `self` by one.
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
119
/// An [`AgentThreadEntry`] tagged with the id it was assigned within its `Thread`.
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub entry: AgentThreadEntry,
}
124
/// Entity that holds an [`Agent`]'s thread summaries and opens or creates threads on it.
pub struct ThreadStore<T: Agent> {
    /// Thread summaries fetched from the agent when the store was loaded.
    threads: Vec<AgentThreadSummary>,
    agent: Arc<T>,
    /// Project handed to every `Thread` this store opens or creates.
    project: Entity<Project>,
}
130
131impl<T: Agent> ThreadStore<T> {
132    pub async fn load(
133        agent: Arc<T>,
134        project: Entity<Project>,
135        cx: &mut AsyncApp,
136    ) -> Result<Entity<Self>> {
137        let threads = agent.threads().await?;
138        cx.new(|cx| Self {
139            threads,
140            agent,
141            project,
142        })
143    }
144
145    /// Returns the threads in reverse chronological order.
146    pub fn threads(&self) -> &[AgentThreadSummary] {
147        &self.threads
148    }
149
150    /// Opens a thread with the given ID.
151    pub fn open_thread(
152        &self,
153        id: ThreadId,
154        cx: &mut Context<Self>,
155    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
156        let agent = self.agent.clone();
157        let project = self.project.clone();
158        cx.spawn(async move |_, cx| {
159            let agent_thread = agent.open_thread(id).await?;
160            Thread::load(Arc::new(agent_thread), project, cx).await
161        })
162    }
163
164    /// Creates a new thread.
165    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
166        let agent = self.agent.clone();
167        let project = self.project.clone();
168        cx.spawn(async move |_, cx| {
169            let agent_thread = agent.create_thread().await?;
170            Thread::load(Arc::new(agent_thread), project, cx).await
171        })
172    }
173}
174
/// Entity that tracks the entries of an [`AgentThread`] and sends new messages to it.
pub struct Thread<T: AgentThread> {
    /// Id assigned to the next entry appended to `entries`.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order. Ids are assigned increasingly and entries are only
    /// appended, so this stays sorted by id (lookups rely on binary search).
    entries: Vec<ThreadEntry>,
    agent_thread: Arc<T>,
    // NOTE(review): not read anywhere in this file yet — presumably intended for file
    // access (e.g. answering read-file requests); confirm.
    project: Entity<Project>,
}
181
182impl<T: AgentThread> Thread<T> {
183    pub async fn load(
184        agent_thread: Arc<T>,
185        project: Entity<Project>,
186        cx: &mut AsyncApp,
187    ) -> Result<Entity<Self>> {
188        let entries = agent_thread.entries().await?;
189        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
190    }
191
192    pub fn new(
193        agent_thread: Arc<T>,
194        entries: Vec<AgentThreadEntry>,
195        project: Entity<Project>,
196        cx: &mut Context<Self>,
197    ) -> Self {
198        let mut next_entry_id = ThreadEntryId(0);
199        Self {
200            entries: entries
201                .into_iter()
202                .map(|entry| ThreadEntry {
203                    id: next_entry_id.post_inc(),
204                    entry,
205                })
206                .collect(),
207            next_entry_id,
208            agent_thread,
209            project,
210        }
211    }
212
213    async fn handle_message(
214        this: WeakEntity<Self>,
215        role: Role,
216        mut chunks: BoxStream<'static, Result<MessageChunk>>,
217        cx: &mut AsyncApp,
218    ) -> Result<()> {
219        let entry_id = this.update(cx, |this, cx| {
220            let entry_id = this.next_entry_id.post_inc();
221            this.entries.push(ThreadEntry {
222                id: entry_id,
223                entry: AgentThreadEntry::Message(Message {
224                    role,
225                    chunks: Vec::new(),
226                }),
227            });
228            cx.notify();
229            entry_id
230        })?;
231
232        while let Some(chunk) = chunks.next().await {
233            match chunk {
234                Ok(chunk) => {
235                    this.update(cx, |this, cx| {
236                        let ix = this
237                            .entries
238                            .binary_search_by_key(&entry_id, |entry| entry.id)
239                            .map_err(|_| anyhow!("message not found"))?;
240                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
241                            unreachable!()
242                        };
243                        message.chunks.push(chunk);
244                        cx.notify();
245                        anyhow::Ok(())
246                    })??;
247                }
248                Err(err) => todo!("show error"),
249            }
250        }
251
252        Ok(())
253    }
254
255    pub fn entries(&self) -> &[ThreadEntry] {
256        &self.entries
257    }
258
259    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
260        let agent_thread = self.agent_thread.clone();
261        cx.spawn(async move |this, cx| {
262            let mut events = agent_thread.send(message).await?;
263            while let Some(event) = events.next().await {
264                match event {
265                    Ok(ResponseEvent::MessageResponse(message)) => {
266                        this.update(cx, |this, cx| this.handle_message_response(message, cx))?
267                            .await?;
268                    }
269                    Ok(ResponseEvent::ReadFileRequest(request)) => {
270                        this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
271                            .await?;
272                    }
273                    Err(_) => todo!(),
274                }
275            }
276            Ok(())
277        })
278    }
279
280    fn handle_message_response(
281        &mut self,
282        mut message: MessageResponse,
283        cx: &mut Context<Self>,
284    ) -> Task<Result<()>> {
285        let entry_id = self.next_entry_id.post_inc();
286        self.entries.push(ThreadEntry {
287            id: entry_id,
288            entry: AgentThreadEntry::Message(Message {
289                role: message.role,
290                chunks: Vec::new(),
291            }),
292        });
293        cx.notify();
294
295        cx.spawn(async move |this, cx| {
296            while let Some(chunk) = message.chunks.next().await {
297                match chunk {
298                    Ok(chunk) => {
299                        this.update(cx, |this, cx| {
300                            let ix = this
301                                .entries
302                                .binary_search_by_key(&entry_id, |entry| entry.id)
303                                .map_err(|_| anyhow!("message not found"))?;
304                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
305                            else {
306                                unreachable!()
307                            };
308                            message.chunks.push(chunk);
309                            cx.notify();
310                            anyhow::Ok(())
311                        })??;
312                    }
313                    Err(err) => todo!("show error"),
314                }
315            }
316
317            Ok(())
318        })
319    }
320
321    fn handle_read_file_request(
322        &mut self,
323        request: ReadFileRequest,
324        cx: &mut Context<Self>,
325    ) -> Task<Result<()>> {
326        todo!()
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{BackgroundExecutor, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use std::path::Path;
    use util::path;

    /// Smoke test: builds a project over a fake filesystem and loads a `ThreadStore`
    /// backed by `GeminiAgent`.
    ///
    /// NOTE(review): `GeminiAgent::threads` is still `todo!()` and `ThreadStore::load`
    /// awaits it, so this test panics until the agent is implemented. `thread_store` is
    /// not inspected yet either.
    #[gpui::test]
    async fn test_basic(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        // The script path is a placeholder ("change-me.js").
        let agent = GeminiAgent::start("~/gemini-cli/change-me.js", &cx.executor())
            .await
            .unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
    }

    /// Placeholder agent: `start` ignores `path`, and every `Agent` method is `todo!()`.
    struct GeminiAgent {}

    impl GeminiAgent {
        pub fn start(path: impl AsRef<Path>, executor: &BackgroundExecutor) -> Task<Result<Self>> {
            executor.spawn(async move { Ok(GeminiAgent {}) })
        }
    }

    impl Agent for GeminiAgent {
        type Thread = GeminiAgentThread;

        async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
            todo!()
        }

        async fn create_thread(&self) -> Result<Self::Thread> {
            todo!()
        }

        async fn open_thread(&self, id: ThreadId) -> Result<Self::Thread> {
            todo!()
        }
    }

    /// Placeholder thread: every `AgentThread` method is `todo!()`.
    struct GeminiAgentThread {}

    impl AgentThread for GeminiAgentThread {
        async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
            todo!()
        }

        async fn send(
            &self,
            message: Message,
        ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
            todo!()
        }
    }
}