agent2.rs

  1mod acp;
  2mod thread_element;
  3
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use chrono::{DateTime, Utc};
  7use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  8use project::Project;
  9use std::{ops::Range, path::PathBuf, sync::Arc};
 10
 11pub use acp::AcpAgent;
 12pub use thread_element::ThreadElement;
 13
/// A backend that owns agent threads: it can list, create, and open threads,
/// return a thread's entries, and deliver new messages to a thread.
///
/// Every method receives `&mut AsyncApp` so implementations can interact with
/// the app while they await. `async_trait(?Send)` means the futures produced by
/// these methods are not required to be `Send`.
#[async_trait(?Send)]
pub trait Agent: 'static {
    /// Lists summaries of the threads this agent knows about.
    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
    /// Creates a new thread and returns its entity.
    ///
    /// Unlike the other methods this takes `self: Arc<Self>` — presumably so the
    /// new [`Thread`] (which stores an `Arc<dyn Agent>`) can keep a handle back to
    /// this agent; TODO confirm against implementations.
    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
    /// Opens the existing thread identified by `id`.
    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
    /// Returns the entry contents of thread `id`; [`Thread::load`] uses these to
    /// populate a thread.
    async fn thread_entries(
        &self,
        id: ThreadId,
        cx: &mut AsyncApp,
    ) -> Result<Vec<AgentThreadEntryContent>>;
    /// Sends `message` to thread `thread_id`. [`Thread::send`] awaits this in a
    /// background task.
    async fn send_thread_message(
        &self,
        thread_id: ThreadId,
        message: Message,
        cx: &mut AsyncApp,
    ) -> Result<()>;
}
 31
/// Opaque identifier of an agent thread.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);

/// Version tag attached to file snapshots ([`FileContent`]) and symbol chunks.
/// NOTE(review): presumably increases as the file changes — confirm with the
/// code that constructs it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
 37
/// Summary metadata for one thread, as returned by [`Agent::threads`].
#[derive(Debug)]
pub struct AgentThreadSummary {
    /// Identifier that can be passed to `ThreadStore::open_thread`.
    pub id: ThreadId,
    /// Human-readable title of the thread.
    pub title: String,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
}
 44
 45#[derive(Clone, Debug, PartialEq, Eq)]
 46pub struct FileContent {
 47    pub path: PathBuf,
 48    pub version: FileVersion,
 49    pub content: SharedString,
 50}
 51
/// The author of a [`Message`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Role {
    User,
    Assistant,
}

/// A message in a thread: who wrote it plus its content as an ordered list of
/// chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    pub role: Role,
    pub chunks: Vec<MessageChunk>,
}
 63
