load_project.rs

  1use crate::{
  2    example::{Example, ExampleBuffer, ExampleState},
  3    headless::EpAppState,
  4};
  5use anyhow::{Result, anyhow};
  6use collections::HashMap;
  7use edit_prediction::EditPredictionStore;
  8use edit_prediction::udiff::OpenedBuffers;
  9use futures::{
 10    AsyncWriteExt as _,
 11    lock::{Mutex, OwnedMutexGuard},
 12};
 13use gpui::{AsyncApp, Entity};
 14use language::{Anchor, Buffer, ToOffset, ToPoint};
 15use project::buffer_store::BufferStoreEvent;
 16use project::{Project, ProjectPath};
 17use std::{
 18    cell::RefCell,
 19    fs,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use util::{paths::PathStyle, rel_path::RelPath};
 24use zeta_prompt::CURSOR_MARKER;
 25
 26pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
 27    if example.state.is_some() {
 28        return;
 29    }
 30
 31    let project = setup_project(example, &app_state, &mut cx).await;
 32    let buffer_store = project
 33        .read_with(&cx, |project, _| project.buffer_store().clone())
 34        .unwrap();
 35
 36    let ep_store = cx
 37        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 38        .unwrap();
 39
 40    cx.subscribe(&buffer_store, {
 41        let project = project.clone();
 42        move |_, event, cx| match event {
 43            BufferStoreEvent::BufferAdded(buffer) => {
 44                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
 45            }
 46            _ => {}
 47        }
 48    })
 49    .unwrap()
 50    .detach();
 51
 52    let _open_buffers = apply_edit_history(example, &project, &mut cx)
 53        .await
 54        .unwrap();
 55    let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
 56    example.buffer = buffer
 57        .read_with(&cx, |buffer, _cx| {
 58            let cursor_point = cursor_position.to_point(&buffer);
 59            Some(ExampleBuffer {
 60                content: buffer.text(),
 61                cursor_row: cursor_point.row,
 62                cursor_column: cursor_point.column,
 63                cursor_offset: cursor_position.to_offset(&buffer),
 64            })
 65        })
 66        .unwrap();
 67    example.state = Some(ExampleState {
 68        buffer,
 69        project,
 70        cursor_position,
 71        _open_buffers,
 72    });
 73}
 74
 75async fn cursor_position(
 76    example: &Example,
 77    project: &Entity<Project>,
 78    cx: &mut AsyncApp,
 79) -> (Entity<Buffer>, Anchor) {
 80    let worktree = project
 81        .read_with(cx, |project, cx| {
 82            project.visible_worktrees(cx).next().unwrap()
 83        })
 84        .unwrap();
 85
 86    let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
 87        .unwrap()
 88        .into_arc();
 89    let cursor_buffer = project
 90        .update(cx, |project, cx| {
 91            project.open_buffer(
 92                ProjectPath {
 93                    worktree_id: worktree.read(cx).id(),
 94                    path: cursor_path,
 95                },
 96                cx,
 97            )
 98        })
 99        .unwrap()
100        .await
101        .unwrap();
102    let cursor_offset_within_excerpt = example
103        .cursor_position
104        .find(CURSOR_MARKER)
105        .ok_or_else(|| anyhow!("missing cursor marker"))
106        .unwrap();
107    let mut cursor_excerpt = example.cursor_position.clone();
108    cursor_excerpt.replace_range(
109        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
110        "",
111    );
112    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
113        let text = buffer.text();
114
115        let mut matches = text.match_indices(&cursor_excerpt);
116        let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
117            panic!(
118                "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
119            );
120        });
121        assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
122        excerpt_offset
123    }).unwrap();
124
125    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
126    let cursor_anchor = cursor_buffer
127        .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
128        .unwrap();
129
130    (cursor_buffer, cursor_anchor)
131}
132
133async fn setup_project(
134    example: &mut Example,
135    app_state: &Arc<EpAppState>,
136    cx: &mut AsyncApp,
137) -> Entity<Project> {
138    setup_worktree(example).await;
139
140    let project = cx
141        .update(|cx| {
142            Project::local(
143                app_state.client.clone(),
144                app_state.node_runtime.clone(),
145                app_state.user_store.clone(),
146                app_state.languages.clone(),
147                app_state.fs.clone(),
148                None,
149                cx,
150            )
151        })
152        .unwrap();
153
154    let worktree = project
155        .update(cx, |project, cx| {
156            project.create_worktree(&example.worktree_path(), true, cx)
157        })
158        .unwrap()
159        .await
160        .unwrap();
161    worktree
162        .read_with(cx, |worktree, _cx| {
163            worktree.as_local().unwrap().scan_complete()
164        })
165        .unwrap()
166        .await;
167    project
168}
169
170pub async fn setup_worktree(example: &Example) {
171    let repo_dir = example.repo_path();
172    let repo_lock = lock_repo(&repo_dir).await;
173
174    if !repo_dir.is_dir() {
175        fs::create_dir_all(&repo_dir).unwrap();
176        run_git(&repo_dir, &["init"]).await.unwrap();
177        run_git(
178            &repo_dir,
179            &["remote", "add", "origin", &example.repository_url],
180        )
181        .await
182        .unwrap();
183    }
184
185    // Resolve the example to a revision, fetching it if needed.
186    let revision = run_git(
187        &repo_dir,
188        &["rev-parse", &format!("{}^{{commit}}", example.revision)],
189    )
190    .await;
191    let revision = if let Ok(revision) = revision {
192        revision
193    } else {
194        if run_git(
195            &repo_dir,
196            &["fetch", "--depth", "1", "origin", &example.revision],
197        )
198        .await
199        .is_err()
200        {
201            run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
202        }
203        let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
204            .await
205            .unwrap();
206        if revision != example.revision {
207            run_git(&repo_dir, &["tag", &example.revision, &revision])
208                .await
209                .unwrap();
210        }
211        revision
212    };
213
214    // Create the worktree for this example if needed.
215    let worktree_path = example.worktree_path();
216    if worktree_path.is_dir() {
217        run_git(&worktree_path, &["clean", "--force", "-d"])
218            .await
219            .unwrap();
220        run_git(&worktree_path, &["reset", "--hard", "HEAD"])
221            .await
222            .unwrap();
223        run_git(&worktree_path, &["checkout", revision.as_str()])
224            .await
225            .unwrap();
226    } else {
227        let worktree_path_string = worktree_path.to_string_lossy();
228        run_git(
229            &repo_dir,
230            &["branch", "-f", &example.name, revision.as_str()],
231        )
232        .await
233        .unwrap();
234        run_git(
235            &repo_dir,
236            &[
237                "worktree",
238                "add",
239                "-f",
240                &worktree_path_string,
241                &example.name,
242            ],
243        )
244        .await
245        .unwrap();
246    }
247    drop(repo_lock);
248
249    // Apply the uncommitted diff for this example.
250    if !example.uncommitted_diff.is_empty() {
251        let mut apply_process = smol::process::Command::new("git")
252            .current_dir(&worktree_path)
253            .args(&["apply", "-"])
254            .stdin(std::process::Stdio::piped())
255            .spawn()
256            .unwrap();
257
258        let mut stdin = apply_process.stdin.take().unwrap();
259        stdin
260            .write_all(example.uncommitted_diff.as_bytes())
261            .await
262            .unwrap();
263        stdin.close().await.unwrap();
264        drop(stdin);
265
266        let apply_result = apply_process.output().await.unwrap();
267        if !apply_result.status.success() {
268            panic!(
269                "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
270                apply_result.status,
271                String::from_utf8_lossy(&apply_result.stderr),
272                String::from_utf8_lossy(&apply_result.stdout),
273            );
274        }
275    }
276}
277
278async fn apply_edit_history(
279    example: &Example,
280    project: &Entity<Project>,
281    cx: &mut AsyncApp,
282) -> Result<OpenedBuffers> {
283    edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
284}
285
286thread_local! {
287    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
288}
289
290#[must_use]
291pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
292    REPO_LOCKS
293        .with(|cell| {
294            cell.borrow_mut()
295                .entry(path.as_ref().to_path_buf())
296                .or_default()
297                .clone()
298        })
299        .lock_owned()
300        .await
301}
302
303async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
304    let output = smol::process::Command::new("git")
305        .current_dir(repo_path)
306        .args(args)
307        .output()
308        .await?;
309
310    anyhow::ensure!(
311        output.status.success(),
312        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
313        args.join(" "),
314        repo_path.display(),
315        output.status,
316        String::from_utf8_lossy(&output.stderr),
317        String::from_utf8_lossy(&output.stdout),
318    );
319    Ok(String::from_utf8(output.stdout)?.trim().to_string())
320}