agent2.rs

  1mod acp;
  2
  3use anyhow::{Result, anyhow};
  4use async_trait::async_trait;
  5use chrono::{DateTime, Utc};
  6use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
  7use project::Project;
  8use std::{ops::Range, path::PathBuf, sync::Arc};
  9
/// A backend that manages conversation threads: listing, creating, opening,
/// reading, and sending messages to them.
///
/// Declared with `#[async_trait(?Send)]`, so the futures returned by these
/// methods are not required to be `Send`. Every method receives the
/// `AsyncApp` context it runs against.
#[async_trait(?Send)]
pub trait Agent: 'static {
    /// Returns summaries of the threads known to this agent.
    ///
    /// No ordering is specified by this trait.
    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
    /// Creates a new thread and returns its entity.
    ///
    /// Takes `self: Arc<Self>`, presumably so the created thread can keep a
    /// handle to the agent — TODO confirm against implementations.
    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
    /// Opens the existing thread identified by `id`.
    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
    /// Returns the content entries stored for thread `id`, in thread order.
    async fn thread_entries(
        &self,
        id: ThreadId,
        cx: &mut AsyncApp,
    ) -> Result<Vec<AgentThreadEntryContent>>;
    /// Sends `message` to thread `thread_id`.
    ///
    /// NOTE(review): whether resulting entries are pushed onto the thread
    /// before this future resolves is up to the implementation — confirm.
    async fn send_thread_message(
        &self,
        thread_id: ThreadId,
        message: Message,
        cx: &mut AsyncApp,
    ) -> Result<()>;
}
 27
/// Opaque identifier of an agent thread.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);

/// Version stamp attached to captured file contents.
///
/// NOTE(review): the producer of this number (buffer version, mtime, …) is not
/// visible here — confirm before relying on its meaning.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
 33
/// Lightweight listing entry for a thread, as returned by `Agent::threads`.
#[derive(Debug)]
pub struct AgentThreadSummary {
    /// Identifier used to open the thread.
    pub id: ThreadId,
    /// Human-readable title of the thread.
    pub title: String,
    /// When the thread was created, in UTC.
    pub created_at: DateTime<Utc>,
}

/// The text of a file together with the version it was captured at.
#[derive(Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version the `content` corresponds to.
    pub version: FileVersion,
    /// Text content of the file.
    pub content: String,
}
 47
/// Author of a [`Message`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Role {
    User,
    Assistant,
}

/// A message in a thread: who sent it and its content, split into chunks.
#[derive(Debug, Eq, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// Ordered pieces of content making up the message.
    pub chunks: Vec<MessageChunk>,
}
 59
/// One piece of a [`Message`]'s content: plain text or attached context.
#[derive(Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Plain text.
    Text {
        chunk: String,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory path plus file contents — presumably the files inside
    /// `path` (TODO confirm whether this is recursive).
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol taken from a file.
    Symbol {
        path: PathBuf,
        /// Location of the symbol within the file.
        /// NOTE(review): units (byte offsets? lines?) not established here — confirm.
        range: Range<u64>,
        version: FileVersion,
        name: String,
        /// Text of the symbol.
        content: String,
    },
    /// Another thread, included by title and entries.
    Thread {
        title: String,
        content: Vec<AgentThreadEntryContent>,
    },
    /// Content retrieved from `url`.
    Fetch {
        url: String,
        content: String,
    },
}
 88
 89impl From<&str> for MessageChunk {
 90    fn from(chunk: &str) -> Self {
 91        MessageChunk::Text {
 92            chunk: chunk.to_string(),
 93        }
 94    }
 95}
 96
