example.rs

  1use std::{
  2    borrow::Cow,
  3    cell::RefCell,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    path::{Path, PathBuf},
  9    sync::Arc,
 10};
 11
 12use anyhow::{Context as _, Result, anyhow};
 13use clap::ValueEnum;
 14use cloud_zeta2_prompt::CURSOR_MARKER;
 15use collections::HashMap;
 16use edit_prediction_context::Line;
 17use futures::{
 18    AsyncWriteExt as _,
 19    lock::{Mutex, OwnedMutexGuard},
 20};
 21use gpui::{AsyncApp, Entity, http_client::Url};
 22use language::{Anchor, Buffer};
 23use project::{Project, ProjectPath};
 24use pulldown_cmark::CowStr;
 25use serde::{Deserialize, Serialize};
 26use util::{paths::PathStyle, rel_path::RelPath};
 27use zeta2::udiff::OpenedBuffers;
 28
 29use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 30
// Section titles (H2 headings) of the Markdown example format.
// `NamedExample::parse_md` compares them ASCII-case-insensitively and the
// `Display` impl for `NamedExample` writes them back out verbatim.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// Keys of the `key = value` lines that follow the H1 title in the Markdown
// example format.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
 38
/// An [`Example`] together with its name.
///
/// For JSON and TOML files the name is the file stem; for Markdown files it is
/// the text of the single H1 heading (see [`NamedExample::load`]).
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// Name of the example; [`NamedExample::file_name`] derives a slug from it.
    pub name: String,
    /// The example data itself.
    pub example: Example,
}
 44
/// The serializable payload of an example (the JSON/TOML representation
/// contains exactly these fields).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git URL registered as the `origin` remote when the repository is set up.
    pub repository_url: String,
    /// Revision that is resolved (and fetched from `origin` if necessary)
    /// before the worktree is checked out.
    pub revision: String,
    /// Patch applied on top of the checked-out revision with `git apply`;
    /// nothing is applied when it is empty.
    pub uncommitted_diff: String,
    /// Path of the buffer that contains the cursor, opened relative to the
    /// project's first visible worktree.
    pub cursor_path: PathBuf,
    /// Excerpt of the cursor buffer that contains the cursor marker.
    pub cursor_position: String,
    /// Diff text handed to `zeta2::udiff::apply_diff`.
    pub edit_history: String,
    /// Contents of the "Expected Patch" code block (Markdown format).
    pub expected_patch: String,
    /// One entry per H3 heading of the "Expected Context" section (Markdown format).
    pub expected_context: Vec<ExpectedContextEntry>,
}
 56
/// Alias for [`Excerpt`].
// NOTE(review): presumably names excerpts that were actually retrieved, as
// opposed to [`ExpectedExcerpt`] — confirm against callers.
pub type ActualExcerpt = Excerpt;
 58
/// A piece of text associated with the path of the file it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    /// Path of the file the excerpt refers to.
    pub path: PathBuf,
    /// The excerpt's text.
    pub text: String,
}
 64
/// One entry of the "Expected Context" section; the Markdown parser creates
/// an entry for every H3 heading inside that section.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
    /// Text of the H3 heading that introduced this entry.
    pub heading: String,
    /// Excerpt sets belonging to this entry. Every H4 heading starts a new
    /// set; code blocks that appear before any H4 heading go into a single
    /// set with an empty heading.
    pub alternatives: Vec<ExpectedExcerptSet>,
}
 70
/// A group of expected excerpts listed under one (possibly empty) H4 heading.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
    /// Text of the H4 heading, or empty when the set has no heading.
    pub heading: String,
    /// Excerpts parsed from the code blocks of this set.
    pub excerpts: Vec<ExpectedExcerpt>,
}
 76
/// A single expected excerpt parsed from a fenced code block.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// Taken from the (trimmed) info string of the fenced code block.
    pub path: PathBuf,
    /// Excerpt text with `[ZETA]` marker comments removed; always ends with a newline.
    pub text: String,
    /// Lines whose `[ZETA]` marker mentioned `required`, as zero-based
    /// indices into `text`.
    pub required_lines: Vec<Line>,
}
 83