/// One piece of a [`Message`]'s content: plain text or an attached piece of
/// context (file, directory, symbol, or fetched URL).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Plain text.
    Text {
        // todo! should it be shared string? what about streaming?
        // Context: `Thread::push_assistant_chunk` streams text by calling
        // `push_str` on this field in place, which relies on an owned `String`.
        chunk: String,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory path together with file contents associated with it
    /// (presumably the files inside the directory — confirm with producers).
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol from a file, with its range, file version, and text.
    Symbol {
        path: PathBuf,
        // NOTE(review): the unit of this range (byte offsets? lines?) is not
        // visible here — confirm with the code that builds symbol chunks.
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// Content fetched from `url`.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
 89
 90impl From<&str> for MessageChunk {
 91    fn from(chunk: &str) -> Self {
 92        MessageChunk::Text {
 93            chunk: chunk.to_string().into(),
 94        }
 95    }
 96}
 97
/// The content of one entry in a thread.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AgentThreadEntryContent {
    /// A user or assistant message.
    Message(Message),
    /// A file read during the thread: the file's path and the text read from it.
    ReadFile { path: PathBuf, content: String },
}
103
/// Sequential identifier of an entry within a thread; ids are handed out in
/// increasing order via [`ThreadEntryId::post_inc`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Post-increment: hands back the current id and advances `self` by one,
    /// like `id++` in C.
    pub fn post_inc(&mut self) -> Self {
        let Self(current) = *self;
        self.0 = current + 1;
        Self(current)
    }
}
114
/// One entry of a thread: its sequential id and its content.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}
120
/// Holds the list of an agent's threads and opens or creates threads through
/// that agent.
pub struct ThreadStore {
    /// Thread summaries fetched from the agent when the store was loaded.
    threads: Vec<AgentThreadSummary>,
    /// Agent used to open and create threads.
    agent: Arc<dyn Agent>,
    // NOTE(review): not read by any code visible in this file — confirm it is
    // still needed.
    project: Entity<Project>,
}
126
127impl ThreadStore {
128    pub async fn load(
129        agent: Arc<dyn Agent>,
130        project: Entity<Project>,
131        cx: &mut AsyncApp,
132    ) -> Result<Entity<Self>> {
133        let threads = agent.threads(cx).await?;
134        cx.new(|_cx| Self {
135            threads,
136            agent,
137            project,
138        })
139    }
140
141    /// Returns the threads in reverse chronological order.
142    pub fn threads(&self) -> &[AgentThreadSummary] {
143        &self.threads
144    }
145
146    /// Opens a thread with the given ID.
147    pub fn open_thread(
148        &self,
149        id: ThreadId,
150        cx: &mut Context<Self>,
151    ) -> Task<Result<Entity<Thread>>> {
152        let agent = self.agent.clone();
153        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
154    }
155
156    /// Creates a new thread.
157    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
158        let agent = self.agent.clone();
159        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
160    }
161}
162
/// A conversation with an agent: an ordered list of entries, plus the agent
/// that new messages are sent to.
pub struct Thread {
    id: ThreadId,
    /// Id that will be assigned to the next entry pushed onto `entries`.
    next_entry_id: ThreadEntryId,
    entries: Vec<ThreadEntry>,
    agent: Arc<dyn Agent>,
    title: SharedString,
    // NOTE(review): not read by any code visible in this file — confirm it is
    // still needed.
    project: Entity<Project>,
}
171
172impl Thread {
173    pub async fn load(
174        agent: Arc<dyn Agent>,
175        thread_id: ThreadId,
176        project: Entity<Project>,
177        cx: &mut AsyncApp,
178    ) -> Result<Entity<Self>> {
179        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
180        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
181    }
182
183    pub fn new(
184        agent: Arc<dyn Agent>,
185        thread_id: ThreadId,
186        entries: Vec<AgentThreadEntryContent>,
187        project: Entity<Project>,
188        _: &mut Context<Self>,
189    ) -> Self {
190        let mut next_entry_id = ThreadEntryId(0);
191        Self {
192            title: "A new agent2 thread".into(),
193            entries: entries
194                .into_iter()
195                .map(|entry| ThreadEntry {
196                    id: next_entry_id.post_inc(),
197                    content: entry,
198                })
199                .collect(),
200            agent,
201            id: thread_id,
202            next_entry_id,
203            project,
204        }
205    }
206
207    pub fn title(&self) -> SharedString {
208        self.title.clone()
209    }
210
211    pub fn entries(&self) -> &[ThreadEntry] {
212        &self.entries
213    }
214
215    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
216        self.entries.push(ThreadEntry {
217            id: self.next_entry_id.post_inc(),
218            content: entry,
219        });
220        cx.notify();
221    }
222
223    pub fn push_assistant_chunk(&mut self, chunk: MessageChunk, cx: &mut Context<Self>) {
224        if let Some(last_entry) = self.entries.last_mut() {
225            if let AgentThreadEntryContent::Message(Message {
226                ref mut chunks,
227                role: Role::Assistant,
228            }) = last_entry.content
229            {
230                if let (
231                    Some(MessageChunk::Text { chunk: old_chunk }),
232                    MessageChunk::Text { chunk: new_chunk },
233                ) = (chunks.last_mut(), &chunk)
234                {
235                    old_chunk.push_str(&new_chunk);
236                    return cx.notify();
237                }
238
239                chunks.push(chunk);
240                return cx.notify();
241            }
242        }
243
244        self.entries.push(ThreadEntry {
245            id: self.next_entry_id.post_inc(),
246            content: AgentThreadEntryContent::Message(Message {
247                role: Role::Assistant,
248                chunks: vec![chunk],
249            }),
250        });
251        cx.notify();
252    }
253
254    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
255        let agent = self.agent.clone();
256        let id = self.id.clone();
257        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
258        cx.spawn(async move |_, cx| {
259            agent.send_thread_message(id, message, cx).await?;
260            Ok(())
261        })
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs the logger plus the test settings and language globals.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` instead of `init`: `init` panics when a logger is already
        // installed, which happens as soon as a second test in this process
        // calls `init_test`.
        let _ = env_logger::try_init();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end check against a real Gemini CLI process: asks it to read a
    /// file and expects a matching `ReadFile` entry in the thread.
    ///
    /// NOTE(review): depends on `node`, a local gemini-cli checkout (see
    /// `gemini_agent`) and presumably network access / `GEMINI_API_KEY`.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(agent, project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries()[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns the Gemini CLI in ACP mode (`node <cli> --acp`) with piped
    /// stdin/stdout and wraps the child process in an [`AcpAgent`].
    ///
    /// Expects a gemini-cli checkout at `../../../gemini-cli/packages/cli`
    /// relative to this crate's manifest directory. `GEMINI_API_KEY` is
    /// forwarded to the child when set.
    ///
    /// # Errors
    ///
    /// Returns an error if the CLI process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpAgent>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(&cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate spawn failures (e.g. `node` missing) instead of panicking,
        // since this function already returns `Result`.
        let child = command
            .spawn()
            .with_context(|| format!("failed to spawn gemini cli at {}", cli_path.display()))?;

        Ok(AcpAgent::stdio(child, project, &mut cx))
    }
}