1mod acp;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use futures::{
7 FutureExt, StreamExt,
8 channel::{mpsc, oneshot},
9 select_biased,
10 stream::{BoxStream, FuturesUnordered},
11};
12use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
13use project::Project;
14use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
15
16#[async_trait]
17pub trait Agent: 'static {
18 async fn threads(&self) -> Result<Vec<AgentThreadSummary>>;
19 async fn create_thread(&self) -> Result<Entity<Thread>>;
20 async fn open_thread(&self, id: ThreadId) -> Result<Entity<Thread>>;
21 async fn thread_entries(&self, id: ThreadId) -> Result<Vec<AgentThreadEntryContent>>;
22 async fn send_thread_message(
23 &self,
24 thread_id: ThreadId,
25 message: Message,
26 ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
27}
28
29pub enum ResponseEvent {
30 MessageResponse(MessageResponse),
31 ReadFileRequest(ReadFileRequest),
32 // GlobSearchRequest(SearchRequest),
33 // RegexSearchRequest(RegexSearchRequest),
34 // RunCommandRequest(RunCommandRequest),
35 // WebSearchResponse(WebSearchResponse),
36}
37
38pub struct MessageResponse {
39 role: Role,
40 chunks: BoxStream<'static, Result<MessageChunk>>,
41}
42
43#[derive(Debug)]
44pub struct ReadFileRequest {
45 path: PathBuf,
46 range: Range<usize>,
47 response_tx: oneshot::Sender<Result<FileContent>>,
48}
49
50impl ReadFileRequest {
51 pub fn respond(self, content: Result<FileContent>) {
52 self.response_tx.send(content).ok();
53 }
54}
55
56#[derive(Debug, Clone)]
57pub struct ThreadId(SharedString);
58
/// Version number attached to file contents (see [`FileContent`] and
/// [`MessageChunk::Symbol`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileVersion(u64);
61
62#[derive(Debug)]
63pub struct AgentThreadSummary {
64 pub id: ThreadId,
65 pub title: String,
66 pub created_at: DateTime<Utc>,
67}
68
69#[derive(Debug, PartialEq, Eq)]
70pub struct FileContent {
71 pub path: PathBuf,
72 pub version: FileVersion,
73 pub content: String,
74}
75
/// Author of a [`Message`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Role {
    /// Written by the user.
    User,
    /// Written by the agent.
    Assistant,
}
81
82#[derive(Debug, Eq, PartialEq)]
83pub struct Message {
84 pub role: Role,
85 pub chunks: Vec<MessageChunk>,
86}
87
88#[derive(Debug, Eq, PartialEq)]
89pub enum MessageChunk {
90 Text {
91 chunk: String,
92 },
93 File {
94 content: FileContent,
95 },
96 Directory {
97 path: PathBuf,
98 contents: Vec<FileContent>,
99 },
100 Symbol {
101 path: PathBuf,
102 range: Range<u64>,
103 version: FileVersion,
104 name: String,
105 content: String,
106 },
107 Thread {
108 title: String,
109 content: Vec<AgentThreadEntryContent>,
110 },
111 Fetch {
112 url: String,
113 content: String,
114 },
115}
116
117impl From<&str> for MessageChunk {
118 fn from(chunk: &str) -> Self {
119 MessageChunk::Text {
120 chunk: chunk.to_string(),
121 }
122 }
123}
124
125#[derive(Debug, Eq, PartialEq)]
126pub enum AgentThreadEntryContent {
127 Message(Message),
128 ReadFile { path: PathBuf, content: String },
129}
130
/// Identifier of a [`ThreadEntry`]; also used as the counter that hands out ids.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Advances the counter by one and returns the value it held before,
    /// like a postfix `x++`.
    pub fn post_inc(&mut self) -> Self {
        let advanced = Self(self.0 + 1);
        std::mem::replace(self, advanced)
    }
}
141
142#[derive(Debug)]
143pub struct ThreadEntry {
144 pub id: ThreadEntryId,
145 pub content: AgentThreadEntryContent,
146}
147
148pub struct ThreadStore<T: Agent> {
149 threads: Vec<AgentThreadSummary>,
150 agent: Arc<T>,
151 project: Entity<Project>,
152}
153
154impl<T: Agent> ThreadStore<T> {
155 pub async fn load(
156 agent: Arc<T>,
157 project: Entity<Project>,
158 cx: &mut AsyncApp,
159 ) -> Result<Entity<Self>> {
160 let threads = agent.threads().await?;
161 cx.new(|cx| Self {
162 threads,
163 agent,
164 project,
165 })
166 }
167
168 /// Returns the threads in reverse chronological order.
169 pub fn threads(&self) -> &[AgentThreadSummary] {
170 &self.threads
171 }
172
173 /// Opens a thread with the given ID.
174 pub fn open_thread(
175 &self,
176 id: ThreadId,
177 cx: &mut Context<Self>,
178 ) -> Task<Result<Entity<Thread>>> {
179 let agent = self.agent.clone();
180 let project = self.project.clone();
181 cx.spawn(async move |_, cx| {
182 let agent_thread = agent.open_thread(id).await?;
183 Thread::load(agent_thread, project, cx).await
184 })
185 }
186
187 /// Creates a new thread.
188 pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
189 let agent = self.agent.clone();
190 let project = self.project.clone();
191 cx.spawn(async move |_, cx| {
192 let agent_thread = agent.create_thread().await?;
193 Thread::load(agent_thread, project, cx).await
194 })
195 }
196}
197
198pub struct Thread {
199 id: ThreadId,
200 next_entry_id: ThreadEntryId,
201 entries: Vec<ThreadEntry>,
202 agent: Arc<dyn Agent>,
203 project: Entity<Project>,
204}
205
206impl Thread {
207 pub async fn load(
208 agent: Arc<dyn Agent>,
209 thread_id: ThreadId,
210 project: Entity<Project>,
211 cx: &mut AsyncApp,
212 ) -> Result<Entity<Self>> {
213 let entries = agent.thread_entries(thread_id.clone()).await?;
214 cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
215 }
216
217 pub fn new(
218 agent: Arc<dyn Agent>,
219 thread_id: ThreadId,
220 entries: Vec<AgentThreadEntryContent>,
221 project: Entity<Project>,
222 cx: &mut Context<Self>,
223 ) -> Self {
224 let mut next_entry_id = ThreadEntryId(0);
225 Self {
226 entries: entries
227 .into_iter()
228 .map(|entry| ThreadEntry {
229 id: next_entry_id.post_inc(),
230 content: entry,
231 })
232 .collect(),
233 agent,
234 id: thread_id,
235 next_entry_id,
236 project,
237 }
238 }
239
240 pub fn entries(&self) -> &[ThreadEntry] {
241 &self.entries
242 }
243
244 pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
245 let agent = self.agent.clone();
246 let id = self.id;
247 cx.spawn(async move |this, cx| {
248 let mut events = agent.send_thread_message(id, message).await?;
249 let mut pending_event_handlers = FuturesUnordered::new();
250
251 loop {
252 let mut next_event_handler_result = pin!(
253 async {
254 if pending_event_handlers.is_empty() {
255 future::pending::<()>().await;
256 }
257
258 pending_event_handlers.next().await
259 }
260 .fuse()
261 );
262
263 select_biased! {
264 event = events.next() => {
265 let Some(event) = event else {
266 while let Some(result) = pending_event_handlers.next().await {
267 result?;
268 }
269
270 break;
271 };
272
273 let task = match event {
274 Ok(ResponseEvent::MessageResponse(message)) => {
275 this.update(cx, |this, cx| this.handle_message_response(message, cx))?
276 }
277 Ok(ResponseEvent::ReadFileRequest(request)) => {
278 this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
279 }
280 Err(_) => todo!(),
281 };
282 pending_event_handlers.push(task);
283 }
284 result = next_event_handler_result => {
285 // Event handlers should only return errors that are
286 // unrecoverable and should therefore stop this turn of
287 // the agentic loop.
288 result.unwrap()?;
289 }
290 }
291 }
292
293 Ok(())
294 })
295 }
296
297 fn handle_message_response(
298 &mut self,
299 mut message: MessageResponse,
300 cx: &mut Context<Self>,
301 ) -> Task<Result<()>> {
302 let entry_id = self.next_entry_id.post_inc();
303 self.entries.push(ThreadEntry {
304 id: entry_id,
305 content: AgentThreadEntryContent::Message(Message {
306 role: message.role,
307 chunks: Vec::new(),
308 }),
309 });
310 cx.notify();
311
312 cx.spawn(async move |this, cx| {
313 while let Some(chunk) = message.chunks.next().await {
314 match chunk {
315 Ok(chunk) => {
316 this.update(cx, |this, cx| {
317 let ix = this
318 .entries
319 .binary_search_by_key(&entry_id, |entry| entry.id)
320 .map_err(|_| anyhow!("message not found"))?;
321 let AgentThreadEntryContent::Message(message) =
322 &mut this.entries[ix].content
323 else {
324 unreachable!()
325 };
326 message.chunks.push(chunk);
327 cx.notify();
328 anyhow::Ok(())
329 })??;
330 }
331 Err(err) => todo!("show error"),
332 }
333 }
334
335 Ok(())
336 })
337 }
338
339 fn handle_read_file_request(
340 &mut self,
341 request: ReadFileRequest,
342 cx: &mut Context<Self>,
343 ) -> Task<Result<()>> {
344 todo!()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs a logger and the global settings that `Project` needs.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics when a logger is already installed, which
        // happens as soon as a second test in this process calls this helper.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the 'test/foo' file and output all of its contents.".into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "test/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns the Gemini CLI in ACP mode and connects to it over stdio.
    ///
    /// Requires `GEMINI_API_KEY` to be set and a `gemini-cli` checkout at
    /// `../../../gemini-cli` relative to the test's working directory.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key =
            env::var("GEMINI_API_KEY").context("GEMINI_API_KEY must be set to run this test")?;
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn the gemini-cli agent")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}