agent2.rs

  1mod acp;
  2
  3use anyhow::Result;
  4use async_trait::async_trait;
  5use chrono::{DateTime, Utc};
  6use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  7use project::Project;
  8use std::{ops::Range, path::PathBuf, sync::Arc};
  9
 10pub use acp::AcpAgent;
 11
 12#[async_trait(?Send)]
 13pub trait Agent: 'static {
 14    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
 15    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 16    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 17    async fn thread_entries(
 18        &self,
 19        id: ThreadId,
 20        cx: &mut AsyncApp,
 21    ) -> Result<Vec<AgentThreadEntryContent>>;
 22    async fn send_thread_message(
 23        &self,
 24        thread_id: ThreadId,
 25        message: Message,
 26        cx: &mut AsyncApp,
 27    ) -> Result<()>;
 28}
 29
 30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 31pub struct ThreadId(SharedString);
 32
 33#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 34pub struct FileVersion(u64);
 35
 36#[derive(Debug)]
 37pub struct AgentThreadSummary {
 38    pub id: ThreadId,
 39    pub title: String,
 40    pub created_at: DateTime<Utc>,
 41}
 42
 43#[derive(Debug, PartialEq, Eq)]
 44pub struct FileContent {
 45    pub path: PathBuf,
 46    pub version: FileVersion,
 47    pub content: String,
 48}
 49
 50#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 51pub enum Role {
 52    User,
 53    Assistant,
 54}
 55
 56#[derive(Debug, Eq, PartialEq)]
 57pub struct Message {
 58    pub role: Role,
 59    pub chunks: Vec<MessageChunk>,
 60}
 61
 62#[derive(Debug, Eq, PartialEq)]
 63pub enum MessageChunk {
 64    Text {
 65        chunk: String,
 66    },
 67    File {
 68        content: FileContent,
 69    },
 70    Directory {
 71        path: PathBuf,
 72        contents: Vec<FileContent>,
 73    },
 74    Symbol {
 75        path: PathBuf,
 76        range: Range<u64>,
 77        version: FileVersion,
 78        name: String,
 79        content: String,
 80    },
 81    Thread {
 82        title: String,
 83        content: Vec<AgentThreadEntryContent>,
 84    },
 85    Fetch {
 86        url: String,
 87        content: String,
 88    },
 89}
 90
 91impl From<&str> for MessageChunk {
 92    fn from(chunk: &str) -> Self {
 93        MessageChunk::Text {
 94            chunk: chunk.to_string(),
 95        }
 96    }
 97}
 98
 99#[derive(Debug, Eq, PartialEq)]
100pub enum AgentThreadEntryContent {
101    Message(Message),
102    ReadFile { path: PathBuf, content: String },
103}
104
/// Numeric identifier of an entry within a [`Thread`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before the
    /// increment (post-increment semantics).
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        // `replace` stores the successor and hands back the previous value.
        std::mem::replace(self, successor)
    }
}
115
116#[derive(Debug)]
117pub struct ThreadEntry {
118    pub id: ThreadEntryId,
119    pub content: AgentThreadEntryContent,
120}
121
122pub struct ThreadStore {
123    threads: Vec<AgentThreadSummary>,
124    agent: Arc<dyn Agent>,
125    project: Entity<Project>,
126}
127
128impl ThreadStore {
129    pub async fn load(
130        agent: Arc<dyn Agent>,
131        project: Entity<Project>,
132        cx: &mut AsyncApp,
133    ) -> Result<Entity<Self>> {
134        let threads = agent.threads(cx).await?;
135        cx.new(|_cx| Self {
136            threads,
137            agent,
138            project,
139        })
140    }
141
142    /// Returns the threads in reverse chronological order.
143    pub fn threads(&self) -> &[AgentThreadSummary] {
144        &self.threads
145    }
146
147    /// Opens a thread with the given ID.
148    pub fn open_thread(
149        &self,
150        id: ThreadId,
151        cx: &mut Context<Self>,
152    ) -> Task<Result<Entity<Thread>>> {
153        let agent = self.agent.clone();
154        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
155    }
156
157    /// Creates a new thread.
158    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
159        let agent = self.agent.clone();
160        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
161    }
162}
163
164pub struct Thread {
165    id: ThreadId,
166    next_entry_id: ThreadEntryId,
167    entries: Vec<ThreadEntry>,
168    agent: Arc<dyn Agent>,
169    title: SharedString,
170    project: Entity<Project>,
171}
172
173impl Thread {
174    pub async fn load(
175        agent: Arc<dyn Agent>,
176        thread_id: ThreadId,
177        project: Entity<Project>,
178        cx: &mut AsyncApp,
179    ) -> Result<Entity<Self>> {
180        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
181        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
182    }
183
184    pub fn new(
185        agent: Arc<dyn Agent>,
186        thread_id: ThreadId,
187        entries: Vec<AgentThreadEntryContent>,
188        project: Entity<Project>,
189        _: &mut Context<Self>,
190    ) -> Self {
191        let mut next_entry_id = ThreadEntryId(0);
192        Self {
193            title: "A new agent2 thread".into(),
194            entries: entries
195                .into_iter()
196                .map(|entry| ThreadEntry {
197                    id: next_entry_id.post_inc(),
198                    content: entry,
199                })
200                .collect(),
201            agent,
202            id: thread_id,
203            next_entry_id,
204            project,
205        }
206    }
207
208    pub fn title(&self) -> SharedString {
209        self.title.clone()
210    }
211
212    pub fn entries(&self) -> &[ThreadEntry] {
213        &self.entries
214    }
215
216    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
217        self.entries.push(ThreadEntry {
218            id: self.next_entry_id.post_inc(),
219            content: entry,
220        });
221        cx.notify();
222    }
223
224    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
225        let agent = self.agent.clone();
226        let id = self.id.clone();
227        cx.spawn(async move |_, cx| {
228            agent.send_thread_message(id, message, cx).await?;
229            Ok(())
230        })
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs the logger and the globals the project needs in a test app.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics when a logger is already installed, which
        // happens as soon as a second test in this binary calls `init_test`.
        // `try_init` reports that case as an error instead, and ignoring it
        // is fine here.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end check against a gemini-cli process started by
    /// `gemini_agent`: asks the agent to read a file and expects a `ReadFile`
    /// entry with the file's content to appear in the thread.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        // The agent runs as an external process, so the executor must be
        // allowed to park while waiting on it.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(agent, project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns gemini-cli under `node` in ACP mode and wraps the child process
    /// in an `AcpAgent` that talks to it over stdio.
    ///
    /// The CLI is looked up at `../../../gemini-cli/packages/cli` relative to
    /// this crate's manifest directory. `GEMINI_API_KEY` is forwarded to the
    /// child when it is set in the environment.
    ///
    /// # Errors
    ///
    /// Returns an error if the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpAgent>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // This function already returns `Result`, so propagate a spawn
        // failure to the caller instead of panicking here.
        let child = command
            .spawn()
            .context("failed to spawn `node` for the gemini-cli agent")?;

        Ok(AcpAgent::stdio(child, project, &mut cx))
    }
}