/// Output formats supported by [`NamedExample::write`].
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    /// The [`Example`] serialized with `serde_json`.
    Json,
    /// The [`Example`] serialized with `toml` (pretty-printed).
    Toml,
    /// The Markdown rendering produced by the `Display` impl of [`NamedExample`].
    Md,
}
 90
 91impl NamedExample {
 92    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 93        let path = path.as_ref();
 94        let content = std::fs::read_to_string(path)?;
 95        let ext = path.extension();
 96
 97        match ext.and_then(|s| s.to_str()) {
 98            Some("json") => Ok(Self {
 99                name: path.file_stem().unwrap_or_default().display().to_string(),
100                example: serde_json::from_str(&content)?,
101            }),
102            Some("toml") => Ok(Self {
103                name: path.file_stem().unwrap_or_default().display().to_string(),
104                example: toml::from_str(&content)?,
105            }),
106            Some("md") => Self::parse_md(&content),
107            Some(_) => {
108                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
109            }
110            None => {
111                anyhow::bail!(
112                    "Failed to determine example type since the file does not have an extension."
113                );
114            }
115        }
116    }
117
118    pub fn parse_md(input: &str) -> Result<Self> {
119        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
120
121        let parser = Parser::new(input);
122
123        let mut named = NamedExample {
124            name: String::new(),
125            example: Example {
126                repository_url: String::new(),
127                revision: String::new(),
128                uncommitted_diff: String::new(),
129                cursor_path: PathBuf::new(),
130                cursor_position: String::new(),
131                edit_history: String::new(),
132                expected_patch: String::new(),
133                expected_context: Vec::new(),
134            },
135        };
136
137        let mut text = String::new();
138        let mut block_info: CowStr = "".into();
139
140        #[derive(PartialEq)]
141        enum Section {
142            UncommittedDiff,
143            EditHistory,
144            CursorPosition,
145            ExpectedExcerpts,
146            ExpectedPatch,
147            Other,
148        }
149
150        let mut current_section = Section::Other;
151
152        for event in parser {
153            match event {
154                Event::Text(line) => {
155                    text.push_str(&line);
156
157                    if !named.name.is_empty()
158                        && current_section == Section::Other
159                        // in h1 section
160                        && let Some((field, value)) = line.split_once('=')
161                    {
162                        match field.trim() {
163                            REPOSITORY_URL_FIELD => {
164                                named.example.repository_url = value.trim().to_string();
165                            }
166                            REVISION_FIELD => {
167                                named.example.revision = value.trim().to_string();
168                            }
169                            _ => {}
170                        }
171                    }
172                }
173                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
174                    if !named.name.is_empty() {
175                        anyhow::bail!(
176                            "Found multiple H1 headings. There should only be one with the name of the example."
177                        );
178                    }
179                    named.name = mem::take(&mut text);
180                }
181                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
182                    let title = mem::take(&mut text);
183                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
184                        Section::UncommittedDiff
185                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
186                        Section::EditHistory
187                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
188                        Section::CursorPosition
189                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
190                        Section::ExpectedPatch
191                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
192                        Section::ExpectedExcerpts
193                    } else {
194                        Section::Other
195                    };
196                }
197                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
198                    let heading = mem::take(&mut text);
199                    match current_section {
200                        Section::ExpectedExcerpts => {
201                            named.example.expected_context.push(ExpectedContextEntry {
202                                heading,
203                                alternatives: Vec::new(),
204                            });
205                        }
206                        _ => {}
207                    }
208                }
209                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
210                    let heading = mem::take(&mut text);
211                    match current_section {
212                        Section::ExpectedExcerpts => {
213                            let expected_context = &mut named.example.expected_context;
214                            let last_entry = expected_context.last_mut().unwrap();
215                            last_entry.alternatives.push(ExpectedExcerptSet {
216                                heading,
217                                excerpts: Vec::new(),
218                            })
219                        }
220                        _ => {}
221                    }
222                }
223                Event::End(TagEnd::Heading(level)) => {
224                    anyhow::bail!("Unexpected heading level: {level}");
225                }
226                Event::Start(Tag::CodeBlock(kind)) => {
227                    match kind {
228                        CodeBlockKind::Fenced(info) => {
229                            block_info = info;
230                        }
231                        CodeBlockKind::Indented => {
232                            anyhow::bail!("Unexpected indented codeblock");
233                        }
234                    };
235                }
236                Event::Start(_) => {
237                    text.clear();
238                    block_info = "".into();
239                }
240                Event::End(TagEnd::CodeBlock) => {
241                    let block_info = block_info.trim();
242                    match current_section {
243                        Section::UncommittedDiff => {
244                            named.example.uncommitted_diff = mem::take(&mut text);
245                        }
246                        Section::EditHistory => {
247                            named.example.edit_history.push_str(&mem::take(&mut text));
248                        }
249                        Section::CursorPosition => {
250                            named.example.cursor_path = block_info.into();
251                            named.example.cursor_position = mem::take(&mut text);
252                        }
253                        Section::ExpectedExcerpts => {
254                            let text = mem::take(&mut text);
255                            for excerpt in text.split("\n\n") {
256                                let (mut text, required_lines) = extract_required_lines(&excerpt);
257                                if !text.ends_with('\n') {
258                                    text.push('\n');
259                                }
260                                let alternatives = &mut named
261                                    .example
262                                    .expected_context
263                                    .last_mut()
264                                    .unwrap()
265                                    .alternatives;
266
267                                if alternatives.is_empty() {
268                                    alternatives.push(ExpectedExcerptSet {
269                                        heading: String::new(),
270                                        excerpts: vec![],
271                                    });
272                                }
273
274                                alternatives
275                                    .last_mut()
276                                    .unwrap()
277                                    .excerpts
278                                    .push(ExpectedExcerpt {
279                                        path: block_info.into(),
280                                        text,
281                                        required_lines,
282                                    });
283                            }
284                        }
285                        Section::ExpectedPatch => {
286                            named.example.expected_patch = mem::take(&mut text);
287                        }
288                        Section::Other => {}
289                    }
290                }
291                _ => {}
292            }
293        }
294
295        if named.example.cursor_path.as_path() == Path::new("")
296            || named.example.cursor_position.is_empty()
297        {
298            anyhow::bail!("Missing cursor position codeblock");
299        }
300
301        Ok(named)
302    }
303
304    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
305        match format {
306            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
307            ExampleFormat::Toml => {
308                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
309            }
310            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
311        }
312    }
313
    /// Prepares a git worktree containing the example's revision with its
    /// uncommitted diff applied, and returns the worktree's path.
    ///
    /// The repository lives at `REPOS_DIR/<owner>/<repo>` and the worktree at
    /// `WORKTREES_DIR/<file_name>/<repo>`. All git operations on the shared
    /// repository directory run while holding the lock from [`lock_repo`].
    ///
    /// # Errors
    ///
    /// Fails when the repository URL cannot be split into owner and repo,
    /// when a directory cannot be created, or when any git command
    /// (including the final `git apply`) exits unsuccessfully.
    pub async fn setup_worktree(&self) -> Result<PathBuf> {
        let (repo_owner, repo_name) = self.repo_name()?;
        let file_name = self.file_name();

        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
        let repo_lock = lock_repo(&repo_dir).await;

        // First use of this repository: create it empty and register the remote.
        if !repo_dir.is_dir() {
            fs::create_dir_all(&repo_dir)?;
            run_git(&repo_dir, &["init"]).await?;
            run_git(
                &repo_dir,
                &["remote", "add", "origin", &self.example.repository_url],
            )
            .await?;
        }

        // Resolve the example to a revision, fetching it if needed.
        let revision = run_git(
            &repo_dir,
            &[
                "rev-parse",
                &format!("{}^{{commit}}", self.example.revision),
            ],
        )
        .await;
        let revision = if let Ok(revision) = revision {
            revision
        } else {
            // Not available locally: fetch just that revision (shallow).
            run_git(
                &repo_dir,
                &["fetch", "--depth", "1", "origin", &self.example.revision],
            )
            .await?;
            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
            // The requested revision was a name rather than the resolved
            // hash: tag the fetched commit with that name (presumably so the
            // `rev-parse` above succeeds locally next time).
            if revision != self.example.revision {
                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
            }
            revision
        };

        // Create the worktree for this example if needed.
        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
        if worktree_path.is_dir() {
            // Reuse the existing worktree: discard untracked files and local
            // modifications, then move it to the resolved revision.
            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
        } else {
            // Point a branch named after the example at the revision and
            // check it out into a new worktree.
            let worktree_path_string = worktree_path.to_string_lossy();
            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
            run_git(
                &repo_dir,
                &["worktree", "add", "-f", &worktree_path_string, &file_name],
            )
            .await?;
        }
        // The remaining work only touches this example's worktree.
        drop(repo_lock);

        // Apply the uncommitted diff for this example.
        if !self.example.uncommitted_diff.is_empty() {
            let mut apply_process = smol::process::Command::new("git")
                .current_dir(&worktree_path)
                .args(&["apply", "-"])
                .stdin(std::process::Stdio::piped())
                .spawn()?;

            // Feed the patch through stdin and close it so `git apply` sees EOF.
            let mut stdin = apply_process.stdin.take().unwrap();
            stdin
                .write_all(self.example.uncommitted_diff.as_bytes())
                .await?;
            stdin.close().await?;
            drop(stdin);

            let apply_result = apply_process.output().await?;
            if !apply_result.status.success() {
                anyhow::bail!(
                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
                    apply_result.status,
                    String::from_utf8_lossy(&apply_result.stderr),
                    String::from_utf8_lossy(&apply_result.stdout),
                );
            }
        }

        Ok(worktree_path)
    }
