Merge commit 'd6c76d8d33' into acp

Mikayla Maki created

Change summary

Cargo.lock                    |   1 
crates/acp/Cargo.toml         |   1 
crates/acp/src/acp.rs         |   4 
crates/acp/src/server.rs      |  42 ++++++++++-
crates/acp/src/thread_view.rs | 129 ++++++++++++++++++++++++++++--------
5 files changed, 141 insertions(+), 36 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -22,6 +22,7 @@ dependencies = [
  "markdown",
  "parking_lot",
  "project",
+ "proto",
  "serde_json",
  "settings",
  "smol",

crates/acp/Cargo.toml 🔗

@@ -31,6 +31,7 @@ log.workspace = true
 markdown.workspace = true
 parking_lot.workspace = true
 project.workspace = true
+proto.workspace = true
 settings.workspace = true
 smol.workspace = true
 theme.workspace = true

crates/acp/src/acp.rs 🔗

@@ -817,7 +817,7 @@ mod tests {
         }
     }
 
-    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
+    pub fn gemini_acp_server(project: Entity<Project>, cx: AsyncApp) -> Result<Arc<AcpServer>> {
         let cli_path =
             Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
         let mut command = util::command::new_smol_command("node");
@@ -836,6 +836,6 @@ mod tests {
 
         let child = command.spawn().unwrap();
 
-        Ok(AcpServer::stdio(child, project, &mut cx))
+        cx.update(|cx| AcpServer::stdio(child, project, cx))
     }
 }

crates/acp/src/server.rs 🔗

@@ -238,13 +238,13 @@ impl acp::Client for AcpClientDelegate {
 }
 
 impl AcpServer {
-    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
+    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
         let stdin = process.stdin.take().expect("process didn't have stdin");
         let stdout = process.stdout.take().expect("process didn't have stdout");
 
         let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
         let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
+            AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
             stdin,
             stdout,
         );
@@ -269,14 +269,35 @@ impl AcpServer {
         })
     }
 
+    pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
+        self.connection
+            .request(acp::InitializeParams)
+            .await
+            .map_err(to_anyhow)
+    }
+
+    pub async fn authenticate(&self) -> Result<()> {
+        self.connection
+            .request(acp::AuthenticateParams)
+            .await
+            .map_err(to_anyhow)?;
+
+        Ok(())
+    }
+
     pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
-        let response = self.connection.request(acp::CreateThreadParams).await?;
+        let response = self
+            .connection
+            .request(acp::CreateThreadParams)
+            .await
+            .map_err(to_anyhow)?;
+
         let thread_id: ThreadId = response.thread_id.into();
         let server = self.clone();
         let thread = cx.new(|_| AcpThread {
             // todo!
             title: "ACP Thread".into(),
-            id: thread_id.clone(),
+            id: thread_id.clone(), // Either<ErrorState, Id>
             next_entry_id: ThreadEntryId(0),
             entries: Vec::default(),
             project: self.project.clone(),
@@ -297,7 +318,8 @@ impl AcpServer {
                 thread_id: thread_id.clone().into(),
                 message,
             })
-            .await?;
+            .await
+            .map_err(to_anyhow)?;
         Ok(())
     }
 
@@ -306,6 +328,16 @@ impl AcpServer {
     }
 }
 
