example.rs

  1use std::{
  2    borrow::Cow,
  3    cell::RefCell,
  4    fmt::{self, Display},
  5    fs,
  6    hash::Hash,
  7    hash::Hasher,
  8    io::Write,
  9    mem,
 10    path::{Path, PathBuf},
 11    sync::{Arc, OnceLock},
 12};
 13
 14use crate::headless::ZetaCliAppState;
 15use anyhow::{Context as _, Result, anyhow};
 16use clap::ValueEnum;
 17use cloud_zeta2_prompt::CURSOR_MARKER;
 18use collections::HashMap;
 19use edit_prediction::udiff::OpenedBuffers;
 20use futures::{
 21    AsyncWriteExt as _,
 22    lock::{Mutex, OwnedMutexGuard},
 23};
 24use futures::{FutureExt as _, future::Shared};
 25use gpui::{AsyncApp, Entity, Task, http_client::Url};
 26use language::{Anchor, Buffer};
 27use project::{Project, ProjectPath};
 28use pulldown_cmark::CowStr;
 29use serde::{Deserialize, Serialize};
 30use util::{paths::PathStyle, rel_path::RelPath};
 31
 32use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 33
 34const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
 35const EDIT_HISTORY_HEADING: &str = "Edit History";
 36const CURSOR_POSITION_HEADING: &str = "Cursor Position";
 37const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
 38const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
 39const REPOSITORY_URL_FIELD: &str = "repository_url";
 40const REVISION_FIELD: &str = "revision";
 41
 42#[derive(Debug, Clone)]
 43pub struct NamedExample {
 44    pub name: String,
 45    pub example: Example,
 46}
 47
 48#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
 49pub struct Example {
 50    pub repository_url: String,
 51    pub revision: String,
 52    pub uncommitted_diff: String,
 53    pub cursor_path: PathBuf,
 54    pub cursor_position: String,
 55    pub edit_history: String,
 56    pub expected_patch: String,
 57}
 58
 59impl Example {
 60    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
 61        // git@github.com:owner/repo.git
 62        if self.repository_url.contains('@') {
 63            let (owner, repo) = self
 64                .repository_url
 65                .split_once(':')
 66                .context("expected : in git url")?
 67                .1
 68                .split_once('/')
 69                .context("expected / in git url")?;
 70            Ok((
 71                Cow::Borrowed(owner),
 72                Cow::Borrowed(repo.trim_end_matches(".git")),
 73            ))
 74        // http://github.com/owner/repo.git
 75        } else {
 76            let url = Url::parse(&self.repository_url)?;
 77            let mut segments = url.path_segments().context("empty http url")?;
 78            let owner = segments
 79                .next()
 80                .context("expected owner path segment")?
 81                .to_string();
 82            let repo = segments
 83                .next()
 84                .context("expected repo path segment")?
 85                .trim_end_matches(".git")
 86                .to_string();
 87            assert!(segments.next().is_none());
 88
 89            Ok((owner.into(), repo.into()))
 90        }
 91    }
 92
 93    pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
 94        let (repo_owner, repo_name) = self.repo_name()?;
 95
 96        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
 97        let repo_lock = lock_repo(&repo_dir).await;
 98
 99        if !repo_dir.is_dir() {
100            fs::create_dir_all(&repo_dir)?;
101            run_git(&repo_dir, &["init"]).await?;
102            run_git(
103                &repo_dir,
104                &["remote", "add", "origin", &self.repository_url],
105            )
106            .await?;
107        }
108
109        // Resolve the example to a revision, fetching it if needed.
110        let revision = run_git(
111            &repo_dir,
112            &["rev-parse", &format!("{}^{{commit}}", self.revision)],
113        )
114        .await;
115        let revision = if let Ok(revision) = revision {
116            revision
117        } else {
118            if run_git(
119                &repo_dir,
120                &["fetch", "--depth", "1", "origin", &self.revision],
121            )
122            .await
123            .is_err()
124            {
125                run_git(&repo_dir, &["fetch", "origin"]).await?;
126            }
127            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
128            if revision != self.revision {
129                run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
130            }
131            revision
132        };
133
134        // Create the worktree for this example if needed.
135        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
136        if worktree_path.is_dir() {
137            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
138            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
139            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
140        } else {
141            let worktree_path_string = worktree_path.to_string_lossy();
142            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
143            run_git(
144                &repo_dir,
145                &["worktree", "add", "-f", &worktree_path_string, &file_name],
146            )
147            .await?;
148        }
149        drop(repo_lock);
150
151        // Apply the uncommitted diff for this example.
152        if !self.uncommitted_diff.is_empty() {
153            let mut apply_process = smol::process::Command::new("git")
154                .current_dir(&worktree_path)
155                .args(&["apply", "-"])
156                .stdin(std::process::Stdio::piped())
157                .spawn()?;
158
159            let mut stdin = apply_process.stdin.take().unwrap();
160            stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
161            stdin.close().await?;
162            drop(stdin);
163
164            let apply_result = apply_process.output().await?;
165            if !apply_result.status.success() {
166                anyhow::bail!(
167                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
168                    apply_result.status,
169                    String::from_utf8_lossy(&apply_result.stderr),
170                    String::from_utf8_lossy(&apply_result.stdout),
171                );
172            }
173        }
174
175        Ok(worktree_path)
176    }
177
178    pub fn unique_name(&self) -> String {
179        let mut hasher = std::hash::DefaultHasher::new();
180        self.hash(&mut hasher);
181        let disambiguator = hasher.finish();
182        let hash = format!("{:04x}", disambiguator);
183        format!("{}_{}", &self.revision[..8], &hash[..4])
184    }
185}
186
187pub type ActualExcerpt = Excerpt;
188
189#[derive(Clone, Debug, Serialize, Deserialize)]
190pub struct Excerpt {
191    pub path: PathBuf,
192    pub text: String,
193}
194
195#[derive(ValueEnum, Debug, Clone)]
196pub enum ExampleFormat {
197    Json,
198    Toml,
199    Md,
200}
201
202impl NamedExample {
203    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
204        let path = path.as_ref();
205        let content = std::fs::read_to_string(path)?;
206        let ext = path.extension();
207
208        match ext.and_then(|s| s.to_str()) {
209            Some("json") => Ok(Self {
210                name: path.file_stem().unwrap_or_default().display().to_string(),
211                example: serde_json::from_str(&content)?,
212            }),
213            Some("toml") => Ok(Self {
214                name: path.file_stem().unwrap_or_default().display().to_string(),
215                example: toml::from_str(&content)?,
216            }),
217            Some("md") => Self::parse_md(&content),
218            Some(_) => {
219                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
220            }
221            None => {
222                anyhow::bail!(
223                    "Failed to determine example type since the file does not have an extension."
224                );
225            }
226        }
227    }
228
229    pub fn parse_md(input: &str) -> Result<Self> {
230        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
231
232        let parser = Parser::new(input);
233
234        let mut named = NamedExample {
235            name: String::new(),
236            example: Example {
237                repository_url: String::new(),
238                revision: String::new(),
239                uncommitted_diff: String::new(),
240                cursor_path: PathBuf::new(),
241                cursor_position: String::new(),
242                edit_history: String::new(),
243                expected_patch: String::new(),
244            },
245        };
246
247        let mut text = String::new();
248        let mut block_info: CowStr = "".into();
249
250        #[derive(PartialEq)]
251        enum Section {
252            UncommittedDiff,
253            EditHistory,
254            CursorPosition,
255            ExpectedExcerpts,
256            ExpectedPatch,
257            Other,
258        }
259
260        let mut current_section = Section::Other;
261
262        for event in parser {
263            match event {
264                Event::Text(line) => {
265                    text.push_str(&line);
266
267                    if !named.name.is_empty()
268                        && current_section == Section::Other
269                        // in h1 section
270                        && let Some((field, value)) = line.split_once('=')
271                    {
272                        match field.trim() {
273                            REPOSITORY_URL_FIELD => {
274                                named.example.repository_url = value.trim().to_string();
275                            }
276                            REVISION_FIELD => {
277                                named.example.revision = value.trim().to_string();
278                            }
279                            _ => {}
280                        }
281                    }
282                }
283                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
284                    if !named.name.is_empty() {
285                        anyhow::bail!(
286                            "Found multiple H1 headings. There should only be one with the name of the example."
287                        );
288                    }
289                    named.name = mem::take(&mut text);
290                }
291                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
292                    let title = mem::take(&mut text);
293                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
294                        Section::UncommittedDiff
295                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
296                        Section::EditHistory
297                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
298                        Section::CursorPosition
299                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
300                        Section::ExpectedPatch
301                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
302                        Section::ExpectedExcerpts
303                    } else {
304                        Section::Other
305                    };
306                }
307                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
308                    mem::take(&mut text);
309                }
310                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
311                    mem::take(&mut text);
312                }
313                Event::End(TagEnd::Heading(level)) => {
314                    anyhow::bail!("Unexpected heading level: {level}");
315                }
316                Event::Start(Tag::CodeBlock(kind)) => {
317                    match kind {
318                        CodeBlockKind::Fenced(info) => {
319                            block_info = info;
320                        }
321                        CodeBlockKind::Indented => {
322                            anyhow::bail!("Unexpected indented codeblock");
323                        }
324                    };
325                }
326                Event::Start(_) => {
327                    text.clear();
328                    block_info = "".into();
329                }
330                Event::End(TagEnd::CodeBlock) => {
331                    let block_info = block_info.trim();
332                    match current_section {
333                        Section::UncommittedDiff => {
334                            named.example.uncommitted_diff = mem::take(&mut text);
335                        }
336                        Section::EditHistory => {
337                            named.example.edit_history.push_str(&mem::take(&mut text));
338                        }
339                        Section::CursorPosition => {
340                            named.example.cursor_path = block_info.into();
341                            named.example.cursor_position = mem::take(&mut text);
342                        }
343                        Section::ExpectedExcerpts => {
344                            mem::take(&mut text);
345                        }
346                        Section::ExpectedPatch => {
347                            named.example.expected_patch = mem::take(&mut text);
348                        }
349                        Section::Other => {}
350                    }
351                }
352                _ => {}
353            }
354        }
355
356        if named.example.cursor_path.as_path() == Path::new("")
357            || named.example.cursor_position.is_empty()
358        {
359            anyhow::bail!("Missing cursor position codeblock");
360        }
361
362        Ok(named)
363    }
364
365    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
366        match format {
367            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
368            ExampleFormat::Toml => {
369                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
370            }
371            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
372        }
373    }
374
375    pub async fn setup_project(
376        &self,
377        app_state: &Arc<ZetaCliAppState>,
378        cx: &mut AsyncApp,
379    ) -> Result<Entity<Project>> {
380        let worktree_path = self.setup_worktree().await?;
381
382        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
383
384        AUTHENTICATED
385            .get_or_init(|| {
386                let client = app_state.client.clone();
387                cx.spawn(async move |cx| {
388                    client
389                        .sign_in_with_optional_connect(true, cx)
390                        .await
391                        .unwrap();
392                })
393                .shared()
394            })
395            .clone()
396            .await;
397
398        let project = cx.update(|cx| {
399            Project::local(
400                app_state.client.clone(),
401                app_state.node_runtime.clone(),
402                app_state.user_store.clone(),
403                app_state.languages.clone(),
404                app_state.fs.clone(),
405                None,
406                cx,
407            )
408        })?;
409
410        let worktree = project
411            .update(cx, |project, cx| {
412                project.create_worktree(&worktree_path, true, cx)
413            })?
414            .await?;
415        worktree
416            .read_with(cx, |worktree, _cx| {
417                worktree.as_local().unwrap().scan_complete()
418            })?
419            .await;
420
421        anyhow::Ok(project)
422    }
423
424    pub async fn setup_worktree(&self) -> Result<PathBuf> {
425        self.example.setup_worktree(self.file_name()).await
426    }
427
428    pub fn file_name(&self) -> String {
429        self.name
430            .chars()
431            .map(|c| {
432                if c.is_whitespace() {
433                    '-'
434                } else {
435                    c.to_ascii_lowercase()
436                }
437            })
438            .collect()
439    }
440
441    pub async fn cursor_position(
442        &self,
443        project: &Entity<Project>,
444        cx: &mut AsyncApp,
445    ) -> Result<(Entity<Buffer>, Anchor)> {
446        let worktree = project.read_with(cx, |project, cx| {
447            project.visible_worktrees(cx).next().unwrap()
448        })?;
449        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
450        let cursor_buffer = project
451            .update(cx, |project, cx| {
452                project.open_buffer(
453                    ProjectPath {
454                        worktree_id: worktree.read(cx).id(),
455                        path: cursor_path,
456                    },
457                    cx,
458                )
459            })?
460            .await?;
461        let cursor_offset_within_excerpt = self
462            .example
463            .cursor_position
464            .find(CURSOR_MARKER)
465            .ok_or_else(|| anyhow!("missing cursor marker"))?;
466        let mut cursor_excerpt = self.example.cursor_position.clone();
467        cursor_excerpt.replace_range(
468            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
469            "",
470        );
471        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
472            let text = buffer.text();
473
474            let mut matches = text.match_indices(&cursor_excerpt);
475            let Some((excerpt_offset, _)) = matches.next() else {
476                anyhow::bail!(
477                    "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
478                );
479            };
480            assert!(matches.next().is_none());
481
482            Ok(excerpt_offset)
483        })??;
484
485        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
486        let cursor_anchor =
487            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
488        Ok((cursor_buffer, cursor_anchor))
489    }
490
491    #[must_use]
492    pub async fn apply_edit_history(
493        &self,
494        project: &Entity<Project>,
495        cx: &mut AsyncApp,
496    ) -> Result<OpenedBuffers<'_>> {
497        edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
498    }
499}
500
501async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
502    let output = smol::process::Command::new("git")
503        .current_dir(repo_path)
504        .args(args)
505        .output()
506        .await?;
507
508    anyhow::ensure!(
509        output.status.success(),
510        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
511        args.join(" "),
512        repo_path.display(),
513        output.status,
514        String::from_utf8_lossy(&output.stderr),
515        String::from_utf8_lossy(&output.stdout),
516    );
517    Ok(String::from_utf8(output.stdout)?.trim().to_string())
518}
519
520impl Display for NamedExample {
521    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522        write!(f, "# {}\n\n", self.name)?;
523        write!(
524            f,
525            "{REPOSITORY_URL_FIELD} = {}\n",
526            self.example.repository_url
527        )?;
528        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
529
530        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
531        write!(f, "`````diff\n")?;
532        write!(f, "{}", self.example.uncommitted_diff)?;
533        write!(f, "`````\n")?;
534
535        if !self.example.edit_history.is_empty() {
536            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
537        }
538
539        write!(
540            f,
541            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
542            self.example.cursor_path.display(),
543            self.example.cursor_position
544        )?;
545        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
546
547        if !self.example.expected_patch.is_empty() {
548            write!(
549                f,
550                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
551                self.example.expected_patch
552            )?;
553        }
554
555        Ok(())
556    }
557}
558
559thread_local! {
560    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
561}
562
563#[must_use]
564pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
565    REPO_LOCKS
566        .with(|cell| {
567            cell.borrow_mut()
568                .entry(path.as_ref().to_path_buf())
569                .or_default()
570                .clone()
571        })
572        .lock_owned()
573        .await
574}