example.rs

  1use std::{
  2    borrow::Cow,
  3    cell::RefCell,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    path::{Path, PathBuf},
  9    sync::Arc,
 10};
 11
 12use anyhow::{Context as _, Result, anyhow};
 13use clap::ValueEnum;
 14use cloud_zeta2_prompt::CURSOR_MARKER;
 15use collections::HashMap;
 16use edit_prediction_context::Line;
 17use futures::{
 18    AsyncWriteExt as _,
 19    lock::{Mutex, OwnedMutexGuard},
 20};
 21use gpui::{AsyncApp, Entity, http_client::Url};
 22use language::{Anchor, Buffer};
 23use project::{Project, ProjectPath};
 24use pulldown_cmark::CowStr;
 25use serde::{Deserialize, Serialize};
 26use util::{paths::PathStyle, rel_path::RelPath};
 27use zeta2::udiff::OpenedBuffers;
 28
 29use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 30
// H2 section titles of the Markdown example format. `NamedExample::parse_md` matches them
// case-insensitively; the `Display` impl for `NamedExample` writes them out.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// Keys of the `key = value` metadata lines that follow the H1 example name.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
 38
/// An [`Example`] paired with its name.
///
/// Markdown examples take the name from their H1 heading. JSON and TOML examples are named
/// after the file stem (see [`NamedExample::load`]).
#[derive(Debug, Clone)]
pub struct NamedExample {
    pub name: String,
    pub example: Example,
}

/// A single example: a repository snapshot plus cursor location, recent edits, and the
/// expected results.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git remote URL, either scp-style (`git@host:owner/repo.git`) or a URL with a scheme.
    pub repository_url: String,
    /// Git revision to check out. It is resolved with `git rev-parse` and fetched from
    /// `origin` when not present locally.
    pub revision: String,
    /// Diff applied with `git apply` on top of `revision` in the example's worktree.
    pub uncommitted_diff: String,
    /// Path of the cursor file, relative to the worktree root, using POSIX separators.
    pub cursor_path: PathBuf,
    /// Excerpt of the cursor file with `CURSOR_MARKER` inserted at the cursor location.
    pub cursor_position: String,
    /// Diff text replayed by `NamedExample::apply_edit_history` via
    /// `zeta2::udiff::apply_diff`.
    pub edit_history: String,
    /// Contents of the "Expected Patch" section.
    /// NOTE(review): presumably the edit the predictor should produce; confirm.
    pub expected_patch: String,
    /// Entries of the "Expected Context" section, in document order.
    pub expected_context: Vec<ExpectedContextEntry>,
}
 56
/// Alias for [`Excerpt`].
/// NOTE(review): presumably used for excerpts actually produced by a run, as opposed to the
/// expected ones; confirm at call sites.
pub type ActualExcerpt = Excerpt;

/// A piece of text attributed to the file at `path`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    pub path: PathBuf,
    pub text: String,
}

/// One H3 entry of the "Expected Context" section.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
    /// Text of the H3 heading.
    pub heading: String,
    /// Alternative excerpt sets, one per H4 heading. `parse_md` creates a single set with
    /// an empty heading when the entry has no H4 headings.
    pub alternatives: Vec<ExpectedExcerptSet>,
}

/// A group of expected excerpts under one H4 heading (empty heading for the implicit set).
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
    pub heading: String,
    pub excerpts: Vec<ExpectedExcerpt>,
}

/// A single expected excerpt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// File path taken from the code block's info string.
    pub path: PathBuf,
    /// Excerpt text. `parse_md` strips `[ZETA]` markers and guarantees a trailing newline.
    pub text: String,
    /// 0-based rows of `text` that were tagged with `[ZETA] required`.
    pub required_lines: Vec<Line>,
}
 83
