agent_ui: More agent notifications (#35441)

Ben Brandt created

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs    |  32 +
crates/agent_ui/src/acp/thread_view.rs | 523 +++++++++++++++++++++++++++
crates/agent_ui/src/agent_diff.rs      |   3 
3 files changed, 547 insertions(+), 11 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -580,6 +580,9 @@ pub struct AcpThread {
 pub enum AcpThreadEvent {
     NewEntry,
     EntryUpdated(usize),
+    ToolAuthorizationRequired,
+    Stopped,
+    Error,
 }
 
 impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -676,6 +679,18 @@ impl AcpThread {
         false
     }
 
+    pub fn used_tools_since_last_user_message(&self) -> bool {
+        for entry in self.entries.iter().rev() {
+            match entry {
+                AgentThreadEntry::UserMessage(..) => return false,
+                AgentThreadEntry::AssistantMessage(..) => continue,
+                AgentThreadEntry::ToolCall(..) => return true,
+            }
+        }
+
+        false
+    }
+
     pub fn handle_session_update(
         &mut self,
         update: acp::SessionUpdate,
@@ -879,6 +894,7 @@ impl AcpThread {
         };
 
         self.upsert_tool_call_inner(tool_call, status, cx);
+        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
         rx
     }
 
@@ -1018,12 +1034,18 @@ impl AcpThread {
             .log_err();
         }));
 
-        async move {
-            match rx.await {
-                Ok(Err(e)) => Err(e)?,
-                _ => Ok(()),
+        cx.spawn(async move |this, cx| match rx.await {
+            Ok(Err(e)) => {
+                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
+                    .log_err();
+                Err(e)?
             }
-        }
+            _ => {
+                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
+                    .log_err();
+                Ok(())
+            }
+        })
         .boxed()
     }
 

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -1,5 +1,7 @@
 use acp_thread::{AgentConnection, Plan};
 use agent_servers::AgentServer;
+use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
+use audio::{Audio, Sound};
 use std::cell::RefCell;
 use std::collections::BTreeMap;
 use std::path::Path;
@@ -18,10 +20,10 @@ use editor::{
 use file_icons::FileIcons;
 use gpui::{
     Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
-    FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
-    Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
-    Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
-    pulsating_between,
+    FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
+    StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
+    UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
+    list, percentage, point, prelude::*, pulsating_between,
 };
 use language::language_settings::SoftWrap;
 use language::{Buffer, Language};
@@ -45,7 +47,10 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe
 use crate::acp::message_history::MessageHistory;
 use crate::agent_diff::AgentDiff;
 use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES};
-use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll};
+use crate::ui::{AgentNotification, AgentNotificationEvent};
+use crate::{
+    AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
+};
 
 const RESPONSE_PADDING_X: Pixels = px(19.);
 
