gemini.rs

  1use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path};
  2use crate::{AgentServerCommand, AgentServerVersion};
  3use anyhow::{Context as _, Result};
  4use gpui::{AsyncApp, Entity};
  5use project::Project;
  6use settings::SettingsStore;
  7
  8use crate::AllAgentServersSettings;
  9
 10#[derive(Clone)]
 11pub struct Gemini;
 12
 13const ACP_ARG: &str = "--experimental-acp";
 14
 15impl StdioAgentServer for Gemini {
 16    fn name(&self) -> &'static str {
 17        "Gemini"
 18    }
 19
 20    fn empty_state_headline(&self) -> &'static str {
 21        "Welcome to Gemini"
 22    }
 23
 24    fn empty_state_message(&self) -> &'static str {
 25        "Ask questions, edit files, run commands.\nBe specific for the best results."
 26    }
 27
 28    fn supports_always_allow(&self) -> bool {
 29        true
 30    }
 31
 32    fn logo(&self) -> ui::IconName {
 33        ui::IconName::AiGemini
 34    }
 35
 36    async fn command(
 37        &self,
 38        project: &Entity<Project>,
 39        cx: &mut AsyncApp,
 40    ) -> Result<AgentServerCommand> {
 41        let custom_command = cx.read_global(|settings: &SettingsStore, _| {
 42            let settings = settings.get::<AllAgentServersSettings>(None);
 43            settings
 44                .gemini
 45                .as_ref()
 46                .map(|gemini_settings| AgentServerCommand {
 47                    path: gemini_settings.command.path.clone(),
 48                    args: gemini_settings
 49                        .command
 50                        .args
 51                        .iter()
 52                        .cloned()
 53                        .chain(std::iter::once(ACP_ARG.into()))
 54                        .collect(),
 55                    env: gemini_settings.command.env.clone(),
 56                })
 57        })?;
 58
 59        if let Some(custom_command) = custom_command {
 60            return Ok(custom_command);
 61        }
 62
 63        if let Some(path) = find_bin_in_path("gemini", project, cx).await {
 64            return Ok(AgentServerCommand {
 65                path,
 66                args: vec![ACP_ARG.into()],
 67                env: None,
 68            });
 69        }
 70
 71        let (fs, node_runtime) = project.update(cx, |project, _| {
 72            (project.fs().clone(), project.node_runtime().cloned())
 73        })?;
 74        let node_runtime = node_runtime.context("gemini not found on path")?;
 75
 76        let directory = ::paths::agent_servers_dir().join("gemini");
 77        fs.create_dir(&directory).await?;
 78        node_runtime
 79            .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
 80            .await?;
 81        let path = directory.join("node_modules/.bin/gemini");
 82
 83        Ok(AgentServerCommand {
 84            path,
 85            args: vec![ACP_ARG.into()],
 86            env: None,
 87        })
 88    }
 89
 90    async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
 91        let version_fut = util::command::new_smol_command(&command.path)
 92            .args(command.args.iter())
 93            .arg("--version")
 94            .kill_on_drop(true)
 95            .output();
 96
 97        let help_fut = util::command::new_smol_command(&command.path)
 98            .args(command.args.iter())
 99            .arg("--help")
100            .kill_on_drop(true)
101            .output();
102
103        let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
104
105        let current_version = String::from_utf8(version_output?.stdout)?;
106        let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
107
108        if supported {
109            Ok(AgentServerVersion::Supported)
110        } else {
111            Ok(AgentServerVersion::Unsupported {
112                error_message: format!(
113                    "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
114                    current_version
115                ).into(),
116                upgrade_message: "Upgrade Gemini to Latest".into(),
117                upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
118            })
119        }
120    }
121}
122
#[cfg(test)]
mod test {
    use std::{path::Path, time::Duration};

