1use std::{
  2    borrow::Cow,
  3    env,
  4    fmt::{self, Display},
  5    fs,
  6    io::Write,
  7    mem,
  8    ops::Range,
  9    path::{Path, PathBuf},
 10};
 11
 12use anyhow::{Context as _, Result};
 13use clap::ValueEnum;
 14use collections::HashSet;
 15use futures::AsyncWriteExt as _;
 16use gpui::{AsyncApp, Entity, http_client::Url};
 17use language::Buffer;
 18use project::{Project, ProjectPath};
 19use pulldown_cmark::CowStr;
 20use serde::{Deserialize, Serialize};
 21
// Titles of the `##` sections in the markdown form of an example. They are
// written verbatim by the `Display` impl and matched ASCII case-insensitively
// by `NamedExample::parse_md`.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
// Keys of the `key = value` lines that sit between the `#` title and the
// first `##` section in the markdown form.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
 29
/// An [`Example`] together with its name.
///
/// For JSON and TOML inputs the name is the file stem; for markdown inputs it
/// is the text of the `#` heading (see [`NamedExample::load`]).
#[derive(Debug)]
pub struct NamedExample {
    /// Name of the example; `file_name` derives the worktree and branch name
    /// from it.
    pub name: String,
    /// The example payload.
    pub example: Example,
}
 35
/// The data of one example: where to obtain the repository state, plus the
/// diffs, cursor information and expectations attached to it.
///
/// This struct is what the JSON and TOML formats (de)serialize directly.
#[derive(Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git remote URL; added as `origin` when the repository is set up.
    pub repository_url: String,
    /// Revision that is resolved (and fetched when missing) for the worktree.
    pub revision: String,
    /// Patch applied with `git apply` on top of `revision`; may be empty.
    pub uncommitted_diff: String,
    /// In markdown input, the info string of the "Cursor Position" code block.
    pub cursor_path: PathBuf,
    /// In markdown input, the contents of the "Cursor Position" code block.
    pub cursor_position: String,
    /// Diff applied to project buffers by `NamedExample::apply_edit_history`.
    pub edit_history: String,
    /// In markdown input, the contents of the "Expected Patch" code block.
    pub expected_patch: String,
    /// One entry per code block of the "Expected Excerpts" section.
    pub expected_excerpts: Vec<ExpectedExcerpt>,
}
 47
/// One code block from the "Expected Excerpts" section of an example.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// Path associated with the excerpt; in markdown input it comes from the
    /// code block's info string.
    path: PathBuf,
    /// Contents of the excerpt's code block.
    text: String,
}
 53
