1use crate::{
2 example::{Example, ExampleBuffer, ExampleState},
3 headless::EpAppState,
4 paths::{REPOS_DIR, WORKTREES_DIR},
5 progress::{InfoStyle, Progress, Step, StepProgress},
6};
7use anyhow::{Result, anyhow};
8use collections::HashMap;
9use edit_prediction::EditPredictionStore;
10use edit_prediction::udiff::OpenedBuffers;
11use futures::{
12 AsyncWriteExt as _,
13 lock::{Mutex, OwnedMutexGuard},
14};
15use gpui::{AsyncApp, Entity};
16use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
17use project::buffer_store::BufferStoreEvent;
18use project::{Project, ProjectPath};
19use std::{
20 cell::RefCell,
21 fs,
22 path::{Path, PathBuf},
23 sync::Arc,
24};
25use util::{paths::PathStyle, rel_path::RelPath};
26use zeta_prompt::CURSOR_MARKER;
27
28pub async fn run_load_project(
29 example: &mut Example,
30 app_state: Arc<EpAppState>,
31 progress: Arc<Progress>,
32 mut cx: AsyncApp,
33) {
34 if example.state.is_some() {
35 return;
36 }
37
38 let progress = progress.start(Step::LoadProject, &example.name);
39
40 let project = setup_project(example, &app_state, &progress, &mut cx).await;
41
42 let _open_buffers = apply_edit_history(example, &project, &mut cx)
43 .await
44 .unwrap();
45
46 let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
47 let (example_buffer, language_name) = buffer
48 .read_with(&cx, |buffer, _cx| {
49 let cursor_point = cursor_position.to_point(&buffer);
50 let language_name = buffer
51 .language()
52 .map(|l| l.name().to_string())
53 .unwrap_or_else(|| "Unknown".to_string());
54 (
55 ExampleBuffer {
56 content: buffer.text(),
57 cursor_row: cursor_point.row,
58 cursor_column: cursor_point.column,
59 cursor_offset: cursor_position.to_offset(&buffer),
60 },
61 language_name,
62 )
63 })
64 .unwrap();
65
66 progress.set_info(language_name, InfoStyle::Normal);
67
68 example.buffer = Some(example_buffer);
69 example.state = Some(ExampleState {
70 buffer,
71 project,
72 cursor_position,
73 _open_buffers,
74 });
75}
76
77async fn cursor_position(
78 example: &Example,
79 project: &Entity<Project>,
80 cx: &mut AsyncApp,
81) -> (Entity<Buffer>, Anchor) {
82 let language_registry = project
83 .read_with(cx, |project, _| project.languages().clone())
84 .unwrap();
85 let result = language_registry
86 .load_language_for_file_path(&example.cursor_path)
87 .await;
88
89 if let Err(error) = result
90 && !error.is::<LanguageNotFound>()
91 {
92 panic!("Failed to load language for file path: {}", error);
93 }
94
95 let worktree = project
96 .read_with(cx, |project, cx| {
97 project.visible_worktrees(cx).next().unwrap()
98 })
99 .unwrap();
100
101 let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
102 .unwrap()
103 .into_arc();
104 let cursor_buffer = project
105 .update(cx, |project, cx| {
106 project.open_buffer(
107 ProjectPath {
108 worktree_id: worktree.read(cx).id(),
109 path: cursor_path,
110 },
111 cx,
112 )
113 })
114 .unwrap()
115 .await
116 .unwrap();
117 let cursor_offset_within_excerpt = example
118 .cursor_position
119 .find(CURSOR_MARKER)
120 .ok_or_else(|| anyhow!("missing cursor marker"))
121 .unwrap();
122 let mut cursor_excerpt = example.cursor_position.clone();
123 cursor_excerpt.replace_range(
124 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
125 "",
126 );
127 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
128 let text = buffer.text();
129
130 let mut matches = text.match_indices(&cursor_excerpt);
131 let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
132 panic!(
133 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
134 example.name
135 );
136 });
137 assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
138 excerpt_offset
139 }).unwrap();
140
141 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
142 let cursor_anchor = cursor_buffer
143 .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
144 .unwrap();
145
146 (cursor_buffer, cursor_anchor)
147}
148
149async fn setup_project(
150 example: &mut Example,
151 app_state: &Arc<EpAppState>,
152 step_progress: &Arc<StepProgress>,
153 cx: &mut AsyncApp,
154) -> Entity<Project> {
155 let ep_store = cx
156 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
157 .unwrap();
158
159 let worktree_path = setup_worktree(example, step_progress).await;
160
161 if let Some(project) = app_state.project_cache.get(&example.repository_url) {
162 ep_store
163 .update(cx, |ep_store, _| {
164 ep_store.clear_history_for_project(&project);
165 })
166 .unwrap();
167 let buffer_store = project
168 .read_with(cx, |project, _| project.buffer_store().clone())
169 .unwrap();
170 let buffers = buffer_store
171 .read_with(cx, |buffer_store, _| {
172 buffer_store.buffers().collect::<Vec<_>>()
173 })
174 .unwrap();
175 for buffer in buffers {
176 buffer
177 .update(cx, |buffer, cx| buffer.reload(cx))
178 .unwrap()
179 .await
180 .ok();
181 }
182 return project;
183 }
184
185 let project = cx
186 .update(|cx| {
187 Project::local(
188 app_state.client.clone(),
189 app_state.node_runtime.clone(),
190 app_state.user_store.clone(),
191 app_state.languages.clone(),
192 app_state.fs.clone(),
193 None,
194 cx,
195 )
196 })
197 .unwrap();
198
199 project
200 .update(cx, |project, cx| {
201 project.disable_worktree_scanner(cx);
202 project.create_worktree(&worktree_path, true, cx)
203 })
204 .unwrap()
205 .await
206 .unwrap();
207
208 app_state
209 .project_cache
210 .insert(example.repository_url.clone(), project.clone());
211
212 let buffer_store = project
213 .read_with(cx, |project, _| project.buffer_store().clone())
214 .unwrap();
215 cx.subscribe(&buffer_store, {
216 let project = project.clone();
217 move |_, event, cx| match event {
218 BufferStoreEvent::BufferAdded(buffer) => {
219 ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
220 }
221 _ => {}
222 }
223 })
224 .unwrap()
225 .detach();
226
227 project
228}
229
230async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
231 let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
232 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
233 let worktree_path = WORKTREES_DIR
234 .join(repo_owner.as_ref())
235 .join(repo_name.as_ref());
236 let repo_lock = lock_repo(&repo_dir).await;
237
238 if !repo_dir.is_dir() {
239 step_progress.set_substatus(format!("cloning {}", repo_name));
240 fs::create_dir_all(&repo_dir).unwrap();
241 run_git(&repo_dir, &["init"]).await.unwrap();
242 run_git(
243 &repo_dir,
244 &["remote", "add", "origin", &example.repository_url],
245 )
246 .await
247 .unwrap();
248 }
249
250 // Resolve the example to a revision, fetching it if needed.
251 let revision = run_git(
252 &repo_dir,
253 &["rev-parse", &format!("{}^{{commit}}", example.revision)],
254 )
255 .await;
256 let revision = if let Ok(revision) = revision {
257 revision
258 } else {
259 step_progress.set_substatus("fetching");
260 if run_git(
261 &repo_dir,
262 &["fetch", "--depth", "1", "origin", &example.revision],
263 )
264 .await
265 .is_err()
266 {
267 run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
268 }
269 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
270 .await
271 .unwrap();
272 revision
273 };
274
275 // Create the worktree for this example if needed.
276 step_progress.set_substatus("preparing worktree");
277 if worktree_path.is_dir() {
278 run_git(&worktree_path, &["clean", "--force", "-d"])
279 .await
280 .unwrap();
281 run_git(&worktree_path, &["reset", "--hard", "HEAD"])
282 .await
283 .unwrap();
284 run_git(&worktree_path, &["checkout", revision.as_str()])
285 .await
286 .unwrap();
287 } else {
288 let worktree_path_string = worktree_path.to_string_lossy();
289 run_git(
290 &repo_dir,
291 &["branch", "-f", &example.name, revision.as_str()],
292 )
293 .await
294 .unwrap();
295 run_git(
296 &repo_dir,
297 &[
298 "worktree",
299 "add",
300 "-f",
301 &worktree_path_string,
302 &example.name,
303 ],
304 )
305 .await
306 .unwrap();
307 }
308 drop(repo_lock);
309
310 // Apply the uncommitted diff for this example.
311 if !example.uncommitted_diff.is_empty() {
312 step_progress.set_substatus("applying diff");
313 let mut apply_process = smol::process::Command::new("git")
314 .current_dir(&worktree_path)
315 .args(&["apply", "-"])
316 .stdin(std::process::Stdio::piped())
317 .spawn()
318 .unwrap();
319
320 let mut stdin = apply_process.stdin.take().unwrap();
321 stdin
322 .write_all(example.uncommitted_diff.as_bytes())
323 .await
324 .unwrap();
325 stdin.close().await.unwrap();
326 drop(stdin);
327
328 let apply_result = apply_process.output().await.unwrap();
329 if !apply_result.status.success() {
330 panic!(
331 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
332 apply_result.status,
333 String::from_utf8_lossy(&apply_result.stderr),
334 String::from_utf8_lossy(&apply_result.stdout),
335 );
336 }
337 }
338
339 step_progress.clear_substatus();
340 worktree_path
341}
342
343async fn apply_edit_history(
344 example: &Example,
345 project: &Entity<Project>,
346 cx: &mut AsyncApp,
347) -> Result<OpenedBuffers> {
348 edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
349}
350
351thread_local! {
352 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
353}
354
355#[must_use]
356pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
357 REPO_LOCKS
358 .with(|cell| {
359 cell.borrow_mut()
360 .entry(path.as_ref().to_path_buf())
361 .or_default()
362 .clone()
363 })
364 .lock_owned()
365 .await
366}
367
368async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
369 let output = smol::process::Command::new("git")
370 .current_dir(repo_path)
371 .args(args)
372 .output()
373 .await?;
374
375 anyhow::ensure!(
376 output.status.success(),
377 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
378 args.join(" "),
379 repo_path.display(),
380 output.status,
381 String::from_utf8_lossy(&output.stderr),
382 String::from_utf8_lossy(&output.stdout),
383 );
384 Ok(String::from_utf8(output.stdout)?.trim().to_string())
385}