agent_ui: Enable mentioning branch diff with main (#51235)

Danilo Leal created

As a follow up to the possibility of sending the branch diff to an agent
review, this PR enables directly @-mentioning the content of your diff
with main to the agent. Here's a quick video of it:


https://github.com/user-attachments/assets/f27b7287-c9b9-4ccf-875e-4ac6ce4cd8ad

Release Notes:

- Agent: Enabled mentioning the branch diff with main.

Change summary

crates/agent_ui/src/completion_provider.rs | 134 +++++++++++++++++++++++
crates/agent_ui/src/mention_set.rs         |  45 +++++++
crates/agent_ui/src/message_editor.rs      |   1 
3 files changed, 175 insertions(+), 5 deletions(-)

Detailed changes

crates/agent_ui/src/completion_provider.rs 🔗

@@ -64,6 +64,7 @@ pub(crate) enum PromptContextType {
     Thread,
     Rules,
     Diagnostics,
+    BranchDiff,
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -102,6 +103,7 @@ impl TryFrom<&str> for PromptContextType {
             "thread" => Ok(Self::Thread),
             "rule" => Ok(Self::Rules),
             "diagnostics" => Ok(Self::Diagnostics),
+            "diff" => Ok(Self::BranchDiff),
             _ => Err(format!("Invalid context picker mode: {}", value)),
         }
     }
@@ -116,6 +118,7 @@ impl PromptContextType {
             Self::Thread => "thread",
             Self::Rules => "rule",
             Self::Diagnostics => "diagnostics",
+            Self::BranchDiff => "branch diff",
         }
     }
 
@@ -127,6 +130,7 @@ impl PromptContextType {
             Self::Thread => "Threads",
             Self::Rules => "Rules",
             Self::Diagnostics => "Diagnostics",
+            Self::BranchDiff => "Branch Diff",
         }
     }
 
@@ -138,6 +142,7 @@ impl PromptContextType {
             Self::Thread => IconName::Thread,
             Self::Rules => IconName::Reader,
             Self::Diagnostics => IconName::Warning,
+            Self::BranchDiff => IconName::GitBranch,
         }
     }
 }
@@ -150,6 +155,12 @@ pub(crate) enum Match {
     Fetch(SharedString),
     Rules(RulesContextEntry),
     Entry(EntryMatch),
+    BranchDiff(BranchDiffMatch),
+}
+
+#[derive(Debug, Clone)]
+pub struct BranchDiffMatch {
+    pub base_ref: SharedString,
 }
 
 impl Match {
@@ -162,6 +173,7 @@ impl Match {
             Match::Symbol(_) => 1.,
             Match::Rules(_) => 1.,
             Match::Fetch(_) => 1.,
+            Match::BranchDiff(_) => 1.,
         }
     }
 }
@@ -781,6 +793,47 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
         }
     }
 
+    fn build_branch_diff_completion(
+        base_ref: SharedString,
+        source_range: Range<Anchor>,
+        source: Arc<T>,
+        editor: WeakEntity<Editor>,
+        mention_set: WeakEntity<MentionSet>,
+        workspace: Entity<Workspace>,
+        cx: &mut App,
+    ) -> Completion {
+        let uri = MentionUri::GitDiff {
+            base_ref: base_ref.to_string(),
+        };
+        let crease_text: SharedString = format!("Branch Diff (vs {})", base_ref).into();
+        let display_text = format!("@{}", crease_text);
+        let new_text = format!("[{}]({}) ", display_text, uri.to_uri());
+        let new_text_len = new_text.len();
+        let icon_path = uri.icon_path(cx);
+
+        Completion {
+            replace_range: source_range.clone(),
+            new_text,
+            label: CodeLabel::plain(crease_text.to_string(), None),
+            documentation: None,
+            source: project::CompletionSource::Custom,
+            icon_path: Some(icon_path),
+            match_start: None,
+            snippet_deduplication_key: None,
+            insert_text_mode: None,
+            confirm: Some(confirm_completion_callback(
+                crease_text,
+                source_range.start,
+                new_text_len - 1,
+                uri,
+                source,
+                editor,
+                mention_set,
+                workspace,
+            )),
+        }
+    }
+
     fn search_slash_commands(&self, query: String, cx: &mut App) -> Task<Vec<AvailableCommand>> {
         let commands = self.source.available_commands(cx);
         if commands.is_empty() {
@@ -812,6 +865,27 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
         })
     }
 
+    fn fetch_branch_diff_match(
+        &self,
+        workspace: &Entity<Workspace>,
+        cx: &mut App,
+    ) -> Option<Task<Option<BranchDiffMatch>>> {
+        let project = workspace.read(cx).project().clone();
+        let repo = project.read(cx).active_repository(cx)?;
+
+        let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(false));
+
+        Some(cx.spawn(async move |_cx| {
+            let base_ref = default_branch_receiver
+                .await
+                .ok()
+                .and_then(|r| r.ok())
+                .flatten()?;
+
+            Some(BranchDiffMatch { base_ref })
+        }))
+    }
+
     fn search_mentions(
         &self,
         mode: Option<PromptContextType>,
@@ -892,6 +966,8 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
 
             Some(PromptContextType::Diagnostics) => Task::ready(Vec::new()),
 
+            Some(PromptContextType::BranchDiff) => Task::ready(Vec::new()),
+
             None if query.is_empty() => {
                 let recent_task = self.recent_context_picker_entries(&workspace, cx);
                 let entries = self
@@ -905,9 +981,25 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
                     })
                     .collect::<Vec<_>>();
 
