example.rs

  1use std::{
  2    borrow::Cow,
  3    cell::RefCell,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    path::{Path, PathBuf},
  9    sync::Arc,
 10};
 11
 12use anyhow::{Context as _, Result, anyhow};
 13use clap::ValueEnum;
 14use cloud_zeta2_prompt::CURSOR_MARKER;
 15use collections::HashMap;
 16use edit_prediction_context::Line;
 17use futures::{
 18    AsyncWriteExt as _,
 19    lock::{Mutex, OwnedMutexGuard},
 20};
 21use gpui::{AsyncApp, Entity, http_client::Url};
 22use language::{Anchor, Buffer};
 23use project::{Project, ProjectPath};
 24use pulldown_cmark::CowStr;
 25use serde::{Deserialize, Serialize};
 26use util::{paths::PathStyle, rel_path::RelPath};
 27use zeta2::udiff::OpenedBuffers;
 28
 29use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 30
// Titles of the H2 sections of the Markdown example format. `NamedExample::parse_md`
// compares them ASCII-case-insensitively; the `Display` impl writes them verbatim.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// Keys of the `key = value` metadata lines that follow the H1 title in the
// Markdown example format.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
 38
/// An [`Example`] together with its name.
///
/// The name is taken from the file stem for JSON/TOML files and from the H1
/// heading for Markdown files. A lowercased, whitespace-free form of it names
/// the example's git branch and worktree directory.
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// Name of the example.
    pub name: String,
    /// The example data itself.
    pub example: Example,
}
 44
/// The serializable contents of an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// URL registered as the `origin` remote of the local repository clone.
    pub repository_url: String,
    /// Git revision the example's worktree is checked out at.
    pub revision: String,
    /// Patch applied to the worktree with `git apply` after checkout; may be empty.
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor, opened relative to the
    /// project's first visible worktree.
    pub cursor_path: PathBuf,
    /// Excerpt of the cursor file that contains `CURSOR_MARKER` at the cursor location.
    pub cursor_position: String,
    /// Diff text handed to `zeta2::udiff::apply_diff`.
    pub edit_history: String,
    /// Contents of the "Expected Patch" code block.
    pub expected_patch: String,
    /// Entries of the "Expected Context" section, one per H3 heading.
    pub expected_context: Vec<ExpectedContextEntry>,
}
 56
/// Alias of [`Excerpt`].
pub type ActualExcerpt = Excerpt;

/// A piece of text associated with a file path.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    /// Path of the file the text belongs to.
    pub path: PathBuf,
    /// The excerpt's text.
    pub text: String,
}
 64
/// One H3 subsection of the "Expected Context" section.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
    /// Text of the H3 heading.
    pub heading: String,
    /// Excerpt sets listed under this entry: one per H4 heading, or a single
    /// set with an empty heading when code blocks follow the H3 directly.
    pub alternatives: Vec<ExpectedExcerptSet>,
}
 70
/// A group of expected excerpts collected under one H4 heading.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
    /// Text of the H4 heading; empty for the implicit set created when no H4 is present.
    pub heading: String,
    /// Excerpts parsed from the code blocks under this heading.
    pub excerpts: Vec<ExpectedExcerpt>,
}
 76
/// A single expected excerpt parsed from a code block in the "Expected Context" section.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// Path taken from the code block's info string.
    pub path: PathBuf,
    /// Excerpt text with `[ZETA]` markers stripped; always ends with a newline
    /// when produced by `parse_md`.
    pub text: String,
    /// Zero-based rows of `text` that were flagged with a `[ZETA] ... required` marker.
    pub required_lines: Vec<Line>,
}
 83