400
401    pub fn file_name(&self) -> String {
402        self.name
403            .chars()
404            .map(|c| {
405                if c.is_whitespace() {
406                    '-'
407                } else {
408                    c.to_ascii_lowercase()
409                }
410            })
411            .collect()
412    }
413
414    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
415        // git@github.com:owner/repo.git
416        if self.example.repository_url.contains('@') {
417            let (owner, repo) = self
418                .example
419                .repository_url
420                .split_once(':')
421                .context("expected : in git url")?
422                .1
423                .split_once('/')
424                .context("expected / in git url")?;
425            Ok((
426                Cow::Borrowed(owner),
427                Cow::Borrowed(repo.trim_end_matches(".git")),
428            ))
429        // http://github.com/owner/repo.git
430        } else {
431            let url = Url::parse(&self.example.repository_url)?;
432            let mut segments = url.path_segments().context("empty http url")?;
433            let owner = segments
434                .next()
435                .context("expected owner path segment")?
436                .to_string();
437            let repo = segments
438                .next()
439                .context("expected repo path segment")?
440                .trim_end_matches(".git")
441                .to_string();
442            assert!(segments.next().is_none());
443
444            Ok((owner.into(), repo.into()))
445        }
446    }
447
448    pub async fn cursor_position(
449        &self,
450        project: &Entity<Project>,
451        cx: &mut AsyncApp,
452    ) -> Result<(Entity<Buffer>, Anchor)> {
453        let worktree = project.read_with(cx, |project, cx| {
454            project.visible_worktrees(cx).next().unwrap()
455        })?;
456        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
457        let cursor_buffer = project
458            .update(cx, |project, cx| {
459                project.open_buffer(
460                    ProjectPath {
461                        worktree_id: worktree.read(cx).id(),
462                        path: cursor_path,
463                    },
464                    cx,
465                )
466            })?
467            .await?;
468        let cursor_offset_within_excerpt = self
469            .example
470            .cursor_position
471            .find(CURSOR_MARKER)
472            .ok_or_else(|| anyhow!("missing cursor marker"))?;
473        let mut cursor_excerpt = self.example.cursor_position.clone();
474        cursor_excerpt.replace_range(
475            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
476            "",
477        );
478        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
479            let text = buffer.text();
480
481            let mut matches = text.match_indices(&cursor_excerpt);
482            let Some((excerpt_offset, _)) = matches.next() else {
483                anyhow::bail!(
484                    "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
485                );
486            };
487            assert!(matches.next().is_none());
488
489            Ok(excerpt_offset)
490        })??;
491
492        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
493        let cursor_anchor =
494            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
495        Ok((cursor_buffer, cursor_anchor))
496    }
497
    /// Applies the example's edit history to `project` by delegating to
    /// `zeta2::udiff::apply_diff`, returning whatever that call yields.
    // NOTE(review): the returned `OpenedBuffers` presumably has to be kept
    // alive by the caller (hence `#[must_use]`) — confirm in `zeta2::udiff`.
    #[must_use]
    pub async fn apply_edit_history(
        &self,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<OpenedBuffers<'_>> {
        zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
    }
