agent2.rs

  1use anyhow::{Result, anyhow};
  2use chrono::{DateTime, Utc};
  3use futures::{StreamExt, stream::BoxStream};
  4use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  5use std::{ops::Range, path::PathBuf, sync::Arc};
  6use uuid::Uuid;
  7
  8pub trait Agent: 'static {
  9    type Thread: AgentThread;
 10
 11    fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
 12    fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
 13    fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
 14}
 15
 16pub trait AgentThread: 'static {
 17    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
 18    fn send(&self, message: Message) -> impl Future<Output = Result<()>>;
 19    fn on_message(
 20        &self,
 21        handler: impl AsyncFn(Role, BoxStream<'static, Result<MessageChunk>>) -> Result<()>,
 22    );
 23}
 24
 25pub struct ThreadId(Uuid);
 26
 /// Version stamp attached to file contents ([`FileContent`], [`MessageChunk::Symbol`]).
///
/// The inner number is public so that code outside this module can construct
/// the values those types require. The derives allow copying, comparing, ordering,
/// hashing and debug-printing versions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FileVersion(pub u64);
 28
 29pub struct AgentThreadSummary {
 30    pub id: ThreadId,
 31    pub title: String,
 32    pub created_at: DateTime<Utc>,
 33}
 34
 35pub struct FileContent {
 36    pub path: PathBuf,
 37    pub version: FileVersion,
 38    pub content: String,
 39}
 40
 /// Author of a [`Message`].
///
/// Field-less, so it derives the cheap value traits (copy, compare, hash, debug).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// Message written by the user.
    User,
    /// Message produced by the agent.
    Assistant,
}
 45
 46pub struct Message {
 47    pub role: Role,
 48    pub chunks: Vec<MessageChunk>,
 49}
 50
 51pub enum MessageChunk {
 52    Text {
 53        chunk: String,
 54    },
 55    File {
 56        content: FileContent,
 57    },
 58    Directory {
 59        path: PathBuf,
 60        contents: Vec<FileContent>,
 61    },
 62    Symbol {
 63        path: PathBuf,
 64        range: Range<u64>,
 65        version: FileVersion,
 66        name: String,
 67        content: String,
 68    },
 69    Thread {
 70        title: String,
 71        content: Vec<AgentThreadEntry>,
 72    },
 73    Fetch {
 74        url: String,
 75        content: String,
 76    },
 77}
 78
 79pub enum AgentThreadEntry {
 80    Message(Message),
 81}
 82
 /// Identifier of a [`ThreadEntry`] within a thread, handed out sequentially
/// through [`ThreadEntryId::post_inc`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Post-increment (like `i++`): returns the current id and advances `self`
    /// to the following one.
    pub fn post_inc(&mut self) -> Self {
        let successor = ThreadEntryId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
 93
 94pub struct ThreadEntry {
 95    pub id: ThreadEntryId,
 96    pub entry: AgentThreadEntry,
 97}
 98
 99pub struct ThreadStore<T: Agent> {
100    agent: Arc<T>,
101    threads: Vec<AgentThreadSummary>,
102}
103
104impl<T: Agent> ThreadStore<T> {
105    pub async fn load(agent: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
106        let threads = agent.threads().await?;
107        cx.new(|cx| Self { agent, threads })
108    }
109
110    /// Returns the threads in reverse chronological order.
111    pub fn threads(&self) -> &[AgentThreadSummary] {
112        &self.threads
113    }
114
115    /// Opens a thread with the given ID.
116    pub fn open_thread(
117        &self,
118        id: ThreadId,
119        cx: &mut Context<Self>,
120    ) -> Task<Result<Entity<Thread<T::Thread>>>> {
121        let agent = self.agent.clone();
122        cx.spawn(async move |_, cx| {
123            let agent_thread = agent.open_thread(id).await?;
124            Thread::load(Arc::new(agent_thread), cx).await
125        })
126    }
127
128    /// Creates a new thread.
129    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
130        let agent = self.agent.clone();
131        cx.spawn(async move |_, cx| {
132            let agent_thread = agent.create_thread().await?;
133            Thread::load(Arc::new(agent_thread), cx).await
134        })
135    }
136}
137
138pub struct Thread<T: AgentThread> {
139    agent_thread: Arc<T>,
140    entries: Vec<ThreadEntry>,
141    next_entry_id: ThreadEntryId,
142}
143
144impl<T: AgentThread> Thread<T> {
145    pub async fn load(agent_thread: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
146        let entries = agent_thread.entries().await?;
147        cx.new(|cx| Self::new(agent_thread, entries, cx))
148    }
149
150    pub fn new(
151        agent_thread: Arc<T>,
152        entries: Vec<AgentThreadEntry>,
153        cx: &mut Context<Self>,
154    ) -> Self {
155        agent_thread.on_message({
156            let this = cx.weak_entity();
157            let cx = cx.to_async();
158            async move |role, chunks| {
159                Self::handle_message(this.clone(), role, chunks, &mut cx.clone()).await
160            }
161        });
162        let mut next_entry_id = ThreadEntryId(0);
163        Self {
164            agent_thread,
165            entries: entries
166                .into_iter()
167                .map(|entry| ThreadEntry {
168                    id: next_entry_id.post_inc(),
169                    entry,
170                })
171                .collect(),
172            next_entry_id,
173        }
174    }
175
176    async fn handle_message(
177        this: WeakEntity<Self>,
178        role: Role,
179        mut chunks: BoxStream<'static, Result<MessageChunk>>,
180        cx: &mut AsyncApp,
181    ) -> Result<()> {
182        let entry_id = this.update(cx, |this, cx| {
183            let entry_id = this.next_entry_id.post_inc();
184            this.entries.push(ThreadEntry {
185                id: entry_id,
186                entry: AgentThreadEntry::Message(Message {
187                    role,
188                    chunks: Vec::new(),
189                }),
190            });
191            cx.notify();
192            entry_id
193        })?;
194
195        while let Some(chunk) = chunks.next().await {
196            match chunk {
197                Ok(chunk) => {
198                    this.update(cx, |this, cx| {
199                        let ix = this
200                            .entries
201                            .binary_search_by_key(&entry_id, |entry| entry.id)
202                            .map_err(|_| anyhow!("message not found"))?;
203                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
204                            unreachable!()
205                        };
206                        message.chunks.push(chunk);
207                        cx.notify();
208                        anyhow::Ok(())
209                    })??;
210                }
211                Err(err) => todo!("show error"),
212            }
213        }
214
215        Ok(())
216    }
217
218    pub fn entries(&self) -> &[ThreadEntry] {
219        &self.entries
220    }
221
222    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
223        let agent_thread = self.agent_thread.clone();
224        cx.spawn(async move |_, cx| agent_thread.send(message).await)
225    }
226}
227
#[cfg(test)]
mod tests {
    //! Unit tests for the thread model in this module (none written yet).

    use super::*;
}