// Serialization formats accepted by `NamedExample::write`.
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// `ValueEnum` would become part of the generated help text.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    // Written with `serde_json::to_writer`.
    Json,
    // Written with `toml::to_string_pretty`.
    Toml,
    // Written through the `Display` impl of `NamedExample`.
    Md,
}
 90
 91impl NamedExample {
 92    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 93        let path = path.as_ref();
 94        let content = std::fs::read_to_string(path)?;
 95        let ext = path.extension();
 96
 97        match ext.and_then(|s| s.to_str()) {
 98            Some("json") => Ok(Self {
 99                name: path.file_stem().unwrap_or_default().display().to_string(),
100                example: serde_json::from_str(&content)?,
101            }),
102            Some("toml") => Ok(Self {
103                name: path.file_stem().unwrap_or_default().display().to_string(),
104                example: toml::from_str(&content)?,
105            }),
106            Some("md") => Self::parse_md(&content),
107            Some(_) => {
108                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
109            }
110            None => {
111                anyhow::bail!(
112                    "Failed to determine example type since the file does not have an extension."
113                );
114            }
115        }
116    }
117
118    pub fn parse_md(input: &str) -> Result<Self> {
119        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
120
121        let parser = Parser::new(input);
122
123        let mut named = NamedExample {
124            name: String::new(),
125            example: Example {
126                repository_url: String::new(),
127                revision: String::new(),
128                uncommitted_diff: String::new(),
129                cursor_path: PathBuf::new(),
130                cursor_position: String::new(),
131                edit_history: String::new(),
132                expected_patch: String::new(),
133                expected_context: Vec::new(),
134            },
135        };
136
137        let mut text = String::new();
138        let mut block_info: CowStr = "".into();
139
140        #[derive(PartialEq)]
141        enum Section {
142            UncommittedDiff,
143            EditHistory,
144            CursorPosition,
145            ExpectedExcerpts,
146            ExpectedPatch,
147            Other,
148        }
149
150        let mut current_section = Section::Other;
151
152        for event in parser {
153            match event {
154                Event::Text(line) => {
155                    text.push_str(&line);
156
157                    if !named.name.is_empty()
158                        && current_section == Section::Other
159                        // in h1 section
160                        && let Some((field, value)) = line.split_once('=')
161                    {
162                        match field.trim() {
163                            REPOSITORY_URL_FIELD => {
164                                named.example.repository_url = value.trim().to_string();
165                            }
166                            REVISION_FIELD => {
167                                named.example.revision = value.trim().to_string();
168                            }
169                            _ => {
170                                eprintln!("Warning: Unrecognized field `{field}`");
171                            }
172                        }
173                    }
174                }
175                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
176                    if !named.name.is_empty() {
177                        anyhow::bail!(
178                            "Found multiple H1 headings. There should only be one with the name of the example."
179                        );
180                    }
181                    named.name = mem::take(&mut text);
182                }
183                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
184                    let title = mem::take(&mut text);
185                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
186                        Section::UncommittedDiff
187                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
188                        Section::EditHistory
189                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
190                        Section::CursorPosition
191                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
192                        Section::ExpectedPatch
193                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
194                        Section::ExpectedExcerpts
195                    } else {
196                        eprintln!("Warning: Unrecognized section `{title:?}`");
197                        Section::Other
198                    };
199                }
200                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
201                    let heading = mem::take(&mut text);
202                    match current_section {
203                        Section::ExpectedExcerpts => {
204                            named.example.expected_context.push(ExpectedContextEntry {
205                                heading,
206                                alternatives: Vec::new(),
207                            });
208                        }
209                        _ => {}
210                    }
211                }
212                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
213                    let heading = mem::take(&mut text);
214                    match current_section {
215                        Section::ExpectedExcerpts => {
216                            let expected_context = &mut named.example.expected_context;
217                            let last_entry = expected_context.last_mut().unwrap();
218                            last_entry.alternatives.push(ExpectedExcerptSet {
219                                heading,
220                                excerpts: Vec::new(),
221                            })
222                        }
223                        _ => {}
224                    }
225                }
226                Event::End(TagEnd::Heading(level)) => {
227                    anyhow::bail!("Unexpected heading level: {level}");
228                }
229                Event::Start(Tag::CodeBlock(kind)) => {
230                    match kind {
231                        CodeBlockKind::Fenced(info) => {
232                            block_info = info;
233                        }
234                        CodeBlockKind::Indented => {
235                            anyhow::bail!("Unexpected indented codeblock");
236                        }
237                    };
238                }
239                Event::Start(_) => {
240                    text.clear();
241                    block_info = "".into();
242                }
243                Event::End(TagEnd::CodeBlock) => {
244                    let block_info = block_info.trim();
245                    match current_section {
246                        Section::UncommittedDiff => {
247                            named.example.uncommitted_diff = mem::take(&mut text);
248                        }
249                        Section::EditHistory => {
250                            named.example.edit_history.push_str(&mem::take(&mut text));
251                        }
252                        Section::CursorPosition => {
253                            named.example.cursor_path = block_info.into();
254                            named.example.cursor_position = mem::take(&mut text);
255                        }
256                        Section::ExpectedExcerpts => {
257                            let text = mem::take(&mut text);
258                            for excerpt in text.split("\n\n") {
259                                let (mut text, required_lines) = extract_required_lines(&excerpt);
260                                if !text.ends_with('\n') {
261                                    text.push('\n');
262                                }
263                                let alternatives = &mut named
264                                    .example
265                                    .expected_context
266                                    .last_mut()
267                                    .unwrap()
268                                    .alternatives;
269
270                                if alternatives.is_empty() {
271                                    alternatives.push(ExpectedExcerptSet {
272                                        heading: String::new(),
273                                        excerpts: vec![],
274                                    });
275                                }
276
277                                alternatives
278                                    .last_mut()
279                                    .unwrap()
280                                    .excerpts
281                                    .push(ExpectedExcerpt {
282                                        path: block_info.into(),
283                                        text,
284                                        required_lines,
285                                    });
286                            }
287                        }
288                        Section::ExpectedPatch => {
289                            named.example.expected_patch = mem::take(&mut text);
290                        }
291                        Section::Other => {}
292                    }
293                }
294                _ => {}
295            }
296        }
297
298        if named.example.cursor_path.as_path() == Path::new("")
299            || named.example.cursor_position.is_empty()
300        {
301            anyhow::bail!("Missing cursor position codeblock");
302        }
303
304        Ok(named)
305    }
306
307    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
308        match format {
309            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
310            ExampleFormat::Toml => {
311                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
312            }
313            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
314        }
315    }
316
317    pub async fn setup_worktree(&self) -> Result<PathBuf> {
318        let (repo_owner, repo_name) = self.repo_name()?;
319        let file_name = self.file_name();
320
321        fs::create_dir_all(&*REPOS_DIR)?;
322        fs::create_dir_all(&*WORKTREES_DIR)?;
323
324        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
325        let repo_lock = lock_repo(&repo_dir).await;
326
327        if !repo_dir.is_dir() {
328            fs::create_dir_all(&repo_dir)?;
329            run_git(&repo_dir, &["init"]).await?;
330            run_git(
331                &repo_dir,
332                &["remote", "add", "origin", &self.example.repository_url],
333            )
334            .await?;
335        }
336
337        // Resolve the example to a revision, fetching it if needed.
338        let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
339        let revision = if let Ok(revision) = revision {
340            revision
341        } else {
342            run_git(
343                &repo_dir,
344                &["fetch", "--depth", "1", "origin", &self.example.revision],
345            )
346            .await?;
347            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
348            if revision != self.example.revision {
349                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
350            }
351            revision
352        };
353
354        // Create the worktree for this example if needed.
355        let worktree_path = WORKTREES_DIR.join(&file_name);
356        if worktree_path.is_dir() {
357            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
358            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
359            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
360        } else {
361            let worktree_path_string = worktree_path.to_string_lossy();
362            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
363            run_git(
364                &repo_dir,
365                &["worktree", "add", "-f", &worktree_path_string, &file_name],
366            )
367            .await?;
368        }
369        drop(repo_lock);
370
371        // Apply the uncommitted diff for this example.
372        if !self.example.uncommitted_diff.is_empty() {
373            let mut apply_process = smol::process::Command::new("git")
374                .current_dir(&worktree_path)
375                .args(&["apply", "-"])
376                .stdin(std::process::Stdio::piped())
377                .spawn()?;
378
379            let mut stdin = apply_process.stdin.take().unwrap();
380            stdin
381                .write_all(self.example.uncommitted_diff.as_bytes())
382                .await?;
383            stdin.close().await?;
384            drop(stdin);
385
386            let apply_result = apply_process.output().await?;
387            if !apply_result.status.success() {
388                anyhow::bail!(
389                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
390                    apply_result.status,
391                    String::from_utf8_lossy(&apply_result.stderr),
392                    String::from_utf8_lossy(&apply_result.stdout),
393                );
394            }
395        }
396
397        Ok(worktree_path)
398    }
399
400    fn file_name(&self) -> String {
401        self.name
402            .chars()
403            .map(|c| {
404                if c.is_whitespace() {
405                    '-'
406                } else {
407                    c.to_ascii_lowercase()
408                }
409            })
410            .collect()
411    }
412
413    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
414        // git@github.com:owner/repo.git
415        if self.example.repository_url.contains('@') {
416            let (owner, repo) = self
417                .example
418                .repository_url
419                .split_once(':')
420                .context("expected : in git url")?
421                .1
422                .split_once('/')
423                .context("expected / in git url")?;
424            Ok((
425                Cow::Borrowed(owner),
426                Cow::Borrowed(repo.trim_end_matches(".git")),
427            ))
428        // http://github.com/owner/repo.git
429        } else {
430            let url = Url::parse(&self.example.repository_url)?;
431            let mut segments = url.path_segments().context("empty http url")?;
432            let owner = segments
433                .next()
434                .context("expected owner path segment")?
435                .to_string();
436            let repo = segments
437                .next()
438                .context("expected repo path segment")?
439                .trim_end_matches(".git")
440                .to_string();
441            assert!(segments.next().is_none());
442
443            Ok((owner.into(), repo.into()))
444        }
445    }
446
447    pub async fn cursor_position(
448        &self,
449        project: &Entity<Project>,
450        cx: &mut AsyncApp,
451    ) -> Result<(Entity<Buffer>, Anchor)> {
452        let worktree = project.read_with(cx, |project, cx| {
453            project.visible_worktrees(cx).next().unwrap()
454        })?;
455        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
456        let cursor_buffer = project
457            .update(cx, |project, cx| {
458                project.open_buffer(
459                    ProjectPath {
460                        worktree_id: worktree.read(cx).id(),
461                        path: cursor_path,
462                    },
463                    cx,
464                )
465            })?
466            .await?;
467        let cursor_offset_within_excerpt = self
468            .example
469            .cursor_position
470            .find(CURSOR_MARKER)
471            .ok_or_else(|| anyhow!("missing cursor marker"))?;
472        let mut cursor_excerpt = self.example.cursor_position.clone();
473        cursor_excerpt.replace_range(
474            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
475            "",
476        );
477        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
478            let text = buffer.text();
479
480            let mut matches = text.match_indices(&cursor_excerpt);
481            let Some((excerpt_offset, _)) = matches.next() else {
482                anyhow::bail!(
483                    "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
484                );
485            };
486            assert!(matches.next().is_none());
487
488            Ok(excerpt_offset)
489        })??;
490
491        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
492        let cursor_anchor =
493            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
494        Ok((cursor_buffer, cursor_anchor))
495    }
496
497    #[must_use]
498    pub async fn apply_edit_history(
499        &self,
500        project: &Entity<Project>,
501        cx: &mut AsyncApp,
502    ) -> Result<OpenedBuffers<'_>> {
503        zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
504    }
505}
506
507fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
508    const MARKER: &str = "[ZETA]";
509    let mut new_text = String::new();
510    let mut required_lines = Vec::new();
511    let mut skipped_lines = 0_u32;
512
513    for (row, mut line) in text.split('\n').enumerate() {
514        if let Some(marker_column) = line.find(MARKER) {
515            let mut strip_column = marker_column;
516
517            while strip_column > 0 {
518                let prev_char = line[strip_column - 1..].chars().next().unwrap();
519                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
520                    strip_column -= 1;
521                } else {
522                    break;
523                }
524            }
525
526            let metadata = &line[marker_column + MARKER.len()..];
527            if metadata.contains("required") {
528                required_lines.push(Line(row as u32 - skipped_lines));
529            }
530
531            if strip_column == 0 {
532                skipped_lines += 1;
533                continue;
534            }
535
536            line = &line[..strip_column];
537        }
538
539        new_text.push_str(line);
540        new_text.push('\n');
541    }
542
543    new_text.pop();
544
545    (new_text, required_lines)
546}
547
548async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
549    let output = smol::process::Command::new("git")
550        .current_dir(repo_path)
551        .args(args)
552        .output()
553        .await?;
554
555    anyhow::ensure!(
556        output.status.success(),
557        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
558        args.join(" "),
559        repo_path.display(),
560        output.status,
561        String::from_utf8_lossy(&output.stderr),
562        String::from_utf8_lossy(&output.stdout),
563    );
564    Ok(String::from_utf8(output.stdout)?.trim().to_string())
565}
566
567impl Display for NamedExample {
568    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
569        write!(f, "# {}\n\n", self.name)?;
570        write!(
571            f,
572            "{REPOSITORY_URL_FIELD} = {}\n",
573            self.example.repository_url
574        )?;
575        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
576
577        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
578        write!(f, "`````diff\n")?;
579        write!(f, "{}", self.example.uncommitted_diff)?;
580        write!(f, "`````\n")?;
581
582        if !self.example.edit_history.is_empty() {
583            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
584        }
585
586        write!(
587            f,
588            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
589            self.example.cursor_path.display(),
590            self.example.cursor_position
591        )?;
592        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
593
594        if !self.example.expected_patch.is_empty() {
595            write!(
596                f,
597                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
598                self.example.expected_patch
599            )?;
600        }
601
602        if !self.example.expected_context.is_empty() {
603            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
604
605            for entry in &self.example.expected_context {
606                write!(f, "\n### {}\n\n", entry.heading)?;
607
608                let skip_h4 =
609                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
610
611                for excerpt_set in &entry.alternatives {
612                    if !skip_h4 {
613                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
614                    }
615
616                    for excerpt in &excerpt_set.excerpts {
617                        write!(
618                            f,
619                            "`````{}{}\n{}`````\n\n",
620                            excerpt
621                                .path
622                                .extension()
623                                .map(|ext| format!("{} ", ext.to_string_lossy()))
624                                .unwrap_or_default(),
625                            excerpt.path.display(),
626                            excerpt.text
627                        )?;
628                    }
629                }
630            }
631        }
632
633        Ok(())
634    }
635}
636
// Map from repository directory to the async mutex guarding git operations
// on it; entries are created on demand by `lock_repo`.
// NOTE(review): the map is thread-local, so each thread has its own set of
// mutexes — presumably every caller runs on the same thread; confirm.
thread_local! {
    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
640
641#[must_use]
642pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
643    REPO_LOCKS
644        .with(|cell| {
645            cell.borrow_mut()
646                .entry(path.as_ref().to_path_buf())
647                .or_default()
648                .clone()
649        })
650        .lock_owned()
651        .await
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    // Checks that `[ZETA]` markers are stripped, that a marker-only line is
    // removed from the output, and that required rows are reported relative
    // to the cleaned text.
    #[test]
    fn test_extract_required_lines() {
        let input = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};

        let expected_updated_input = indoc! {"
            zero
            one
            two
            three
            four
            five
        "};

        // "four" is row 5 of the input but row 4 once the marker-only line is gone.
        let expected_required_lines = vec![Line(1), Line(4)];

        let (updated_input, required_lines) = extract_required_lines(input);
        assert_eq!(updated_input, expected_updated_input);
        assert_eq!(required_lines, expected_required_lines);
    }
}