agent2.rs

  1mod acp;
  2mod thread_element;
  3
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use chrono::{DateTime, Utc};
  7use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  8use project::Project;
  9use std::{ops::Range, path::PathBuf, sync::Arc};
 10
 11pub use acp::AcpAgent;
 12pub use thread_element::ThreadElement;
 13
 14#[async_trait(?Send)]
 15pub trait Agent: 'static {
 16    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
 17    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 18    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
 19    async fn thread_entries(
 20        &self,
 21        id: ThreadId,
 22        cx: &mut AsyncApp,
 23    ) -> Result<Vec<AgentThreadEntryContent>>;
 24    async fn send_thread_message(
 25        &self,
 26        thread_id: ThreadId,
 27        message: Message,
 28        cx: &mut AsyncApp,
 29    ) -> Result<()>;
 30}
 31
 32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 33pub struct ThreadId(SharedString);
 34
 35#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 36pub struct FileVersion(u64);
 37
 38#[derive(Debug)]
 39pub struct AgentThreadSummary {
 40    pub id: ThreadId,
 41    pub title: String,
 42    pub created_at: DateTime<Utc>,
 43}
 44
 45#[derive(Clone, Debug, PartialEq, Eq)]
 46pub struct FileContent {
 47    pub path: PathBuf,
 48    pub version: FileVersion,
 49    pub content: SharedString,
 50}
 51
 52#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 53pub enum Role {
 54    User,
 55    Assistant,
 56}
 57
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub struct Message {
 60    pub role: Role,
 61    pub chunks: Vec<MessageChunk>,
 62}
 63
 64#[derive(Clone, Debug, Eq, PartialEq)]
 65pub enum MessageChunk {
 66    Text {
 67        chunk: SharedString,
 68    },
 69    File {
 70        content: FileContent,
 71    },
 72    Directory {
 73        path: PathBuf,
 74        contents: Vec<FileContent>,
 75    },
 76    Symbol {
 77        path: PathBuf,
 78        range: Range<u64>,
 79        version: FileVersion,
 80        name: SharedString,
 81        content: SharedString,
 82    },
 83    Fetch {
 84        url: SharedString,
 85        content: SharedString,
 86    },
 87}
 88
 89impl From<&str> for MessageChunk {
 90    fn from(chunk: &str) -> Self {
 91        MessageChunk::Text {
 92            chunk: chunk.to_string().into(),
 93        }
 94    }
 95}
 96
 97#[derive(Clone, Debug, Eq, PartialEq)]
 98pub enum AgentThreadEntryContent {
 99    Message(Message),
100    ReadFile { path: PathBuf, content: String },
101}
102
/// Identifier of an entry within a thread, backed by a `usize` counter.
///
/// Ids are ordered by their numeric value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held *before* the
    /// increment (post-increment semantics).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        std::mem::replace(self, next)
    }
}
113
114#[derive(Debug)]
115pub struct ThreadEntry {
116    pub id: ThreadEntryId,
117    pub content: AgentThreadEntryContent,
118}
119
120pub struct ThreadStore {
121    threads: Vec<AgentThreadSummary>,
122    agent: Arc<dyn Agent>,
123    project: Entity<Project>,
124}
125
126impl ThreadStore {
127    pub async fn load(
128        agent: Arc<dyn Agent>,
129        project: Entity<Project>,
130        cx: &mut AsyncApp,
131    ) -> Result<Entity<Self>> {
132        let threads = agent.threads(cx).await?;
133        cx.new(|_cx| Self {
134            threads,
135            agent,
136            project,
137        })
138    }
139
140    /// Returns the threads in reverse chronological order.
141    pub fn threads(&self) -> &[AgentThreadSummary] {
142        &self.threads
143    }
144
145    /// Opens a thread with the given ID.
146    pub fn open_thread(
147        &self,
148        id: ThreadId,
149        cx: &mut Context<Self>,
150    ) -> Task<Result<Entity<Thread>>> {
151        let agent = self.agent.clone();
152        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
153    }
154
155    /// Creates a new thread.
156    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
157        let agent = self.agent.clone();
158        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
159    }
160}
161
162pub struct Thread {
163    id: ThreadId,
164    next_entry_id: ThreadEntryId,
165    entries: Vec<ThreadEntry>,
166    agent: Arc<dyn Agent>,
167    title: SharedString,
168    project: Entity<Project>,
169}
170
171impl Thread {
172    pub async fn load(
173        agent: Arc<dyn Agent>,
174        thread_id: ThreadId,
175        project: Entity<Project>,
176        cx: &mut AsyncApp,
177    ) -> Result<Entity<Self>> {
178        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
179        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
180    }
181
182    pub fn new(
183        agent: Arc<dyn Agent>,
184        thread_id: ThreadId,
185        entries: Vec<AgentThreadEntryContent>,
186        project: Entity<Project>,
187        _: &mut Context<Self>,
188    ) -> Self {
189        let mut next_entry_id = ThreadEntryId(0);
190        Self {
191            title: "A new agent2 thread".into(),
192            entries: entries
193                .into_iter()
194                .map(|entry| ThreadEntry {
195                    id: next_entry_id.post_inc(),
196                    content: entry,
197                })
198                .collect(),
199            agent,
200            id: thread_id,
201            next_entry_id,
202            project,
203        }
204    }
205
206    pub fn title(&self) -> SharedString {
207        self.title.clone()
208    }
209
210    pub fn entries(&self) -> &[ThreadEntry] {
211        &self.entries
212    }
213
214    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
215        self.entries.push(ThreadEntry {
216            id: self.next_entry_id.post_inc(),
217            content: entry,
218        });
219        cx.notify();
220    }
221
222    pub fn push_assistant_chunk(&mut self, chunk: MessageChunk, cx: &mut Context<Self>) {
223        if let Some(last_entry) = self.entries.last_mut() {
224            if let AgentThreadEntryContent::Message(Message {
225                ref mut chunks,
226                role: Role::Assistant,
227            }) = last_entry.content
228            {
229                // todo! merge with last chunk if same type
230                chunks.push(chunk);
231                cx.notify();
232                return;
233            }
234        }
235
236        self.entries.push(ThreadEntry {
237            id: self.next_entry_id.post_inc(),
238            content: AgentThreadEntryContent::Message(Message {
239                role: Role::Assistant,
240                chunks: vec![chunk],
241            }),
242        });
243        cx.notify();
244    }
245
246    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
247        let agent = self.agent.clone();
248        let id = self.id.clone();
249        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
250        cx.spawn(async move |_, cx| {
251            agent.send_thread_message(id, message, cx).await?;
252            Ok(())
253        })
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs logging and the global settings/language state the tests need.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init()` panics if a logger is already installed, which
        // happens as soon as a second test in this process calls `init_test`.
        // `try_init` reports that case as an error we can safely ignore.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end test against a real `gemini-cli` process: asks the agent to
    /// read a file and checks that the thread records both the user message
    /// and the resulting file read.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        // The agent runs in an external process, so the test executor has to
        // be allowed to park while waiting for it.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(agent, project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            // `first()` turns an unexpectedly empty thread into a readable
            // assertion failure instead of an index-out-of-bounds panic.
            assert!(
                matches!(
                    thread.entries().first().map(|entry| &entry.content),
                    Some(AgentThreadEntryContent::Message(Message {
                        role: Role::User,
                        ..
                    }))
                ),
                "First entry is not the user message. Actual: {:?}",
                thread.entries()
            );
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns `gemini-cli` in ACP mode via `node` and wraps its stdio in an
    /// [`AcpAgent`].
    ///
    /// The CLI is looked up relative to this crate's manifest directory, and
    /// `GEMINI_API_KEY` is forwarded to the child when it is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the child process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpAgent>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // This function already returns `Result`; propagate spawn failures
        // (e.g. `node` missing from PATH) instead of panicking here.
        let child = command.spawn()?;

        Ok(AcpAgent::stdio(child, project, &mut cx))
    }
}