agent2: Port feedback (#36603)

Bennet Bo Fenner and Ben Brandt created

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/acp_thread/src/connection.rs    |  17 +
crates/agent/src/thread.rs             |  53 -----
crates/agent2/src/agent.rs             |  25 ++
crates/agent_ui/src/acp/thread_view.rs | 283 +++++++++++++++++++++++++++
crates/agent_ui/src/active_thread.rs   |   1 
5 files changed, 321 insertions(+), 58 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs 🔗

@@ -64,6 +64,10 @@ pub trait AgentConnection {
         None
     }
 
+    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
+        None
+    }
+
     fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 }
 
@@ -81,6 +85,19 @@ pub trait AgentSessionResume {
     fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 }
 
+pub trait AgentTelemetry {
+    /// The name of the agent used for telemetry.
+    fn agent_name(&self) -> String;
+
+    /// A representation of the current thread state that can be serialized for
+    /// storage with telemetry events.
+    fn thread_data(
+        &self,
+        session_id: &acp::SessionId,
+        cx: &mut App,
+    ) -> Task<Result<serde_json::Value>>;
+}
+
 #[derive(Debug)]
 pub struct AuthRequired {
     pub description: Option<String>,

crates/agent/src/thread.rs 🔗

@@ -387,7 +387,6 @@ pub struct Thread {
     cumulative_token_usage: TokenUsage,
     exceeded_window_error: Option<ExceededWindowError>,
     tool_use_limit_reached: bool,
-    feedback: Option<ThreadFeedback>,
     retry_state: Option<RetryState>,
     message_feedback: HashMap<MessageId, ThreadFeedback>,
     last_received_chunk_at: Option<Instant>,
@@ -487,7 +486,6 @@ impl Thread {
             cumulative_token_usage: TokenUsage::default(),
             exceeded_window_error: None,
             tool_use_limit_reached: false,
-            feedback: None,
             retry_state: None,
             message_feedback: HashMap::default(),
             last_error_context: None,
@@ -612,7 +610,6 @@ impl Thread {
             cumulative_token_usage: serialized.cumulative_token_usage,
             exceeded_window_error: None,
             tool_use_limit_reached: serialized.tool_use_limit_reached,
-            feedback: None,
             message_feedback: HashMap::default(),
             last_error_context: None,
             last_received_chunk_at: None,
@@ -2787,10 +2784,6 @@ impl Thread {
         cx.emit(ThreadEvent::CancelEditing);
     }
 
-    pub fn feedback(&self) -> Option<ThreadFeedback> {
-        self.feedback
-    }
-
     pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
         self.message_feedback.get(&message_id).copied()
     }
@@ -2852,52 +2845,6 @@ impl Thread {
         })
     }
 
-    pub fn report_feedback(
-        &mut self,
-        feedback: ThreadFeedback,
-        cx: &mut Context<Self>,
-    ) -> Task<Result<()>> {
-        let last_assistant_message_id = self
-            .messages
-            .iter()
-            .rev()
-            .find(|msg| msg.role == Role::Assistant)
-            .map(|msg| msg.id);
-
-        if let Some(message_id) = last_assistant_message_id {
-            self.report_message_feedback(message_id, feedback, cx)
-        } else {
-            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
-            let serialized_thread = self.serialize(cx);
-            let thread_id = self.id().clone();
-            let client = self.project.read(cx).client();
-            self.feedback = Some(feedback);
-            cx.notify();
-
-            cx.background_spawn(async move {
-                let final_project_snapshot = final_project_snapshot.await;
-                let serialized_thread = serialized_thread.await?;
-                let thread_data = serde_json::to_value(serialized_thread)
-                    .unwrap_or_else(|_| serde_json::Value::Null);
-
-                let rating = match feedback {
-                    ThreadFeedback::Positive => "positive",
-                    ThreadFeedback::Negative => "negative",
-                };
-                telemetry::event!(
-                    "Assistant Thread Rated",
-                    rating,
-                    thread_id,
-                    thread_data,
-                    final_project_snapshot
-                );
-                client.telemetry().flush_events().await;
-
-                Ok(())
-            })
-        }
-    }
-
     /// Create a snapshot of the current project state including git information and unsaved buffers.
     fn project_snapshot(
         project: Entity<Project>,

crates/agent2/src/agent.rs 🔗

@@ -948,11 +948,36 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         })
     }
 