506}
507
508fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
509    const MARKER: &str = "[ZETA]";
510    let mut new_text = String::new();
511    let mut required_lines = Vec::new();
512    let mut skipped_lines = 0_u32;
513
514    for (row, mut line) in text.split('\n').enumerate() {
515        if let Some(marker_column) = line.find(MARKER) {
516            let mut strip_column = marker_column;
517
518            while strip_column > 0 {
519                let prev_char = line[strip_column - 1..].chars().next().unwrap();
520                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
521                    strip_column -= 1;
522                } else {
523                    break;
524                }
525            }
526
527            let metadata = &line[marker_column + MARKER.len()..];
528            if metadata.contains("required") {
529                required_lines.push(Line(row as u32 - skipped_lines));
530            }
531
532            if strip_column == 0 {
533                skipped_lines += 1;
534                continue;
535            }
536
537            line = &line[..strip_column];
538        }
539
540        new_text.push_str(line);
541        new_text.push('\n');
542    }
543
544    new_text.pop();
545
546    (new_text, required_lines)
547}
548
549async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
550    let output = smol::process::Command::new("git")
551        .current_dir(repo_path)
552        .args(args)
553        .output()
554        .await?;
555
556    anyhow::ensure!(
557        output.status.success(),
558        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
559        args.join(" "),
560        repo_path.display(),
561        output.status,
562        String::from_utf8_lossy(&output.stderr),
563        String::from_utf8_lossy(&output.stdout),
564    );
565    Ok(String::from_utf8(output.stdout)?.trim().to_string())
566}
567
568impl Display for NamedExample {
569    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570        write!(f, "# {}\n\n", self.name)?;
571        write!(
572            f,
573            "{REPOSITORY_URL_FIELD} = {}\n",
574            self.example.repository_url
575        )?;
576        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
577
578        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
579        write!(f, "`````diff\n")?;
580        write!(f, "{}", self.example.uncommitted_diff)?;
581        write!(f, "`````\n")?;
582
583        if !self.example.edit_history.is_empty() {
584            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
585        }
586
587        write!(
588            f,
589            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
590            self.example.cursor_path.display(),
591            self.example.cursor_position
592        )?;
593        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
594
595        if !self.example.expected_patch.is_empty() {
596            write!(
597                f,
598                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
599                self.example.expected_patch
600            )?;
601        }
602
603        if !self.example.expected_context.is_empty() {
604            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
605
606            for entry in &self.example.expected_context {
607                write!(f, "\n### {}\n\n", entry.heading)?;
608
609                let skip_h4 =
610                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
611
612                for excerpt_set in &entry.alternatives {
613                    if !skip_h4 {
614                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
615                    }
616
617                    for excerpt in &excerpt_set.excerpts {
618                        write!(
619                            f,
620                            "`````{}{}\n{}`````\n\n",
621                            excerpt
622                                .path
623                                .extension()
624                                .map(|ext| format!("{} ", ext.to_string_lossy()))
625                                .unwrap_or_default(),
626                            excerpt.path.display(),
627                            excerpt.text
628                        )?;
629                    }
630                }
631            }
632        }
633
634        Ok(())
635    }
636}
637
// One async mutex per repository directory, created on demand by `lock_repo`.
// The map is thread-local, so callers on different threads do not share
// mutexes.
thread_local! {
    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
641
642#[must_use]
643pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
644    REPO_LOCKS
645        .with(|cell| {
646            cell.borrow_mut()
647                .entry(path.as_ref().to_path_buf())
648                .or_default()
649                .clone()
650        })
651        .lock_owned()
652        .await
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Trailing markers are stripped from their line, marker-only lines are
    /// removed, and `required` markers are reported by their index in the
    /// cleaned text.
    #[test]
    fn test_extract_required_lines() {
        let annotated = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};
        let cleaned = indoc! {"
            zero
            one
            two
            three
            four
            five
        "};

        let (actual_text, actual_required) = extract_required_lines(annotated);

        assert_eq!(actual_text, cleaned);
        assert_eq!(actual_required, vec![Line(1), Line(4)]);
    }
}