@@ -59,6 +64,8 @@ pub struct AcpThreadView {
     message_set_from_history: bool,
     _message_editor_subscription: Subscription,
     mention_set: Arc<Mutex<MentionSet>>,
+    notifications: Vec<WindowHandle<AgentNotification>>,
+    notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
     last_error: Option<Entity<Markdown>>,
     list_state: ListState,
     auth_task: Option<Task<()>>,
@@ -174,6 +181,8 @@ impl AcpThreadView {
             message_set_from_history: false,
             _message_editor_subscription: message_editor_subscription,
             mention_set,
+            notifications: Vec::new(),
+            notification_subscriptions: HashMap::default(),
             diff_editors: Default::default(),
             list_state: list_state,
             last_error: None,
@@ -381,7 +390,9 @@ impl AcpThreadView {
             return;
         }
 
-        let Some(thread) = self.thread() else { return };
+        let Some(thread) = self.thread() else {
+            return;
+        };
         let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
 
         cx.spawn(async move |this, cx| {
@@ -564,6 +575,30 @@ impl AcpThreadView {
                 self.sync_thread_entry_view(index, window, cx);
                 self.list_state.splice(index..index + 1, 1);
             }
+            AcpThreadEvent::ToolAuthorizationRequired => {
+                self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
+            }
+            AcpThreadEvent::Stopped => {
+                let used_tools = thread.read(cx).used_tools_since_last_user_message();
+                self.notify_with_sound(
+                    if used_tools {
+                        "Finished running tools"
+                    } else {
+                        "New message"
+                    },
+                    IconName::ZedAssistant,
+                    window,
+                    cx,
+                );
+            }
+            AcpThreadEvent::Error => {
+                self.notify_with_sound(
+                    "Agent stopped due to an error",
+                    IconName::Warning,
+                    window,
+                    cx,
+                );
+            }
         }
         cx.notify();
     }
@@ -2160,6 +2195,154 @@ impl AcpThreadView {
         self.list_state.scroll_to(ListOffset::default());
         cx.notify();
     }
+
+    fn notify_with_sound(
+        &mut self,
+        caption: impl Into<SharedString>,
+        icon: IconName,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        self.play_notification_sound(window, cx);
+        self.show_notification(caption, icon, window, cx);
+    }
+
+    fn play_notification_sound(&self, window: &Window, cx: &mut App) {
+        let settings = AgentSettings::get_global(cx);
+        if settings.play_sound_when_agent_done && !window.is_window_active() {
+            Audio::play_sound(Sound::AgentDone, cx);
+        }
+    }
+
+    fn show_notification(
+        &mut self,
+        caption: impl Into<SharedString>,
+        icon: IconName,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        if window.is_window_active() || !self.notifications.is_empty() {
+            return;
+        }
+
+        let title = self.title(cx);
+
+        match AgentSettings::get_global(cx).notify_when_agent_waiting {
+            NotifyWhenAgentWaiting::PrimaryScreen => {
+                if let Some(primary) = cx.primary_display() {
+                    self.pop_up(icon, caption.into(), title, window, primary, cx);
+                }
+            }
+            NotifyWhenAgentWaiting::AllScreens => {
+                let caption = caption.into();
+                for screen in cx.displays() {
+                    self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx);
+                }
+            }
+            NotifyWhenAgentWaiting::Never => {
+                // Don't show anything
+            }
+        }
+    }
+
+    fn pop_up(
+        &mut self,
+        icon: IconName,
+        caption: SharedString,
+        title: SharedString,
+        window: &mut Window,
+        screen: Rc<dyn PlatformDisplay>,
+        cx: &mut Context<Self>,
+    ) {
+        let options = AgentNotification::window_options(screen, cx);
+
+        let project_name = self.workspace.upgrade().and_then(|workspace| {
+            workspace
+                .read(cx)
+                .project()
+                .read(cx)
+                .visible_worktrees(cx)
+                .next()
+                .map(|worktree| worktree.read(cx).root_name().to_string())
+        });
+
+        if let Some(screen_window) = cx
+            .open_window(options, |_, cx| {
+                cx.new(|_| {
+                    AgentNotification::new(title.clone(), caption.clone(), icon, project_name)
+                })
+            })
+            .log_err()
+        {
+            if let Some(pop_up) = screen_window.entity(cx).log_err() {
+                self.notification_subscriptions
+                    .entry(screen_window)
+                    .or_insert_with(Vec::new)
+                    .push(cx.subscribe_in(&pop_up, window, {
+                        |this, _, event, window, cx| match event {
+                            AgentNotificationEvent::Accepted => {
+                                let handle = window.window_handle();
+                                cx.activate(true);
+
+                                let workspace_handle = this.workspace.clone();
+
+                                // If there are multiple Zed windows, activate the correct one.
+                                cx.defer(move |cx| {
+                                    handle
+                                        .update(cx, |_view, window, _cx| {
+                                            window.activate_window();
+
+                                            if let Some(workspace) = workspace_handle.upgrade() {
+                                                workspace.update(_cx, |workspace, cx| {
+                                                    workspace.focus_panel::<AgentPanel>(window, cx);
+                                                });
+                                            }
+                                        })
+                                        .log_err();
+                                });
+
+                                this.dismiss_notifications(cx);
+                            }
+                            AgentNotificationEvent::Dismissed => {
+                                this.dismiss_notifications(cx);
+                            }
+                        }
+                    }));
+
+                self.notifications.push(screen_window);
+
+                // If the user manually refocuses the original window, dismiss the popup.
+                self.notification_subscriptions
+                    .entry(screen_window)
+                    .or_insert_with(Vec::new)
+                    .push({
+                        let pop_up_weak = pop_up.downgrade();
+
+                        cx.observe_window_activation(window, move |_, window, cx| {
+                            if window.is_window_active() {
+                                if let Some(pop_up) = pop_up_weak.upgrade() {
+                                    pop_up.update(cx, |_, cx| {
+                                        cx.emit(AgentNotificationEvent::Dismissed);
+                                    });
+                                }
+                            }
+                        })
+                    });
+            }
+        }
+    }
+
+    fn dismiss_notifications(&mut self, cx: &mut Context<Self>) {
+        for window in self.notifications.drain(..) {
+            window
+                .update(cx, |_, window, _| {
+                    window.remove_window();
+                })
+                .ok();
+
+            self.notification_subscriptions.remove(&window);
+        }
+    }
 }
 
 impl Focusable for AcpThreadView {
@@ -2441,3 +2624,331 @@ fn plan_label_markdown_style(
         ..default_md_style
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use agent_client_protocol::SessionId;
+    use editor::EditorSettings;
+    use fs::FakeFs;
+    use futures::future::try_join_all;
+    use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
+    use rand::Rng;
+    use settings::SettingsStore;
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
+
+        let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Hello", window, cx);
+        });
+
+        cx.deactivate_window();
+
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.chat(&Chat, window, cx);
+        });
+
+        cx.run_until_parked();
+
+        assert!(
+            cx.windows()
+                .iter()
+                .any(|window| window.downcast::<AgentNotification>().is_some())
+        );
+    }
+
+    #[gpui::test]
+    async fn test_notification_for_error(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let (thread_view, cx) =
+            setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await;
+
+        let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Hello", window, cx);
+        });
+
+        cx.deactivate_window();
+
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.chat(&Chat, window, cx);
+        });
+
+        cx.run_until_parked();
+
+        assert!(
+            cx.windows()
+                .iter()
+                .any(|window| window.downcast::<AgentNotification>().is_some())
+        );
+    }
+
+    #[gpui::test]
+    async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let tool_call_id = acp::ToolCallId("1".into());
+        let tool_call = acp::ToolCall {
+            id: tool_call_id.clone(),
+            label: "Label".into(),
+            kind: acp::ToolKind::Edit,
+            status: acp::ToolCallStatus::Pending,
+            content: vec!["hi".into()],
+            locations: vec![],
+            raw_input: None,
+        };
+        let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
+            .with_permission_requests(HashMap::from_iter([(
+                tool_call_id,
+                vec![acp::PermissionOption {
+                    id: acp::PermissionOptionId("1".into()),
+                    label: "Allow".into(),
+                    kind: acp::PermissionOptionKind::AllowOnce,
+                }],
+            )]));
+        let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
+
+        let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Hello", window, cx);
+        });
+
+        cx.deactivate_window();
+
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.chat(&Chat, window, cx);
+        });
+
+        cx.run_until_parked();
+
+        assert!(
+            cx.windows()
+                .iter()
+                .any(|window| window.downcast::<AgentNotification>().is_some())
+        );
+    }
+
+    async fn setup_thread_view(
+        agent: impl AgentServer + 'static,
+        cx: &mut TestAppContext,
+    ) -> (Entity<AcpThreadView>, &mut VisualTestContext) {
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let thread_view = cx.update(|window, cx| {
+            cx.new(|cx| {
+                AcpThreadView::new(
+                    Rc::new(agent),
+                    workspace.downgrade(),
+                    project,
+                    Rc::new(RefCell::new(MessageHistory::default())),
+                    1,
+                    None,
+                    window,
+                    cx,
+                )
+            })
+        });
+        cx.run_until_parked();
+        (thread_view, cx)
+    }
+
+    struct StubAgentServer<C> {
+        connection: C,
+    }
+
+    impl<C> StubAgentServer<C> {
+        fn new(connection: C) -> Self {
+            Self { connection }
+        }
+    }
+
+    impl StubAgentServer<StubAgentConnection> {
+        fn default() -> Self {
+            Self::new(StubAgentConnection::default())
+        }
+    }
+
+    impl<C> AgentServer for StubAgentServer<C>
+    where
+        C: 'static + AgentConnection + Send + Clone,
+    {
+        fn logo(&self) -> ui::IconName {
+            unimplemented!()
+        }
+
+        fn name(&self) -> &'static str {
+            unimplemented!()
+        }
+
+        fn empty_state_headline(&self) -> &'static str {
+            unimplemented!()
+        }
+
+        fn empty_state_message(&self) -> &'static str {
+            unimplemented!()
+        }
+
+        fn connect(
+            &self,
+            _root_dir: &Path,
+            _project: &Entity<Project>,
+            _cx: &mut App,
+        ) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
+            Task::ready(Ok(Rc::new(self.connection.clone())))
+        }
+    }
+
+    #[derive(Clone, Default)]
+    struct StubAgentConnection {
+        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
+        updates: Vec<acp::SessionUpdate>,
+    }
+
+    impl StubAgentConnection {
+        fn new(updates: Vec<acp::SessionUpdate>) -> Self {
+            Self {
+                updates,
+                permission_requests: HashMap::default(),
+                sessions: Arc::default(),
+            }
+        }
+
+        fn with_permission_requests(
+            mut self,
+            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
+        ) -> Self {
+            self.permission_requests = permission_requests;
+            self
+        }
+    }
+
+    impl AgentConnection for StubAgentConnection {
+        fn name(&self) -> &'static str {
+            "StubAgentConnection"
+        }
+
+        fn new_thread(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            _cwd: &Path,
+            cx: &mut gpui::AsyncApp,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            let session_id = SessionId(
+                rand::thread_rng()
+                    .sample_iter(&rand::distributions::Alphanumeric)
+                    .take(7)
+                    .map(char::from)
+                    .collect::<String>()
+                    .into(),
+            );
+            let thread = cx
+                .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
+                .unwrap();
+            self.sessions.lock().insert(session_id, thread.downgrade());
+            Task::ready(Ok(thread))
+        }
+
+        fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
+            unimplemented!()
+        }
+
+        fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
+            let sessions = self.sessions.lock();
+            let thread = sessions.get(&params.session_id).unwrap();
+            let mut tasks = vec![];
+            for update in &self.updates {
+                let thread = thread.clone();
+                let update = update.clone();
+                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
+                    && let Some(options) = self.permission_requests.get(&tool_call.id)
+                {
+                    Some((tool_call.clone(), options.clone()))
+                } else {
+                    None
+                };
+                let task = cx.spawn(async move |cx| {
+                    if let Some((tool_call, options)) = permission_request {
+                        let permission = thread.update(cx, |thread, cx| {
+                            thread.request_tool_call_permission(
+                                tool_call.clone(),
+                                options.clone(),
+                                cx,
+                            )
+                        })?;
+                        permission.await?;
+                    }
+                    thread.update(cx, |thread, cx| {
+                        thread.handle_session_update(update.clone(), cx).unwrap();
+                    })?;
+                    anyhow::Ok(())
+                });
+                tasks.push(task);
+            }
+            cx.spawn(async move |_| {
+                try_join_all(tasks).await?;
+                Ok(())
+            })
+        }
+
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
+            unimplemented!()
+        }
+    }
+
+    #[derive(Clone)]
+    struct SaboteurAgentConnection;
+
+    impl AgentConnection for SaboteurAgentConnection {
+        fn name(&self) -> &'static str {
+            "SaboteurAgentConnection"
+        }
+
+        fn new_thread(
+            self: Rc<Self>,
+            project: Entity<Project>,
+            _cwd: &Path,
+            cx: &mut gpui::AsyncApp,
+        ) -> Task<gpui::Result<Entity<AcpThread>>> {
+            Task::ready(Ok(cx
+                .new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx))
+                .unwrap()))
+        }
+
+        fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
+            unimplemented!()
+        }
+
+        fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task<gpui::Result<()>> {
+            Task::ready(Err(anyhow::anyhow!("Error prompting")))
+        }
+
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
+            unimplemented!()
+        }
+    }
+
+    fn init_test(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+            AgentSettings::register(cx);
+            workspace::init_settings(cx);
+            ThemeSettings::register(cx);
+            release_channel::init(SemanticVersion::default(), cx);
+            EditorSettings::register(cx);
+        });
+    }
+}

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1521,6 +1521,9 @@ impl AgentDiff {
                     self.update_reviewing_editors(workspace, window, cx);
                 }
             }
+            AcpThreadEvent::Stopped
+            | AcpThreadEvent::ToolAuthorizationRequired
+            | AcpThreadEvent::Error => {}
         }
     }