1use anyhow::Result;
2use collections::HashMap;
3use parking_lot::Mutex;
4use serde_derive::{Deserialize, Serialize};
5use std::{
6 cmp::Ordering,
7 ffi::OsStr,
8 os::unix::prelude::OsStrExt,
9 path::{Component, Path, PathBuf},
10 sync::Arc,
11};
12use sum_tree::{MapSeekTarget, TreeMap};
13use util::ResultExt;
14
15pub use git2::Repository as LibGitRepository;
16
/// Abstraction over a git repository.
///
/// Implemented in this file by the libgit2-backed [`LibGitRepository`] and by the
/// in-memory [`FakeGitRepository`].
#[async_trait::async_trait]
pub trait GitRepository: Send {
    /// Asks the repository to re-read its index. Implementations treat this as a
    /// best-effort operation and report no errors.
    fn reload_index(&self);

    /// Returns the text of `relative_file_path` as stored in the index, or `None`
    /// if there is no such entry or it could not be loaded.
    fn load_index_text(&self, relative_file_path: &Path) -> Option<String>;

    /// Returns the current branch name (HEAD's shorthand in the libgit2 backend),
    /// if available.
    fn branch_name(&self) -> Option<String>;

    /// Returns a map from repository path to status for every path that has a
    /// reported status, or `None` if statuses could not be read.
    fn worktree_statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>>;

