example.rs

  1use std::{
  2    borrow::Cow,
  3    cell::RefCell,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    path::{Path, PathBuf},
  9    sync::{Arc, OnceLock},
 10};
 11
 12use crate::headless::ZetaCliAppState;
 13use anyhow::{Context as _, Result, anyhow};
 14use clap::ValueEnum;
 15use cloud_zeta2_prompt::CURSOR_MARKER;
 16use collections::HashMap;
 17use edit_prediction_context::Line;
 18use futures::{
 19    AsyncWriteExt as _,
 20    lock::{Mutex, OwnedMutexGuard},
 21};
 22use futures::{FutureExt as _, future::Shared};
 23use gpui::{AsyncApp, Entity, Task, http_client::Url};
 24use language::{Anchor, Buffer};
 25use project::{Project, ProjectPath};
 26use pulldown_cmark::CowStr;
 27use serde::{Deserialize, Serialize};
 28use util::{paths::PathStyle, rel_path::RelPath};
 29use zeta2::udiff::OpenedBuffers;
 30
 31use crate::paths::{REPOS_DIR, WORKTREES_DIR};
 32
// Markdown H2 section headings. `NamedExample::parse_md` matches them case-insensitively;
// the `Display` impl for `NamedExample` writes them out.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// `<field> = <value>` keys read from text lines that follow the H1 title.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
 40
/// An [`Example`] together with its human-readable name.
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// Example name: the H1 heading for markdown examples, or the file stem for JSON/TOML
    /// examples. [`NamedExample::file_name`] derives the filesystem-friendly form.
    pub name: String,
    /// The example data itself.
    pub example: Example,
}
 46
/// An edit-prediction example: a repository state, a cursor location, and the expected
/// outcome.
///
/// Serde serializes fields in declaration order, so reordering them changes the JSON/TOML
/// output.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git remote the repository is fetched from (registered as `origin`).
    pub repository_url: String,
    /// Revision to check out. It is resolved with `git rev-parse` and shallow-fetched from
    /// `origin` when it is not available locally.
    pub revision: String,
    /// Diff applied with `git apply` on top of `revision` in the example worktree.
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor, relative to the worktree root and
    /// interpreted with POSIX separators.
    pub cursor_path: PathBuf,
    /// Excerpt of the cursor file that contains the cursor marker (`CURSOR_MARKER`).
    pub cursor_position: String,
    /// Diff applied to project buffers by `NamedExample::apply_edit_history`.
    pub edit_history: String,
    /// Contents of the "Expected Patch" markdown section.
    pub expected_patch: String,
    /// Contents of the "Expected Context" markdown section, one entry per H3 heading.
    pub expected_context: Vec<ExpectedContextEntry>,
}
 58
// NOTE(review): presumably names excerpts produced by an actual run, as opposed to
// `ExpectedExcerpt`; confirm at the use sites.
pub type ActualExcerpt = Excerpt;

/// A `text` snippet associated with a file `path`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    pub path: PathBuf,
    pub text: String,
}
 66
/// One entry of the expected context, introduced by an H3 heading in markdown.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
    /// The H3 heading text; empty when excerpts appear without an H3.
    pub heading: String,
    /// Alternative excerpt sets for this entry, one per H4 heading.
    pub alternatives: Vec<ExpectedExcerptSet>,
}
 72
/// One alternative set of excerpts within an [`ExpectedContextEntry`], introduced by an H4
/// heading in markdown.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
    /// The H4 heading text; empty when excerpts appear without an H4.
    pub heading: String,
    /// Excerpts parsed from this set's fenced code blocks.
    pub excerpts: Vec<ExpectedExcerpt>,
}
 78
/// A single expected excerpt, parsed from a fenced code block.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// File path, taken from the code block's info string.
    pub path: PathBuf,
    /// Excerpt text with `[ZETA]` marker comments removed. When parsed from markdown, it
    /// always ends with a newline.
    pub text: String,
    /// Zero-based rows of `text` that were tagged with `[ZETA] required`.
    pub required_lines: Vec<Line>,
}
 85