+    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
+        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
+    }
+
     fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
         self
     }
 }
 
+impl acp_thread::AgentTelemetry for NativeAgentConnection {
+    fn agent_name(&self) -> String {
+        "Zed".into()
+    }
+
+    fn thread_data(
+        &self,
+        session_id: &acp::SessionId,
+        cx: &mut App,
+    ) -> Task<Result<serde_json::Value>> {
+        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
+            return Task::ready(Err(anyhow!("Session not found")));
+        };
+
+        let task = session.thread.read(cx).to_db(cx);
+        cx.background_spawn(async move {
+            serde_json::to_value(task.await).context("Failed to serialize thread")
+        })
+    }
+}
+
 struct NativeAgentSessionEditor {
     thread: Entity<Thread>,
     acp_thread: WeakEntity<AcpThread>,

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -65,6 +65,12 @@ const RESPONSE_PADDING_X: Pixels = px(19.);
 pub const MIN_EDITOR_LINES: usize = 4;
 pub const MAX_EDITOR_LINES: usize = 8;
 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum ThreadFeedback {
+    Positive,
+    Negative,
+}
+
 enum ThreadError {
     PaymentRequired,
     ModelRequestLimitReached(cloud_llm_client::Plan),
@@ -106,6 +112,128 @@ impl ProfileProvider for Entity<agent2::Thread> {
     }
 }
 
+#[derive(Default)]
+struct ThreadFeedbackState {
+    feedback: Option<ThreadFeedback>,
+    comments_editor: Option<Entity<Editor>>,
+}
+
+impl ThreadFeedbackState {
+    pub fn submit(
+        &mut self,
+        thread: Entity<AcpThread>,
+        feedback: ThreadFeedback,
+        window: &mut Window,
+        cx: &mut App,
+    ) {
+        let Some(telemetry) = thread.read(cx).connection().telemetry() else {
+            return;
+        };
+
+        if self.feedback == Some(feedback) {
+            return;
+        }
+
+        self.feedback = Some(feedback);
+        match feedback {
+            ThreadFeedback::Positive => {
+                self.comments_editor = None;
+            }
+            ThreadFeedback::Negative => {
+                self.comments_editor = Some(Self::build_feedback_comments_editor(window, cx));
+            }
+        }
+        let session_id = thread.read(cx).session_id().clone();
+        let agent_name = telemetry.agent_name();
+        let task = telemetry.thread_data(&session_id, cx);
+        let rating = match feedback {
+            ThreadFeedback::Positive => "positive",
+            ThreadFeedback::Negative => "negative",
+        };
+        cx.background_spawn(async move {
+            let thread = task.await?;
+            telemetry::event!(
+                "Agent Thread Rated",
+                session_id = session_id,
+                rating = rating,
+                agent = agent_name,
+                thread = thread
+            );
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
+    }
+
+    pub fn submit_comments(&mut self, thread: Entity<AcpThread>, cx: &mut App) {
+        let Some(telemetry) = thread.read(cx).connection().telemetry() else {
+            return;
+        };
+
+        let Some(comments) = self
+            .comments_editor
+            .as_ref()
+            .map(|editor| editor.read(cx).text(cx))
+            .filter(|text| !text.trim().is_empty())
+        else {
+            return;
+        };
+
+        self.comments_editor.take();
+
+        let session_id = thread.read(cx).session_id().clone();
+        let agent_name = telemetry.agent_name();
+        let task = telemetry.thread_data(&session_id, cx);
+        cx.background_spawn(async move {
+            let thread = task.await?;
+            telemetry::event!(
+                "Agent Thread Feedback Comments",
+                session_id = session_id,
+                comments = comments,
+                agent = agent_name,
+                thread = thread
+            );
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
+    }
+
+    pub fn clear(&mut self) {
+        *self = Self::default()
+    }
+
+    pub fn dismiss_comments(&mut self) {
+        self.comments_editor.take();
+    }
+
+    fn build_feedback_comments_editor(window: &mut Window, cx: &mut App) -> Entity<Editor> {
+        let buffer = cx.new(|cx| {
+            let empty_string = String::new();
+            MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx)
+        });
+
+        let editor = cx.new(|cx| {
+            let mut editor = Editor::new(
+                editor::EditorMode::AutoHeight {
+                    min_lines: 1,
+                    max_lines: Some(4),
+                },
+                buffer,
+                None,
+                window,
+                cx,
+            );
+            editor.set_placeholder_text(
+                "What went wrong? Share your feedback so we can improve.",
+                cx,
+            );
+            editor
+        });
+
+        editor.read(cx).focus_handle(cx).focus(window);
+        editor
+    }
+}
+
 pub struct AcpThreadView {
     agent: Rc<dyn AgentServer>,
     workspace: WeakEntity<Workspace>,
@@ -120,6 +248,7 @@ pub struct AcpThreadView {
     notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
     thread_retry_status: Option<RetryStatus>,
     thread_error: Option<ThreadError>,
+    thread_feedback: ThreadFeedbackState,
     list_state: ListState,
     scrollbar_state: ScrollbarState,
     auth_task: Option<Task<()>>,
@@ -218,6 +347,7 @@ impl AcpThreadView {
             scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
             thread_retry_status: None,
             thread_error: None,
+            thread_feedback: Default::default(),
             auth_task: None,
             expanded_tool_calls: HashSet::default(),
             expanded_thinking_blocks: HashSet::default(),
@@ -615,6 +745,7 @@ impl AcpThreadView {
     ) {
         self.thread_error.take();
         self.editing_message.take();
+        self.thread_feedback.clear();
 
         let Some(thread) = self.thread().cloned() else {
             return;
@@ -1087,6 +1218,12 @@ impl AcpThreadView {
                 .w_full()
                 .child(primary)
                 .child(self.render_thread_controls(cx))
+                .when_some(
+                    self.thread_feedback.comments_editor.clone(),
+                    |this, editor| {
+                        this.child(Self::render_feedback_feedback_editor(editor, window, cx))
+                    },
+                )
                 .into_any_element()
         } else {
             primary
@@ -3556,7 +3693,9 @@ impl AcpThreadView {
                 this.scroll_to_top(cx);
             }));
 
-        h_flex()
+        let mut container = h_flex()
+            .id("thread-controls-container")
+            .group("thread-controls-container")
             .w_full()
             .mr_1()
             .pb_2()
@@ -3564,9 +3703,145 @@ impl AcpThreadView {
             .opacity(0.4)
             .hover(|style| style.opacity(1.))
             .flex_wrap()
-            .justify_end()
-            .child(open_as_markdown)
-            .child(scroll_to_top)
+            .justify_end();
+
+        if AgentSettings::get_global(cx).enable_feedback {
+            let feedback = self.thread_feedback.feedback;
+            container = container.child(
+                div().visible_on_hover("thread-controls-container").child(
+                    Label::new(
+                        match feedback {
+                            Some(ThreadFeedback::Positive) => "Thanks for your feedback!",
+                            Some(ThreadFeedback::Negative) => "We appreciate your feedback and will use it to improve.",
+                            None => "Rating the thread sends all of your current conversation to the Zed team.",
+                        }
+                    )
+                    .color(Color::Muted)
+                    .size(LabelSize::XSmall)
+                    .truncate(),
+                ),
+            ).child(
+                h_flex()
+                    .child(
+                        IconButton::new("feedback-thumbs-up", IconName::ThumbsUp)
+                            .shape(ui::IconButtonShape::Square)
+                            .icon_size(IconSize::Small)
+                            .icon_color(match feedback {
+                                Some(ThreadFeedback::Positive) => Color::Accent,
+                                _ => Color::Ignored,
+                            })
+                            .tooltip(Tooltip::text("Helpful Response"))
+                            .on_click(cx.listener(move |this, _, window, cx| {
+                                this.handle_feedback_click(
+                                    ThreadFeedback::Positive,
+                                    window,
+                                    cx,
+                                );
+                            })),
+                    )
+                    .child(
+                        IconButton::new("feedback-thumbs-down", IconName::ThumbsDown)
+                            .shape(ui::IconButtonShape::Square)
+                            .icon_size(IconSize::Small)
+                            .icon_color(match feedback {
+                                Some(ThreadFeedback::Negative) => Color::Accent,
+                                _ => Color::Ignored,
+                            })
+                            .tooltip(Tooltip::text("Not Helpful"))
+                            .on_click(cx.listener(move |this, _, window, cx| {
+                                this.handle_feedback_click(
+                                    ThreadFeedback::Negative,
+                                    window,
+                                    cx,
+                                );
+                            })),
+                    )
+            )
+        }
+
+        container.child(open_as_markdown).child(scroll_to_top)
+    }
+
+    fn render_feedback_feedback_editor(
+        editor: Entity<Editor>,
+        window: &mut Window,
+        cx: &Context<Self>,
+    ) -> Div {
+        let focus_handle = editor.focus_handle(cx);
+        v_flex()
+            .key_context("AgentFeedbackMessageEditor")
+            .on_action(cx.listener(move |this, _: &menu::Cancel, _, cx| {
+                this.thread_feedback.dismiss_comments();
+                cx.notify();
+            }))
+            .on_action(cx.listener(move |this, _: &menu::Confirm, _window, cx| {
+                this.submit_feedback_message(cx);
+            }))
+            .mb_2()
+            .mx_4()
+            .p_2()
+            .rounded_md()
+            .border_1()
+            .border_color(cx.theme().colors().border)
+            .bg(cx.theme().colors().editor_background)
+            .child(editor)
+            .child(
+                h_flex()
+                    .gap_1()
+                    .justify_end()
+                    .child(
+                        Button::new("dismiss-feedback-message", "Cancel")
+                            .label_size(LabelSize::Small)
+                            .key_binding(
+                                KeyBinding::for_action_in(&menu::Cancel, &focus_handle, window, cx)
+                                    .map(|kb| kb.size(rems_from_px(10.))),
+                            )
+                            .on_click(cx.listener(move |this, _, _window, cx| {
+                                this.thread_feedback.dismiss_comments();
+                                cx.notify();
+                            })),
+                    )
+                    .child(
+                        Button::new("submit-feedback-message", "Share Feedback")
+                            .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+                            .label_size(LabelSize::Small)
+                            .key_binding(
+                                KeyBinding::for_action_in(
+                                    &menu::Confirm,
+                                    &focus_handle,
+                                    window,
+                                    cx,
+                                )
+                                .map(|kb| kb.size(rems_from_px(10.))),
+                            )
+                            .on_click(cx.listener(move |this, _, _window, cx| {
+                                this.submit_feedback_message(cx);
+                            })),
+                    ),
+            )
+    }
+
+    fn handle_feedback_click(
+        &mut self,
+        feedback: ThreadFeedback,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(thread) = self.thread().cloned() else {
+            return;
+        };
+
+        self.thread_feedback.submit(thread, feedback, window, cx);
+        cx.notify();
+    }
+
+    fn submit_feedback_message(&mut self, cx: &mut Context<Self>) {
+        let Some(thread) = self.thread().cloned() else {
+            return;
+        };
+
+        self.thread_feedback.submit_comments(thread, cx);
+        cx.notify();
     }
 
     fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {

crates/agent_ui/src/active_thread.rs 🔗

@@ -2349,7 +2349,6 @@ impl ActiveThread {
                                     this.submit_feedback_message(message_id, cx);
                                     cx.notify();
                                 }))
-                                .on_action(cx.listener(Self::confirm_editing_message))
                                 .mb_2()
                                 .mx_4()
                                 .p_2()