1use crate::{
2 example::{Example, ExampleBuffer, ExampleState},
3 headless::EpAppState,
4 paths::{REPOS_DIR, WORKTREES_DIR},
5 progress::{InfoStyle, Progress, Step, StepProgress},
6};
7use anyhow::{Result, anyhow};
8use collections::HashMap;
9use edit_prediction::EditPredictionStore;
10use edit_prediction::udiff::OpenedBuffers;
11use futures::{
12 AsyncWriteExt as _,
13 lock::{Mutex, OwnedMutexGuard},
14};
15use gpui::{AsyncApp, Entity};
16use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
17use project::buffer_store::BufferStoreEvent;
18use project::{Project, ProjectPath};
19use std::{
20 cell::RefCell,
21 fs,
22 path::{Path, PathBuf},
23 sync::Arc,
24};
25use util::{paths::PathStyle, rel_path::RelPath};
26use zeta_prompt::CURSOR_MARKER;
27
28pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
29 if example.state.is_some() {
30 return;
31 }
32
33 let progress = Progress::global().start(Step::LoadProject, &example.name);
34
35 let project = setup_project(example, &app_state, &progress, &mut cx).await;
36
37 let _open_buffers = apply_edit_history(example, &project, &mut cx)
38 .await
39 .unwrap();
40
41 let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
42 let (example_buffer, language_name) = buffer
43 .read_with(&cx, |buffer, _cx| {
44 let cursor_point = cursor_position.to_point(&buffer);
45 let language_name = buffer
46 .language()
47 .map(|l| l.name().to_string())
48 .unwrap_or_else(|| "Unknown".to_string());
49 (
50 ExampleBuffer {
51 content: buffer.text(),
52 cursor_row: cursor_point.row,
53 cursor_column: cursor_point.column,
54 cursor_offset: cursor_position.to_offset(&buffer),
55 },
56 language_name,
57 )
58 })
59 .unwrap();
60
61 progress.set_info(language_name, InfoStyle::Normal);
62
63 example.buffer = Some(example_buffer);
64 example.state = Some(ExampleState {
65 buffer,
66 project,
67 cursor_position,
68 _open_buffers,
69 });
70}
71
72async fn cursor_position(
73 example: &Example,
74 project: &Entity<Project>,
75 cx: &mut AsyncApp,
76) -> (Entity<Buffer>, Anchor) {
77 let language_registry = project
78 .read_with(cx, |project, _| project.languages().clone())
79 .unwrap();
80 let result = language_registry
81 .load_language_for_file_path(&example.cursor_path)
82 .await;
83
84 if let Err(error) = result
85 && !error.is::<LanguageNotFound>()
86 {
87 panic!("Failed to load language for file path: {}", error);
88 }
89
90 let worktree = project
91 .read_with(cx, |project, cx| {
92 project.visible_worktrees(cx).next().unwrap()
93 })
94 .unwrap();
95
96 let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
97 .unwrap()
98 .into_arc();
99 let cursor_buffer = project
100 .update(cx, |project, cx| {
101 project.open_buffer(
102 ProjectPath {
103 worktree_id: worktree.read(cx).id(),
104 path: cursor_path,
105 },
106 cx,
107 )
108 })
109 .unwrap()
110 .await
111 .unwrap();
112 let cursor_offset_within_excerpt = example
113 .cursor_position
114 .find(CURSOR_MARKER)
115 .ok_or_else(|| anyhow!("missing cursor marker"))
116 .unwrap();
117 let mut cursor_excerpt = example.cursor_position.clone();
118 cursor_excerpt.replace_range(
119 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
120 "",
121 );
122 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
123 let text = buffer.text();
124
125 let mut matches = text.match_indices(&cursor_excerpt);
126 let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
127 panic!(
128 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
129 example.name
130 );
131 });
132 assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
133 excerpt_offset
134 }).unwrap();
135
136 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
137 let cursor_anchor = cursor_buffer
138 .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
139 .unwrap();
140
141 (cursor_buffer, cursor_anchor)
142}
143
144async fn setup_project(
145 example: &mut Example,
146 app_state: &Arc<EpAppState>,
147 step_progress: &StepProgress,
148 cx: &mut AsyncApp,
149) -> Entity<Project> {
150 let ep_store = cx
151 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
152 .unwrap();
153
154 let worktree_path = setup_worktree(example, step_progress).await;
155
156 if let Some(project) = app_state.project_cache.get(&example.repository_url) {
157 ep_store
158 .update(cx, |ep_store, _| {
159 ep_store.clear_history_for_project(&project);
160 })
161 .unwrap();
162 let buffer_store = project
163 .read_with(cx, |project, _| project.buffer_store().clone())
164 .unwrap();
165 let buffers = buffer_store
166 .read_with(cx, |buffer_store, _| {
167 buffer_store.buffers().collect::<Vec<_>>()
168 })
169 .unwrap();
170 for buffer in buffers {
171 buffer
172 .update(cx, |buffer, cx| buffer.reload(cx))
173 .unwrap()
174 .await
175 .ok();
176 }
177 return project;
178 }
179
180 let project = cx
181 .update(|cx| {
182 Project::local(
183 app_state.client.clone(),
184 app_state.node_runtime.clone(),
185 app_state.user_store.clone(),
186 app_state.languages.clone(),
187 app_state.fs.clone(),
188 None,
189 cx,
190 )
191 })
192 .unwrap();
193
194 project
195 .update(cx, |project, cx| {
196 project.disable_worktree_scanner(cx);
197 project.create_worktree(&worktree_path, true, cx)
198 })
199 .unwrap()
200 .await
201 .unwrap();
202
203 app_state
204 .project_cache
205 .insert(example.repository_url.clone(), project.clone());
206
207 let buffer_store = project
208 .read_with(cx, |project, _| project.buffer_store().clone())
209 .unwrap();
210 cx.subscribe(&buffer_store, {
211 let project = project.clone();
212 move |_, event, cx| match event {
213 BufferStoreEvent::BufferAdded(buffer) => {
214 ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
215 }
216 _ => {}
217 }
218 })
219 .unwrap()
220 .detach();
221
222 project
223}
224
225async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> PathBuf {
226 let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
227 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
228 let worktree_path = WORKTREES_DIR
229 .join(repo_owner.as_ref())
230 .join(repo_name.as_ref());
231 let repo_lock = lock_repo(&repo_dir).await;
232
233 if !repo_dir.is_dir() {
234 step_progress.set_substatus(format!("cloning {}", repo_name));
235 fs::create_dir_all(&repo_dir).unwrap();
236 run_git(&repo_dir, &["init"]).await.unwrap();
237 run_git(
238 &repo_dir,
239 &["remote", "add", "origin", &example.repository_url],
240 )
241 .await
242 .unwrap();
243 }
244
245 // Resolve the example to a revision, fetching it if needed.
246 let revision = run_git(
247 &repo_dir,
248 &["rev-parse", &format!("{}^{{commit}}", example.revision)],
249 )
250 .await;
251 let revision = if let Ok(revision) = revision {
252 revision
253 } else {
254 step_progress.set_substatus("fetching");
255 if run_git(
256 &repo_dir,
257 &["fetch", "--depth", "1", "origin", &example.revision],
258 )
259 .await
260 .is_err()
261 {
262 run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
263 }
264 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
265 .await
266 .unwrap();
267 revision
268 };
269
270 // Create the worktree for this example if needed.
271 step_progress.set_substatus("preparing worktree");
272 if worktree_path.is_dir() {
273 run_git(&worktree_path, &["clean", "--force", "-d"])
274 .await
275 .unwrap();
276 run_git(&worktree_path, &["reset", "--hard", "HEAD"])
277 .await
278 .unwrap();
279 run_git(&worktree_path, &["checkout", revision.as_str()])
280 .await
281 .unwrap();
282 } else {
283 let worktree_path_string = worktree_path.to_string_lossy();
284 run_git(
285 &repo_dir,
286 &["branch", "-f", &example.name, revision.as_str()],
287 )
288 .await
289 .unwrap();
290 run_git(
291 &repo_dir,
292 &[
293 "worktree",
294 "add",
295 "-f",
296 &worktree_path_string,
297 &example.name,
298 ],
299 )
300 .await
301 .unwrap();
302 }
303 drop(repo_lock);
304
305 // Apply the uncommitted diff for this example.
306 if !example.uncommitted_diff.is_empty() {
307 step_progress.set_substatus("applying diff");
308 let mut apply_process = smol::process::Command::new("git")
309 .current_dir(&worktree_path)
310 .args(&["apply", "-"])
311 .stdin(std::process::Stdio::piped())
312 .spawn()
313 .unwrap();
314
315 let mut stdin = apply_process.stdin.take().unwrap();
316 stdin
317 .write_all(example.uncommitted_diff.as_bytes())
318 .await
319 .unwrap();
320 stdin.close().await.unwrap();
321 drop(stdin);
322
323 let apply_result = apply_process.output().await.unwrap();
324 if !apply_result.status.success() {
325 panic!(
326 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
327 apply_result.status,
328 String::from_utf8_lossy(&apply_result.stderr),
329 String::from_utf8_lossy(&apply_result.stdout),
330 );
331 }
332 }
333
334 step_progress.clear_substatus();
335 worktree_path
336}
337
338async fn apply_edit_history(
339 example: &Example,
340 project: &Entity<Project>,
341 cx: &mut AsyncApp,
342) -> Result<OpenedBuffers> {
343 edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
344}
345
346thread_local! {
347 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
348}
349
350#[must_use]
351pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
352 REPO_LOCKS
353 .with(|cell| {
354 cell.borrow_mut()
355 .entry(path.as_ref().to_path_buf())
356 .or_default()
357 .clone()
358 })
359 .lock_owned()
360 .await
361}
362
363async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
364 let output = smol::process::Command::new("git")
365 .current_dir(repo_path)
366 .args(args)
367 .output()
368 .await?;
369
370 anyhow::ensure!(
371 output.status.success(),
372 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
373 args.join(" "),
374 repo_path.display(),
375 output.status,
376 String::from_utf8_lossy(&output.stderr),
377 String::from_utf8_lossy(&output.stdout),
378 );
379 Ok(String::from_utf8(output.stdout)?.trim().to_string())
380}