agent2.rs

  1mod acp;
  2mod thread_element;
  3
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use chrono::{DateTime, Utc};
  7use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  8use project::Project;
  9use std::{ops::Range, path::PathBuf, sync::Arc};
 10
 11pub use acp::AcpAgent;
 12pub use thread_element::ThreadElement;
 13
 14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 15pub struct ThreadId(SharedString);
 16
 17#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 18pub struct FileVersion(u64);
 19
 20#[derive(Debug)]
 21pub struct AgentThreadSummary {
 22    pub id: ThreadId,
 23    pub title: String,
 24    pub created_at: DateTime<Utc>,
 25}
 26
 27#[derive(Clone, Debug, PartialEq, Eq)]
 28pub struct FileContent {
 29    pub path: PathBuf,
 30    pub version: FileVersion,
 31    pub content: SharedString,
 32}
 33
 34#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 35pub enum Role {
 36    User,
 37    Assistant,
 38}
 39
 40#[derive(Clone, Debug, Eq, PartialEq)]
 41pub struct Message {
 42    pub role: Role,
 43    pub chunks: Vec<MessageChunk>,
 44}
 45
 46#[derive(Clone, Debug, Eq, PartialEq)]
 47pub enum MessageChunk {
 48    Text {
 49        // todo! should it be shared string? what about streaming?
 50        chunk: String,
 51    },
 52    File {
 53        content: FileContent,
 54    },
 55    Directory {
 56        path: PathBuf,
 57        contents: Vec<FileContent>,
 58    },
 59    Symbol {
 60        path: PathBuf,
 61        range: Range<u64>,
 62        version: FileVersion,
 63        name: SharedString,
 64        content: SharedString,
 65    },
 66    Fetch {
 67        url: SharedString,
 68        content: SharedString,
 69    },
 70}
 71
 72impl From<&str> for MessageChunk {
 73    fn from(chunk: &str) -> Self {
 74        MessageChunk::Text {
 75            chunk: chunk.to_string().into(),
 76        }
 77    }
 78}
 79
 80#[derive(Clone, Debug, Eq, PartialEq)]
 81pub enum AgentThreadEntryContent {
 82    Message(Message),
 83    ReadFile { path: PathBuf, content: String },
 84}
 85
 86#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 87pub struct ThreadEntryId(usize);
 88
 89impl ThreadEntryId {
 90    pub fn post_inc(&mut self) -> Self {
 91        let id = *self;
 92        self.0 += 1;
 93        id
 94    }
 95}
 96
 97#[derive(Debug)]
 98pub struct ThreadEntry {
 99    pub id: ThreadEntryId,
100    pub content: AgentThreadEntryContent,
101}
102
103pub struct ThreadStore {
104    threads: Vec<AgentThreadSummary>,
105    agent: Arc<dyn Agent>,
106    project: Entity<Project>,
107}
108
109impl ThreadStore {
110    pub async fn load(
111        agent: Arc<dyn Agent>,
112        project: Entity<Project>,
113        cx: &mut AsyncApp,
114    ) -> Result<Entity<Self>> {
115        let threads = agent.threads(cx).await?;
116        cx.new(|_cx| Self {
117            threads,
118            agent,
119            project,
120        })
121    }
122
123    /// Returns the threads in reverse chronological order.
124    pub fn threads(&self) -> &[AgentThreadSummary] {
125        &self.threads
126    }
127
128    /// Opens a thread with the given ID.
129    pub fn open_thread(
130        &self,
131        id: ThreadId,
132        cx: &mut Context<Self>,
133    ) -> Task<Result<Entity<Thread>>> {
134        let agent = self.agent.clone();
135        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
136    }
137
138    /// Creates a new thread.
139    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
140        let agent = self.agent.clone();
141        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
142    }
143}
144
145pub struct Thread {
146    id: ThreadId,
147    next_entry_id: ThreadEntryId,
148    entries: Vec<ThreadEntry>,
149    agent: Arc<dyn Agent>,
150    title: SharedString,
151    project: Entity<Project>,
152}
153
154impl Thread {
155    pub async fn load(
156        agent: Arc<dyn Agent>,
157        thread_id: ThreadId,
158        project: Entity<Project>,
159        cx: &mut AsyncApp,
160    ) -> Result<Entity<Self>> {
161        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
162        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
163    }
164
165    pub fn new(
166        agent: Arc<dyn Agent>,
167        thread_id: ThreadId,
168        entries: Vec<AgentThreadEntryContent>,
169        project: Entity<Project>,
170        _: &mut Context<Self>,
171    ) -> Self {
172        let mut next_entry_id = ThreadEntryId(0);
173        Self {
174            title: "A new agent2 thread".into(),
175            entries: entries
176                .into_iter()
177                .map(|entry| ThreadEntry {
178                    id: next_entry_id.post_inc(),
179                    content: entry,
180                })
181                .collect(),
182            agent,
183            id: thread_id,
184            next_entry_id,
185            project,
186        }
187    }
188
189    pub fn title(&self) -> SharedString {
190        self.title.clone()
191    }
192
193    pub fn entries(&self) -> &[ThreadEntry] {
194        &self.entries
195    }
196
197    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
198        self.entries.push(ThreadEntry {
199            id: self.next_entry_id.post_inc(),
200            content: entry,
201        });
202        cx.notify();
203    }
204
205    pub fn push_assistant_chunk(&mut self, chunk: MessageChunk, cx: &mut Context<Self>) {
206        if let Some(last_entry) = self.entries.last_mut() {
207            if let AgentThreadEntryContent::Message(Message {
208                ref mut chunks,
209                role: Role::Assistant,
210            }) = last_entry.content
211            {
212                if let (
213                    Some(MessageChunk::Text { chunk: old_chunk }),
214                    MessageChunk::Text { chunk: new_chunk },
215                ) = (chunks.last_mut(), &chunk)
216                {
217                    old_chunk.push_str(&new_chunk);
218                    return cx.notify();
219                }
220
221                chunks.push(chunk);
222                return cx.notify();
223            }
224        }
225
226        self.entries.push(ThreadEntry {
227            id: self.next_entry_id.post_inc(),
228            content: AgentThreadEntryContent::Message(Message {
229                role: Role::Assistant,
230                chunks: vec![chunk],
231            }),
232        });
233        cx.notify();
234    }
235
236    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
237        let agent = self.agent.clone();
238        let id = self.id.clone();
239        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
240        cx.spawn(async move |_, cx| {
241            agent.send_thread_message(id, message, cx).await?;
242            Ok(())
243        })
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Installs a logger and the global settings the project needs for tests.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics if a logger is already installed, which
        // happens as soon as a second test in this process calls this helper.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end test against a locally built gemini-cli. It needs `node`, the
    /// gemini-cli checkout next to this repository, and network access.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(agent, project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns gemini-cli in ACP mode over stdio and wraps it in an [`AcpAgent`].
    ///
    /// `GEMINI_API_KEY` is forwarded to the child when it is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpAgent>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate spawn failures (e.g. `node` not installed) instead of
        // panicking, as this function's `Result` signature promises.
        let child = command
            .spawn()
            .context("failed to spawn gemini-cli via `node`")?;

        Ok(AcpAgent::stdio(child, project, &mut cx))
    }
}