/// Serialization formats accepted by [`NamedExample::write`]. It derives `clap::ValueEnum`,
/// so it can be selected from command-line arguments.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    Json,
    Toml,
    Md,
}
 90
 91impl NamedExample {
 92    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 93        let path = path.as_ref();
 94        let content = std::fs::read_to_string(path)?;
 95        let ext = path.extension();
 96
 97        match ext.and_then(|s| s.to_str()) {
 98            Some("json") => Ok(Self {
 99                name: path.file_stem().unwrap_or_default().display().to_string(),
100                example: serde_json::from_str(&content)?,
101            }),
102            Some("toml") => Ok(Self {
103                name: path.file_stem().unwrap_or_default().display().to_string(),
104                example: toml::from_str(&content)?,
105            }),
106            Some("md") => Self::parse_md(&content),
107            Some(_) => {
108                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
109            }
110            None => {
111                anyhow::bail!(
112                    "Failed to determine example type since the file does not have an extension."
113                );
114            }
115        }
116    }
117
118    pub fn parse_md(input: &str) -> Result<Self> {
119        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
120
121        let parser = Parser::new(input);
122
123        let mut named = NamedExample {
124            name: String::new(),
125            example: Example {
126                repository_url: String::new(),
127                revision: String::new(),
128                uncommitted_diff: String::new(),
129                cursor_path: PathBuf::new(),
130                cursor_position: String::new(),
131                edit_history: String::new(),
132                expected_patch: String::new(),
133                expected_context: Vec::new(),
134            },
135        };
136
137        let mut text = String::new();
138        let mut block_info: CowStr = "".into();
139
140        #[derive(PartialEq)]
141        enum Section {
142            UncommittedDiff,
143            EditHistory,
144            CursorPosition,
145            ExpectedExcerpts,
146            ExpectedPatch,
147            Other,
148        }
149
150        let mut current_section = Section::Other;
151
152        for event in parser {
153            match event {
154                Event::Text(line) => {
155                    text.push_str(&line);
156
157                    if !named.name.is_empty()
158                        && current_section == Section::Other
159                        // in h1 section
160                        && let Some((field, value)) = line.split_once('=')
161                    {
162                        match field.trim() {
163                            REPOSITORY_URL_FIELD => {
164                                named.example.repository_url = value.trim().to_string();
165                            }
166                            REVISION_FIELD => {
167                                named.example.revision = value.trim().to_string();
168                            }
169                            _ => {}
170                        }
171                    }
172                }
173                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
174                    if !named.name.is_empty() {
175                        anyhow::bail!(
176                            "Found multiple H1 headings. There should only be one with the name of the example."
177                        );
178                    }
179                    named.name = mem::take(&mut text);
180                }
181                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
182                    let title = mem::take(&mut text);
183                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
184                        Section::UncommittedDiff
185                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
186                        Section::EditHistory
187                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
188                        Section::CursorPosition
189                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
190                        Section::ExpectedPatch
191                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
192                        Section::ExpectedExcerpts
193                    } else {
194                        Section::Other
195                    };
196                }
197                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
198                    let heading = mem::take(&mut text);
199                    match current_section {
200                        Section::ExpectedExcerpts => {
201                            named.example.expected_context.push(ExpectedContextEntry {
202                                heading,
203                                alternatives: Vec::new(),
204                            });
205                        }
206                        _ => {}
207                    }
208                }
209                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
210                    let heading = mem::take(&mut text);
211                    match current_section {
212                        Section::ExpectedExcerpts => {
213                            let expected_context = &mut named.example.expected_context;
214                            let last_entry = expected_context.last_mut().unwrap();
215                            last_entry.alternatives.push(ExpectedExcerptSet {
216                                heading,
217                                excerpts: Vec::new(),
218                            })
219                        }
220                        _ => {}
221                    }
222                }
223                Event::End(TagEnd::Heading(level)) => {
224                    anyhow::bail!("Unexpected heading level: {level}");
225                }
226                Event::Start(Tag::CodeBlock(kind)) => {
227                    match kind {
228                        CodeBlockKind::Fenced(info) => {
229                            block_info = info;
230                        }
231                        CodeBlockKind::Indented => {
232                            anyhow::bail!("Unexpected indented codeblock");
233                        }
234                    };
235                }
236                Event::Start(_) => {
237                    text.clear();
238                    block_info = "".into();
239                }
240                Event::End(TagEnd::CodeBlock) => {
241                    let block_info = block_info.trim();
242                    match current_section {
243                        Section::UncommittedDiff => {
244                            named.example.uncommitted_diff = mem::take(&mut text);
245                        }
246                        Section::EditHistory => {
247                            named.example.edit_history.push_str(&mem::take(&mut text));
248                        }
249                        Section::CursorPosition => {
250                            named.example.cursor_path = block_info.into();
251                            named.example.cursor_position = mem::take(&mut text);
252                        }
253                        Section::ExpectedExcerpts => {
254                            let text = mem::take(&mut text);
255                            for excerpt in text.split("\n\n") {
256                                let (mut text, required_lines) = extract_required_lines(&excerpt);
257                                if !text.ends_with('\n') {
258                                    text.push('\n');
259                                }
260                                let alternatives = &mut named
261                                    .example
262                                    .expected_context
263                                    .last_mut()
264                                    .unwrap()
265                                    .alternatives;
266
267                                if alternatives.is_empty() {
268                                    alternatives.push(ExpectedExcerptSet {
269                                        heading: String::new(),
270                                        excerpts: vec![],
271                                    });
272                                }
273
274                                alternatives
275                                    .last_mut()
276                                    .unwrap()
277                                    .excerpts
278                                    .push(ExpectedExcerpt {
279                                        path: block_info.into(),
280                                        text,
281                                        required_lines,
282                                    });
283                            }
284                        }
285                        Section::ExpectedPatch => {
286                            named.example.expected_patch = mem::take(&mut text);
287                        }
288                        Section::Other => {}
289                    }
290                }
291                _ => {}
292            }
293        }
294
295        if named.example.cursor_path.as_path() == Path::new("")
296            || named.example.cursor_position.is_empty()
297        {
298            anyhow::bail!("Missing cursor position codeblock");
299        }
300
301        Ok(named)
302    }
303
304    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
305        match format {
306            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
307            ExampleFormat::Toml => {
308                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
309            }
310            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
311        }
312    }
313
314    pub async fn setup_worktree(&self) -> Result<PathBuf> {
315        let (repo_owner, repo_name) = self.repo_name()?;
316        let file_name = self.file_name();
317
318        fs::create_dir_all(&*REPOS_DIR)?;
319        fs::create_dir_all(&*WORKTREES_DIR)?;
320
321        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
322        let repo_lock = lock_repo(&repo_dir).await;
323
324        if !repo_dir.is_dir() {
325            fs::create_dir_all(&repo_dir)?;
326            run_git(&repo_dir, &["init"]).await?;
327            run_git(
328                &repo_dir,
329                &["remote", "add", "origin", &self.example.repository_url],
330            )
331            .await?;
332        }
333
334        // Resolve the example to a revision, fetching it if needed.
335        let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
336        let revision = if let Ok(revision) = revision {
337            revision
338        } else {
339            run_git(
340                &repo_dir,
341                &["fetch", "--depth", "1", "origin", &self.example.revision],
342            )
343            .await?;
344            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
345            if revision != self.example.revision {
346                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
347            }
348            revision
349        };
350
351        // Create the worktree for this example if needed.
352        let worktree_path = WORKTREES_DIR.join(&file_name);
353        if worktree_path.is_dir() {
354            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
355            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
356            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
357        } else {
358            let worktree_path_string = worktree_path.to_string_lossy();
359            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
360            run_git(
361                &repo_dir,
362                &["worktree", "add", "-f", &worktree_path_string, &file_name],
363            )
364            .await?;
365        }
366        drop(repo_lock);
367
368        // Apply the uncommitted diff for this example.
369        if !self.example.uncommitted_diff.is_empty() {
370            let mut apply_process = smol::process::Command::new("git")
371                .current_dir(&worktree_path)
372                .args(&["apply", "-"])
373                .stdin(std::process::Stdio::piped())
374                .spawn()?;
375
376            let mut stdin = apply_process.stdin.take().unwrap();
377            stdin
378                .write_all(self.example.uncommitted_diff.as_bytes())
379                .await?;
380            stdin.close().await?;
381            drop(stdin);
382
383            let apply_result = apply_process.output().await?;
384            if !apply_result.status.success() {
385                anyhow::bail!(
386                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
387                    apply_result.status,
388                    String::from_utf8_lossy(&apply_result.stderr),
389                    String::from_utf8_lossy(&apply_result.stdout),
390                );
391            }
392        }
393
394        Ok(worktree_path)
395    }
396
397    fn file_name(&self) -> String {
398        self.name
399            .chars()
400            .map(|c| {
401                if c.is_whitespace() {
402                    '-'
403                } else {
404                    c.to_ascii_lowercase()
405                }
406            })
407            .collect()
408    }
409
410    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
411        // git@github.com:owner/repo.git
412        if self.example.repository_url.contains('@') {
413            let (owner, repo) = self
414                .example
415                .repository_url
416                .split_once(':')
417                .context("expected : in git url")?
418                .1
419                .split_once('/')
420                .context("expected / in git url")?;
421            Ok((
422                Cow::Borrowed(owner),
423                Cow::Borrowed(repo.trim_end_matches(".git")),
424            ))
425        // http://github.com/owner/repo.git
426        } else {
427            let url = Url::parse(&self.example.repository_url)?;
428            let mut segments = url.path_segments().context("empty http url")?;
429            let owner = segments
430                .next()
431                .context("expected owner path segment")?
432                .to_string();
433            let repo = segments
434                .next()
435                .context("expected repo path segment")?
436                .trim_end_matches(".git")
437                .to_string();
438            assert!(segments.next().is_none());
439
440            Ok((owner.into(), repo.into()))
441        }
442    }
443
444    pub async fn cursor_position(
445        &self,
446        project: &Entity<Project>,
447        cx: &mut AsyncApp,
448    ) -> Result<(Entity<Buffer>, Anchor)> {
449        let worktree = project.read_with(cx, |project, cx| {
450            project.visible_worktrees(cx).next().unwrap()
451        })?;
452        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
453        let cursor_buffer = project
454            .update(cx, |project, cx| {
455                project.open_buffer(
456                    ProjectPath {
457                        worktree_id: worktree.read(cx).id(),
458                        path: cursor_path,
459                    },
460                    cx,
461                )
462            })?
463            .await?;
464        let cursor_offset_within_excerpt = self
465            .example
466            .cursor_position
467            .find(CURSOR_MARKER)
468            .ok_or_else(|| anyhow!("missing cursor marker"))?;
469        let mut cursor_excerpt = self.example.cursor_position.clone();
470        cursor_excerpt.replace_range(
471            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
472            "",
473        );
474        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
475            let text = buffer.text();
476
477            let mut matches = text.match_indices(&cursor_excerpt);
478            let Some((excerpt_offset, _)) = matches.next() else {
479                anyhow::bail!(
480                    "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
481                );
482            };
483            assert!(matches.next().is_none());
484
485            Ok(excerpt_offset)
486        })??;
487
488        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
489        let cursor_anchor =
490            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
491        Ok((cursor_buffer, cursor_anchor))
492    }
493
494    #[must_use]
495    pub async fn apply_edit_history(
496        &self,
497        project: &Entity<Project>,
498        cx: &mut AsyncApp,
499    ) -> Result<OpenedBuffers<'_>> {
500        zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
501    }
502}
503
504fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
505    const MARKER: &str = "[ZETA]";
506    let mut new_text = String::new();
507    let mut required_lines = Vec::new();
508    let mut skipped_lines = 0_u32;
509
510    for (row, mut line) in text.split('\n').enumerate() {
511        if let Some(marker_column) = line.find(MARKER) {
512            let mut strip_column = marker_column;
513
514            while strip_column > 0 {
515                let prev_char = line[strip_column - 1..].chars().next().unwrap();
516                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
517                    strip_column -= 1;
518                } else {
519                    break;
520                }
521            }
522
523            let metadata = &line[marker_column + MARKER.len()..];
524            if metadata.contains("required") {
525                required_lines.push(Line(row as u32 - skipped_lines));
526            }
527
528            if strip_column == 0 {
529                skipped_lines += 1;
530                continue;
531            }
532
533            line = &line[..strip_column];
534        }
535
536        new_text.push_str(line);
537        new_text.push('\n');
538    }
539
540    new_text.pop();
541
542    (new_text, required_lines)
543}
544
545async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
546    let output = smol::process::Command::new("git")
547        .current_dir(repo_path)
548        .args(args)
549        .output()
550        .await?;
551
552    anyhow::ensure!(
553        output.status.success(),
554        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
555        args.join(" "),
556        repo_path.display(),
557        output.status,
558        String::from_utf8_lossy(&output.stderr),
559        String::from_utf8_lossy(&output.stdout),
560    );
561    Ok(String::from_utf8(output.stdout)?.trim().to_string())
562}
563
564impl Display for NamedExample {
565    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566        write!(f, "# {}\n\n", self.name)?;
567        write!(
568            f,
569            "{REPOSITORY_URL_FIELD} = {}\n",
570            self.example.repository_url
571        )?;
572        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
573
574        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
575        write!(f, "`````diff\n")?;
576        write!(f, "{}", self.example.uncommitted_diff)?;
577        write!(f, "`````\n")?;
578
579        if !self.example.edit_history.is_empty() {
580            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
581        }
582
583        write!(
584            f,
585            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
586            self.example.cursor_path.display(),
587            self.example.cursor_position
588        )?;
589        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
590
591        if !self.example.expected_patch.is_empty() {
592            write!(
593                f,
594                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
595                self.example.expected_patch
596            )?;
597        }
598
599        if !self.example.expected_context.is_empty() {
600            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
601
602            for entry in &self.example.expected_context {
603                write!(f, "\n### {}\n\n", entry.heading)?;
604
605                let skip_h4 =
606                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
607
608                for excerpt_set in &entry.alternatives {
609                    if !skip_h4 {
610                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
611                    }
612
613                    for excerpt in &excerpt_set.excerpts {
614                        write!(
615                            f,
616                            "`````{}{}\n{}`````\n\n",
617                            excerpt
618                                .path
619                                .extension()
620                                .map(|ext| format!("{} ", ext.to_string_lossy()))
621                                .unwrap_or_default(),
622                            excerpt.path.display(),
623                            excerpt.text
624                        )?;
625                    }
626                }
627            }
628        }
629
630        Ok(())
631    }
632}
633
634thread_local! {
635    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
636}
637
638#[must_use]
639pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
640    REPO_LOCKS
641        .with(|cell| {
642            cell.borrow_mut()
643                .entry(path.as_ref().to_path_buf())
644                .or_default()
645                .clone()
646        })
647        .lock_owned()
648        .await
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Checks three behaviours: `[ZETA]` comments are stripped, marker-only lines are dropped,
    /// and lines tagged `required` are reported by their row in the stripped text.
    #[test]
    fn test_extract_required_lines() {
        let annotated = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};

        let (stripped, required) = extract_required_lines(annotated);

        assert_eq!(
            stripped,
            indoc! {"
                zero
                one
                two
                three
                four
                five
            "}
        );
        assert_eq!(required, vec![Line(1), Line(4)]);
    }
}