1mod acp;
2
3use anyhow::{Result, anyhow};
4use chrono::{DateTime, Utc};
5use futures::{
6 FutureExt, StreamExt,
7 channel::{mpsc, oneshot},
8 select_biased,
9 stream::{BoxStream, FuturesUnordered},
10};
11use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
12use project::Project;
13use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
14
15pub trait Agent: 'static {
16 type Thread: AgentThread;
17
18 fn threads(&self) -> impl Future<Output = Result<Vec<AgentThreadSummary>>>;
19 fn create_thread(&self) -> impl Future<Output = Result<Self::Thread>>;
20 fn open_thread(&self, id: ThreadId) -> impl Future<Output = Result<Self::Thread>>;
21}
22
23pub trait AgentThread: 'static {
24 fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
25 fn send(
26 &self,
27 message: Message,
28 ) -> impl Future<Output = Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>>;
29}
30
31pub enum ResponseEvent {
32 MessageResponse(MessageResponse),
33 ReadFileRequest(ReadFileRequest),
34 // GlobSearchRequest(SearchRequest),
35 // RegexSearchRequest(RegexSearchRequest),
36 // RunCommandRequest(RunCommandRequest),
37 // WebSearchResponse(WebSearchResponse),
38}
39
40pub struct MessageResponse {
41 role: Role,
42 chunks: BoxStream<'static, Result<MessageChunk>>,
43}
44
45#[derive(Debug)]
46pub struct ReadFileRequest {
47 path: PathBuf,
48 range: Range<usize>,
49 response_tx: oneshot::Sender<Result<FileContent>>,
50}
51
52impl ReadFileRequest {
53 pub fn respond(self, content: Result<FileContent>) {
54 self.response_tx.send(content).ok();
55 }
56}
57
/// Opaque identifier of an agent thread.
///
/// Derives `PartialEq`/`Eq`/`Hash` so ids can be compared and used as keys in
/// maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(String);
60
/// Version stamp attached to file contents.
///
/// Derives `PartialEq`/`Eq`/`Hash` so two stamps can be compared (e.g. to
/// tell whether content read earlier is still current) or used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FileVersion(u64);
63
64#[derive(Debug, Clone)]
65pub struct AgentThreadSummary {
66 pub id: ThreadId,
67 pub title: String,
68 pub created_at: DateTime<Utc>,
69}
70
71#[derive(Debug, Clone)]
72pub struct FileContent {
73 pub path: PathBuf,
74 pub version: FileVersion,
75 pub content: String,
76}
77
/// Author of a [`Message`].
///
/// A plain fieldless enum, so it is `Copy` and comparable rather than
/// requiring clones or pattern matches to inspect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// Written by the human user.
    User,
    /// Written by the agent.
    Assistant,
}
83
84#[derive(Debug, Clone)]
85pub struct Message {
86 pub role: Role,
87 pub chunks: Vec<MessageChunk>,
88}
89
90#[derive(Debug, Clone)]
91pub enum MessageChunk {
92 Text {
93 chunk: String,
94 },
95 File {
96 content: FileContent,
97 },
98 Directory {
99 path: PathBuf,
100 contents: Vec<FileContent>,
101 },
102 Symbol {
103 path: PathBuf,
104 range: Range<u64>,
105 version: FileVersion,
106 name: String,
107 content: String,
108 },
109 Thread {
110 title: String,
111 content: Vec<AgentThreadEntry>,
112 },
113 Fetch {
114 url: String,
115 content: String,
116 },
117}
118
119#[derive(Debug, Clone)]
120pub enum AgentThreadEntry {
121 Message(Message),
122}
123
/// Sequential identifier for entries within a thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the next one, like a
    /// postfix `id++`.
    pub fn post_inc(&mut self) -> Self {
        let advanced = Self(self.0 + 1);
        std::mem::replace(self, advanced)
    }
}
134
135#[derive(Debug, Clone)]
136pub struct ThreadEntry {
137 pub id: ThreadEntryId,
138 pub entry: AgentThreadEntry,
139}
140
141pub struct ThreadStore<T: Agent> {
142 threads: Vec<AgentThreadSummary>,
143 agent: Arc<T>,
144 project: Entity<Project>,
145}
146
147impl<T: Agent> ThreadStore<T> {
148 pub async fn load(
149 agent: Arc<T>,
150 project: Entity<Project>,
151 cx: &mut AsyncApp,
152 ) -> Result<Entity<Self>> {
153 let threads = agent.threads().await?;
154 cx.new(|cx| Self {
155 threads,
156 agent,
157 project,
158 })
159 }
160
161 /// Returns the threads in reverse chronological order.
162 pub fn threads(&self) -> &[AgentThreadSummary] {
163 &self.threads
164 }
165
166 /// Opens a thread with the given ID.
167 pub fn open_thread(
168 &self,
169 id: ThreadId,
170 cx: &mut Context<Self>,
171 ) -> Task<Result<Entity<Thread<T::Thread>>>> {
172 let agent = self.agent.clone();
173 let project = self.project.clone();
174 cx.spawn(async move |_, cx| {
175 let agent_thread = agent.open_thread(id).await?;
176 Thread::load(Arc::new(agent_thread), project, cx).await
177 })
178 }
179
180 /// Creates a new thread.
181 pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
182 let agent = self.agent.clone();
183 let project = self.project.clone();
184 cx.spawn(async move |_, cx| {
185 let agent_thread = agent.create_thread().await?;
186 Thread::load(Arc::new(agent_thread), project, cx).await
187 })
188 }
189}
190
191pub struct Thread<T: AgentThread> {
192 next_entry_id: ThreadEntryId,
193 entries: Vec<ThreadEntry>,
194 agent_thread: Arc<T>,
195 project: Entity<Project>,
196}
197
198impl<T: AgentThread> Thread<T> {
199 pub async fn load(
200 agent_thread: Arc<T>,
201 project: Entity<Project>,
202 cx: &mut AsyncApp,
203 ) -> Result<Entity<Self>> {
204 let entries = agent_thread.entries().await?;
205 cx.new(|cx| Self::new(agent_thread, entries, project, cx))
206 }
207
208 pub fn new(
209 agent_thread: Arc<T>,
210 entries: Vec<AgentThreadEntry>,
211 project: Entity<Project>,
212 cx: &mut Context<Self>,
213 ) -> Self {
214 let mut next_entry_id = ThreadEntryId(0);
215 Self {
216 entries: entries
217 .into_iter()
218 .map(|entry| ThreadEntry {
219 id: next_entry_id.post_inc(),
220 entry,
221 })
222 .collect(),
223 next_entry_id,
224 agent_thread,
225 project,
226 }
227 }
228
229 async fn handle_message(
230 this: WeakEntity<Self>,
231 role: Role,
232 mut chunks: BoxStream<'static, Result<MessageChunk>>,
233 cx: &mut AsyncApp,
234 ) -> Result<()> {
235 let entry_id = this.update(cx, |this, cx| {
236 let entry_id = this.next_entry_id.post_inc();
237 this.entries.push(ThreadEntry {
238 id: entry_id,
239 entry: AgentThreadEntry::Message(Message {
240 role,
241 chunks: Vec::new(),
242 }),
243 });
244 cx.notify();
245 entry_id
246 })?;
247
248 while let Some(chunk) = chunks.next().await {
249 match chunk {
250 Ok(chunk) => {
251 this.update(cx, |this, cx| {
252 let ix = this
253 .entries
254 .binary_search_by_key(&entry_id, |entry| entry.id)
255 .map_err(|_| anyhow!("message not found"))?;
256 let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
257 unreachable!()
258 };
259 message.chunks.push(chunk);
260 cx.notify();
261 anyhow::Ok(())
262 })??;
263 }
264 Err(err) => todo!("show error"),
265 }
266 }
267
268 Ok(())
269 }
270
271 pub fn entries(&self) -> &[ThreadEntry] {
272 &self.entries
273 }
274
275 pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
276 let agent_thread = self.agent_thread.clone();
277 cx.spawn(async move |this, cx| {
278 let mut events = agent_thread.send(message).await?;
279 let mut pending_event_handlers = FuturesUnordered::new();
280
281 loop {
282 let mut next_event_handler_result = pin!(async {
283 if pending_event_handlers.is_empty() {
284 future::pending::<()>().await;
285 }
286
287 pending_event_handlers.next().await
288 }.fuse());
289
290 select_biased! {
291 event = events.next() => {
292 let Some(event) = event else {
293 while let Some(result) = pending_event_handlers.next().await {
294 result?;
295 }
296
297 break;
298 };
299
300 let task = match event {
301 Ok(ResponseEvent::MessageResponse(message)) => {
302 this.update(cx, |this, cx| this.handle_message_response(message, cx))?
303 }
304 Ok(ResponseEvent::ReadFileRequest(request)) => {
305 this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
306 }
307 Err(_) => todo!(),
308 };
309 pending_event_handlers.push(task);
310 }
311 result = next_event_handler_result => {
312 // Event handlers should only return errors that are
313 // unrecoverable and should therefore stop this turn of
314 // the agentic loop.
315 result.unwrap()?;
316 }
317 }
318 }
319
320 Ok(())
321 })
322 }
323
324 fn handle_message_response(
325 &mut self,
326 mut message: MessageResponse,
327 cx: &mut Context<Self>,
328 ) -> Task<Result<()>> {
329 let entry_id = self.next_entry_id.post_inc();
330 self.entries.push(ThreadEntry {
331 id: entry_id,
332 entry: AgentThreadEntry::Message(Message {
333 role: message.role,
334 chunks: Vec::new(),
335 }),
336 });
337 cx.notify();
338
339 cx.spawn(async move |this, cx| {
340 while let Some(chunk) = message.chunks.next().await {
341 match chunk {
342 Ok(chunk) => {
343 this.update(cx, |this, cx| {
344 let ix = this
345 .entries
346 .binary_search_by_key(&entry_id, |entry| entry.id)
347 .map_err(|_| anyhow!("message not found"))?;
348 let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
349 else {
350 unreachable!()
351 };
352 message.chunks.push(chunk);
353 cx.notify();
354 anyhow::Ok(())
355 })??;
356 }
357 Err(err) => todo!("show error"),
358 }
359 }
360
361 Ok(())
362 })
363 }
364
365 fn handle_read_file_request(
366 &mut self,
367 request: ReadFileRequest,
368 cx: &mut Context<Self>,
369 ) -> Task<Result<()>> {
370 todo!()
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acp::AcpAgent;
    use anyhow::Context as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::{env, process::Stdio};
    use util::path;

    /// Installs test settings and (once per process) the logger.
    fn init_test(cx: &mut TestAppContext) {
        // `env_logger::init` panics if a logger is already installed, which
        // happens as soon as a second test in this binary calls `init_test`.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_basic(cx: &mut TestAppContext) {
        init_test(cx);

        // This test talks to a real child process, so the executor is allowed
        // to park while waiting on it.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
        let _thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
            .await
            .unwrap();
    }

    /// Spawns the Gemini CLI in ACP mode and connects to it over stdio.
    ///
    /// Requires `GEMINI_API_KEY` to be set and a `gemini-cli` checkout at
    /// `../../../gemini-cli` relative to the test's working directory.
    ///
    /// # Errors
    ///
    /// Fails if the API key is missing or the `node` process cannot be spawned.
    pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
        let api_key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY must be set")?;
        let child = util::command::new_smol_command("node")
            .arg("../../../gemini-cli/packages/cli")
            .arg("--acp")
            .env("GEMINI_API_KEY", api_key)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true)
            .spawn()
            .context("failed to spawn gemini-cli")?;

        Ok(AcpAgent::stdio(child, project, cx))
    }
}