load_project.rs

  1use crate::{
  2    example::{Example, ExampleBuffer, ExampleState},
  3    headless::EpAppState,
  4    paths::{REPOS_DIR, WORKTREES_DIR},
  5    progress::{InfoStyle, Progress, Step, StepProgress},
  6};
  7use anyhow::{Context as _, Result};
  8use collections::HashMap;
  9use edit_prediction::EditPredictionStore;
 10use edit_prediction::udiff::OpenedBuffers;
 11use futures::{
 12    AsyncWriteExt as _,
 13    lock::{Mutex, OwnedMutexGuard},
 14};
 15use gpui::{AsyncApp, Entity};
 16use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
 17use project::Project;
 18use project::buffer_store::BufferStoreEvent;
 19use std::{
 20    cell::RefCell,
 21    fs,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24};
 25use zeta_prompt::CURSOR_MARKER;
 26
 27pub async fn run_load_project(
 28    example: &mut Example,
 29    app_state: Arc<EpAppState>,
 30    mut cx: AsyncApp,
 31) -> Result<()> {
 32    if example.state.is_some() {
 33        return Ok(());
 34    }
 35
 36    let progress = Progress::global().start(Step::LoadProject, &example.spec.name);
 37
 38    let project = setup_project(example, &app_state, &progress, &mut cx).await?;
 39
 40    let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
 41
 42    let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
 43    let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
 44        let cursor_point = cursor_position.to_point(&buffer);
 45        let language_name = buffer
 46            .language()
 47            .map(|l| l.name().to_string())
 48            .unwrap_or_else(|| "Unknown".to_string());
 49        (
 50            ExampleBuffer {
 51                content: buffer.text(),
 52                cursor_row: cursor_point.row,
 53                cursor_column: cursor_point.column,
 54                cursor_offset: cursor_position.to_offset(&buffer),
 55            },
 56            language_name,
 57        )
 58    })?;
 59
 60    progress.set_info(language_name, InfoStyle::Normal);
 61
 62    example.buffer = Some(example_buffer);
 63    example.state = Some(ExampleState {
 64        buffer,
 65        project,
 66        cursor_position,
 67        _open_buffers,
 68    });
 69    Ok(())
 70}
 71
 72async fn cursor_position(
 73    example: &Example,
 74    project: &Entity<Project>,
 75    cx: &mut AsyncApp,
 76) -> Result<(Entity<Buffer>, Anchor)> {
 77    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
 78    let result = language_registry
 79        .load_language_for_file_path(&example.spec.cursor_path)
 80        .await;
 81
 82    if let Err(error) = result
 83        && !error.is::<LanguageNotFound>()
 84    {
 85        return Err(error);
 86    }
 87
 88    let cursor_path = project
 89        .read_with(cx, |project, cx| {
 90            project.find_project_path(&example.spec.cursor_path, cx)
 91        })?
 92        .with_context(|| {
 93            format!(
 94                "failed to find cursor path {}",
 95                example.spec.cursor_path.display()
 96            )
 97        })?;
 98    let cursor_buffer = project
 99        .update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
100        .await?;
101    let cursor_offset_within_excerpt = example
102        .spec
103        .cursor_position
104        .find(CURSOR_MARKER)
105        .context("missing cursor marker")?;
106    let mut cursor_excerpt = example.spec.cursor_position.clone();
107    cursor_excerpt.replace_range(
108        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
109        "",
110    );
111    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
112        let text = buffer.text();
113
114        let mut matches = text.match_indices(&cursor_excerpt);
115        let (excerpt_offset, _) = matches.next().with_context(|| {
116            format!(
117                "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
118                example.spec.name
119            )
120        })?;
121        anyhow::ensure!(
122            matches.next().is_none(),
123            "More than one cursor position match found for {}",
124            &example.spec.name
125        );
126        Ok(excerpt_offset)
127    })??;
128
129    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
130    let cursor_anchor =
131        cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
132
133    Ok((cursor_buffer, cursor_anchor))
134}
135
136async fn setup_project(
137    example: &mut Example,
138    app_state: &Arc<EpAppState>,
139    step_progress: &StepProgress,
140    cx: &mut AsyncApp,
141) -> Result<Entity<Project>> {
142    let ep_store = cx
143        .update(|cx| EditPredictionStore::try_global(cx))?
144        .context("Store should be initialized at init")?;
145
146    let worktree_path = setup_worktree(example, step_progress).await?;
147
148    if let Some(project) = app_state.project_cache.get(&example.spec.repository_url) {
149        ep_store.update(cx, |ep_store, _| {
150            ep_store.clear_history_for_project(&project);
151        })?;
152        let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
153        let buffers = buffer_store.read_with(cx, |buffer_store, _| {
154            buffer_store.buffers().collect::<Vec<_>>()
155        })?;
156        for buffer in buffers {
157            buffer
158                .update(cx, |buffer, cx| buffer.reload(cx))?
159                .await
160                .ok();
161        }
162        return Ok(project);
163    }
164
165    let project = cx.update(|cx| {
166        Project::local(
167            app_state.client.clone(),
168            app_state.node_runtime.clone(),
169            app_state.user_store.clone(),
170            app_state.languages.clone(),
171            app_state.fs.clone(),
172            None,
173            false,
174            cx,
175        )
176    })?;
177
178    project
179        .update(cx, |project, cx| {
180            project.disable_worktree_scanner(cx);
181            project.create_worktree(&worktree_path, true, cx)
182        })?
183        .await?;
184
185    app_state
186        .project_cache
187        .insert(example.spec.repository_url.clone(), project.clone());
188
189    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
190    cx.subscribe(&buffer_store, {
191        let project = project.clone();
192        move |_, event, cx| match event {
193            BufferStoreEvent::BufferAdded(buffer) => {
194                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
195            }
196            _ => {}
197        }
198    })?
199    .detach();
200
201    Ok(project)
202}
203
204async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
205    let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
206    let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
207    let worktree_path = WORKTREES_DIR
208        .join(repo_owner.as_ref())
209        .join(repo_name.as_ref());
210    let repo_lock = lock_repo(&repo_dir).await;
211
212    if !repo_dir.is_dir() {
213        step_progress.set_substatus(format!("cloning {}", repo_name));
214        fs::create_dir_all(&repo_dir)?;
215        run_git(&repo_dir, &["init"]).await?;
216        run_git(
217            &repo_dir,
218            &["remote", "add", "origin", &example.spec.repository_url],
219        )
220        .await?;
221    }
222
223    // Resolve the example to a revision, fetching it if needed.
224    let revision = run_git(
225        &repo_dir,
226        &[
227            "rev-parse",
228            &format!("{}^{{commit}}", example.spec.revision),
229        ],
230    )
231    .await;
232    let revision = if let Ok(revision) = revision {
233        revision
234    } else {
235        step_progress.set_substatus("fetching");
236        if run_git(
237            &repo_dir,
238            &["fetch", "--depth", "1", "origin", &example.spec.revision],
239        )
240        .await
241        .is_err()
242        {
243            run_git(&repo_dir, &["fetch", "origin"]).await?;
244        }
245        let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
246        revision
247    };
248
249    // Create the worktree for this example if needed.
250    step_progress.set_substatus("preparing worktree");
251    if worktree_path.is_dir() {
252        run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
253        run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
254        run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
255    } else {
256        let worktree_path_string = worktree_path.to_string_lossy();
257        run_git(
258            &repo_dir,
259            &["branch", "-f", &example.spec.name, revision.as_str()],
260        )
261        .await?;
262        run_git(
263            &repo_dir,
264            &[
265                "worktree",
266                "add",
267                "-f",
268                &worktree_path_string,
269                &example.spec.name,
270            ],
271        )
272        .await?;
273    }
274    drop(repo_lock);
275
276    // Apply the uncommitted diff for this example.
277    if !example.spec.uncommitted_diff.is_empty() {
278        step_progress.set_substatus("applying diff");
279        let mut apply_process = smol::process::Command::new("git")
280            .current_dir(&worktree_path)
281            .args(&["apply", "-"])
282            .stdin(std::process::Stdio::piped())
283            .spawn()?;
284
285        let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
286        stdin
287            .write_all(example.spec.uncommitted_diff.as_bytes())
288            .await?;
289        stdin.close().await?;
290        drop(stdin);
291
292        let apply_result = apply_process.output().await?;
293        anyhow::ensure!(
294            apply_result.status.success(),
295            "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
296            apply_result.status,
297            String::from_utf8_lossy(&apply_result.stderr),
298            String::from_utf8_lossy(&apply_result.stdout),
299        );
300    }
301
302    step_progress.clear_substatus();
303    Ok(worktree_path)
304}
305
306async fn apply_edit_history(
307    example: &Example,
308    project: &Entity<Project>,
309    cx: &mut AsyncApp,
310) -> Result<OpenedBuffers> {
311    edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await
312}
313
314thread_local! {
315    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
316}
317
318#[must_use]
319pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
320    REPO_LOCKS
321        .with(|cell| {
322            cell.borrow_mut()
323                .entry(path.as_ref().to_path_buf())
324                .or_default()
325                .clone()
326        })
327        .lock_owned()
328        .await
329}
330
331async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
332    let output = smol::process::Command::new("git")
333        .current_dir(repo_path)
334        .args(args)
335        .output()
336        .await?;
337
338    anyhow::ensure!(
339        output.status.success(),
340        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
341        args.join(" "),
342        repo_path.display(),
343        output.status,
344        String::from_utf8_lossy(&output.stderr),
345        String::from_utf8_lossy(&output.stdout),
346    );
347    Ok(String::from_utf8(output.stdout)?.trim().to_string())
348}