/// Output formats supported by [`NamedExample::write`].
///
/// Derives clap's `ValueEnum` so it can be used as a command-line value.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    /// JSON serialization of the [`Example`] payload.
    Json,
    /// Pretty-printed TOML serialization of the [`Example`] payload.
    Toml,
    /// Markdown document produced by the `Display` impl of [`NamedExample`].
    Md,
}
 60
 61impl NamedExample {
 62    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
 63        let path = path.as_ref();
 64        let content = std::fs::read_to_string(path)?;
 65        let ext = path.extension();
 66
 67        match ext.and_then(|s| s.to_str()) {
 68            Some("json") => Ok(Self {
 69                name: path.file_stem().unwrap_or_default().display().to_string(),
 70                example: serde_json::from_str(&content)?,
 71            }),
 72            Some("toml") => Ok(Self {
 73                name: path.file_stem().unwrap_or_default().display().to_string(),
 74                example: toml::from_str(&content)?,
 75            }),
 76            Some("md") => Self::parse_md(&content),
 77            Some(_) => {
 78                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
 79            }
 80            None => {
 81                anyhow::bail!(
 82                    "Failed to determine example type since the file does not have an extension."
 83                );
 84            }
 85        }
 86    }
 87
 88    pub fn parse_md(input: &str) -> Result<Self> {
 89        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
 90
 91        let parser = Parser::new(input);
 92
 93        let mut named = NamedExample {
 94            name: String::new(),
 95            example: Example {
 96                repository_url: String::new(),
 97                revision: String::new(),
 98                uncommitted_diff: String::new(),
 99                cursor_path: PathBuf::new(),
100                cursor_position: String::new(),
101                edit_history: String::new(),
102                expected_patch: String::new(),
103                expected_excerpts: Vec::new(),
104            },
105        };
106
107        let mut text = String::new();
108        let mut current_section = String::new();
109        let mut block_info: CowStr = "".into();
110
111        for event in parser {
112            match event {
113                Event::Text(line) => {
114                    text.push_str(&line);
115
116                    if !named.name.is_empty()
117                        && current_section.is_empty()
118                        // in h1 section
119                        && let Some((field, value)) = line.split_once('=')
120                    {
121                        match field.trim() {
122                            REPOSITORY_URL_FIELD => {
123                                named.example.repository_url = value.trim().to_string();
124                            }
125                            REVISION_FIELD => {
126                                named.example.revision = value.trim().to_string();
127                            }
128                            _ => {
129                                eprintln!("Warning: Unrecognized field `{field}`");
130                            }
131                        }
132                    }
133                }
134                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
135                    if !named.name.is_empty() {
136                        anyhow::bail!(
137                            "Found multiple H1 headings. There should only be one with the name of the example."
138                        );
139                    }
140                    named.name = mem::take(&mut text);
141                }
142                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
143                    current_section = mem::take(&mut text);
144                }
145                Event::End(TagEnd::Heading(level)) => {
146                    anyhow::bail!("Unexpected heading level: {level}");
147                }
148                Event::Start(Tag::CodeBlock(kind)) => {
149                    match kind {
150                        CodeBlockKind::Fenced(info) => {
151                            block_info = info;
152                        }
153                        CodeBlockKind::Indented => {
154                            anyhow::bail!("Unexpected indented codeblock");
155                        }
156                    };
157                }
158                Event::Start(_) => {
159                    text.clear();
160                    block_info = "".into();
161                }
162                Event::End(TagEnd::CodeBlock) => {
163                    let block_info = block_info.trim();
164                    if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
165                        named.example.uncommitted_diff = mem::take(&mut text);
166                    } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
167                        named.example.edit_history.push_str(&mem::take(&mut text));
168                    } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
169                        named.example.cursor_path = block_info.into();
170                        named.example.cursor_position = mem::take(&mut text);
171                    } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
172                        named.example.expected_patch = mem::take(&mut text);
173                    } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
174                        named.example.expected_excerpts.push(ExpectedExcerpt {
175                            path: block_info.into(),
176                            text: mem::take(&mut text),
177                        });
178                    } else {
179                        eprintln!("Warning: Unrecognized section `{current_section:?}`")
180                    }
181                }
182                _ => {}
183            }
184        }
185
186        if named.example.cursor_path.as_path() == Path::new("")
187            || named.example.cursor_position.is_empty()
188        {
189            anyhow::bail!("Missing cursor position codeblock");
190        }
191
192        Ok(named)
193    }
194
195    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
196        match format {
197            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
198            ExampleFormat::Toml => {
199                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
200            }
201            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
202        }
203    }
204
205    #[allow(unused)]
206    pub async fn setup_worktree(&self) -> Result<PathBuf> {
207        let (repo_owner, repo_name) = self.repo_name()?;
208        let file_name = self.file_name();
209
210        let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
211        let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
212        fs::create_dir_all(&repos_dir)?;
213        fs::create_dir_all(&worktrees_dir)?;
214
215        let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
216        if !repo_dir.is_dir() {
217            fs::create_dir_all(&repo_dir)?;
218            run_git(&repo_dir, &["init"]).await?;
219            run_git(
220                &repo_dir,
221                &["remote", "add", "origin", &self.example.repository_url],
222            )
223            .await?;
224        }
225
226        // Resolve the example to a revision, fetching it if needed.
227        let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
228        let revision = if let Ok(revision) = revision {
229            revision
230        } else {
231            run_git(
232                &repo_dir,
233                &["fetch", "--depth", "1", "origin", &self.example.revision],
234            )
235            .await?;
236            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
237            if revision != self.example.revision {
238                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
239            }
240            revision
241        };
242
243        // Create the worktree for this example if needed.
244        let worktree_path = worktrees_dir.join(&file_name);
245        if worktree_path.is_dir() {
246            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
247            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
248            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
249        } else {
250            let worktree_path_string = worktree_path.to_string_lossy();
251            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
252            run_git(
253                &repo_dir,
254                &["worktree", "add", "-f", &worktree_path_string, &file_name],
255            )
256            .await?;
257        }
258
259        // Apply the uncommitted diff for this example.
260        if !self.example.uncommitted_diff.is_empty() {
261            let mut apply_process = smol::process::Command::new("git")
262                .current_dir(&worktree_path)
263                .args(&["apply", "-"])
264                .stdin(std::process::Stdio::piped())
265                .spawn()?;
266
267            let mut stdin = apply_process.stdin.take().unwrap();
268            stdin
269                .write_all(self.example.uncommitted_diff.as_bytes())
270                .await?;
271            stdin.close().await?;
272            drop(stdin);
273
274            let apply_result = apply_process.output().await?;
275            if !apply_result.status.success() {
276                anyhow::bail!(
277                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
278                    apply_result.status,
279                    String::from_utf8_lossy(&apply_result.stderr),
280                    String::from_utf8_lossy(&apply_result.stdout),
281                );
282            }
283        }
284
285        Ok(worktree_path)
286    }
287
288    fn file_name(&self) -> String {
289        self.name
290            .chars()
291            .map(|c| {
292                if c.is_whitespace() {
293                    '-'
294                } else {
295                    c.to_ascii_lowercase()
296                }
297            })
298            .collect()
299    }
300
301    #[allow(unused)]
302    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
303        // git@github.com:owner/repo.git
304        if self.example.repository_url.contains('@') {
305            let (owner, repo) = self
306                .example
307                .repository_url
308                .split_once(':')
309                .context("expected : in git url")?
310                .1
311                .split_once('/')
312                .context("expected / in git url")?;
313            Ok((
314                Cow::Borrowed(owner),
315                Cow::Borrowed(repo.trim_end_matches(".git")),
316            ))
317        // http://github.com/owner/repo.git
318        } else {
319            let url = Url::parse(&self.example.repository_url)?;
320            let mut segments = url.path_segments().context("empty http url")?;
321            let owner = segments
322                .next()
323                .context("expected owner path segment")?
324                .to_string();
325            let repo = segments
326                .next()
327                .context("expected repo path segment")?
328                .trim_end_matches(".git")
329                .to_string();
330            assert!(segments.next().is_none());
331
332            Ok((owner.into(), repo.into()))
333        }
334    }
335
336    #[must_use]
337    pub async fn apply_edit_history(
338        &self,
339        project: &Entity<Project>,
340        cx: &mut AsyncApp,
341    ) -> Result<HashSet<Entity<Buffer>>> {
342        apply_diff(&self.example.edit_history, project, cx).await
343    }
344}
345
346async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
347    let output = smol::process::Command::new("git")
348        .current_dir(repo_path)
349        .args(args)
350        .output()
351        .await?;
352
353    anyhow::ensure!(
354        output.status.success(),
355        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
356        args.join(" "),
357        repo_path.display(),
358        output.status,
359        String::from_utf8_lossy(&output.stderr),
360        String::from_utf8_lossy(&output.stdout),
361    );
362    Ok(String::from_utf8(output.stdout)?.trim().to_string())
363}
364
365impl Display for NamedExample {
366    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367        write!(f, "# {}\n\n", self.name)?;
368        write!(
369            f,
370            "{REPOSITORY_URL_FIELD} = {}\n",
371            self.example.repository_url
372        )?;
373        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
374
375        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
376        write!(f, "`````diff\n")?;
377        write!(f, "{}", self.example.uncommitted_diff)?;
378        write!(f, "`````\n")?;
379
380        if !self.example.edit_history.is_empty() {
381            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
382        }
383
384        write!(
385            f,
386            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
387            self.example.cursor_path.display(),
388            self.example.cursor_position
389        )?;
390        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
391
392        if !self.example.expected_patch.is_empty() {
393            write!(
394                f,
395                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
396                self.example.expected_patch
397            )?;
398        }
399
400        if !self.example.expected_excerpts.is_empty() {
401            write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
402
403            for excerpt in &self.example.expected_excerpts {
404                write!(
405                    f,
406                    "`````{}{}\n{}`````\n\n",
407                    excerpt
408                        .path
409                        .extension()
410                        .map(|ext| format!("{} ", ext.to_string_lossy()))
411                        .unwrap_or_default(),
412                    excerpt.path.display(),
413                    excerpt.text
414                )?;
415            }
416        }
417
418        Ok(())
419    }
420}
421
422#[must_use]
423pub async fn apply_diff(
424    diff: &str,
425    project: &Entity<Project>,
426    cx: &mut AsyncApp,
427) -> Result<HashSet<Entity<Buffer>>> {
428    use cloud_llm_client::udiff::DiffLine;
429    use std::fmt::Write;
430
431    #[derive(Debug, Default)]
432    struct HunkState {
433        context: String,
434        edits: Vec<Edit>,
435    }
436
437    // #[derive(Debug, Default)]
438    // struct Edit {
439    //     deletion_start: Option<usize>,
440    //     addition: String,
441    // }
442
443    #[derive(Debug)]
444    struct Edit {
445        range: Range<usize>,
446        text: String,
447    }
448
449    let mut old_path = None;
450    let mut new_path = None;
451    let mut hunk = HunkState::default();
452    let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
453    let mut open_buffers = HashSet::default();
454
455    while let Some(diff_line) = diff_lines.next() {
456        match diff_line {
457            DiffLine::OldPath { path } => old_path = Some(path),
458            DiffLine::NewPath { path } => {
459                if old_path.is_none() {
460                    anyhow::bail!(
461                        "Found a new path header (`+++`) before an (`---`) old path header"
462                    );
463                }
464                new_path = Some(path)
465            }
466            DiffLine::Context(ctx) => {
467                writeln!(&mut hunk.context, "{ctx}")?;
468            }
469            DiffLine::Deletion(del) => {
470                let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
471                if let Some(last_edit) = hunk.edits.last_mut()
472                    && last_edit.range.end == range.start
473                {
474                    last_edit.range.end = range.end;
475                } else {
476                    hunk.edits.push(Edit {
477                        range,
478                        text: String::new(),
479                    });
480                }
481                writeln!(&mut hunk.context, "{del}")?;
482            }
483            DiffLine::Addition(add) => {
484                let range = hunk.context.len()..hunk.context.len();
485                if let Some(last_edit) = hunk.edits.last_mut()
486                    && last_edit.range.end == range.start
487                {
488                    writeln!(&mut last_edit.text, "{add}").unwrap();
489                } else {
490                    hunk.edits.push(Edit {
491                        range,
492                        text: format!("{add}\n"),
493                    });
494                }
495            }
496            DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
497        }
498
499        let at_hunk_end = match diff_lines.peek() {
500            Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
501            _ => false,
502        };
503
504        if at_hunk_end {
505            let hunk = mem::take(&mut hunk);
506
507            let Some(old_path) = old_path.as_deref() else {
508                anyhow::bail!("Missing old path (`---`) header")
509            };
510
511            let Some(new_path) = new_path.as_deref() else {
512                anyhow::bail!("Missing new path (`+++`) header")
513            };
514
515            let buffer = project
516                .update(cx, |project, cx| {
517                    let project_path = project
518                        .find_project_path(old_path, cx)
519                        .context("Failed to find old_path in project")?;
520
521                    anyhow::Ok(project.open_buffer(project_path, cx))
522                })??
523                .await?;
524            open_buffers.insert(buffer.clone());
525
526            if old_path != new_path {
527                project
528                    .update(cx, |project, cx| {
529                        let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
530                        let new_path = ProjectPath {
531                            worktree_id: project_file.worktree_id(cx),
532                            path: project_file.path.clone(),
533                        };
534                        project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
535                    })?
536                    .await?;
537            }
538
539            // TODO is it worth using project search?
540            buffer.update(cx, |buffer, cx| {
541                let context_offset = if hunk.context.is_empty() {
542                    0
543                } else {
544                    let text = buffer.text();
545                    if let Some(offset) = text.find(&hunk.context) {
546                        if text[offset + 1..].find(&hunk.context).is_some() {
547                            anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
548                        }
549                        offset
550                    } else {
551                        anyhow::bail!(
552                            "Failed to match context:\n{}\n\nBuffer:\n{}",
553                            hunk.context,
554                            text
555                        );
556                    }
557                };
558
559                buffer.edit(
560                    hunk.edits.into_iter().map(|edit| {
561                        (
562                            context_offset + edit.range.start..context_offset + edit.range.end,
563                            edit.text,
564                        )
565                    }),
566                    None,
567                    cx,
568                );
569
570                anyhow::Ok(())
571            })??;
572        }
573    }
574
575    anyhow::Ok(open_buffers)
576}
577
// Tests for `apply_diff`, run against a fake file system rooted at `/root`.
#[cfg(test)]
mod tests {
    use super::*;
    use ::fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use pretty_assertions::assert_eq;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;

