1use anyhow::Result;
  2use buffer_diff::BufferDiff;
  3use collections::HashSet;
  4use futures::StreamExt;
  5use git::{
  6    repository::RepoPath,
  7    status::{DiffTreeType, FileStatus, StatusCode, TrackedStatus, TreeDiff, TreeDiffStatus},
  8};
  9use gpui::{
 10    App, AsyncWindowContext, Context, Entity, EventEmitter, SharedString, Subscription, Task,
 11    WeakEntity, Window,
 12};
 13
 14use language::Buffer;
 15use text::BufferId;
 16use util::ResultExt;
 17
 18use crate::{
 19    Project,
 20    git_store::{GitStoreEvent, Repository, RepositoryEvent},
 21};
 22
/// The baseline a [`BranchDiff`] compares the working tree against.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DiffBase {
    /// Compare against `HEAD` only: the repository's own status is used and no tree
    /// diff is loaded.
    Head,
    /// Compare against the merge base between `base_ref` and `HEAD`
    /// (computed via `DiffTreeType::MergeBase`).
    ///
    /// `base_ref` is also matched against the current branch's `ref_name` to notice when
    /// that branch's tip commit moves.
    Merge { base_ref: SharedString },
}
 28
 29impl DiffBase {
 30    pub fn is_merge_base(&self) -> bool {
 31        matches!(self, DiffBase::Merge { .. })
 32    }
 33}
 34
/// Tracks the files that differ between the working tree of the project's active
/// repository and a [`DiffBase`], refreshing the merge-base tree diff from a background
/// worker as the repository changes.
pub struct BranchDiff {
    /// Baseline chosen at construction.
    diff_base: DiffBase,
    /// Repository currently being diffed; re-synced with the git store's active
    /// repository by the background worker.
    repo: Option<Entity<Repository>>,
    /// Project whose git store supplies the repositories.
    project: Entity<Project>,
    /// Last observed tip commit of the `base_ref` branch. Only recorded while that branch
    /// is the repository's current branch; a change triggers a tree-diff reload.
    base_commit: Option<SharedString>,
    /// Last observed `HEAD` commit; a change triggers a tree-diff reload.
    head_commit: Option<SharedString>,
    /// Tree diff computed with `DiffTreeType::MergeBase` against `HEAD`. `None` until
    /// first loaded, always `None` in [`DiffBase::Head`] mode, and cleared when no
    /// repository is tracked.
    tree_diff: Option<TreeDiff>,
    /// Git store subscription that emits events and wakes the worker.
    _subscription: Subscription,
    /// Wakes the background worker so it re-checks the repository for changes.
    update_needed: postage::watch::Sender<()>,
    /// Background worker running `handle_status_updates`.
    _task: Task<()>,
}
 46
/// Events emitted by [`BranchDiff`].
pub enum BranchDiffEvent {
    /// The set of changed files may have changed. Emitted on relevant git store events
    /// and after the tree diff is reloaded.
    FileListChanged,
}

