example.rs

  1use std::{
  2    borrow::Cow,
  3    env,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    path::{Path, PathBuf},
  9};
 10
 11use anyhow::{Context as _, Result};
 12use clap::ValueEnum;
 13use collections::HashSet;
 14use futures::AsyncWriteExt as _;
 15use gpui::{AsyncApp, Entity, http_client::Url};
 16use language::Buffer;
 17use project::{Project, ProjectPath};
 18use pulldown_cmark::CowStr;
 19use serde::{Deserialize, Serialize};
 20
 21const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
 22const EDIT_HISTORY_HEADING: &str = "Edit History";
 23const CURSOR_POSITION_HEADING: &str = "Cursor Position";
 24const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
 25const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
 26const REPOSITORY_URL_FIELD: &str = "repository_url";
 27const REVISION_FIELD: &str = "revision";
 28
 29#[derive(Debug)]
 30pub struct NamedExample {
 31    pub name: String,
 32    pub example: Example,
 33}
 34
 35#[derive(Debug, Serialize, Deserialize)]
 36pub struct Example {
 37    pub repository_url: String,
 38    pub revision: String,
 39    pub uncommitted_diff: String,
 40    pub cursor_path: PathBuf,
 41    pub cursor_position: String,
 42    pub edit_history: String,
 43    pub expected_patch: String,
 44    pub expected_excerpts: Vec<ExpectedExcerpt>,
 45}
 46
 47#[derive(Debug, Serialize, Deserialize)]
 48pub struct ExpectedExcerpt {
 49    path: PathBuf,
 50    text: String,
 51}
 52
 53#[derive(ValueEnum, Debug, Clone)]
 54pub enum ExampleFormat {
 55    Json,
 56    Toml,
 57    Md,
 58}
 59
 60impl NamedExample {
 61    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 62        let path = path.as_ref();
 63        let content = std::fs::read_to_string(path)?;
 64        let ext = path.extension();
 65
 66        match ext.and_then(|s| s.to_str()) {
 67            Some("json") => Ok(Self {
 68                name: path.file_stem().unwrap_or_default().display().to_string(),
 69                example: serde_json::from_str(&content)?,
 70            }),
 71            Some("toml") => Ok(Self {
 72                name: path.file_stem().unwrap_or_default().display().to_string(),
 73                example: toml::from_str(&content)?,
 74            }),
 75            Some("md") => Self::parse_md(&content),
 76            Some(_) => {
 77                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
 78            }
 79            None => {
 80                anyhow::bail!(
 81                    "Failed to determine example type since the file does not have an extension."
 82                );
 83            }
 84        }
 85    }
 86
 87    pub fn parse_md(input: &str) -> Result<Self> {
 88        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
 89
 90        let parser = Parser::new(input);
 91
 92        let mut named = NamedExample {
 93            name: String::new(),
 94            example: Example {
 95                repository_url: String::new(),
 96                revision: String::new(),
 97                uncommitted_diff: String::new(),
 98                cursor_path: PathBuf::new(),
 99                cursor_position: String::new(),
100                edit_history: String::new(),
101                expected_patch: String::new(),
102                expected_excerpts: Vec::new(),
103            },
104        };
105
106        let mut text = String::new();
107        let mut current_section = String::new();
108        let mut block_info: CowStr = "".into();
109
110        for event in parser {
111            match event {
112                Event::Text(line) => {
113                    text.push_str(&line);
114
115                    if !named.name.is_empty()
116                        && current_section.is_empty()
117                        // in h1 section
118                        && let Some((field, value)) = line.split_once('=')
119                    {
120                        match field.trim() {
121                            REPOSITORY_URL_FIELD => {
122                                named.example.repository_url = value.trim().to_string();
123                            }
124                            REVISION_FIELD => {
125                                named.example.revision = value.trim().to_string();
126                            }
127                            _ => {
128                                eprintln!("Warning: Unrecognized field `{field}`");
129                            }
130                        }
131                    }
132                }
133                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
134                    if !named.name.is_empty() {
135                        anyhow::bail!(
136                            "Found multiple H1 headings. There should only be one with the name of the example."
137                        );
138                    }
139                    named.name = mem::take(&mut text);
140                }
141                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
142                    current_section = mem::take(&mut text);
143                }
144                Event::End(TagEnd::Heading(level)) => {
145                    anyhow::bail!("Unexpected heading level: {level}");
146                }
147                Event::Start(Tag::CodeBlock(kind)) => {
148                    match kind {
149                        CodeBlockKind::Fenced(info) => {
150                            block_info = info;
151                        }
152                        CodeBlockKind::Indented => {
153                            anyhow::bail!("Unexpected indented codeblock");
154                        }
155                    };
156                }
157                Event::Start(_) => {
158                    text.clear();
159                    block_info = "".into();
160                }
161                Event::End(TagEnd::CodeBlock) => {
162                    let block_info = block_info.trim();
163                    if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
164                        named.example.uncommitted_diff = mem::take(&mut text);
165                    } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
166                        named.example.edit_history.push_str(&mem::take(&mut text));
167                    } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
168                        named.example.cursor_path = block_info.into();
169                        named.example.cursor_position = mem::take(&mut text);
170                    } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
171                        named.example.expected_patch = mem::take(&mut text);
172                    } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
173                        named.example.expected_excerpts.push(ExpectedExcerpt {
174                            path: block_info.into(),
175                            text: mem::take(&mut text),
176                        });
177                    } else {
178                        eprintln!("Warning: Unrecognized section `{current_section:?}`")
179                    }
180                }
181                _ => {}
182            }
183        }
184
185        if named.example.cursor_path.as_path() == Path::new("")
186            || named.example.cursor_position.is_empty()
187        {
188            anyhow::bail!("Missing cursor position codeblock");
189        }
190
191        Ok(named)
192    }
193
194    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
195        match format {
196            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
197            ExampleFormat::Toml => {
198                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
199            }
200            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
201        }
202    }
203
204    #[allow(unused)]
205    pub async fn setup_worktree(&self) -> Result<PathBuf> {
206        let (repo_owner, repo_name) = self.repo_name()?;
207        let file_name = self.file_name();
208
209        let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
210        let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
211        fs::create_dir_all(&repos_dir)?;
212        fs::create_dir_all(&worktrees_dir)?;
213
214        let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
215        if !repo_dir.is_dir() {
216            fs::create_dir_all(&repo_dir)?;
217            run_git(&repo_dir, &["init"]).await?;
218            run_git(
219                &repo_dir,
220                &["remote", "add", "origin", &self.example.repository_url],
221            )
222            .await?;
223        }
224
225        // Resolve the example to a revision, fetching it if needed.
226        let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
227        let revision = if let Ok(revision) = revision {
228            revision
229        } else {
230            run_git(
231                &repo_dir,
232                &["fetch", "--depth", "1", "origin", &self.example.revision],
233            )
234            .await?;
235            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
236            if revision != self.example.revision {
237                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
238            }
239            revision
240        };
241
242        // Create the worktree for this example if needed.
243        let worktree_path = worktrees_dir.join(&file_name);
244        if worktree_path.is_dir() {
245            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
246            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
247            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
248        } else {
249            let worktree_path_string = worktree_path.to_string_lossy();
250            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
251            run_git(
252                &repo_dir,
253                &["worktree", "add", "-f", &worktree_path_string, &file_name],
254            )
255            .await?;
256        }
257
258        // Apply the uncommitted diff for this example.
259        if !self.example.uncommitted_diff.is_empty() {
260            let mut apply_process = smol::process::Command::new("git")
261                .current_dir(&worktree_path)
262                .args(&["apply", "-"])
263                .stdin(std::process::Stdio::piped())
264                .spawn()?;
265
266            let mut stdin = apply_process.stdin.take().unwrap();
267            stdin
268                .write_all(self.example.uncommitted_diff.as_bytes())
269                .await?;
270            stdin.close().await?;
271            drop(stdin);
272
273            let apply_result = apply_process.output().await?;
274            if !apply_result.status.success() {
275                anyhow::bail!(
276                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
277                    apply_result.status,
278                    String::from_utf8_lossy(&apply_result.stderr),
279                    String::from_utf8_lossy(&apply_result.stdout),
280                );
281            }
282        }
283
284        Ok(worktree_path)
285    }
286
287    fn file_name(&self) -> String {
288        self.name
289            .chars()
290            .map(|c| {
291                if c.is_whitespace() {
292                    '-'
293                } else {
294                    c.to_ascii_lowercase()
295                }
296            })
297            .collect()
298    }
299
300    #[allow(unused)]
301    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
302        // git@github.com:owner/repo.git
303        if self.example.repository_url.contains('@') {
304            let (owner, repo) = self
305                .example
306                .repository_url
307                .split_once(':')
308                .context("expected : in git url")?
309                .1
310                .split_once('/')
311                .context("expected / in git url")?;
312            Ok((
313                Cow::Borrowed(owner),
314                Cow::Borrowed(repo.trim_end_matches(".git")),
315            ))
316        // http://github.com/owner/repo.git
317        } else {
318            let url = Url::parse(&self.example.repository_url)?;
319            let mut segments = url.path_segments().context("empty http url")?;
320            let owner = segments
321                .next()
322                .context("expected owner path segment")?
323                .to_string();
324            let repo = segments
325                .next()
326                .context("expected repo path segment")?
327                .trim_end_matches(".git")
328                .to_string();
329            assert!(segments.next().is_none());
330
331            Ok((owner.into(), repo.into()))
332        }
333    }
334
335    #[must_use]
336    pub async fn apply_edit_history(
337        &self,
338        project: &Entity<Project>,
339        cx: &mut AsyncApp,
340    ) -> Result<HashSet<Entity<Buffer>>> {
341        apply_diff(&self.example.edit_history, project, cx).await
342    }
343}
344
345async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
346    let output = smol::process::Command::new("git")
347        .current_dir(repo_path)
348        .args(args)
349        .output()
350        .await?;
351
352    anyhow::ensure!(
353        output.status.success(),
354        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
355        args.join(" "),
356        repo_path.display(),
357        output.status,
358        String::from_utf8_lossy(&output.stderr),
359        String::from_utf8_lossy(&output.stdout),
360    );
361    Ok(String::from_utf8(output.stdout)?.trim().to_string())
362}
363
364impl Display for NamedExample {
365    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366        write!(f, "# {}\n\n", self.name)?;
367        write!(
368            f,
369            "{REPOSITORY_URL_FIELD} = {}\n",
370            self.example.repository_url
371        )?;
372        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
373
374        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
375        write!(f, "`````diff\n")?;
376        write!(f, "{}", self.example.uncommitted_diff)?;
377        write!(f, "`````\n")?;
378
379        if !self.example.edit_history.is_empty() {
380            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
381        }
382
383        write!(
384            f,
385            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
386            self.example.cursor_path.display(),
387            self.example.cursor_position
388        )?;
389        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
390
391        if !self.example.expected_patch.is_empty() {
392            write!(
393                f,
394                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
395                self.example.expected_patch
396            )?;
397        }
398
399        if !self.example.expected_excerpts.is_empty() {
400            write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
401
402            for excerpt in &self.example.expected_excerpts {
403                write!(
404                    f,
405                    "`````{}{}\n{}`````\n\n",
406                    excerpt
407                        .path
408                        .extension()
409                        .map(|ext| format!("{} ", ext.to_string_lossy()))
410                        .unwrap_or_default(),
411                    excerpt.path.display(),
412                    excerpt.text
413                )?;
414            }
415        }
416
417        Ok(())
418    }
419}
420
421#[must_use]
422pub async fn apply_diff(
423    diff: &str,
424    project: &Entity<Project>,
425    cx: &mut AsyncApp,
426) -> Result<HashSet<Entity<Buffer>>> {
427    use cloud_llm_client::udiff::DiffLine;
428    use std::fmt::Write;
429
430    #[derive(Debug, Default)]
431    struct Edit {
432        context: String,
433        deletion_start: Option<usize>,
434        addition: String,
435    }
436
437    let mut old_path = None;
438    let mut new_path = None;
439    let mut pending = Edit::default();
440    let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
441    let mut open_buffers = HashSet::default();
442
443    while let Some(diff_line) = diff_lines.next() {
444        match diff_line {
445            DiffLine::OldPath { path } => {
446                mem::take(&mut pending);
447                old_path = Some(path)
448            }
449            DiffLine::HunkHeader(_) => {
450                mem::take(&mut pending);
451            }
452            DiffLine::NewPath { path } => {
453                if old_path.is_none() {
454                    anyhow::bail!(
455                        "Found a new path header (`+++`) before an (`---`) old path header"
456                    );
457                }
458                new_path = Some(path)
459            }
460            DiffLine::Context(ctx) => {
461                writeln!(&mut pending.context, "{ctx}")?;
462            }
463            DiffLine::Deletion(del) => {
464                pending.deletion_start.get_or_insert(pending.context.len());
465                writeln!(&mut pending.context, "{del}")?;
466            }
467            DiffLine::Addition(add) => {
468                if pending.context.is_empty() {
469                    anyhow::bail!("Found an addition before any context or deletion lines");
470                }
471
472                writeln!(&mut pending.addition, "{add}")?;
473            }
474            DiffLine::Garbage => {}
475        }
476
477        let commit_pending = match diff_lines.peek() {
478            Some(DiffLine::OldPath { .. })
479            | Some(DiffLine::HunkHeader(_))
480            | Some(DiffLine::Context(_))
481            | None => {
482                // commit pending edit cluster
483                !pending.addition.is_empty() || pending.deletion_start.is_some()
484            }
485            Some(DiffLine::Deletion(_)) => {
486                // start a new cluster if we have any additions specifically
487                // if we only have deletions, we continue to aggregate them
488                !pending.addition.is_empty()
489            }
490            _ => false,
491        };
492
493        if commit_pending {
494            let edit = mem::take(&mut pending);
495
496            let Some(old_path) = old_path.as_deref() else {
497                anyhow::bail!("Missing old path (`---`) header")
498            };
499
500            let Some(new_path) = new_path.as_deref() else {
501                anyhow::bail!("Missing new path (`+++`) header")
502            };
503
504            let buffer = project
505                .update(cx, |project, cx| {
506                    let project_path = project
507                        .find_project_path(old_path, cx)
508                        .context("Failed to find old_path in project")?;
509
510                    anyhow::Ok(project.open_buffer(project_path, cx))
511                })??
512                .await?;
513            open_buffers.insert(buffer.clone());
514
515            if old_path != new_path {
516                project
517                    .update(cx, |project, cx| {
518                        let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
519                        let new_path = ProjectPath {
520                            worktree_id: project_file.worktree_id(cx),
521                            path: project_file.path.clone(),
522                        };
523                        project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
524                    })?
525                    .await?;
526            }
527
528            // TODO is it worth using project search?
529            buffer.update(cx, |buffer, cx| {
530                let text = buffer.text();
531                // todo! check there's only one
532                if let Some(context_offset) = text.find(&edit.context) {
533                    let end = context_offset + edit.context.len();
534                    let start = if let Some(deletion_start) = edit.deletion_start {
535                        context_offset + deletion_start
536                    } else {
537                        end
538                    };
539
540                    buffer.edit([(start..end, edit.addition)], None, cx);
541
542                    anyhow::Ok(())
543                } else {
544                    anyhow::bail!("Failed to match context:\n{}", edit.context);
545                }
546            })??;
547        }
548    }
549
550    anyhow::Ok(open_buffers)
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    // `super::*` brings `std::fs` into scope as `fs`; the leading `::` unambiguously names the
    // workspace `fs` crate instead.
    use ::fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use project::Project;
    use serde_json::json;

    /// Replacing one line in the middle of a file leaves the surrounding lines intact.
    #[gpui::test]
    async fn test_apply_diff(cx: &mut TestAppContext) {
        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
        "# };

        let buffer_2_text = indoc! {r#"
            six
            seven
            eight
            nine
            ten
        "# };

        let fs = FakeFs::new(cx.background_executor().clone());
        fs.insert_tree(
            "/root",
            json!({
                "file1": buffer_1_text,
                "file2": buffer_2_text,
            }),
        )
        .await;

        // NOTE(review): other project tests may initialise a settings store first — confirm
        // none is required here.
        let project = Project::test(fs, ["/root".as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/root/file1
            +++ b/root/file1
             one
             two
            -three
            +3
             four
             five
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();

        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("/root/file1", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _| {
            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {r#"
                    one
                    two
                    3
                    four
                    five
                "#}
            );
        });
    }
}