@@ -64,6 +64,7 @@ pub(crate) enum PromptContextType {
Thread,
Rules,
Diagnostics,
+ BranchDiff,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -102,6 +103,7 @@ impl TryFrom<&str> for PromptContextType {
"thread" => Ok(Self::Thread),
"rule" => Ok(Self::Rules),
"diagnostics" => Ok(Self::Diagnostics),
+ "diff" => Ok(Self::BranchDiff),
_ => Err(format!("Invalid context picker mode: {}", value)),
}
}
@@ -116,6 +118,7 @@ impl PromptContextType {
Self::Thread => "thread",
Self::Rules => "rule",
Self::Diagnostics => "diagnostics",
+ Self::BranchDiff => "branch diff",
}
}
@@ -127,6 +130,7 @@ impl PromptContextType {
Self::Thread => "Threads",
Self::Rules => "Rules",
Self::Diagnostics => "Diagnostics",
+ Self::BranchDiff => "Branch Diff",
}
}
@@ -138,6 +142,7 @@ impl PromptContextType {
Self::Thread => IconName::Thread,
Self::Rules => IconName::Reader,
Self::Diagnostics => IconName::Warning,
+ Self::BranchDiff => IconName::GitBranch,
}
}
}
@@ -150,6 +155,12 @@ pub(crate) enum Match {
Fetch(SharedString),
Rules(RulesContextEntry),
Entry(EntryMatch),
+ BranchDiff(BranchDiffMatch),
+}
+
+#[derive(Debug, Clone)]
+pub struct BranchDiffMatch {
+ pub base_ref: SharedString,
}
impl Match {
@@ -162,6 +173,7 @@ impl Match {
Match::Symbol(_) => 1.,
Match::Rules(_) => 1.,
Match::Fetch(_) => 1.,
+ Match::BranchDiff(_) => 1.,
}
}
}
@@ -781,6 +793,47 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
}
}
+ fn build_branch_diff_completion(
+ base_ref: SharedString,
+ source_range: Range<Anchor>,
+ source: Arc<T>,
+ editor: WeakEntity<Editor>,
+ mention_set: WeakEntity<MentionSet>,
+ workspace: Entity<Workspace>,
+ cx: &mut App,
+ ) -> Completion {
+ let uri = MentionUri::GitDiff {
+ base_ref: base_ref.to_string(),
+ };
+ let crease_text: SharedString = format!("Branch Diff (vs {})", base_ref).into();
+ let display_text = format!("@{}", crease_text);
+ let new_text = format!("[{}]({}) ", display_text, uri.to_uri());
+ let new_text_len = new_text.len();
+ let icon_path = uri.icon_path(cx);
+
+ Completion {
+ replace_range: source_range.clone(),
+ new_text,
+ label: CodeLabel::plain(crease_text.to_string(), None),
+ documentation: None,
+ source: project::CompletionSource::Custom,
+ icon_path: Some(icon_path),
+ match_start: None,
+ snippet_deduplication_key: None,
+ insert_text_mode: None,
+ confirm: Some(confirm_completion_callback(
+ crease_text,
+ source_range.start,
+ new_text_len - 1,
+ uri,
+ source,
+ editor,
+ mention_set,
+ workspace,
+ )),
+ }
+ }
+
fn search_slash_commands(&self, query: String, cx: &mut App) -> Task<Vec<AvailableCommand>> {
let commands = self.source.available_commands(cx);
if commands.is_empty() {
@@ -812,6 +865,27 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
}
+ fn fetch_branch_diff_match(
+ &self,
+ workspace: &Entity<Workspace>,
+ cx: &mut App,
+ ) -> Option<Task<Option<BranchDiffMatch>>> {
+ let project = workspace.read(cx).project().clone();
+ let repo = project.read(cx).active_repository(cx)?;
+
+ let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(false));
+
+ Some(cx.spawn(async move |_cx| {
+ let base_ref = default_branch_receiver
+ .await
+ .ok()
+ .and_then(|r| r.ok())
+ .flatten()?;
+
+ Some(BranchDiffMatch { base_ref })
+ }))
+ }
+
fn search_mentions(
&self,
mode: Option<PromptContextType>,
@@ -892,6 +966,8 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
Some(PromptContextType::Diagnostics) => Task::ready(Vec::new()),
+ Some(PromptContextType::BranchDiff) => Task::ready(Vec::new()),
+
None if query.is_empty() => {
let recent_task = self.recent_context_picker_entries(&workspace, cx);
let entries = self
@@ -905,9 +981,25 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
.collect::<Vec<_>>();
+ let branch_diff_task = if self
+ .source
+ .supports_context(PromptContextType::BranchDiff, cx)
+ {
+ self.fetch_branch_diff_match(&workspace, cx)
+ } else {
+ None
+ };
+
cx.spawn(async move |_cx| {
let mut matches = recent_task.await;
matches.extend(entries);
+
+ if let Some(branch_diff_task) = branch_diff_task {
+ if let Some(branch_diff_match) = branch_diff_task.await {
+ matches.push(Match::BranchDiff(branch_diff_match));
+ }
+ }
+
matches
})
}
@@ -924,7 +1016,16 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
.collect::<Vec<_>>();
- cx.background_spawn(async move {
+ let branch_diff_task = if self
+ .source
+ .supports_context(PromptContextType::BranchDiff, cx)
+ {
+ self.fetch_branch_diff_match(&workspace, cx)
+ } else {
+ None
+ };
+
+ cx.spawn(async move |cx| {
let mut matches = search_files_task
.await
.into_iter()
@@ -949,6 +1050,26 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
}));
+ if let Some(branch_diff_task) = branch_diff_task {
+ let branch_diff_keyword = PromptContextType::BranchDiff.keyword();
+ let branch_diff_matches = fuzzy::match_strings(
+ &[StringMatchCandidate::new(0, branch_diff_keyword)],
+ &query,
+ false,
+ true,
+ 1,
+ &Arc::new(AtomicBool::default()),
+ cx.background_executor().clone(),
+ )
+ .await;
+
+ if !branch_diff_matches.is_empty() {
+ if let Some(branch_diff_match) = branch_diff_task.await {
+ matches.push(Match::BranchDiff(branch_diff_match));
+ }
+ }
+ }
+
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
@@ -1364,6 +1485,17 @@ impl<T: PromptCompletionProviderDelegate> CompletionProvider for PromptCompletio
cx,
)
}
+ Match::BranchDiff(branch_diff) => {
+ Some(Self::build_branch_diff_completion(
+ branch_diff.base_ref,
+ source_range.clone(),
+ source.clone(),
+ editor.clone(),
+ mention_set.clone(),
+ workspace.clone(),
+ cx,
+ ))
+ }
})
.collect::<Vec<_>>()
});
@@ -147,10 +147,12 @@ impl MentionSet {
include_errors,
include_warnings,
} => self.confirm_mention_for_diagnostics(include_errors, include_warnings, cx),
+ MentionUri::GitDiff { base_ref } => {
+ self.confirm_mention_for_git_diff(base_ref.into(), cx)
+ }
MentionUri::PastedImage
| MentionUri::Selection { .. }
| MentionUri::TerminalSelection { .. }
- | MentionUri::GitDiff { .. }
| MentionUri::MergeConflict { .. } => {
Task::ready(Err(anyhow!("Unsupported mention URI type for paste")))
}
@@ -298,9 +300,8 @@ impl MentionSet {
debug_panic!("unexpected terminal URI");
Task::ready(Err(anyhow!("unexpected terminal URI")))
}
- MentionUri::GitDiff { .. } => {
- debug_panic!("unexpected git diff URI");
- Task::ready(Err(anyhow!("unexpected git diff URI")))
+ MentionUri::GitDiff { base_ref } => {
+ self.confirm_mention_for_git_diff(base_ref.into(), cx)
}
MentionUri::MergeConflict { .. } => {
debug_panic!("unexpected merge conflict URI");
@@ -602,6 +603,42 @@ impl MentionSet {
})
})
}
+
+ fn confirm_mention_for_git_diff(
+ &self,
+ base_ref: SharedString,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Mention>> {
+ let Some(project) = self.project.upgrade() else {
+ return Task::ready(Err(anyhow!("project not found")));
+ };
+
+ let Some(repo) = project.read(cx).active_repository(cx) else {
+ return Task::ready(Err(anyhow!("no active repository")));
+ };
+
+ let diff_receiver = repo.update(cx, |repo, cx| {
+ repo.diff(
+ git::repository::DiffType::MergeBase { base_ref: base_ref },
+ cx,
+ )
+ });
+
+ cx.spawn(async move |_, _| {
+ let diff_text = diff_receiver.await??;
+ if diff_text.is_empty() {
+ Ok(Mention::Text {
+ content: "No changes found in branch diff.".into(),
+ tracked_buffers: Vec::new(),
+ })
+ } else {
+ Ok(Mention::Text {
+ content: diff_text,
+ tracked_buffers: Vec::new(),
+ })
+ }
+ })
+ }
}
#[cfg(test)]