agent2.rs

  1mod acp;
  2
  3use anyhow::Result;
  4use async_trait::async_trait;
  5use chrono::{DateTime, Utc};
  6use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  7use project::Project;
  8use std::{ops::Range, path::PathBuf, sync::Arc};
  9
 10pub use acp::AcpAgent;
 11
 12#[async_trait(?Send)]
 13pub trait Agent: 'static {
 14    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
 15    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 16    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 17    async fn thread_entries(
 18        &self,
 19        id: ThreadId,
 20        cx: &mut AsyncApp,
 21    ) -> Result<Vec<AgentThreadEntryContent>>;
 22    async fn send_thread_message(
 23        &self,
 24        thread_id: ThreadId,
 25        message: Message,
 26        cx: &mut AsyncApp,
 27    ) -> Result<()>;
 28}
 29
 30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 31pub struct ThreadId(SharedString);
 32
 33#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 34pub struct FileVersion(u64);
 35
 36#[derive(Debug)]
 37pub struct AgentThreadSummary {
 38    pub id: ThreadId,
 39    pub title: String,
 40    pub created_at: DateTime<Utc>,
 41}
 42
 43#[derive(Clone, Debug, PartialEq, Eq)]
 44pub struct FileContent {
 45    pub path: PathBuf,
 46    pub version: FileVersion,
 47    pub content: SharedString,
 48}
 49
 50#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 51pub enum Role {
 52    User,
 53    Assistant,
 54}
 55
 56#[derive(Clone, Debug, Eq, PartialEq)]
 57pub struct Message {
 58    pub role: Role,
 59    pub chunks: Vec<MessageChunk>,
 60}
 61
 62#[derive(Clone, Debug, Eq, PartialEq)]
 63pub enum MessageChunk {
 64    Text {
 65        chunk: SharedString,
 66    },
 67    File {
 68        content: FileContent,
 69    },
 70    Directory {
 71        path: PathBuf,
 72        contents: Vec<FileContent>,
 73    },
 74    Symbol {
 75        path: PathBuf,
 76        range: Range<u64>,
 77        version: FileVersion,
 78        name: SharedString,
 79        content: SharedString,
 80    },
 81    Fetch {
 82        url: SharedString,
 83        content: SharedString,
 84    },
 85}
 86
 87impl From<&str> for MessageChunk {
 88    fn from(chunk: &str) -> Self {
 89        MessageChunk::Text {
 90            chunk: chunk.to_string().into(),
 91        }
 92    }
 93}
 94
 95#[derive(Clone, Debug, Eq, PartialEq)]
 96pub enum AgentThreadEntryContent {
 97    Message(Message),
 98    ReadFile { path: PathBuf, content: String },
 99}
100
/// Identifier of an entry within a [`Thread`]; threads hand these out in increasing order,
/// starting from 0.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the following one (C's `id++`).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        std::mem::replace(self, next)
    }
}
111
112#[derive(Debug)]
113pub struct ThreadEntry {
114    pub id: ThreadEntryId,
115    pub content: AgentThreadEntryContent,
116}
117
118pub struct ThreadStore {
119    threads: Vec<AgentThreadSummary>,
120    agent: Arc<dyn Agent>,
121    project: Entity<Project>,
122}
123
124impl ThreadStore {
125    pub async fn load(
126        agent: Arc<dyn Agent>,
127        project: Entity<Project>,
128        cx: &mut AsyncApp,
129    ) -> Result<Entity<Self>> {
130        let threads = agent.threads(cx).await?;
131        cx.new(|_cx| Self {
132            threads,
133            agent,
134            project,
135        })
136    }
137
138    /// Returns the threads in reverse chronological order.
139    pub fn threads(&self) -> &[AgentThreadSummary] {
140        &self.threads
141    }
142
143    /// Opens a thread with the given ID.
144    pub fn open_thread(
145        &self,
146        id: ThreadId,
147        cx: &mut Context<Self>,
148    ) -> Task<Result<Entity<Thread>>> {
149        let agent = self.agent.clone();
150        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
151    }
152
153    /// Creates a new thread.
154    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
155        let agent = self.agent.clone();
156        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
157    }
158}
159
160pub struct Thread {
161    id: ThreadId,
162    next_entry_id: ThreadEntryId,
163    entries: Vec<ThreadEntry>,
164    agent: Arc<dyn Agent>,
165    title: SharedString,
166    project: Entity<Project>,
167}
168
169impl Thread {
170    pub async fn load(
171        agent: Arc<dyn Agent>,
172        thread_id: ThreadId,
173        project: Entity<Project>,
174        cx: &mut AsyncApp,
175    ) -> Result<Entity<Self>> {
176        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
177        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
178    }
179
180    pub fn new(
181        agent: Arc<dyn Agent>,
182        thread_id: ThreadId,
183        entries: Vec<AgentThreadEntryContent>,
184        project: Entity<Project>,
185        _: &mut Context<Self>,
186    ) -> Self {
187        let mut next_entry_id = ThreadEntryId(0);
188        Self {
189            title: "A new agent2 thread".into(),
190            entries: entries
191                .into_iter()
192                .map(|entry| ThreadEntry {
193                    id: next_entry_id.post_inc(),
194                    content: entry,
195                })
196                .collect(),
197            agent,
198            id: thread_id,
199            next_entry_id,
200            project,
201        }
202    }
203
204    pub fn title(&self) -> SharedString {
205        self.title.clone()
206    }
207
208    pub fn entries(&self) -> &[ThreadEntry] {
209        &self.entries
210    }
211
212    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
213        self.entries.push(ThreadEntry {
214            id: self.next_entry_id.post_inc(),
215            content: entry,
216        });
217        cx.notify();
218    }
219
220    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
221        let agent = self.agent.clone();
222        let id = self.id.clone();
223        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
224        cx.spawn(async move |_, cx| {
225            agent.send_thread_message(id, message, cx).await?;
226            Ok(())
227        })
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs the logger and the global settings/language state a test project needs.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics if a logger is already installed, which happens as soon as a
        // second test in this binary calls `init_test`; `try_init` makes repeat calls harmless.
        let _ = env_logger::try_init();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end check against a real gemini-cli process: asks the agent to read a file from
    /// the project and expects a matching `ReadFile` entry in the thread.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(agent, project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries()[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns a local gemini-cli checkout in ACP mode and connects an [`AcpAgent`] to it over
    /// stdio.
    ///
    /// Expects the CLI at `../../../gemini-cli/packages/cli` relative to this crate's manifest
    /// directory and a runnable `node`; `GEMINI_API_KEY` is forwarded to the child when set.
    ///
    /// # Errors
    ///
    /// Returns an error if the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpAgent>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(&cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate instead of panicking: the signature already returns `Result`, and the
        // context tells the reader which external checkout is missing.
        let child = command
            .spawn()
            .with_context(|| format!("failed to spawn gemini-cli at {}", cli_path.display()))?;

        Ok(AcpAgent::stdio(child, project, &mut cx))
    }
}