+                let branch_diff_task = if self
+                    .source
+                    .supports_context(PromptContextType::BranchDiff, cx)
+                {
+                    self.fetch_branch_diff_match(&workspace, cx)
+                } else {
+                    None
+                };
+
                 cx.spawn(async move |_cx| {
                     let mut matches = recent_task.await;
                     matches.extend(entries);
+
+                    if let Some(branch_diff_task) = branch_diff_task {
+                        if let Some(branch_diff_match) = branch_diff_task.await {
+                            matches.push(Match::BranchDiff(branch_diff_match));
+                        }
+                    }
+
                     matches
                 })
             }
@@ -924,7 +1016,16 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
                     .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
                     .collect::<Vec<_>>();
 
-                cx.background_spawn(async move {
+                let branch_diff_task = if self
+                    .source
+                    .supports_context(PromptContextType::BranchDiff, cx)
+                {
+                    self.fetch_branch_diff_match(&workspace, cx)
+                } else {
+                    None
+                };
+
+                cx.spawn(async move |cx| {
                     let mut matches = search_files_task
                         .await
                         .into_iter()
@@ -949,6 +1050,26 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
                         })
                     }));
 
+                    if let Some(branch_diff_task) = branch_diff_task {
+                        let branch_diff_keyword = PromptContextType::BranchDiff.keyword();
+                        let branch_diff_matches = fuzzy::match_strings(
+                            &[StringMatchCandidate::new(0, branch_diff_keyword)],
+                            &query,
+                            false,
+                            true,
+                            1,
+                            &Arc::new(AtomicBool::default()),
+                            cx.background_executor().clone(),
+                        )
+                        .await;
+
+                        if !branch_diff_matches.is_empty() {
+                            if let Some(branch_diff_match) = branch_diff_task.await {
+                                matches.push(Match::BranchDiff(branch_diff_match));
+                            }
+                        }
+                    }
+
                     matches.sort_by(|a, b| {
                         b.score()
                             .partial_cmp(&a.score())
@@ -1364,6 +1485,17 @@ impl<T: PromptCompletionProviderDelegate> CompletionProvider for PromptCompletio
                                         cx,
                                     )
                                 }
+                                Match::BranchDiff(branch_diff) => {
+                                    Some(Self::build_branch_diff_completion(
+                                        branch_diff.base_ref,
+                                        source_range.clone(),
+                                        source.clone(),
+                                        editor.clone(),
+                                        mention_set.clone(),
+                                        workspace.clone(),
+                                        cx,
+                                    ))
+                                }
                             })
                             .collect::<Vec<_>>()
                     });

crates/agent_ui/src/mention_set.rs 🔗

@@ -147,10 +147,12 @@ impl MentionSet {
                 include_errors,
                 include_warnings,
             } => self.confirm_mention_for_diagnostics(include_errors, include_warnings, cx),
+            MentionUri::GitDiff { base_ref } => {
+                self.confirm_mention_for_git_diff(base_ref.into(), cx)
+            }
             MentionUri::PastedImage
             | MentionUri::Selection { .. }
             | MentionUri::TerminalSelection { .. }
-            | MentionUri::GitDiff { .. }
             | MentionUri::MergeConflict { .. } => {
                 Task::ready(Err(anyhow!("Unsupported mention URI type for paste")))
             }
@@ -298,9 +300,8 @@ impl MentionSet {
                 debug_panic!("unexpected terminal URI");
                 Task::ready(Err(anyhow!("unexpected terminal URI")))
             }
-            MentionUri::GitDiff { .. } => {
-                debug_panic!("unexpected git diff URI");
-                Task::ready(Err(anyhow!("unexpected git diff URI")))
+            MentionUri::GitDiff { base_ref } => {
+                self.confirm_mention_for_git_diff(base_ref.into(), cx)
             }
             MentionUri::MergeConflict { .. } => {
                 debug_panic!("unexpected merge conflict URI");
@@ -602,6 +603,42 @@ impl MentionSet {
             })
         })
     }
+
+    fn confirm_mention_for_git_diff(
+        &self,
+        base_ref: SharedString,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<Mention>> {
+        let Some(project) = self.project.upgrade() else {
+            return Task::ready(Err(anyhow!("project not found")));
+        };
+
+        let Some(repo) = project.read(cx).active_repository(cx) else {
+            return Task::ready(Err(anyhow!("no active repository")));
+        };
+
+        let diff_receiver = repo.update(cx, |repo, cx| {
+            repo.diff(
+                git::repository::DiffType::MergeBase { base_ref: base_ref },
+                cx,
+            )
+        });
+
+        cx.spawn(async move |_, _| {
+            let diff_text = diff_receiver.await??;
+            if diff_text.is_empty() {
+                Ok(Mention::Text {
+                    content: "No changes found in branch diff.".into(),
+                    tracked_buffers: Vec::new(),
+                })
+            } else {
+                Ok(Mention::Text {
+                    content: diff_text,
+                    tracked_buffers: Vec::new(),
+                })
+            }
+        })
+    }
 }
 
 #[cfg(test)]

crates/agent_ui/src/message_editor.rs 🔗

@@ -80,6 +80,7 @@ impl PromptCompletionProviderDelegate for Entity<MessageEditor> {
                 PromptContextType::Diagnostics,
                 PromptContextType::Fetch,
                 PromptContextType::Rules,
+                PromptContextType::BranchDiff,
             ]);
         }
         supported