diff --git a/crates/agent_ui/src/completion_provider.rs b/crates/agent_ui/src/completion_provider.rs index 40ad7bc729269d5dae3364ecf3e0de6e5ee5b0ec..d8c45755413ffb14433e3eeb4309e869de195a75 100644 --- a/crates/agent_ui/src/completion_provider.rs +++ b/crates/agent_ui/src/completion_provider.rs @@ -64,6 +64,7 @@ pub(crate) enum PromptContextType { Thread, Rules, Diagnostics, + BranchDiff, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -102,6 +103,7 @@ impl TryFrom<&str> for PromptContextType { "thread" => Ok(Self::Thread), "rule" => Ok(Self::Rules), "diagnostics" => Ok(Self::Diagnostics), + "diff" => Ok(Self::BranchDiff), _ => Err(format!("Invalid context picker mode: {}", value)), } } @@ -116,6 +118,7 @@ impl PromptContextType { Self::Thread => "thread", Self::Rules => "rule", Self::Diagnostics => "diagnostics", + Self::BranchDiff => "branch diff", } } @@ -127,6 +130,7 @@ impl PromptContextType { Self::Thread => "Threads", Self::Rules => "Rules", Self::Diagnostics => "Diagnostics", + Self::BranchDiff => "Branch Diff", } } @@ -138,6 +142,7 @@ impl PromptContextType { Self::Thread => IconName::Thread, Self::Rules => IconName::Reader, Self::Diagnostics => IconName::Warning, + Self::BranchDiff => IconName::GitBranch, } } } @@ -150,6 +155,12 @@ pub(crate) enum Match { Fetch(SharedString), Rules(RulesContextEntry), Entry(EntryMatch), + BranchDiff(BranchDiffMatch), +} + +#[derive(Debug, Clone)] +pub struct BranchDiffMatch { + pub base_ref: SharedString, } impl Match { @@ -162,6 +173,7 @@ impl Match { Match::Symbol(_) => 1., Match::Rules(_) => 1., Match::Fetch(_) => 1., + Match::BranchDiff(_) => 1., } } } @@ -781,6 +793,47 @@ impl PromptCompletionProvider { } } + fn build_branch_diff_completion( + base_ref: SharedString, + source_range: Range, + source: Arc, + editor: WeakEntity, + mention_set: WeakEntity, + workspace: Entity, + cx: &mut App, + ) -> Completion { + let uri = MentionUri::GitDiff { + base_ref: base_ref.to_string(), + }; + let crease_text: SharedString = format!("Branch Diff (vs {})", base_ref).into(); + let display_text = format!("@{}", crease_text); + let new_text = format!("[{}]({}) ", display_text, uri.to_uri()); + let new_text_len = new_text.len(); + let icon_path = uri.icon_path(cx); + + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(crease_text.to_string(), None), + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_path), + match_start: None, + snippet_deduplication_key: None, + insert_text_mode: None, + confirm: Some(confirm_completion_callback( + crease_text, + source_range.start, + new_text_len - 1, + uri, + source, + editor, + mention_set, + workspace, + )), + } + } + fn search_slash_commands(&self, query: String, cx: &mut App) -> Task> { let commands = self.source.available_commands(cx); if commands.is_empty() { @@ -812,6 +865,27 @@ impl PromptCompletionProvider { }) } + fn fetch_branch_diff_match( + &self, + workspace: &Entity, + cx: &mut App, + ) -> Option>> { + let project = workspace.read(cx).project().clone(); + let repo = project.read(cx).active_repository(cx)?; + + let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(false)); + + Some(cx.spawn(async move |_cx| { + let base_ref = default_branch_receiver + .await + .ok() + .and_then(|r| r.ok()) + .flatten()?; + + Some(BranchDiffMatch { base_ref }) + })) + } + fn search_mentions( &self, mode: Option, @@ -892,6 +966,8 @@ impl PromptCompletionProvider { Some(PromptContextType::Diagnostics) => Task::ready(Vec::new()), + Some(PromptContextType::BranchDiff) => Task::ready(Vec::new()), + None if query.is_empty() => { let recent_task = self.recent_context_picker_entries(&workspace, cx); let entries = self @@ -905,9 +981,25 @@ impl PromptCompletionProvider { }) .collect::>(); + let branch_diff_task = if self + .source + .supports_context(PromptContextType::BranchDiff, cx) + { + self.fetch_branch_diff_match(&workspace, cx) + } else { + None + }; + cx.spawn(async move |_cx| { let mut matches = recent_task.await; matches.extend(entries); + + if let Some(branch_diff_task) = branch_diff_task { + if let Some(branch_diff_match) = branch_diff_task.await { + matches.push(Match::BranchDiff(branch_diff_match)); + } + } + matches }) } @@ -924,7 +1016,16 @@ impl PromptCompletionProvider { .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword())) .collect::>(); - cx.background_spawn(async move { + let branch_diff_task = if self + .source + .supports_context(PromptContextType::BranchDiff, cx) + { + self.fetch_branch_diff_match(&workspace, cx) + } else { + None + }; + + cx.spawn(async move |cx| { let mut matches = search_files_task .await .into_iter() @@ -949,6 +1050,26 @@ impl PromptCompletionProvider { }) })); + if let Some(branch_diff_task) = branch_diff_task { + let branch_diff_keyword = PromptContextType::BranchDiff.keyword(); + let branch_diff_matches = fuzzy::match_strings( + &[StringMatchCandidate::new(0, branch_diff_keyword)], + &query, + false, + true, + 1, + &Arc::new(AtomicBool::default()), + cx.background_executor().clone(), + ) + .await; + + if !branch_diff_matches.is_empty() { + if let Some(branch_diff_match) = branch_diff_task.await { + matches.push(Match::BranchDiff(branch_diff_match)); + } + } + } + matches.sort_by(|a, b| { b.score() .partial_cmp(&a.score()) @@ -1364,6 +1485,17 @@ impl CompletionProvider for PromptCompletio cx, ) } + Match::BranchDiff(branch_diff) => { + Some(Self::build_branch_diff_completion( + branch_diff.base_ref, + source_range.clone(), + source.clone(), + editor.clone(), + mention_set.clone(), + workspace.clone(), + cx, + )) + } }) .collect::>() }); diff --git a/crates/agent_ui/src/mention_set.rs b/crates/agent_ui/src/mention_set.rs index e072037f1758e00e648dc46c7ee70599c4363eef..1cb22af6a3fd15df5eeedc5018deaeff77a1dbff 100644 --- a/crates/agent_ui/src/mention_set.rs +++ b/crates/agent_ui/src/mention_set.rs @@ -147,10 +147,12 @@ impl MentionSet { include_errors, include_warnings, } => self.confirm_mention_for_diagnostics(include_errors, include_warnings, cx), + MentionUri::GitDiff { base_ref } => { + self.confirm_mention_for_git_diff(base_ref.into(), cx) + } MentionUri::PastedImage | MentionUri::Selection { .. } | MentionUri::TerminalSelection { .. } - | MentionUri::GitDiff { .. } | MentionUri::MergeConflict { .. } => { Task::ready(Err(anyhow!("Unsupported mention URI type for paste"))) } @@ -298,9 +300,8 @@ impl MentionSet { debug_panic!("unexpected terminal URI"); Task::ready(Err(anyhow!("unexpected terminal URI"))) } - MentionUri::GitDiff { .. } => { - debug_panic!("unexpected git diff URI"); - Task::ready(Err(anyhow!("unexpected git diff URI"))) + MentionUri::GitDiff { base_ref } => { + self.confirm_mention_for_git_diff(base_ref.into(), cx) } MentionUri::MergeConflict { .. } => { debug_panic!("unexpected merge conflict URI"); @@ -602,6 +603,42 @@ impl MentionSet { }) }) } + + fn confirm_mention_for_git_diff( + &self, + base_ref: SharedString, + cx: &mut Context, + ) -> Task> { + let Some(project) = self.project.upgrade() else { + return Task::ready(Err(anyhow!("project not found"))); + }; + + let Some(repo) = project.read(cx).active_repository(cx) else { + return Task::ready(Err(anyhow!("no active repository"))); + }; + + let diff_receiver = repo.update(cx, |repo, cx| { + repo.diff( + git::repository::DiffType::MergeBase { base_ref: base_ref }, + cx, + ) + }); + + cx.spawn(async move |_, _| { + let diff_text = diff_receiver.await??; + if diff_text.is_empty() { + Ok(Mention::Text { + content: "No changes found in branch diff.".into(), + tracked_buffers: Vec::new(), + }) + } else { + Ok(Mention::Text { + content: diff_text, + tracked_buffers: Vec::new(), + }) + } + }) + } } #[cfg(test)] diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 933e24e83c0450dcbdde27d49abebb7fda2fa119..6c2628f9d37efd0531d5663ac4b1d27d9ae5ae0f 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -80,6 +80,7 @@ impl PromptCompletionProviderDelegate for Entity { PromptContextType::Diagnostics, PromptContextType::Fetch, PromptContextType::Rules, + PromptContextType::BranchDiff, ]); } supported