1use crate::{
2 example::{Example, ExampleBuffer, ExampleState},
3 headless::EpAppState,
4};
5use anyhow::{Result, anyhow};
6use collections::HashMap;
7use edit_prediction::EditPredictionStore;
8use edit_prediction::udiff::OpenedBuffers;
9use futures::{
10 AsyncWriteExt as _,
11 lock::{Mutex, OwnedMutexGuard},
12};
13use gpui::{AsyncApp, Entity};
14use language::{Anchor, Buffer, ToOffset, ToPoint};
15use project::buffer_store::BufferStoreEvent;
16use project::{Project, ProjectPath};
17use std::{
18 cell::RefCell,
19 fs,
20 path::{Path, PathBuf},
21 sync::Arc,
22};
23use util::{paths::PathStyle, rel_path::RelPath};
24use zeta_prompt::CURSOR_MARKER;
25
26pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
27 if example.state.is_some() {
28 return;
29 }
30
31 let project = setup_project(example, &app_state, &mut cx).await;
32 let buffer_store = project
33 .read_with(&cx, |project, _| project.buffer_store().clone())
34 .unwrap();
35
36 let ep_store = cx
37 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
38 .unwrap();
39
40 cx.subscribe(&buffer_store, {
41 let project = project.clone();
42 move |_, event, cx| match event {
43 BufferStoreEvent::BufferAdded(buffer) => {
44 ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
45 }
46 _ => {}
47 }
48 })
49 .unwrap()
50 .detach();
51
52 let _open_buffers = apply_edit_history(example, &project, &mut cx)
53 .await
54 .unwrap();
55 let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
56 example.buffer = buffer
57 .read_with(&cx, |buffer, _cx| {
58 let cursor_point = cursor_position.to_point(&buffer);
59 Some(ExampleBuffer {
60 content: buffer.text(),
61 cursor_row: cursor_point.row,
62 cursor_column: cursor_point.column,
63 cursor_offset: cursor_position.to_offset(&buffer),
64 })
65 })
66 .unwrap();
67 example.state = Some(ExampleState {
68 buffer,
69 project,
70 cursor_position,
71 _open_buffers,
72 });
73}
74
75async fn cursor_position(
76 example: &Example,
77 project: &Entity<Project>,
78 cx: &mut AsyncApp,
79) -> (Entity<Buffer>, Anchor) {
80 let worktree = project
81 .read_with(cx, |project, cx| {
82 project.visible_worktrees(cx).next().unwrap()
83 })
84 .unwrap();
85
86 let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
87 .unwrap()
88 .into_arc();
89 let cursor_buffer = project
90 .update(cx, |project, cx| {
91 project.open_buffer(
92 ProjectPath {
93 worktree_id: worktree.read(cx).id(),
94 path: cursor_path,
95 },
96 cx,
97 )
98 })
99 .unwrap()
100 .await
101 .unwrap();
102 let cursor_offset_within_excerpt = example
103 .cursor_position
104 .find(CURSOR_MARKER)
105 .ok_or_else(|| anyhow!("missing cursor marker"))
106 .unwrap();
107 let mut cursor_excerpt = example.cursor_position.clone();
108 cursor_excerpt.replace_range(
109 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
110 "",
111 );
112 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
113 let text = buffer.text();
114
115 let mut matches = text.match_indices(&cursor_excerpt);
116 let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
117 panic!(
118 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
119 );
120 });
121 assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
122 excerpt_offset
123 }).unwrap();
124
125 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
126 let cursor_anchor = cursor_buffer
127 .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
128 .unwrap();
129
130 (cursor_buffer, cursor_anchor)
131}
132
133async fn setup_project(
134 example: &mut Example,
135 app_state: &Arc<EpAppState>,
136 cx: &mut AsyncApp,
137) -> Entity<Project> {
138 setup_worktree(example).await;
139
140 let project = cx
141 .update(|cx| {
142 Project::local(
143 app_state.client.clone(),
144 app_state.node_runtime.clone(),
145 app_state.user_store.clone(),
146 app_state.languages.clone(),
147 app_state.fs.clone(),
148 None,
149 cx,
150 )
151 })
152 .unwrap();
153
154 let worktree = project
155 .update(cx, |project, cx| {
156 project.create_worktree(&example.worktree_path(), true, cx)
157 })
158 .unwrap()
159 .await
160 .unwrap();
161 worktree
162 .read_with(cx, |worktree, _cx| {
163 worktree.as_local().unwrap().scan_complete()
164 })
165 .unwrap()
166 .await;
167 project
168}
169
170pub async fn setup_worktree(example: &Example) {
171 let repo_dir = example.repo_path();
172 let repo_lock = lock_repo(&repo_dir).await;
173
174 if !repo_dir.is_dir() {
175 fs::create_dir_all(&repo_dir).unwrap();
176 run_git(&repo_dir, &["init"]).await.unwrap();
177 run_git(
178 &repo_dir,
179 &["remote", "add", "origin", &example.repository_url],
180 )
181 .await
182 .unwrap();
183 }
184
185 // Resolve the example to a revision, fetching it if needed.
186 let revision = run_git(
187 &repo_dir,
188 &["rev-parse", &format!("{}^{{commit}}", example.revision)],
189 )
190 .await;
191 let revision = if let Ok(revision) = revision {
192 revision
193 } else {
194 if run_git(
195 &repo_dir,
196 &["fetch", "--depth", "1", "origin", &example.revision],
197 )
198 .await
199 .is_err()
200 {
201 run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
202 }
203 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
204 .await
205 .unwrap();
206 if revision != example.revision {
207 run_git(&repo_dir, &["tag", &example.revision, &revision])
208 .await
209 .unwrap();
210 }
211 revision
212 };
213
214 // Create the worktree for this example if needed.
215 let worktree_path = example.worktree_path();
216 if worktree_path.is_dir() {
217 run_git(&worktree_path, &["clean", "--force", "-d"])
218 .await
219 .unwrap();
220 run_git(&worktree_path, &["reset", "--hard", "HEAD"])
221 .await
222 .unwrap();
223 run_git(&worktree_path, &["checkout", revision.as_str()])
224 .await
225 .unwrap();
226 } else {
227 let worktree_path_string = worktree_path.to_string_lossy();
228 run_git(
229 &repo_dir,
230 &["branch", "-f", &example.name, revision.as_str()],
231 )
232 .await
233 .unwrap();
234 run_git(
235 &repo_dir,
236 &[
237 "worktree",
238 "add",
239 "-f",
240 &worktree_path_string,
241 &example.name,
242 ],
243 )
244 .await
245 .unwrap();
246 }
247 drop(repo_lock);
248
249 // Apply the uncommitted diff for this example.
250 if !example.uncommitted_diff.is_empty() {
251 let mut apply_process = smol::process::Command::new("git")
252 .current_dir(&worktree_path)
253 .args(&["apply", "-"])
254 .stdin(std::process::Stdio::piped())
255 .spawn()
256 .unwrap();
257
258 let mut stdin = apply_process.stdin.take().unwrap();
259 stdin
260 .write_all(example.uncommitted_diff.as_bytes())
261 .await
262 .unwrap();
263 stdin.close().await.unwrap();
264 drop(stdin);
265
266 let apply_result = apply_process.output().await.unwrap();
267 if !apply_result.status.success() {
268 panic!(
269 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
270 apply_result.status,
271 String::from_utf8_lossy(&apply_result.stderr),
272 String::from_utf8_lossy(&apply_result.stdout),
273 );
274 }
275 }
276}
277
278async fn apply_edit_history(
279 example: &Example,
280 project: &Entity<Project>,
281 cx: &mut AsyncApp,
282) -> Result<OpenedBuffers> {
283 edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
284}
285
286thread_local! {
287 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
288}
289
290#[must_use]
291pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
292 REPO_LOCKS
293 .with(|cell| {
294 cell.borrow_mut()
295 .entry(path.as_ref().to_path_buf())
296 .or_default()
297 .clone()
298 })
299 .lock_owned()
300 .await
301}
302
303async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
304 let output = smol::process::Command::new("git")
305 .current_dir(repo_path)
306 .args(args)
307 .output()
308 .await?;
309
310 anyhow::ensure!(
311 output.status.success(),
312 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
313 args.join(" "),
314 repo_path.display(),
315 output.status,
316 String::from_utf8_lossy(&output.stderr),
317 String::from_utf8_lossy(&output.stdout),
318 );
319 Ok(String::from_utf8(output.stdout)?.trim().to_string())
320}