impl EventEmitter<BranchDiffEvent> for BranchDiff {}
 52
 53impl BranchDiff {
 54    pub fn new(
 55        source: DiffBase,
 56        project: Entity<Project>,
 57        window: &mut Window,
 58        cx: &mut Context<Self>,
 59    ) -> Self {
 60        let git_store = project.read(cx).git_store().clone();
 61        let git_store_subscription = cx.subscribe_in(
 62            &git_store,
 63            window,
 64            move |this, _git_store, event, _window, cx| match event {
 65                GitStoreEvent::ActiveRepositoryChanged(_)
 66                | GitStoreEvent::RepositoryUpdated(
 67                    _,
 68                    RepositoryEvent::StatusesChanged { full_scan: _ },
 69                    true,
 70                )
 71                | GitStoreEvent::ConflictsUpdated => {
 72                    cx.emit(BranchDiffEvent::FileListChanged);
 73                    *this.update_needed.borrow_mut() = ();
 74                }
 75                _ => {}
 76            },
 77        );
 78
 79        let (send, recv) = postage::watch::channel::<()>();
 80        let worker = window.spawn(cx, {
 81            let this = cx.weak_entity();
 82            async |cx| Self::handle_status_updates(this, recv, cx).await
 83        });
 84        let repo = git_store.read(cx).active_repository();
 85
 86        Self {
 87            diff_base: source,
 88            repo,
 89            project,
 90            tree_diff: None,
 91            base_commit: None,
 92            head_commit: None,
 93            _subscription: git_store_subscription,
 94            _task: worker,
 95            update_needed: send,
 96        }
 97    }
 98
 99    pub fn diff_base(&self) -> &DiffBase {
100        &self.diff_base
101    }
102
103    pub async fn handle_status_updates(
104        this: WeakEntity<Self>,
105        mut recv: postage::watch::Receiver<()>,
106        cx: &mut AsyncWindowContext,
107    ) {
108        Self::reload_tree_diff(this.clone(), cx).await.log_err();
109        while recv.next().await.is_some() {
110            let Ok(needs_update) = this.update(cx, |this, cx| {
111                let mut needs_update = false;
112                let active_repo = this
113                    .project
114                    .read(cx)
115                    .git_store()
116                    .read(cx)
117                    .active_repository();
118                if active_repo != this.repo {
119                    needs_update = true;
120                    this.repo = active_repo;
121                } else if let Some(repo) = this.repo.as_ref() {
122                    repo.update(cx, |repo, _| {
123                        if let Some(branch) = &repo.branch
124                            && let DiffBase::Merge { base_ref } = &this.diff_base
125                            && let Some(commit) = branch.most_recent_commit.as_ref()
126                            && &branch.ref_name == base_ref
127                            && this.base_commit.as_ref() != Some(&commit.sha)
128                        {
129                            this.base_commit = Some(commit.sha.clone());
130                            needs_update = true;
131                        }
132
133                        if repo.head_commit.as_ref().map(|c| &c.sha) != this.head_commit.as_ref() {
134                            this.head_commit = repo.head_commit.as_ref().map(|c| c.sha.clone());
135                            needs_update = true;
136                        }
137                    })
138                }
139                needs_update
140            }) else {
141                return;
142            };
143
144            if needs_update {
145                Self::reload_tree_diff(this.clone(), cx).await.log_err();
146            }
147        }
148    }
149
150    pub fn status_for_buffer_id(&self, buffer_id: BufferId, cx: &App) -> Option<FileStatus> {
151        let (repo, path) = self
152            .project
153            .read(cx)
154            .git_store()
155            .read(cx)
156            .repository_and_path_for_buffer_id(buffer_id, cx)?;
157        if self.repo() == Some(&repo) {
158            return self.merge_statuses(
159                repo.read(cx)
160                    .status_for_path(&path)
161                    .map(|status| status.status),
162                self.tree_diff
163                    .as_ref()
164                    .and_then(|diff| diff.entries.get(&path)),
165            );
166        }
167        None
168    }
169
170    pub fn merge_statuses(
171        &self,
172        diff_from_head: Option<FileStatus>,
173        diff_from_merge_base: Option<&TreeDiffStatus>,
174    ) -> Option<FileStatus> {
175        match (diff_from_head, diff_from_merge_base) {
176            (None, None) => None,
177            (Some(diff_from_head), None) => Some(diff_from_head),
178            (Some(diff_from_head @ FileStatus::Unmerged(_)), _) => Some(diff_from_head),
179
180            // file does not exist in HEAD
181            // but *does* exist in work-tree
182            // and *does* exist in merge-base
183            (
184                Some(FileStatus::Untracked)
185                | Some(FileStatus::Tracked(TrackedStatus {
186                    index_status: StatusCode::Added,
187                    worktree_status: _,
188                })),
189                Some(_),
190            ) => Some(FileStatus::Tracked(TrackedStatus {
191                index_status: StatusCode::Modified,
192                worktree_status: StatusCode::Modified,
193            })),
194
195            // file exists in HEAD
196            // but *does not* exist in work-tree
197            (Some(diff_from_head), Some(diff_from_merge_base)) if diff_from_head.is_deleted() => {
198                match diff_from_merge_base {
199                    TreeDiffStatus::Added => None, // unchanged, didn't exist in merge base or worktree
200                    _ => Some(diff_from_head),
201                }
202            }
203
204            // file exists in HEAD
205            // and *does* exist in work-tree
206            (Some(FileStatus::Tracked(_)), Some(tree_status)) => {
207                Some(FileStatus::Tracked(TrackedStatus {
208                    index_status: match tree_status {
209                        TreeDiffStatus::Added { .. } => StatusCode::Added,
210                        _ => StatusCode::Modified,
211                    },
212                    worktree_status: match tree_status {
213                        TreeDiffStatus::Added => StatusCode::Added,
214                        _ => StatusCode::Modified,
215                    },
216                }))
217            }
218
219            (_, Some(diff_from_merge_base)) => {
220                Some(diff_status_to_file_status(diff_from_merge_base))
221            }
222        }
223    }
224
225    pub async fn reload_tree_diff(
226        this: WeakEntity<Self>,
227        cx: &mut AsyncWindowContext,
228    ) -> Result<()> {
229        let task = this.update(cx, |this, cx| {
230            let DiffBase::Merge { base_ref } = this.diff_base.clone() else {
231                return None;
232            };
233            let Some(repo) = this.repo.as_ref() else {
234                this.tree_diff.take();
235                return None;
236            };
237            repo.update(cx, |repo, cx| {
238                Some(repo.diff_tree(
239                    DiffTreeType::MergeBase {
240                        base: base_ref,
241                        head: "HEAD".into(),
242                    },
243                    cx,
244                ))
245            })
246        })?;
247        let Some(task) = task else { return Ok(()) };
248
249        let diff = task.await??;
250        this.update(cx, |this, cx| {
251            this.tree_diff = Some(diff);
252            cx.emit(BranchDiffEvent::FileListChanged);
253            cx.notify();
254        })
255    }
256
257    pub fn repo(&self) -> Option<&Entity<Repository>> {
258        self.repo.as_ref()
259    }
260
261    pub fn load_buffers(&mut self, cx: &mut Context<Self>) -> Vec<DiffBuffer> {
262        let mut output = Vec::default();
263        let Some(repo) = self.repo.clone() else {
264            return output;
265        };
266
267        self.project.update(cx, |_project, cx| {
268            let mut seen = HashSet::default();
269
270            for item in repo.read(cx).cached_status() {
271                seen.insert(item.repo_path.clone());
272                let branch_diff = self
273                    .tree_diff
274                    .as_ref()
275                    .and_then(|t| t.entries.get(&item.repo_path))
276                    .cloned();
277                let status = self
278                    .merge_statuses(Some(item.status), branch_diff.as_ref())
279                    .unwrap();
280                if !status.has_changes() {
281                    continue;
282                }
283
284                let Some(project_path) =
285                    repo.read(cx).repo_path_to_project_path(&item.repo_path, cx)
286                else {
287                    continue;
288                };
289                let task = Self::load_buffer(branch_diff, project_path, repo.clone(), cx);
290
291                output.push(DiffBuffer {
292                    repo_path: item.repo_path.clone(),
293                    load: task,
294                    file_status: item.status,
295                });
296            }
297            let Some(tree_diff) = self.tree_diff.as_ref() else {
298                return;
299            };
300
301            for (path, branch_diff) in tree_diff.entries.iter() {
302                if seen.contains(&path) {
303                    continue;
304                }
305
306                let Some(project_path) = repo.read(cx).repo_path_to_project_path(&path, cx) else {
307                    continue;
308                };
309                let task =
310                    Self::load_buffer(Some(branch_diff.clone()), project_path, repo.clone(), cx);
311
312                let file_status = diff_status_to_file_status(branch_diff);
313
314                output.push(DiffBuffer {
315                    repo_path: path.clone(),
316                    load: task,
317                    file_status,
318                });
319            }
320        });
321        output
322    }
323
324    fn load_buffer(
325        branch_diff: Option<git::status::TreeDiffStatus>,
326        project_path: crate::ProjectPath,
327        repo: Entity<Repository>,
328        cx: &Context<'_, Project>,
329    ) -> Task<Result<(Entity<Buffer>, Entity<BufferDiff>)>> {
330        let task = cx.spawn(async move |project, cx| {
331            let buffer = project
332                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
333                .await?;
334
335            let languages = project.update(cx, |project, _cx| project.languages().clone())?;
336
337            let changes = if let Some(entry) = branch_diff {
338                let oid = match entry {
339                    git::status::TreeDiffStatus::Added { .. } => None,
340                    git::status::TreeDiffStatus::Modified { old, .. }
341                    | git::status::TreeDiffStatus::Deleted { old } => Some(old),
342                };
343                project
344                    .update(cx, |project, cx| {
345                        project.git_store().update(cx, |git_store, cx| {
346                            git_store.open_diff_since(oid, buffer.clone(), repo, languages, cx)
347                        })
348                    })?
349                    .await?
350            } else {
351                project
352                    .update(cx, |project, cx| {
353                        project.open_uncommitted_diff(buffer.clone(), cx)
354                    })?
355                    .await?
356            };
357            Ok((buffer, changes))
358        });
359        task
360    }
361}
362
363fn diff_status_to_file_status(branch_diff: &git::status::TreeDiffStatus) -> FileStatus {
364    let file_status = match branch_diff {
365        git::status::TreeDiffStatus::Added { .. } => FileStatus::Tracked(TrackedStatus {
366            index_status: StatusCode::Added,
367            worktree_status: StatusCode::Added,
368        }),
369        git::status::TreeDiffStatus::Modified { .. } => FileStatus::Tracked(TrackedStatus {
370            index_status: StatusCode::Modified,
371            worktree_status: StatusCode::Modified,
372        }),
373        git::status::TreeDiffStatus::Deleted { .. } => FileStatus::Tracked(TrackedStatus {
374            index_status: StatusCode::Deleted,
375            worktree_status: StatusCode::Deleted,
376        }),
377    };
378    file_status
379}
380
/// A changed file produced by `BranchDiff::load_buffers`, with a task that opens it.
#[derive(Debug)]
pub struct DiffBuffer {
    /// Path of the file within its repository.
    pub repo_path: RepoPath,
    /// Status reported for the file.
    pub file_status: FileStatus,
    /// Resolves to the opened buffer and the diff to display for it.
    pub load: Task<Result<(Entity<Buffer>, Entity<BufferDiff>)>>,
}