1use crate::{
2 example::{Example, ExampleBuffer, ExampleState},
3 headless::EpAppState,
4 paths::{REPOS_DIR, WORKTREES_DIR},
5};
6use anyhow::{Result, anyhow};
7use collections::HashMap;
8use edit_prediction::EditPredictionStore;
9use edit_prediction::udiff::OpenedBuffers;
10use futures::{
11 AsyncWriteExt as _,
12 lock::{Mutex, OwnedMutexGuard},
13};
14use gpui::{AsyncApp, Entity};
15use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
16use project::buffer_store::BufferStoreEvent;
17use project::{Project, ProjectPath};
18use std::{
19 cell::RefCell,
20 fs,
21 path::{Path, PathBuf},
22 sync::Arc,
23};
24use util::{paths::PathStyle, rel_path::RelPath};
25use zeta_prompt::CURSOR_MARKER;
26
27pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
28 if example.state.is_some() {
29 return;
30 }
31
32 let project = setup_project(example, &app_state, &mut cx).await;
33
34 let _open_buffers = apply_edit_history(example, &project, &mut cx)
35 .await
36 .unwrap();
37
38 let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
39 example.buffer = buffer
40 .read_with(&cx, |buffer, _cx| {
41 let cursor_point = cursor_position.to_point(&buffer);
42 Some(ExampleBuffer {
43 content: buffer.text(),
44 cursor_row: cursor_point.row,
45 cursor_column: cursor_point.column,
46 cursor_offset: cursor_position.to_offset(&buffer),
47 })
48 })
49 .unwrap();
50
51 example.state = Some(ExampleState {
52 buffer,
53 project,
54 cursor_position,
55 _open_buffers,
56 });
57}
58
59async fn cursor_position(
60 example: &Example,
61 project: &Entity<Project>,
62 cx: &mut AsyncApp,
63) -> (Entity<Buffer>, Anchor) {
64 let language_registry = project
65 .read_with(cx, |project, _| project.languages().clone())
66 .unwrap();
67 let result = language_registry
68 .load_language_for_file_path(&example.cursor_path)
69 .await;
70
71 if let Err(error) = result
72 && !error.is::<LanguageNotFound>()
73 {
74 panic!("Failed to load language for file path: {}", error);
75 }
76
77 let worktree = project
78 .read_with(cx, |project, cx| {
79 project.visible_worktrees(cx).next().unwrap()
80 })
81 .unwrap();
82
83 let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
84 .unwrap()
85 .into_arc();
86 let cursor_buffer = project
87 .update(cx, |project, cx| {
88 project.open_buffer(
89 ProjectPath {
90 worktree_id: worktree.read(cx).id(),
91 path: cursor_path,
92 },
93 cx,
94 )
95 })
96 .unwrap()
97 .await
98 .unwrap();
99 let cursor_offset_within_excerpt = example
100 .cursor_position
101 .find(CURSOR_MARKER)
102 .ok_or_else(|| anyhow!("missing cursor marker"))
103 .unwrap();
104 let mut cursor_excerpt = example.cursor_position.clone();
105 cursor_excerpt.replace_range(
106 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
107 "",
108 );
109 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
110 let text = buffer.text();
111
112 let mut matches = text.match_indices(&cursor_excerpt);
113 let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
114 panic!(
115 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
116 example.name
117 );
118 });
119 assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
120 excerpt_offset
121 }).unwrap();
122
123 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
124 let cursor_anchor = cursor_buffer
125 .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
126 .unwrap();
127
128 (cursor_buffer, cursor_anchor)
129}
130
131async fn setup_project(
132 example: &mut Example,
133 app_state: &Arc<EpAppState>,
134 cx: &mut AsyncApp,
135) -> Entity<Project> {
136 let ep_store = cx
137 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
138 .unwrap();
139
140 let worktree_path = setup_worktree(example).await;
141
142 if let Some(project) = app_state.project_cache.get(&example.repository_url) {
143 ep_store
144 .update(cx, |ep_store, _| {
145 ep_store.clear_history_for_project(&project);
146 })
147 .unwrap();
148 let buffer_store = project
149 .read_with(cx, |project, _| project.buffer_store().clone())
150 .unwrap();
151 let buffers = buffer_store
152 .read_with(cx, |buffer_store, _| {
153 buffer_store.buffers().collect::<Vec<_>>()
154 })
155 .unwrap();
156 for buffer in buffers {
157 buffer
158 .update(cx, |buffer, cx| buffer.reload(cx))
159 .unwrap()
160 .await
161 .unwrap();
162 }
163 return project;
164 }
165
166 let project = cx
167 .update(|cx| {
168 Project::local(
169 app_state.client.clone(),
170 app_state.node_runtime.clone(),
171 app_state.user_store.clone(),
172 app_state.languages.clone(),
173 app_state.fs.clone(),
174 None,
175 cx,
176 )
177 })
178 .unwrap();
179
180 project
181 .update(cx, |project, cx| {
182 project.disable_worktree_scanner(cx);
183 project.create_worktree(&worktree_path, true, cx)
184 })
185 .unwrap()
186 .await
187 .unwrap();
188
189 app_state
190 .project_cache
191 .insert(example.repository_url.clone(), project.clone());
192
193 let buffer_store = project
194 .read_with(cx, |project, _| project.buffer_store().clone())
195 .unwrap();
196 cx.subscribe(&buffer_store, {
197 let project = project.clone();
198 move |_, event, cx| match event {
199 BufferStoreEvent::BufferAdded(buffer) => {
200 ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
201 }
202 _ => {}
203 }
204 })
205 .unwrap()
206 .detach();
207
208 project
209}
210
211pub async fn setup_worktree(example: &Example) -> PathBuf {
212 let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
213 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
214 let worktree_path = WORKTREES_DIR
215 .join(repo_owner.as_ref())
216 .join(repo_name.as_ref());
217 let repo_lock = lock_repo(&repo_dir).await;
218
219 if !repo_dir.is_dir() {
220 eprintln!("Cloning repository {}", example.repository_url);
221 fs::create_dir_all(&repo_dir).unwrap();
222 run_git(&repo_dir, &["init"]).await.unwrap();
223 run_git(
224 &repo_dir,
225 &["remote", "add", "origin", &example.repository_url],
226 )
227 .await
228 .unwrap();
229 }
230
231 // Resolve the example to a revision, fetching it if needed.
232 let revision = run_git(
233 &repo_dir,
234 &["rev-parse", &format!("{}^{{commit}}", example.revision)],
235 )
236 .await;
237 let revision = if let Ok(revision) = revision {
238 revision
239 } else {
240 if run_git(
241 &repo_dir,
242 &["fetch", "--depth", "1", "origin", &example.revision],
243 )
244 .await
245 .is_err()
246 {
247 run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
248 }
249 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
250 .await
251 .unwrap();
252 revision
253 };
254
255 // Create the worktree for this example if needed.
256 if worktree_path.is_dir() {
257 run_git(&worktree_path, &["clean", "--force", "-d"])
258 .await
259 .unwrap();
260 run_git(&worktree_path, &["reset", "--hard", "HEAD"])
261 .await
262 .unwrap();
263 run_git(&worktree_path, &["checkout", revision.as_str()])
264 .await
265 .unwrap();
266 } else {
267 let worktree_path_string = worktree_path.to_string_lossy();
268 run_git(
269 &repo_dir,
270 &["branch", "-f", &example.name, revision.as_str()],
271 )
272 .await
273 .unwrap();
274 run_git(
275 &repo_dir,
276 &[
277 "worktree",
278 "add",
279 "-f",
280 &worktree_path_string,
281 &example.name,
282 ],
283 )
284 .await
285 .unwrap();
286 }
287 drop(repo_lock);
288
289 // Apply the uncommitted diff for this example.
290 if !example.uncommitted_diff.is_empty() {
291 let mut apply_process = smol::process::Command::new("git")
292 .current_dir(&worktree_path)
293 .args(&["apply", "-"])
294 .stdin(std::process::Stdio::piped())
295 .spawn()
296 .unwrap();
297
298 let mut stdin = apply_process.stdin.take().unwrap();
299 stdin
300 .write_all(example.uncommitted_diff.as_bytes())
301 .await
302 .unwrap();
303 stdin.close().await.unwrap();
304 drop(stdin);
305
306 let apply_result = apply_process.output().await.unwrap();
307 if !apply_result.status.success() {
308 panic!(
309 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
310 apply_result.status,
311 String::from_utf8_lossy(&apply_result.stderr),
312 String::from_utf8_lossy(&apply_result.stdout),
313 );
314 }
315 }
316
317 worktree_path
318}
319
320async fn apply_edit_history(
321 example: &Example,
322 project: &Entity<Project>,
323 cx: &mut AsyncApp,
324) -> Result<OpenedBuffers> {
325 edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
326}
327
328thread_local! {
329 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
330}
331
332#[must_use]
333pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
334 REPO_LOCKS
335 .with(|cell| {
336 cell.borrow_mut()
337 .entry(path.as_ref().to_path_buf())
338 .or_default()
339 .clone()
340 })
341 .lock_owned()
342 .await
343}
344
345async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
346 let output = smol::process::Command::new("git")
347 .current_dir(repo_path)
348 .args(args)
349 .output()
350 .await?;
351
352 anyhow::ensure!(
353 output.status.success(),
354 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
355 args.join(" "),
356 repo_path.display(),
357 output.status,
358 String::from_utf8_lossy(&output.stderr),
359 String::from_utf8_lossy(&output.stdout),
360 );
361 Ok(String::from_utf8(output.stdout)?.trim().to_string())
362}