1use anyhow::Result;
2use collections::HashMap;
3use parking_lot::Mutex;
4use serde_derive::{Deserialize, Serialize};
5use std::{
6 cmp::Ordering,
7 ffi::OsStr,
8 os::unix::prelude::OsStrExt,
9 path::{Component, Path, PathBuf},
10 sync::Arc,
11};
12use sum_tree::{MapSeekTarget, TreeMap};
13use util::ResultExt;
14
15pub use git2::Repository as LibGitRepository;
16
/// Abstraction over the git operations this crate needs, so that a real
/// `git2` repository ([`LibGitRepository`]) and the in-memory
/// [`FakeGitRepository`] can be used interchangeably.
#[async_trait::async_trait]
pub trait GitRepository: Send {
    /// Refreshes cached index state. The `git2` implementation re-reads the
    /// index (best-effort); the fake implementation does nothing.
    fn reload_index(&self);

    /// Returns the text of `relative_file_path` as recorded in the index, or
    /// `None` if it has no index entry or could not be loaded.
    fn load_index_text(&self, relative_file_path: &Path) -> Option<String>;

    /// Returns the current branch name, if it can be determined.
    fn branch_name(&self) -> Option<String>;

    /// Returns the status of every path that has a reportable status, keyed
    /// by repository-relative path. `None` means statuses could not be
    /// computed at all.
    fn statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>>;

    /// Returns the status of a single path, or `None` if it has no
    /// reportable status (or the lookup failed).
    fn status(&self, path: &RepoPath) -> Option<GitFileStatus>;
}
29
30impl std::fmt::Debug for dyn GitRepository {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("dyn GitRepository<...>").finish()
33 }
34}
35
36#[async_trait::async_trait]
37impl GitRepository for LibGitRepository {
38 fn reload_index(&self) {
39 if let Ok(mut index) = self.index() {
40 _ = index.read(false);
41 }
42 }
43
44 fn load_index_text(&self, relative_file_path: &Path) -> Option<String> {
45 fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result<Option<String>> {
46 const STAGE_NORMAL: i32 = 0;
47 let index = repo.index()?;
48
49 // This check is required because index.get_path() unwraps internally :(
50 check_path_to_repo_path_errors(relative_file_path)?;
51
52 let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) {
53 Some(entry) => entry.id,
54 None => return Ok(None),
55 };
56
57 let content = repo.find_blob(oid)?.content().to_owned();
58 Ok(Some(String::from_utf8(content)?))
59 }
60
61 match logic(&self, relative_file_path) {
62 Ok(value) => return value,
63 Err(err) => log::error!("Error loading head text: {:?}", err),
64 }
65 None
66 }
67
68 fn branch_name(&self) -> Option<String> {
69 let head = self.head().log_err()?;
70 let branch = String::from_utf8_lossy(head.shorthand_bytes());
71 Some(branch.to_string())
72 }
73
74 fn statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>> {
75 let statuses = self.statuses(None).log_err()?;
76
77 let mut map = TreeMap::default();
78
79 for status in statuses
80 .iter()
81 .filter(|status| !status.status().contains(git2::Status::IGNORED))
82 {
83 let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes())));
84 let Some(status) = read_status(status.status()) else {
85 continue
86 };
87
88 map.insert(path, status)
89 }
90
91 Some(map)
92 }
93
94 fn status(&self, path: &RepoPath) -> Option<GitFileStatus> {
95 let status = self.status_file(path).log_err()?;
96 read_status(status)
97 }
98}
99
100fn read_status(status: git2::Status) -> Option<GitFileStatus> {
101 if status.contains(git2::Status::CONFLICTED) {
102 Some(GitFileStatus::Conflict)
103 } else if status.intersects(
104 git2::Status::WT_MODIFIED
105 | git2::Status::WT_RENAMED
106 | git2::Status::INDEX_MODIFIED
107 | git2::Status::INDEX_RENAMED,
108 ) {
109 Some(GitFileStatus::Modified)
110 } else if status.intersects(git2::Status::WT_NEW | git2::Status::INDEX_NEW) {
111 Some(GitFileStatus::Added)
112 } else {
113 None
114 }
115}
116
/// In-memory [`GitRepository`] whose answers are read from a shared
/// [`FakeGitRepositoryState`] (presumably for tests — confirm with callers).
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepository {
    // Shared handle: whoever passed this `Arc` to `FakeGitRepository::open`
    // can mutate the state and the repository will observe the changes.
    state: Arc<Mutex<FakeGitRepositoryState>>,
}
121
/// Mutable backing data for [`FakeGitRepository`]; every field is public so
/// callers can set up the exact state they want the fake to report.
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepositoryState {
    /// Text returned by `load_index_text`, keyed by the requested path.
    pub index_contents: HashMap<PathBuf, String>,
    /// Statuses reported by `statuses` (all entries) and `status` (one path).
    pub worktree_statuses: HashMap<RepoPath, GitFileStatus>,
    /// Value returned by `branch_name`.
    pub branch_name: Option<String>,
}
128
129impl FakeGitRepository {
130 pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<Mutex<dyn GitRepository>> {
131 Arc::new(Mutex::new(FakeGitRepository { state }))
132 }
133}
134
135#[async_trait::async_trait]
136impl GitRepository for FakeGitRepository {
137 fn reload_index(&self) {}
138
139 fn load_index_text(&self, path: &Path) -> Option<String> {
140 let state = self.state.lock();
141 state.index_contents.get(path).cloned()
142 }
143
144 fn branch_name(&self) -> Option<String> {
145 let state = self.state.lock();
146 state.branch_name.clone()
147 }
148
149 fn statuses(&self) -> Option<TreeMap<RepoPath, GitFileStatus>> {
150 let state = self.state.lock();
151 let mut map = TreeMap::default();
152 for (repo_path, status) in state.worktree_statuses.iter() {
153 map.insert(repo_path.to_owned(), status.to_owned());
154 }
155 Some(map)
156 }
157
158 fn status(&self, path: &RepoPath) -> Option<GitFileStatus> {
159 let state = self.state.lock();
160 state.worktree_statuses.get(path).cloned()
161 }
162}
163
164fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
165 match relative_file_path.components().next() {
166 None => anyhow::bail!("repo path should not be empty"),
167 Some(Component::Prefix(_)) => anyhow::bail!(
168 "repo path `{}` should be relative, not a windows prefix",
169 relative_file_path.to_string_lossy()
170 ),
171 Some(Component::RootDir) => {
172 anyhow::bail!(
173 "repo path `{}` should be relative",
174 relative_file_path.to_string_lossy()
175 )
176 }
177 Some(Component::CurDir) => {
178 anyhow::bail!(
179 "repo path `{}` should not start with `.`",
180 relative_file_path.to_string_lossy()
181 )
182 }
183 Some(Component::ParentDir) => {
184 anyhow::bail!(
185 "repo path `{}` should not start with `..`",
186 relative_file_path.to_string_lossy()
187 )
188 }
189 _ => Ok(()),
190 }
191}
192
/// Simplified per-file git status derived from raw `git2::Status` flags.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. Reordering
/// variants could change the encoding in index-based serde formats — confirm
/// before changing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GitFileStatus {
    /// New in the worktree or the index.
    Added,
    /// Modified or renamed in the worktree or the index.
    Modified,
    /// Has an unresolved conflict.
    Conflict,
}
199
/// A path relative to the root of a git repository.
///
/// Ordering, equality and hashing are those of the wrapped [`PathBuf`], so
/// comparisons are component-wise.
#[derive(Clone, Debug, Default, Ord, Hash, PartialOrd, Eq, PartialEq)]
pub struct RepoPath(PathBuf);