    /// Returns the status of a single path, or `None` when it has no reported
    /// status (or it could not be read).
    fn worktree_status(&self, path: &RepoPath) -> Option<GitFileStatus>;
}
29
30impl std::fmt::Debug for dyn GitRepository {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("dyn GitRepository<...>").finish()
33 }
34}
35
36#[async_trait::async_trait]
37impl GitRepository for LibGitRepository {
38 fn reload_index(&self) {
39 if let Ok(mut index) = self.index() {
40 _ = index.read(false);
41 }
42 }
43
44 fn load_index_text(&self, relative_file_path: &Path) -> Option<String> {
45 fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result<Option<String>> {
46 const STAGE_NORMAL: i32 = 0;
47 let index = repo.index()?;
48
49 // This check is required because index.get_path() unwraps internally :(
50 check_path_to_repo_path_errors(relative_file_path)?;
51
52 let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) {
53 Some(entry) => entry.id,
54 None => return Ok(None),
55 };
56
57 let content = repo.find_blob(oid)?.content().to_owned();
58 Ok(Some(String::from_utf8(content)?))
59 }
60
61 match logic(&self, relative_file_path) {
62 Ok(value) => return value,
63 Err(err) => log::error!("Error loading head text: {:?}", err),
64 }
65 None
66 }
67
68 fn branch_name(&self) -> Option<String> {
69 let head = self.head().log_err()?;
70 let branch = String::from_utf8_lossy(head.shorthand_bytes());
71 Some(branch.to_string())
72 }
73
74 fn worktree_statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>> {
75 let statuses = self.statuses(None).log_err()?;
76
77 let mut map = TreeMap::default();
78
79 for status in statuses
80 .iter()
81 .filter(|status| !status.status().contains(git2::Status::IGNORED))
82 {
83 let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes())));
84 let Some(status) = read_status(status.status()) else {
85 continue
86 };
87
88 map.insert(path, status)
89 }
90
91 Some(map)
92 }
93
94 fn worktree_status(&self, path: &RepoPath) -> Option<GitFileStatus> {
95 let status = self.status_file(path).log_err()?;
96 read_status(status)
97 }
98}
99
100fn read_status(status: git2::Status) -> Option<GitFileStatus> {
101 if status.contains(git2::Status::CONFLICTED) {
102 Some(GitFileStatus::Conflict)
103 } else if status.intersects(git2::Status::WT_MODIFIED | git2::Status::WT_RENAMED) {
104 Some(GitFileStatus::Modified)
105 } else if status.intersects(git2::Status::WT_NEW) {
106 Some(GitFileStatus::Added)
107 } else {
108 None
109 }
110}
111
/// In-memory stand-in for a git repository whose answers come from a shared
/// [`FakeGitRepositoryState`].
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepository {
    // Shared, lock-protected state; because it is held in an `Arc`, clones of this
    // handle read the same state.
    state: Arc<Mutex<FakeGitRepositoryState>>,
}
116
/// State backing a [`FakeGitRepository`]; its fields are the canned answers the
/// fake returns.
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepositoryState {
    /// Text returned by `load_index_text`, keyed by path.
    pub index_contents: HashMap<PathBuf, String>,
    /// Statuses returned by `worktree_statuses` and `worktree_status`.
    pub worktree_statuses: HashMap<RepoPath, GitFileStatus>,
    /// Value returned by `branch_name`.
    pub branch_name: Option<String>,
}
123
124impl FakeGitRepository {
125 pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<Mutex<dyn GitRepository>> {
126 Arc::new(Mutex::new(FakeGitRepository { state }))
127 }
128}
129
130#[async_trait::async_trait]
131impl GitRepository for FakeGitRepository {
132 fn reload_index(&self) {}
133
134 fn load_index_text(&self, path: &Path) -> Option<String> {
135 let state = self.state.lock();
136 state.index_contents.get(path).cloned()
137 }
138
139 fn branch_name(&self) -> Option<String> {
140 let state = self.state.lock();
141 state.branch_name.clone()
142 }
143
144 fn worktree_statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>> {
145 let state = self.state.lock();
146 let mut map = TreeMap::default();
147 for (repo_path, status) in state.worktree_statuses.iter() {
148 map.insert(repo_path.to_owned(), status.to_owned());
149 }
150 Some(map)
151 }
152
153 fn worktree_status(&self, path: &RepoPath) -> Option<GitFileStatus> {
154 let state = self.state.lock();
155 state.worktree_statuses.get(path).cloned()
156 }
157}
158
159fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
160 match relative_file_path.components().next() {
161 None => anyhow::bail!("repo path should not be empty"),
162 Some(Component::Prefix(_)) => anyhow::bail!(
163 "repo path `{}` should be relative, not a windows prefix",
164 relative_file_path.to_string_lossy()
165 ),
166 Some(Component::RootDir) => {
167 anyhow::bail!(
168 "repo path `{}` should be relative",
169 relative_file_path.to_string_lossy()
170 )
171 }
172 Some(Component::CurDir) => {
173 anyhow::bail!(
174 "repo path `{}` should not start with `.`",
175 relative_file_path.to_string_lossy()
176 )
177 }
178 Some(Component::ParentDir) => {
179 anyhow::bail!(
180 "repo path `{}` should not start with `..`",
181 relative_file_path.to_string_lossy()
182 )
183 }
184 _ => Ok(()),
185 }
186}
187
/// Status of a file as reported by a [`GitRepository`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GitFileStatus {
    /// New file; produced from libgit2's `WT_NEW` flag.
    Added,
    /// Modified or renamed file; produced from `WT_MODIFIED` / `WT_RENAMED`.
    Modified,
    /// Conflicted file; produced from `CONFLICTED`, which takes precedence.
    Conflict,
}
194
/// A path stored relative to the repository root, such as the paths reported by
/// git status entries.
///
/// `Default` yields the empty path, which is itself relative.
#[derive(Clone, Debug, Default, Ord, Hash, PartialOrd, Eq, PartialEq)]
pub struct RepoPath(PathBuf);

impl RepoPath {
    /// Wraps `path` as a repository path.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `path` is not relative.
    pub fn new(path: PathBuf) -> Self {
        debug_assert!(path.is_relative(), "Repo paths must be relative");

        RepoPath(path)
    }
}

impl From<&Path> for RepoPath {
    /// Copies `value` into a new [`RepoPath`] (debug-asserting it is relative).
    fn from(value: &Path) -> Self {
        RepoPath::new(value.to_path_buf())
    }
}

impl From<PathBuf> for RepoPath {
    /// Takes ownership of `value` as a [`RepoPath`] (debug-asserting it is relative).
    fn from(value: PathBuf) -> Self {
        RepoPath::new(value)
    }
}

impl AsRef<Path> for RepoPath {
    fn as_ref(&self) -> &Path {
        self.0.as_ref()
    }
}

impl std::ops::Deref for RepoPath {
    type Target = PathBuf;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
237
238#[derive(Debug)]
239pub struct RepoPathDescendants<'a>(pub &'a Path);
240
241impl<'a> MapSeekTarget<RepoPath> for RepoPathDescendants<'a> {
242 fn cmp_cursor(&self, key: &RepoPath) -> Ordering {
243 if key.starts_with(&self.0) {
244 Ordering::Greater
245 } else {
246 self.0.cmp(key)
247 }
248 }
249}