    use acp_thread::{
        AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
        ToolCallStatus,
    };
    use agentic_coding_protocol as acp;
    use anyhow::Result;
    use futures::{FutureExt, StreamExt, channel::mpsc, select};
    use gpui::{AsyncApp, Entity, TestAppContext};
    use indoc::indoc;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use super::ACP_ARG;
    use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer};

    /// How long [`run_until_first_tool_call`] waits for a tool call before failing the test.
    const TOOL_CALL_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates and initializes an [`AcpThread`] backed by a development build of the Gemini CLI.
    ///
    /// The CLI is run as `node <CARGO_MANIFEST_DIR>/../../../gemini-cli/packages/cli`
    /// followed by [`ACP_ARG`], so that checkout must exist for these tests to pass. Version
    /// probing is skipped: the dev server always reports itself as supported.
    ///
    /// # Panics
    /// Panics if the thread cannot be created or initialized.
    pub async fn gemini_acp_thread(
        project: Entity<Project>,
        current_dir: impl AsRef<Path>,
        cx: &mut TestAppContext,
    ) -> Entity<AcpThread> {
        #[derive(Clone)]
        struct DevGemini;

        impl StdioAgentServer for DevGemini {
            async fn command(
                &self,
                _project: &Entity<Project>,
                _cx: &mut AsyncApp,
            ) -> Result<AgentServerCommand> {
                let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
                    .join("../../../gemini-cli/packages/cli")
                    .to_string_lossy()
                    .to_string();

                Ok(AgentServerCommand {
                    path: "node".into(),
                    args: vec![cli_path, ACP_ARG.into()],
                    env: None,
                })
            }

            async fn version(&self, _command: &AgentServerCommand) -> Result<AgentServerVersion> {
                Ok(AgentServerVersion::Supported)
            }

            fn logo(&self) -> ui::IconName {
                ui::IconName::AiGemini
            }

            fn name(&self) -> &'static str {
                "test"
            }

            fn empty_state_headline(&self) -> &'static str {
                "test"
            }

            fn empty_state_message(&self) -> &'static str {
                "test"
            }

            fn supports_always_allow(&self) -> bool {
                true
            }
        }

        let thread = cx
            .update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx))
            .await
            .expect("failed to create the dev Gemini thread");

        thread
            .update(cx, |thread, _| thread.initialize())
            .await
            .expect("failed to initialize the dev Gemini thread");
        thread
    }

    /// Installs the test settings store, project settings, and language support that the
    /// thread and project entities depend on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// A plain prompt yields exactly one user message followed by one assistant message.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries().len(), 2);
            assert!(matches!(
                thread.entries()[0],
                AgentThreadEntry::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries()[1],
                AgentThreadEntry::AssistantMessage(_)
            ));
        });
    }

    /// A path chunk in the user message lets the agent read a real file from disk.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();
        let tempdir = tempfile::tempdir().unwrap();
        std::fs::write(
            tempdir.path().join("foo.rs"),
            indoc! {"
                fn main() {
                    println!(\"Hello, world!\");
                }
            "},
        )
        .expect("failed to write file");
        let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
        let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    acp::SendUserMessageParams {
                        chunks: vec![
                            acp::UserMessageChunk::Text {
                                text: "Read the file ".into(),
                            },
                            acp::UserMessageChunk::Path {
                                path: Path::new("foo.rs").into(),
                            },
                            acp::UserMessageChunk::Text {
                                text: " and tell me what the content of the println! is".into(),
                            },
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.entries().len(), 3);
            assert!(matches!(
                thread.entries()[0],
                AgentThreadEntry::UserMessage(_)
            ));
            assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
            let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
                panic!("Expected AssistantMessage")
            };
            assert!(
                assistant_message.to_markdown(cx).contains("Hello, world!"),
                "unexpected assistant message: {:?}",
                assistant_message.to_markdown(cx)
            );
        });
    }

    /// Reading a file goes through a tool call that is allowed without confirmation.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[2],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries()[3],
                AgentThreadEntry::AssistantMessage(_)
            ));
        });
    }

    /// Running a shell command waits for confirmation. Once authorized, the command runs
    /// and its output is recorded on the tool call.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
        });

        let tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let entry = &thread.entries()[tool_call_ix];
            let AgentThreadEntry::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = entry
            else {
                panic!("expected a tool call awaiting confirmation, got {:?}", entry);
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[tool_call_ix],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let entry = &thread.entries()[tool_call_ix];
            let AgentThreadEntry::ToolCall(ToolCall {
                content: Some(ToolCallContent::Markdown { markdown }),
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = entry
            else {
                panic!("expected an allowed tool call with markdown content, got {:?}", entry);
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Cancelling while a tool call awaits confirmation marks it canceled, and the thread
    /// remains usable for a follow-up prompt.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_cancel(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
        });

        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _cx| {
            let entry = &thread.entries()[first_tool_call_ix];
            let AgentThreadEntry::ToolCall(ToolCall {
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = entry
            else {
                panic!("expected a tool call awaiting confirmation, got {:?}", entry);
            };

            assert_eq!(root_command, "echo");
        });

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();
        full_turn.await.unwrap();
        thread.read_with(cx, |thread, _| {
            let entry = &thread.entries()[first_tool_call_ix];
            let AgentThreadEntry::ToolCall(ToolCall {
                status: ToolCallStatus::Canceled,
                ..
            }) = entry
            else {
                panic!("expected a canceled tool call, got {:?}", entry);
            };
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries().last().unwrap(),
                AgentThreadEntry::AssistantMessage(..),
            ))
        });
    }

    /// Waits for `thread` to emit an event while it holds a tool call entry.
    ///
    /// Returns the index of the first [`AgentThreadEntry::ToolCall`] at that point.
    ///
    /// # Panics
    /// Panics if no such event arrives within [`TOOL_CALL_TIMEOUT`].
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries()
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // More events may fire before the receiver below drops this subscription.
                    // Only the first index received matters, so a full buffer is not an error.
                    let _ = tx.try_send(ix);
                }
            })
        });

        select! {
            _ = cx.executor().timer(TOOL_CALL_TIMEOUT).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.expect("tool call channel closed before a tool call was seen")
            }
        }
    }
}