impl RepoPath {
    /// Wraps `path` as a repository path.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `path` is not relative.
    pub fn new(path: PathBuf) -> Self {
        debug_assert!(path.is_relative(), "Repo paths must be relative");

        RepoPath(path)
    }
}

impl From<&Path> for RepoPath {
    /// Copies `value` into an owned [`RepoPath`] (see [`RepoPath::new`]).
    fn from(value: &Path) -> Self {
        RepoPath::new(value.to_path_buf())
    }
}

impl From<PathBuf> for RepoPath {
    /// Takes ownership of `value` as a [`RepoPath`] (see [`RepoPath::new`]).
    fn from(value: PathBuf) -> Self {
        RepoPath::new(value)
    }
}

impl AsRef<Path> for RepoPath {
    fn as_ref(&self) -> &Path {
        self.0.as_ref()
    }
}

// Deref to `PathBuf` lets callers use `Path`/`PathBuf` methods directly
// (e.g. `starts_with`, `as_os_str`).
impl std::ops::Deref for RepoPath {
    type Target = PathBuf;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
242
243#[derive(Debug)]
244pub struct RepoPathDescendants<'a>(pub &'a Path);
245
246impl<'a> MapSeekTarget<RepoPath> for RepoPathDescendants<'a> {
247 fn cmp_cursor(&self, key: &RepoPath) -> Ordering {
248 if key.starts_with(&self.0) {
249 Ordering::Greater
250 } else {
251 self.0.cmp(key)
252 }
253 }
254}