/// The payload of a single entry in a thread.
#[derive(Debug, Eq, PartialEq)]
pub enum AgentThreadEntryContent {
    /// A user or assistant message.
    Message(Message),
    /// A record that the file at `path` was read, with the text that was read.
    ReadFile { path: PathBuf, content: String },
}
102
/// Identifier of a [`ThreadEntry`]; handed out sequentially via [`ThreadEntryId::post_inc`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Post-increment: returns the id as it was, then advances `self` by one.
    pub fn post_inc(&mut self) -> Self {
        let ThreadEntryId(current) = *self;
        *self = ThreadEntryId(current + 1);
        ThreadEntryId(current)
    }
}
113
/// A single entry of a [`Thread`]: its content tagged with a per-thread id.
#[derive(Debug)]
pub struct ThreadEntry {
    /// Id assigned when the entry was added to the thread.
    pub id: ThreadEntryId,
    /// What the entry holds.
    pub content: AgentThreadEntryContent,
}
119
/// Holds the thread summaries loaded from an [`Agent`] and opens or creates
/// threads through that agent.
pub struct ThreadStore {
    /// Summaries fetched when the store was loaded.
    threads: Vec<AgentThreadSummary>,
    /// Backend used to open and create threads.
    agent: Arc<dyn Agent>,
    /// Project the store was loaded for.
    /// NOTE(review): appears unused by the visible methods — confirm before removing.
    project: Entity<Project>,
}
125
126impl ThreadStore {
127    pub async fn load(
128        agent: Arc<dyn Agent>,
129        project: Entity<Project>,
130        cx: &mut AsyncApp,
131    ) -> Result<Entity<Self>> {
132        let threads = agent.threads(cx).await?;
133        cx.new(|_cx| Self {
134            threads,
135            agent,
136            project,
137        })
138    }
139
140    /// Returns the threads in reverse chronological order.
141    pub fn threads(&self) -> &[AgentThreadSummary] {
142        &self.threads
143    }
144
145    /// Opens a thread with the given ID.
146    pub fn open_thread(
147        &self,
148        id: ThreadId,
149        cx: &mut Context<Self>,
150    ) -> Task<Result<Entity<Thread>>> {
151        let agent = self.agent.clone();
152        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
153    }
154
155    /// Creates a new thread.
156    pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
157        let agent = self.agent.clone();
158        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
159    }
160}
161
/// A conversation with an agent: an ordered list of entries plus the handle
/// needed to send further messages.
pub struct Thread {
    /// Identifier of this thread on the agent side.
    id: ThreadId,
    /// Id that will be assigned to the next pushed entry.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order.
    entries: Vec<ThreadEntry>,
    /// Backend used to send messages for this thread.
    agent: Arc<dyn Agent>,
    /// Project this thread belongs to.
    project: Entity<Project>,
}
169
170impl Thread {
171    pub async fn load(
172        agent: Arc<dyn Agent>,
173        thread_id: ThreadId,
174        project: Entity<Project>,
175        cx: &mut AsyncApp,
176    ) -> Result<Entity<Self>> {
177        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
178        cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
179    }
180
181    pub fn new(
182        agent: Arc<dyn Agent>,
183        thread_id: ThreadId,
184        entries: Vec<AgentThreadEntryContent>,
185        project: Entity<Project>,
186        cx: &mut Context<Self>,
187    ) -> Self {
188        let mut next_entry_id = ThreadEntryId(0);
189        Self {
190            entries: entries
191                .into_iter()
192                .map(|entry| ThreadEntry {
193                    id: next_entry_id.post_inc(),
194                    content: entry,
195                })
196                .collect(),
197            agent,
198            id: thread_id,
199            next_entry_id,
200            project,
201        }
202    }
203
204    pub fn entries(&self) -> &[ThreadEntry] {
205        &self.entries
206    }
207
208    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
209        self.entries.push(ThreadEntry {
210            id: self.next_entry_id.post_inc(),
211            content: entry,
212        });
213        cx.notify();
214    }
215
216    pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
217        let agent = self.agent.clone();
218        let id = self.id.clone();
219        cx.spawn(async move |this, cx| {
220            agent.send_thread_message(id, message, cx).await?;
221            Ok(())
222        })
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, path::Path, process::Stdio};
    use util::path;

    /// Shared setup: installs a logger (once per process) and the settings
    /// globals the project needs.
    fn init_test(cx: &mut TestAppContext) {
        // `init` panics if a logger is already installed, which happens as soon
        // as a second test in the same binary calls this; `try_init` tolerates it.
        let _ = env_logger::try_init();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// End-to-end check against a real Gemini CLI process: asking the agent
    /// to read a file must produce a `ReadFile` entry with that file's text.
    ///
    /// Requires `node` and a `gemini-cli` checkout (see `gemini_agent`).
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the '/private/tmp/foo' file and output all of its contents."
                                .into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "/private/tmp/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns the Gemini CLI (`node <gemini-cli>/packages/cli --acp`) with
    /// piped stdio and wraps it in an [`AcpAgent`].
    ///
    /// # Errors
    /// Returns an error, naming the CLI path, if the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let child = util::command::new_smol_command("node")
            .arg(&cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .with_context(|| format!("failed to spawn gemini CLI at {}", cli_path.display()))?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}