+#[track_caller]
+fn to_anyhow(e: acp::Error) -> anyhow::Error {
+    log::error!(
+        "failed to send message: {code}: {message}",
+        code = e.code,
+        message = e.message
+    );
+    anyhow::anyhow!(e.message)
+}
+
 impl From<acp::ThreadId> for ThreadId {
     fn from(thread_id: acp::ThreadId) -> Self {
         Self(thread_id.0.into())

crates/acp/src/thread_view.rs 🔗

@@ -1,5 +1,6 @@
 use std::path::Path;
 use std::rc::Rc;
+use std::sync::Arc;
 use std::time::Duration;
 
 use agentic_coding_protocol::{self as acp};
@@ -13,13 +14,13 @@ use gpui::{
 use gpui::{FocusHandle, Task};
 use language::Buffer;
 use language::language_settings::SoftWrap;
-use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
+use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
 use project::Project;
 use settings::Settings as _;
 use theme::ThemeSettings;
 use ui::prelude::*;
 use ui::{Button, Tooltip};
-use util::ResultExt;
+use util::{ResultExt, paths};
 use zed_actions::agent::Chat;
 
 use crate::{
@@ -28,12 +29,15 @@ use crate::{
 };
 
 pub struct AcpThreadView {
+    agent: Arc<AcpServer>,
     thread_state: ThreadState,
     // todo! reconsider structure. currently pretty sparse, but easy to clean up if we need to delete entries.
     thread_entry_views: Vec<Option<ThreadEntryView>>,
     message_editor: Entity<Editor>,
+    last_error: Option<Entity<Markdown>>,
     list_state: ListState,
     send_task: Option<Task<Result<()>>>,
+    auth_task: Option<Task<()>>,
 }
 
 #[derive(Debug)]
@@ -50,6 +54,7 @@ enum ThreadState {
         _subscription: Subscription,
     },
     LoadError(SharedString),
+    Unauthenticated,
 }
 
 impl AcpThreadView {
@@ -90,30 +95,12 @@ impl AcpThreadView {
             }),
         );
 
-        Self {
-            thread_state: Self::initial_state(project, window, cx),
-            thread_entry_views: Vec::new(),
-            message_editor,
-            send_task: None,
-            list_state: list_state,
-        }
-    }
-
-    fn initial_state(
-        project: Entity<Project>,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) -> ThreadState {
-        let Some(root_dir) = project
+        let root_dir = project
             .read(cx)
             .visible_worktrees(cx)
             .next()
             .map(|worktree| worktree.read(cx).abs_path())
-        else {
-            return ThreadState::LoadError(
-                "Gemini threads must be created within a project".into(),
-            );
-        };
+            .unwrap_or_else(|| paths::home_dir().as_path().into());
 
         let cli_path =
             Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
@@ -129,10 +116,39 @@ impl AcpThreadView {
             .spawn()
             .unwrap();
 
-        let project = project.clone();
+        let agent = AcpServer::stdio(child, project, cx);
+
+        Self {
+            thread_state: Self::initial_state(agent.clone(), window, cx),
+            agent,
+            message_editor,
+            thread_entry_views: Vec::new(),
+            send_task: None,
+            list_state: list_state,
+            last_error: None,
+            auth_task: None,
+        }
+    }
+
+    fn initial_state(
+        agent: Arc<AcpServer>,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> ThreadState {
         let load_task = cx.spawn_in(window, async move |this, cx| {
-            let agent = AcpServer::stdio(child, project, cx);
-            let result = agent.clone().create_thread(cx).await;
+            let result = match agent.initialize().await {
+                Err(e) => Err(e),
+                Ok(response) => {
+                    if !response.is_authenticated {
+                        this.update(cx, |this, _| {
+                            this.thread_state = ThreadState::Unauthenticated;
+                        })
+                        .ok();
+                        return;
+                    }
+                    agent.clone().create_thread(cx).await
+                }
+            };
 
             this.update_in(cx, |this, window, cx| {
                 match result {
@@ -148,6 +164,7 @@ impl AcpThreadView {
                         };
                     }
                     Err(e) => {
+                        dbg!(&e);
                         if let Some(exit_status) = agent.exit_status() {
                             this.thread_state = ThreadState::LoadError(
                                 format!(
@@ -172,7 +189,9 @@ impl AcpThreadView {
     fn thread(&self) -> Option<&Entity<AcpThread>> {
         match &self.thread_state {
             ThreadState::Ready { thread, .. } => Some(thread),
-            ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
+            ThreadState::Loading { .. }
+            | ThreadState::LoadError(..)
+            | ThreadState::Unauthenticated => None,
         }
     }
 
@@ -181,6 +200,7 @@ impl AcpThreadView {
             ThreadState::Ready { thread, .. } => thread.read(cx).title(),
             ThreadState::Loading { .. } => "Loading...".into(),
             ThreadState::LoadError(_) => "Failed to load".into(),
+            ThreadState::Unauthenticated => "Not authenticated".into(),
         }
     }
 
@@ -189,6 +209,7 @@ impl AcpThreadView {
     }
 
     fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context<Self>) {
+        self.last_error.take();
         let text = self.message_editor.read(cx).text(cx);
         if text.is_empty() {
             return;
@@ -198,9 +219,15 @@ impl AcpThreadView {
         let task = thread.update(cx, |thread, cx| thread.send(&text, cx));
 
         self.send_task = Some(cx.spawn(async move |this, cx| {
-            task.await?;
+            let result = task.await;
 
-            this.update(cx, |this, _cx| {
+            this.update(cx, |this, cx| {
+                if let Err(err) = result {
+                    this.last_error =
+                        Some(cx.new(|cx| {
+                            Markdown::new(format!("Error: {err}").into(), None, None, cx)
+                        }))
+                }
                 this.send_task.take();
             })
         }));
@@ -329,6 +356,27 @@ impl AcpThreadView {
         }
     }
 
+    fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        let agent = self.agent.clone();
+
+        self.auth_task = Some(cx.spawn_in(window, async move |this, cx| {
+            let result = agent.authenticate().await;
+
+            this.update_in(cx, |this, window, cx| {
+                if let Err(err) = result {
+                    this.last_error =
+                        Some(cx.new(|cx| {
+                            Markdown::new(format!("Error: {err}").into(), None, None, cx)
+                        }))
+                } else {
+                    this.thread_state = Self::initial_state(agent, window, cx)
+                }
+                this.auth_task.take()
+            })
+            .ok();
+        }));
+    }
+
     fn authorize_tool_call(
         &mut self,
         id: ToolCallId,
@@ -899,7 +947,7 @@ impl Focusable for AcpThreadView {
 }
 
 impl Render for AcpThreadView {
-    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let text = self.message_editor.read(cx).text(cx);
         let is_editor_empty = text.is_empty();
         let focus_handle = self.message_editor.focus_handle(cx);
@@ -909,6 +957,14 @@ impl Render for AcpThreadView {
             .on_action(cx.listener(Self::chat))
             .h_full()
             .child(match &self.thread_state {
+                ThreadState::Unauthenticated => v_flex()
+                    .p_2()
+                    .flex_1()
+                    .justify_end()
+                    .child(Label::new("Not authenticated"))
+                    .child(Button::new("sign-in", "Sign in via Gemini CLI").on_click(
+                        cx.listener(|this, _, window, cx| this.authenticate(window, cx)),
+                    )),
                 ThreadState::Loading { .. } => v_flex()
                     .p_2()
                     .flex_1()
@@ -941,6 +997,21 @@ impl Render for AcpThreadView {
                         .into()
                     })),
             })
+            .when_some(self.last_error.clone(), |el, error| {
+                el.child(
+                    div()
+                        .text_xs()
+                        .p_2()
+                        .gap_2()
+                        .border_t_1()
+                        .border_color(cx.theme().status().error_border)
+                        .bg(cx.theme().status().error_background)
+                        .child(MarkdownElement::new(
+                            error,
+                            default_markdown_style(window, cx),
+                        )),
+                )
+            })
             .child(
                 v_flex()
                     .bg(cx.theme().colors().editor_background)