/// Output format for [`NamedExample::write`]; derives `clap::ValueEnum` so it can be chosen
/// on the command line.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    // Compact JSON of the `Example` (the name is not included).
    Json,
    // Pretty-printed TOML of the `Example` (the name is not included).
    Toml,
    // Markdown via the `Display` impl, with the name as the H1 heading.
    Md,
}
 92
 93impl NamedExample {
 94    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 95        let path = path.as_ref();
 96        let content = std::fs::read_to_string(path)?;
 97        let ext = path.extension();
 98
 99        match ext.and_then(|s| s.to_str()) {
100            Some("json") => Ok(Self {
101                name: path.file_stem().unwrap_or_default().display().to_string(),
102                example: serde_json::from_str(&content)?,
103            }),
104            Some("toml") => Ok(Self {
105                name: path.file_stem().unwrap_or_default().display().to_string(),
106                example: toml::from_str(&content)?,
107            }),
108            Some("md") => Self::parse_md(&content),
109            Some(_) => {
110                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
111            }
112            None => {
113                anyhow::bail!(
114                    "Failed to determine example type since the file does not have an extension."
115                );
116            }
117        }
118    }
119
120    pub fn parse_md(input: &str) -> Result<Self> {
121        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
122
123        let parser = Parser::new(input);
124
125        let mut named = NamedExample {
126            name: String::new(),
127            example: Example {
128                repository_url: String::new(),
129                revision: String::new(),
130                uncommitted_diff: String::new(),
131                cursor_path: PathBuf::new(),
132                cursor_position: String::new(),
133                edit_history: String::new(),
134                expected_patch: String::new(),
135                expected_context: Vec::new(),
136            },
137        };
138
139        let mut text = String::new();
140        let mut block_info: CowStr = "".into();
141
142        #[derive(PartialEq)]
143        enum Section {
144            UncommittedDiff,
145            EditHistory,
146            CursorPosition,
147            ExpectedExcerpts,
148            ExpectedPatch,
149            Other,
150        }
151
152        let mut current_section = Section::Other;
153
154        for event in parser {
155            match event {
156                Event::Text(line) => {
157                    text.push_str(&line);
158
159                    if !named.name.is_empty()
160                        && current_section == Section::Other
161                        // in h1 section
162                        && let Some((field, value)) = line.split_once('=')
163                    {
164                        match field.trim() {
165                            REPOSITORY_URL_FIELD => {
166                                named.example.repository_url = value.trim().to_string();
167                            }
168                            REVISION_FIELD => {
169                                named.example.revision = value.trim().to_string();
170                            }
171                            _ => {}
172                        }
173                    }
174                }
175                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
176                    if !named.name.is_empty() {
177                        anyhow::bail!(
178                            "Found multiple H1 headings. There should only be one with the name of the example."
179                        );
180                    }
181                    named.name = mem::take(&mut text);
182                }
183                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
184                    let title = mem::take(&mut text);
185                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
186                        Section::UncommittedDiff
187                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
188                        Section::EditHistory
189                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
190                        Section::CursorPosition
191                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
192                        Section::ExpectedPatch
193                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
194                        Section::ExpectedExcerpts
195                    } else {
196                        Section::Other
197                    };
198                }
199                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
200                    let heading = mem::take(&mut text);
201                    match current_section {
202                        Section::ExpectedExcerpts => {
203                            named.example.expected_context.push(ExpectedContextEntry {
204                                heading,
205                                alternatives: Vec::new(),
206                            });
207                        }
208                        _ => {}
209                    }
210                }
211                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
212                    let heading = mem::take(&mut text);
213                    match current_section {
214                        Section::ExpectedExcerpts => {
215                            let expected_context = &mut named.example.expected_context;
216                            let last_entry = expected_context.last_mut().unwrap();
217                            last_entry.alternatives.push(ExpectedExcerptSet {
218                                heading,
219                                excerpts: Vec::new(),
220                            })
221                        }
222                        _ => {}
223                    }
224                }
225                Event::End(TagEnd::Heading(level)) => {
226                    anyhow::bail!("Unexpected heading level: {level}");
227                }
228                Event::Start(Tag::CodeBlock(kind)) => {
229                    match kind {
230                        CodeBlockKind::Fenced(info) => {
231                            block_info = info;
232                        }
233                        CodeBlockKind::Indented => {
234                            anyhow::bail!("Unexpected indented codeblock");
235                        }
236                    };
237                }
238                Event::Start(_) => {
239                    text.clear();
240                    block_info = "".into();
241                }
242                Event::End(TagEnd::CodeBlock) => {
243                    let block_info = block_info.trim();
244                    match current_section {
245                        Section::UncommittedDiff => {
246                            named.example.uncommitted_diff = mem::take(&mut text);
247                        }
248                        Section::EditHistory => {
249                            named.example.edit_history.push_str(&mem::take(&mut text));
250                        }
251                        Section::CursorPosition => {
252                            named.example.cursor_path = block_info.into();
253                            named.example.cursor_position = mem::take(&mut text);
254                        }
255                        Section::ExpectedExcerpts => {
256                            let text = mem::take(&mut text);
257                            for excerpt in text.split("\n\n") {
258                                let (mut text, required_lines) = extract_required_lines(&excerpt);
259                                if !text.ends_with('\n') {
260                                    text.push('\n');
261                                }
262
263                                if named.example.expected_context.is_empty() {
264                                    named.example.expected_context.push(Default::default());
265                                }
266
267                                let alternatives = &mut named
268                                    .example
269                                    .expected_context
270                                    .last_mut()
271                                    .unwrap()
272                                    .alternatives;
273
274                                if alternatives.is_empty() {
275                                    alternatives.push(ExpectedExcerptSet {
276                                        heading: String::new(),
277                                        excerpts: vec![],
278                                    });
279                                }
280
281                                alternatives
282                                    .last_mut()
283                                    .unwrap()
284                                    .excerpts
285                                    .push(ExpectedExcerpt {
286                                        path: block_info.into(),
287                                        text,
288                                        required_lines,
289                                    });
290                            }
291                        }
292                        Section::ExpectedPatch => {
293                            named.example.expected_patch = mem::take(&mut text);
294                        }
295                        Section::Other => {}
296                    }
297                }
298                _ => {}
299            }
300        }
301
302        if named.example.cursor_path.as_path() == Path::new("")
303            || named.example.cursor_position.is_empty()
304        {
305            anyhow::bail!("Missing cursor position codeblock");
306        }
307
308        Ok(named)
309    }
310
311    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
312        match format {
313            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
314            ExampleFormat::Toml => {
315                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
316            }
317            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
318        }
319    }
320
321    pub async fn setup_project(
322        &self,
323        app_state: &Arc<ZetaCliAppState>,
324        cx: &mut AsyncApp,
325    ) -> Result<Entity<Project>> {
326        let worktree_path = self.setup_worktree().await?;
327
328        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
329
330        AUTHENTICATED
331            .get_or_init(|| {
332                let client = app_state.client.clone();
333                cx.spawn(async move |cx| {
334                    client
335                        .sign_in_with_optional_connect(true, cx)
336                        .await
337                        .unwrap();
338                })
339                .shared()
340            })
341            .clone()
342            .await;
343
344        let project = cx.update(|cx| {
345            Project::local(
346                app_state.client.clone(),
347                app_state.node_runtime.clone(),
348                app_state.user_store.clone(),
349                app_state.languages.clone(),
350                app_state.fs.clone(),
351                None,
352                cx,
353            )
354        })?;
355
356        let worktree = project
357            .update(cx, |project, cx| {
358                project.create_worktree(&worktree_path, true, cx)
359            })?
360            .await?;
361        worktree
362            .read_with(cx, |worktree, _cx| {
363                worktree.as_local().unwrap().scan_complete()
364            })?
365            .await;
366
367        anyhow::Ok(project)
368    }
369
370    pub async fn setup_worktree(&self) -> Result<PathBuf> {
371        let (repo_owner, repo_name) = self.repo_name()?;
372        let file_name = self.file_name();
373
374        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
375        let repo_lock = lock_repo(&repo_dir).await;
376
377        if !repo_dir.is_dir() {
378            fs::create_dir_all(&repo_dir)?;
379            run_git(&repo_dir, &["init"]).await?;
380            run_git(
381                &repo_dir,
382                &["remote", "add", "origin", &self.example.repository_url],
383            )
384            .await?;
385        }
386
387        // Resolve the example to a revision, fetching it if needed.
388        let revision = run_git(
389            &repo_dir,
390            &[
391                "rev-parse",
392                &format!("{}^{{commit}}", self.example.revision),
393            ],
394        )
395        .await;
396        let revision = if let Ok(revision) = revision {
397            revision
398        } else {
399            run_git(
400                &repo_dir,
401                &["fetch", "--depth", "1", "origin", &self.example.revision],
402            )
403            .await?;
404            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
405            if revision != self.example.revision {
406                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
407            }
408            revision
409        };
410
411        // Create the worktree for this example if needed.
412        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
413        if worktree_path.is_dir() {
414            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
415            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
416            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
417        } else {
418            let worktree_path_string = worktree_path.to_string_lossy();
419            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
420            run_git(
421                &repo_dir,
422                &["worktree", "add", "-f", &worktree_path_string, &file_name],
423            )
424            .await?;
425        }
426        drop(repo_lock);
427
428        // Apply the uncommitted diff for this example.
429        if !self.example.uncommitted_diff.is_empty() {
430            let mut apply_process = smol::process::Command::new("git")
431                .current_dir(&worktree_path)
432                .args(&["apply", "-"])
433                .stdin(std::process::Stdio::piped())
434                .spawn()?;
435
436            let mut stdin = apply_process.stdin.take().unwrap();
437            stdin
438                .write_all(self.example.uncommitted_diff.as_bytes())
439                .await?;
440            stdin.close().await?;
441            drop(stdin);
442
443            let apply_result = apply_process.output().await?;
444            if !apply_result.status.success() {
445                anyhow::bail!(
446                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
447                    apply_result.status,
448                    String::from_utf8_lossy(&apply_result.stderr),
449                    String::from_utf8_lossy(&apply_result.stdout),
450                );
451            }
452        }
453
454        Ok(worktree_path)
455    }
456
457    pub fn file_name(&self) -> String {
458        self.name
459            .chars()
460            .map(|c| {
461                if c.is_whitespace() {
462                    '-'
463                } else {
464                    c.to_ascii_lowercase()
465                }
466            })
467            .collect()
468    }
469
470    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
471        // git@github.com:owner/repo.git
472        if self.example.repository_url.contains('@') {
473            let (owner, repo) = self
474                .example
475                .repository_url
476                .split_once(':')
477                .context("expected : in git url")?
478                .1
479                .split_once('/')
480                .context("expected / in git url")?;
481            Ok((
482                Cow::Borrowed(owner),
483                Cow::Borrowed(repo.trim_end_matches(".git")),
484            ))
485        // http://github.com/owner/repo.git
486        } else {
487            let url = Url::parse(&self.example.repository_url)?;
488            let mut segments = url.path_segments().context("empty http url")?;
489            let owner = segments
490                .next()
491                .context("expected owner path segment")?
492                .to_string();
493            let repo = segments
494                .next()
495                .context("expected repo path segment")?
496                .trim_end_matches(".git")
497                .to_string();
498            assert!(segments.next().is_none());
499
500            Ok((owner.into(), repo.into()))
501        }
502    }
503
504    pub async fn cursor_position(
505        &self,
506        project: &Entity<Project>,
507        cx: &mut AsyncApp,
508    ) -> Result<(Entity<Buffer>, Anchor)> {
509        let worktree = project.read_with(cx, |project, cx| {
510            project.visible_worktrees(cx).next().unwrap()
511        })?;
512        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
513        let cursor_buffer = project
514            .update(cx, |project, cx| {
515                project.open_buffer(
516                    ProjectPath {
517                        worktree_id: worktree.read(cx).id(),
518                        path: cursor_path,
519                    },
520                    cx,
521                )
522            })?
523            .await?;
524        let cursor_offset_within_excerpt = self
525            .example
526            .cursor_position
527            .find(CURSOR_MARKER)
528            .ok_or_else(|| anyhow!("missing cursor marker"))?;
529        let mut cursor_excerpt = self.example.cursor_position.clone();
530        cursor_excerpt.replace_range(
531            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
532            "",
533        );
534        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
535            let text = buffer.text();
536
537            let mut matches = text.match_indices(&cursor_excerpt);
538            let Some((excerpt_offset, _)) = matches.next() else {
539                anyhow::bail!(
540                    "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
541                );
542            };
543            assert!(matches.next().is_none());
544
545            Ok(excerpt_offset)
546        })??;
547
548        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
549        let cursor_anchor =
550            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
551        Ok((cursor_buffer, cursor_anchor))
552    }
553
554    #[must_use]
555    pub async fn apply_edit_history(
556        &self,
557        project: &Entity<Project>,
558        cx: &mut AsyncApp,
559    ) -> Result<OpenedBuffers<'_>> {
560        zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
561    }
562}
563
564fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
565    const MARKER: &str = "[ZETA]";
566    let mut new_text = String::new();
567    let mut required_lines = Vec::new();
568    let mut skipped_lines = 0_u32;
569
570    for (row, mut line) in text.split('\n').enumerate() {
571        if let Some(marker_column) = line.find(MARKER) {
572            let mut strip_column = marker_column;
573
574            while strip_column > 0 {
575                let prev_char = line[strip_column - 1..].chars().next().unwrap();
576                if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
577                    strip_column -= 1;
578                } else {
579                    break;
580                }
581            }
582
583            let metadata = &line[marker_column + MARKER.len()..];
584            if metadata.contains("required") {
585                required_lines.push(Line(row as u32 - skipped_lines));
586            }
587
588            if strip_column == 0 {
589                skipped_lines += 1;
590                continue;
591            }
592
593            line = &line[..strip_column];
594        }
595
596        new_text.push_str(line);
597        new_text.push('\n');
598    }
599
600    new_text.pop();
601
602    (new_text, required_lines)
603}
604
605async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
606    let output = smol::process::Command::new("git")
607        .current_dir(repo_path)
608        .args(args)
609        .output()
610        .await?;
611
612    anyhow::ensure!(
613        output.status.success(),
614        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
615        args.join(" "),
616        repo_path.display(),
617        output.status,
618        String::from_utf8_lossy(&output.stderr),
619        String::from_utf8_lossy(&output.stdout),
620    );
621    Ok(String::from_utf8(output.stdout)?.trim().to_string())
622}
623
624impl Display for NamedExample {
625    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
626        write!(f, "# {}\n\n", self.name)?;
627        write!(
628            f,
629            "{REPOSITORY_URL_FIELD} = {}\n",
630            self.example.repository_url
631        )?;
632        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
633
634        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
635        write!(f, "`````diff\n")?;
636        write!(f, "{}", self.example.uncommitted_diff)?;
637        write!(f, "`````\n")?;
638
639        if !self.example.edit_history.is_empty() {
640            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
641        }
642
643        write!(
644            f,
645            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
646            self.example.cursor_path.display(),
647            self.example.cursor_position
648        )?;
649        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
650
651        if !self.example.expected_patch.is_empty() {
652            write!(
653                f,
654                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
655                self.example.expected_patch
656            )?;
657        }
658
659        if !self.example.expected_context.is_empty() {
660            write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
661
662            for entry in &self.example.expected_context {
663                write!(f, "\n### {}\n\n", entry.heading)?;
664
665                let skip_h4 =
666                    entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
667
668                for excerpt_set in &entry.alternatives {
669                    if !skip_h4 {
670                        write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
671                    }
672
673                    for excerpt in &excerpt_set.excerpts {
674                        write!(
675                            f,
676                            "`````{}{}\n{}`````\n\n",
677                            excerpt
678                                .path
679                                .extension()
680                                .map(|ext| format!("{} ", ext.to_string_lossy()))
681                                .unwrap_or_default(),
682                            excerpt.path.display(),
683                            excerpt.text
684                        )?;
685                    }
686                }
687            }
688        }
689
690        Ok(())
691    }
692}
693
694thread_local! {
695    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
696}
697
698#[must_use]
699pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
700    REPO_LOCKS
701        .with(|cell| {
702            cell.borrow_mut()
703                .entry(path.as_ref().to_path_buf())
704                .or_default()
705                .clone()
706        })
707        .lock_owned()
708        .await
709}
710
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Checks three behaviors of `extract_required_lines`:
    /// - marker comments are stripped together with their `//` / `#` leader;
    /// - marker-only lines are dropped;
    /// - `required` rows are counted in the cleaned-up text.
    #[test]
    fn test_extract_required_lines() {
        let annotated = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};

        let (cleaned, required) = extract_required_lines(annotated);

        assert_eq!(
            cleaned,
            indoc! {"
                zero
                one
                two
                three
                four
                five
            "}
        );
        assert_eq!(required, vec![Line(1), Line(4)]);
    }
}