1mod acp;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use futures::{
6 FutureExt, StreamExt,
7 channel::{mpsc, oneshot},
8 select_biased,
9 stream::{BoxStream, FuturesUnordered},
10};
11use gpui::{AppContext, AsyncApp, Context, Entity, Task};
12use project::Project;
13use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
14
15pub trait Agent: 'static {
16 type Thread: AgentThread;
17
18 fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
19 fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
20 fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
21}
22
23pub trait AgentThread: 'static {
24 fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntryContent>>>;
25 fn send(
26 &self,
27 message: Message,
28 ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
29}
30
31pub enum ResponseEvent {
32 MessageResponse(MessageResponse),
33 ReadFileRequest(ReadFileRequest),
34 // GlobSearchRequest(SearchRequest),
35 // RegexSearchRequest(RegexSearchRequest),
36 // RunCommandRequest(RunCommandRequest),
37 // WebSearchResponse(WebSearchResponse),
38}
39
40pub struct MessageResponse {
41 role: Role,
42 chunks: BoxStream<'static, Result<MessageChunk>>,
43}
44
45#[derive(Debug)]
46pub struct ReadFileRequest {
47 path: PathBuf,
48 range: Range<usize>,
49 response_tx: oneshot::Sender<Result<FileContent>>,
50}
51
52impl ReadFileRequest {
53 pub fn respond(self, content: Result<FileContent>) {
54 self.response_tx.send(content).ok();
55 }
56}
57
/// Identifier of a thread, wrapping a string.
///
/// Implements equality and hashing so ids can be compared and used as keys in
/// maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(String);
60
/// A version number attached to file contents.
// NOTE(review): how versions are assigned is not visible in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileVersion(u64);
63
64#[derive(Debug)]
65pub struct AgentThreadSummary {
66 pub id: ThreadId,
67 pub title: String,
68 pub created_at: DateTime<Utc>,
69}
70
71#[derive(Debug, PartialEq, Eq)]
72pub struct FileContent {
73 pub path: PathBuf,
74 pub version: FileVersion,
75 pub content: String,
76}
77
/// The author of a message.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Role {
    /// The message was written by the user.
    User,
    /// The message was written by the assistant.
    Assistant,
}
83
84#[derive(Debug, Eq, PartialEq)]
85pub struct Message {
86 pub role: Role,
87 pub chunks: Vec<MessageChunk>,
88}
89
90#[derive(Debug, Eq, PartialEq)]
91pub enum MessageChunk {
92 Text {
93 chunk: String,
94 },
95 File {
96 content: FileContent,
97 },
98 Directory {
99 path: PathBuf,
100 contents: Vec<FileContent>,
101 },
102 Symbol {
103 path: PathBuf,
104 range: Range<u64>,
105 version: FileVersion,
106 name: String,
107 content: String,
108 },
109 Thread {
110 title: String,
111 content: Vec<AgentThreadEntryContent>,
112 },
113 Fetch {
114 url: String,
115 content: String,
116 },
117}
118
119impl From<&str> for MessageChunk {
120 fn from(chunk: &str) -> Self {
121 MessageChunk::Text {
122 chunk: chunk.to_string(),
123 }
124 }
125}
126
127#[derive(Debug, Eq, PartialEq)]
128pub enum AgentThreadEntryContent {
129 Message(Message),
130 ReadFile { path: PathBuf, content: String },
131}
132
/// Identifier of an entry within a thread; ordered by its numeric value.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held beforehand.
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
143
144#[derive(Debug)]
145pub struct ThreadEntry {
146 pub id: ThreadEntryId,
147 pub content: AgentThreadEntryContent,
148}
149
150pub struct ThreadStore<T: Agent> {
151 threads: Vec<AgentThreadSummary>,
152 agent: Arc<T>,
153 project: Entity<Project>,
154}
155
156impl<T: Agent> ThreadStore<T> {
157 pub async fn load(
158 agent: Arc<T>,
159 project: Entity<Project>,
160 cx: &mut AsyncApp,
161 ) -> Result<Entity<Self>> {
162 let threads = agent.threads().await?;
163 cx.new(|cx| Self {
164 threads,
165 agent,
166 project,
167 })
168 }
169
170 /// Returns the threads in reverse chronological order.
171 pub fn threads(&self) -> &[AgentThreadSummary] {
172 &self.threads
173 }
174
175 /// Opens a thread with the given ID.
176 pub fn open_thread(
177 &self,
178 id: ThreadId,
179 cx: &mut Context<Self>,
180 ) -> Task<Result<Entity<Thread<T::Thread>>>> {
181 let agent = self.agent.clone();
182 let project = self.project.clone();
183 cx.spawn(async move |_, cx| {
184 let agent_thread = agent.open_thread(id).await?;
185 Thread::load(Arc::new(agent_thread), project, cx).await
186 })
187 }
188
189 /// Creates a new thread.
190 pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
191 let agent = self.agent.clone();
192 let project = self.project.clone();
193 cx.spawn(async move |_, cx| {
194 let agent_thread = agent.create_thread().await?;
195 Thread::load(Arc::new(agent_thread), project, cx).await
196 })
197 }
198}
199
200pub struct Thread<T: AgentThread> {
201 next_entry_id: ThreadEntryId,
202 entries: Vec<ThreadEntry>,
203 agent_thread: Arc<T>,
204 project: Entity<Project>,
205}
206
207impl<T: AgentThread> Thread<T> {
208 pub async fn load(
209 agent_thread: Arc<T>,
210 project: Entity<Project>,
211 cx: &mut AsyncApp,
212 ) -> Result<Entity<Self>> {
213 let entries = agent_thread.entries().await?;
214 cx.new(|cx| Self::new(agent_thread, entries, project, cx))
215 }
216
217 pub fn new(
218 agent_thread: Arc<T>,
219 entries: Vec<AgentThreadEntryContent>,
220 project: Entity<Project>,
221 cx: &mut Context<Self>,
222 ) -> Self {
223 let mut next_entry_id = ThreadEntryId(0);
224 Self {
225 entries: entries
226 .into_iter()
227 .map(|entry| ThreadEntry {
228 id: next_entry_id.post_inc(),
229 content: entry,
230 })
231 .collect(),
232 next_entry_id,
233 agent_thread,
234 project,
235 }
236 }
237
238 pub fn entries(&self) -> &[ThreadEntry] {
239 &self.entries
240 }
241
242 pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
243 let agent_thread = self.agent_thread.clone();
244 cx.spawn(async move |this, cx| {
245 let mut events = agent_thread.send(message).await?;
246 let mut pending_event_handlers = FuturesUnordered::new();
247
248 loop {
249 let mut next_event_handler_result = pin!(
250 async {
251 if pending_event_handlers.is_empty() {
252 future::pending::<()>().await;
253 }
254
255 pending_event_handlers.next().await
256 }
257 .fuse()
258 );
259
260 select_biased! {
261 event = events.next() => {
262 let Some(event) = event else {
263 while let Some(result) = pending_event_handlers.next().await {
264 result?;
265 }
266
267 break;
268 };
269
270 let task = match event {
271 Ok(ResponseEvent::MessageResponse(message)) => {
272 this.update(cx, |this, cx| this.handle_message_response(message, cx))?
273 }
274 Ok(ResponseEvent::ReadFileRequest(request)) => {
275 this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
276 }
277 Err(_) => todo!(),
278 };
279 pending_event_handlers.push(task);
280 }
281 result = next_event_handler_result => {
282 // Event handlers should only return errors that are
283 // unrecoverable and should therefore stop this turn of
284 // the agentic loop.
285 result.unwrap()?;
286 }
287 }
288 }
289
290 Ok(())
291 })
292 }
293
294 fn handle_message_response(
295 &mut self,
296 mut message: MessageResponse,
297 cx: &mut Context<Self>,
298 ) -> Task<Result<()>> {
299 let entry_id = self.next_entry_id.post_inc();
300 self.entries.push(ThreadEntry {
301 id: entry_id,
302 content: AgentThreadEntryContent::Message(Message {
303 role: message.role,
304 chunks: Vec::new(),
305 }),
306 });
307 cx.notify();
308
309 cx.spawn(async move |this, cx| {
310 while let Some(chunk) = message.chunks.next().await {
311 match chunk {
312 Ok(chunk) => {
313 this.update(cx, |this, cx| {
314 let ix = this
315 .entries
316 .binary_search_by_key(&entry_id, |entry| entry.id)
317 .map_err(|_| anyhow!("message not found"))?;
318 let AgentThreadEntryContent::Message(message) =
319 &mut this.entries[ix].content
320 else {
321 unreachable!()
322 };
323 message.chunks.push(chunk);
324 cx.notify();
325 anyhow::Ok(())
326 })??;
327 }
328 Err(err) => todo!("show error"),
329 }
330 }
331
332 Ok(())
333 })
334 }
335
336 fn handle_read_file_request(
337 &mut self,
338 request: ReadFileRequest,
339 cx: &mut Context<Self>,
340 ) -> Task<Result<()>> {
341 todo!()
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs logging and the global settings the project needs.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` rather than `init`: `init` panics when a logger is
        // already installed, which happens as soon as a second test in the
        // same process calls this helper.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    /// End-to-end test against a real gemini-cli process: asks the agent to
    /// read a file from a fake file system and expects a matching `ReadFile`
    /// entry in the thread.
    ///
    /// Requires `node`, a gemini-cli checkout at the relative path used by
    /// `gemini_agent`, and the `GEMINI_API_KEY` environment variable.
    #[gpui::test]
    async fn test_gemini(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
        let thread = thread_store
            .update(cx, |thread_store, cx| {
                assert_eq!(thread_store.threads().len(), 0);
                thread_store.create_thread(cx)
            })
            .await
            .unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    Message {
                        role: Role::User,
                        chunks: vec![
                            "Read the 'test/foo' file and output all of its contents.".into(),
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(
                thread.entries().iter().any(|entry| {
                    entry.content
                        == AgentThreadEntryContent::ReadFile {
                            path: "test/foo".into(),
                            content: "Lorem ipsum dolor".into(),
                        }
                }),
                "Thread does not contain entry. Actual: {:?}",
                thread.entries()
            );
        });
    }

    /// Spawns gemini-cli in ACP mode under `node` and wraps its stdio in an
    /// [`AcpAgent`].
    ///
    /// # Errors
    ///
    /// Returns an error, rather than panicking, when `GEMINI_API_KEY` is not
    /// set or the child process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY")
            .map_err(|error| anyhow!("GEMINI_API_KEY must be set to run this test: {error}"))?;

        // To pin a specific model, add `.args(["--model", "gemini-2.5-flash"])`.
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .map_err(|error| anyhow!("failed to spawn gemini-cli with node: {error}"))?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}