    // Applies six consecutive hunks to two files and checks the final text of
    // both. Later hunks for `file1` use context produced by earlier ones
    // (e.g. ` 3`), so the hunks must be applied in order.
    #[gpui::test]
    async fn test_apply_diff_successful(cx: &mut TestAppContext) {
        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
        "# };

        let buffer_1_text_final = indoc! {r#"
            3
            4
            5
        "# };

        let buffer_2_text = indoc! {r#"
            six
            seven
            eight
            nine
            ten
        "# };

        let buffer_2_text_final = indoc! {r#"
            5
            six
            seven
            7.5
            eight
            nine
            ten
            11
        "# };

        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });

        let fs = FakeFs::new(cx.background_executor().clone());
        fs.insert_tree(
            "/root",
            json!({
                "file1": buffer_1_text,
                "file2": buffer_2_text,
            }),
        )
        .await;

        let project = Project::test(fs, ["/root".as_ref()], cx).await;

        // Hunks are separated only by repeated `---`/`+++` headers; there are
        // no `@@` hunk headers.
        let diff = indoc! {r#"
            --- a/root/file1
            +++ b/root/file1
             one
             two
            -three
            +3
             four
             five
            --- a/root/file1
            +++ b/root/file1
             3
            -four
            -five
            +4
            +5
            --- a/root/file1
            +++ b/root/file1
            -one
            -two
             3
             4
            --- a/root/file2
            +++ b/root/file2
            +5
             six
            --- a/root/file2
            +++ b/root/file2
             seven
            +7.5
             eight
            --- a/root/file2
            +++ b/root/file2
             ten
            +11
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();
        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("/root/file1", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_1_text_final);
        });
        let buffer_2 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("/root/file2", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_2.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_2_text_final);
        });
    }

    // The file contains the hunk's old-side text twice, so the hunk cannot be
    // placed unambiguously and `apply_diff` must return an error.
    #[gpui::test]
    async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
            one
            two
            three
            four
            five
        "# };

        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });

        let fs = FakeFs::new(cx.background_executor().clone());
        fs.insert_tree(
            "/root",
            json!({
                "file1": buffer_1_text,
            }),
        )
        .await;

        let project = Project::test(fs, ["/root".as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/root/file1
            +++ b/root/file1
             one
             two
            -three
            +3
             four
             five
        "#};

        apply_diff(diff, &project, &mut cx.to_async())
            .await
            .expect_err("Non-unique edits should fail");
    }

    // `four`/`five` occur twice in the file, but the hunk as a whole (starting
    // at `one`) matches only once, so both replacements land in the first
    // occurrence.
    #[gpui::test]
    async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
        let start = indoc! {r#"
            one
            two
            three
            four
            five

            four
            five
        "# };

        let end = indoc! {r#"
            one
            two
            3
            four
            5

            four
            five
        "# };

        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });

        let fs = FakeFs::new(cx.background_executor().clone());
        fs.insert_tree(
            "/root",
            json!({
                "file1": start,
            }),
        )
        .await;

        let project = Project::test(fs, ["/root".as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/root/file1
            +++ b/root/file1
             one
             two
            -three
            +3
             four
            -five
            +5
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();

        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("/root/file1", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), end);
        });
    }
}