1use std::{
2 borrow::Cow,
3 env,
4 fmt::{self, Display},
5 fs,
6 io::Write,
7 mem,
8 ops::Range,
9 path::{Path, PathBuf},
10};
11
12use anyhow::{Context as _, Result};
13use clap::ValueEnum;
14use collections::HashSet;
15use futures::AsyncWriteExt as _;
16use gpui::{AsyncApp, Entity, http_client::Url};
17use language::Buffer;
18use project::{Project, ProjectPath};
19use pulldown_cmark::CowStr;
20use serde::{Deserialize, Serialize};
21
// Titles of the H2 sections in the Markdown example format. `NamedExample::parse_md`
// compares them ASCII-case-insensitively; the `Display` impl writes them verbatim.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
// Keys of the `key = value` lines that sit directly under the H1 title.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
29
/// An [`Example`] paired with a human-readable name.
///
/// The name comes from the file stem for JSON/TOML inputs and from the H1
/// heading for Markdown inputs.
#[derive(Debug)]
pub struct NamedExample {
    /// Display name; also the basis of the worktree/branch name in `file_name`.
    pub name: String,
    /// The example payload itself.
    pub example: Example,
}
35
/// The serializable payload of an example.
#[derive(Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git remote the example repository is fetched from.
    pub repository_url: String,
    /// Git revision the worktree is checked out at.
    pub revision: String,
    /// Patch applied on top of `revision` with `git apply`; may be empty.
    pub uncommitted_diff: String,
    /// Info string of the code block in the "Cursor Position" section.
    pub cursor_path: PathBuf,
    /// Body of the code block in the "Cursor Position" section.
    pub cursor_position: String,
    /// Diff text replayed by `apply_edit_history`.
    pub edit_history: String,
    /// Body of the code block in the "Expected Patch" section.
    pub expected_patch: String,
    /// One entry per code block in the "Expected Excerpts" section.
    pub expected_excerpts: Vec<ExpectedExcerpt>,
}
47
/// A single code block from the "Expected Excerpts" section of an example.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    // Taken from the code block's info string when parsed from Markdown.
    path: PathBuf,
    // Body of the code block.
    text: String,
}
53
/// Output formats accepted by [`NamedExample::write`].
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    /// JSON via `serde_json`.
    Json,
    /// Pretty-printed TOML.
    Toml,
    /// Markdown, rendered through the `Display` impl of `NamedExample`.
    Md,
}
60
61impl NamedExample {
62 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
63 let path = path.as_ref();
64 let content = std::fs::read_to_string(path)?;
65 let ext = path.extension();
66
67 match ext.and_then(|s| s.to_str()) {
68 Some("json") => Ok(Self {
69 name: path.file_stem().unwrap_or_default().display().to_string(),
70 example: serde_json::from_str(&content)?,
71 }),
72 Some("toml") => Ok(Self {
73 name: path.file_stem().unwrap_or_default().display().to_string(),
74 example: toml::from_str(&content)?,
75 }),
76 Some("md") => Self::parse_md(&content),
77 Some(_) => {
78 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
79 }
80 None => {
81 anyhow::bail!(
82 "Failed to determine example type since the file does not have an extension."
83 );
84 }
85 }
86 }
87
88 pub fn parse_md(input: &str) -> Result<Self> {
89 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
90
91 let parser = Parser::new(input);
92
93 let mut named = NamedExample {
94 name: String::new(),
95 example: Example {
96 repository_url: String::new(),
97 revision: String::new(),
98 uncommitted_diff: String::new(),
99 cursor_path: PathBuf::new(),
100 cursor_position: String::new(),
101 edit_history: String::new(),
102 expected_patch: String::new(),
103 expected_excerpts: Vec::new(),
104 },
105 };
106
107 let mut text = String::new();
108 let mut current_section = String::new();
109 let mut block_info: CowStr = "".into();
110
111 for event in parser {
112 match event {
113 Event::Text(line) => {
114 text.push_str(&line);
115
116 if !named.name.is_empty()
117 && current_section.is_empty()
118 // in h1 section
119 && let Some((field, value)) = line.split_once('=')
120 {
121 match field.trim() {
122 REPOSITORY_URL_FIELD => {
123 named.example.repository_url = value.trim().to_string();
124 }
125 REVISION_FIELD => {
126 named.example.revision = value.trim().to_string();
127 }
128 _ => {
129 eprintln!("Warning: Unrecognized field `{field}`");
130 }
131 }
132 }
133 }
134 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
135 if !named.name.is_empty() {
136 anyhow::bail!(
137 "Found multiple H1 headings. There should only be one with the name of the example."
138 );
139 }
140 named.name = mem::take(&mut text);
141 }
142 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
143 current_section = mem::take(&mut text);
144 }
145 Event::End(TagEnd::Heading(level)) => {
146 anyhow::bail!("Unexpected heading level: {level}");
147 }
148 Event::Start(Tag::CodeBlock(kind)) => {
149 match kind {
150 CodeBlockKind::Fenced(info) => {
151 block_info = info;
152 }
153 CodeBlockKind::Indented => {
154 anyhow::bail!("Unexpected indented codeblock");
155 }
156 };
157 }
158 Event::Start(_) => {
159 text.clear();
160 block_info = "".into();
161 }
162 Event::End(TagEnd::CodeBlock) => {
163 let block_info = block_info.trim();
164 if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
165 named.example.uncommitted_diff = mem::take(&mut text);
166 } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
167 named.example.edit_history.push_str(&mem::take(&mut text));
168 } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
169 named.example.cursor_path = block_info.into();
170 named.example.cursor_position = mem::take(&mut text);
171 } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
172 named.example.expected_patch = mem::take(&mut text);
173 } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
174 named.example.expected_excerpts.push(ExpectedExcerpt {
175 path: block_info.into(),
176 text: mem::take(&mut text),
177 });
178 } else {
179 eprintln!("Warning: Unrecognized section `{current_section:?}`")
180 }
181 }
182 _ => {}
183 }
184 }
185
186 if named.example.cursor_path.as_path() == Path::new("")
187 || named.example.cursor_position.is_empty()
188 {
189 anyhow::bail!("Missing cursor position codeblock");
190 }
191
192 Ok(named)
193 }
194
195 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
196 match format {
197 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
198 ExampleFormat::Toml => {
199 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
200 }
201 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
202 }
203 }
204
205 #[allow(unused)]
206 pub async fn setup_worktree(&self) -> Result<PathBuf> {
207 let (repo_owner, repo_name) = self.repo_name()?;
208 let file_name = self.file_name();
209
210 let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
211 let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
212 fs::create_dir_all(&repos_dir)?;
213 fs::create_dir_all(&worktrees_dir)?;
214
215 let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
216 if !repo_dir.is_dir() {
217 fs::create_dir_all(&repo_dir)?;
218 run_git(&repo_dir, &["init"]).await?;
219 run_git(
220 &repo_dir,
221 &["remote", "add", "origin", &self.example.repository_url],
222 )
223 .await?;
224 }
225
226 // Resolve the example to a revision, fetching it if needed.
227 let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
228 let revision = if let Ok(revision) = revision {
229 revision
230 } else {
231 run_git(
232 &repo_dir,
233 &["fetch", "--depth", "1", "origin", &self.example.revision],
234 )
235 .await?;
236 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
237 if revision != self.example.revision {
238 run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
239 }
240 revision
241 };
242
243 // Create the worktree for this example if needed.
244 let worktree_path = worktrees_dir.join(&file_name);
245 if worktree_path.is_dir() {
246 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
247 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
248 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
249 } else {
250 let worktree_path_string = worktree_path.to_string_lossy();
251 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
252 run_git(
253 &repo_dir,
254 &["worktree", "add", "-f", &worktree_path_string, &file_name],
255 )
256 .await?;
257 }
258
259 // Apply the uncommitted diff for this example.
260 if !self.example.uncommitted_diff.is_empty() {
261 let mut apply_process = smol::process::Command::new("git")
262 .current_dir(&worktree_path)
263 .args(&["apply", "-"])
264 .stdin(std::process::Stdio::piped())
265 .spawn()?;
266
267 let mut stdin = apply_process.stdin.take().unwrap();
268 stdin
269 .write_all(self.example.uncommitted_diff.as_bytes())
270 .await?;
271 stdin.close().await?;
272 drop(stdin);
273
274 let apply_result = apply_process.output().await?;
275 if !apply_result.status.success() {
276 anyhow::bail!(
277 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
278 apply_result.status,
279 String::from_utf8_lossy(&apply_result.stderr),
280 String::from_utf8_lossy(&apply_result.stdout),
281 );
282 }
283 }
284
285 Ok(worktree_path)
286 }
287
288 fn file_name(&self) -> String {
289 self.name
290 .chars()
291 .map(|c| {
292 if c.is_whitespace() {
293 '-'
294 } else {
295 c.to_ascii_lowercase()
296 }
297 })
298 .collect()
299 }
300
301 #[allow(unused)]
302 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
303 // git@github.com:owner/repo.git
304 if self.example.repository_url.contains('@') {
305 let (owner, repo) = self
306 .example
307 .repository_url
308 .split_once(':')
309 .context("expected : in git url")?
310 .1
311 .split_once('/')
312 .context("expected / in git url")?;
313 Ok((
314 Cow::Borrowed(owner),
315 Cow::Borrowed(repo.trim_end_matches(".git")),
316 ))
317 // http://github.com/owner/repo.git
318 } else {
319 let url = Url::parse(&self.example.repository_url)?;
320 let mut segments = url.path_segments().context("empty http url")?;
321 let owner = segments
322 .next()
323 .context("expected owner path segment")?
324 .to_string();
325 let repo = segments
326 .next()
327 .context("expected repo path segment")?
328 .trim_end_matches(".git")
329 .to_string();
330 assert!(segments.next().is_none());
331
332 Ok((owner.into(), repo.into()))
333 }
334 }
335
336 #[must_use]
337 pub async fn apply_edit_history(
338 &self,
339 project: &Entity<Project>,
340 cx: &mut AsyncApp,
341 ) -> Result<HashSet<Entity<Buffer>>> {
342 apply_diff(&self.example.edit_history, project, cx).await
343 }
344}
345
346async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
347 let output = smol::process::Command::new("git")
348 .current_dir(repo_path)
349 .args(args)
350 .output()
351 .await?;
352
353 anyhow::ensure!(
354 output.status.success(),
355 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
356 args.join(" "),
357 repo_path.display(),
358 output.status,
359 String::from_utf8_lossy(&output.stderr),
360 String::from_utf8_lossy(&output.stdout),
361 );
362 Ok(String::from_utf8(output.stdout)?.trim().to_string())
363}
364
365impl Display for NamedExample {
366 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367 write!(f, "# {}\n\n", self.name)?;
368 write!(
369 f,
370 "{REPOSITORY_URL_FIELD} = {}\n",
371 self.example.repository_url
372 )?;
373 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
374
375 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
376 write!(f, "`````diff\n")?;
377 write!(f, "{}", self.example.uncommitted_diff)?;
378 write!(f, "`````\n")?;
379
380 if !self.example.edit_history.is_empty() {
381 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
382 }
383
384 write!(
385 f,
386 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
387 self.example.cursor_path.display(),
388 self.example.cursor_position
389 )?;
390 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
391
392 if !self.example.expected_patch.is_empty() {
393 write!(
394 f,
395 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
396 self.example.expected_patch
397 )?;
398 }
399
400 if !self.example.expected_excerpts.is_empty() {
401 write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
402
403 for excerpt in &self.example.expected_excerpts {
404 write!(
405 f,
406 "`````{}{}\n{}`````\n\n",
407 excerpt
408 .path
409 .extension()
410 .map(|ext| format!("{} ", ext.to_string_lossy()))
411 .unwrap_or_default(),
412 excerpt.path.display(),
413 excerpt.text
414 )?;
415 }
416 }
417
418 Ok(())
419 }
420}
421
422#[must_use]
423pub async fn apply_diff(
424 diff: &str,
425 project: &Entity<Project>,
426 cx: &mut AsyncApp,
427) -> Result<HashSet<Entity<Buffer>>> {
428 use cloud_llm_client::udiff::DiffLine;
429 use std::fmt::Write;
430
431 #[derive(Debug, Default)]
432 struct HunkState {
433 context: String,
434 edits: Vec<Edit>,
435 }
436
437 #[derive(Debug)]
438 struct Edit {
439 range: Range<usize>,
440 text: String,
441 }
442
443 let mut old_path = None;
444 let mut new_path = None;
445 let mut hunk = HunkState::default();
446 let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
447 let mut open_buffers = HashSet::default();
448
449 while let Some(diff_line) = diff_lines.next() {
450 match diff_line {
451 DiffLine::OldPath { path } => old_path = Some(path),
452 DiffLine::NewPath { path } => {
453 if old_path.is_none() {
454 anyhow::bail!(
455 "Found a new path header (`+++`) before an (`---`) old path header"
456 );
457 }
458 new_path = Some(path)
459 }
460 DiffLine::Context(ctx) => {
461 writeln!(&mut hunk.context, "{ctx}")?;
462 }
463 DiffLine::Deletion(del) => {
464 let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
465 if let Some(last_edit) = hunk.edits.last_mut()
466 && last_edit.range.end == range.start
467 {
468 last_edit.range.end = range.end;
469 } else {
470 hunk.edits.push(Edit {
471 range,
472 text: String::new(),
473 });
474 }
475 writeln!(&mut hunk.context, "{del}")?;
476 }
477 DiffLine::Addition(add) => {
478 let range = hunk.context.len()..hunk.context.len();
479 if let Some(last_edit) = hunk.edits.last_mut()
480 && last_edit.range.end == range.start
481 {
482 writeln!(&mut last_edit.text, "{add}").unwrap();
483 } else {
484 hunk.edits.push(Edit {
485 range,
486 text: format!("{add}\n"),
487 });
488 }
489 }
490 DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
491 }
492
493 let at_hunk_end = match diff_lines.peek() {
494 Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
495 _ => false,
496 };
497
498 if at_hunk_end {
499 let hunk = mem::take(&mut hunk);
500
501 let Some(old_path) = old_path.as_deref() else {
502 anyhow::bail!("Missing old path (`---`) header")
503 };
504
505 let Some(new_path) = new_path.as_deref() else {
506 anyhow::bail!("Missing new path (`+++`) header")
507 };
508
509 let buffer = project
510 .update(cx, |project, cx| {
511 let project_path = project
512 .find_project_path(old_path, cx)
513 .context("Failed to find old_path in project")?;
514
515 anyhow::Ok(project.open_buffer(project_path, cx))
516 })??
517 .await?;
518 open_buffers.insert(buffer.clone());
519
520 if old_path != new_path {
521 project
522 .update(cx, |project, cx| {
523 let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
524 let new_path = ProjectPath {
525 worktree_id: project_file.worktree_id(cx),
526 path: project_file.path.clone(),
527 };
528 project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
529 })?
530 .await?;
531 }
532
533 // TODO is it worth using project search?
534 buffer.update(cx, |buffer, cx| {
535 let context_offset = if hunk.context.is_empty() {
536 0
537 } else {
538 let text = buffer.text();
539 if let Some(offset) = text.find(&hunk.context) {
540 if text[offset + 1..].contains(&hunk.context) {
541 anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
542 }
543 offset
544 } else {
545 anyhow::bail!(
546 "Failed to match context:\n{}\n\nBuffer:\n{}",
547 hunk.context,
548 text
549 );
550 }
551 };
552
553 buffer.edit(
554 hunk.edits.into_iter().map(|edit| {
555 (
556 context_offset + edit.range.start..context_offset + edit.range.end,
557 edit.text,
558 )
559 }),
560 None,
561 cx,
562 );
563
564 anyhow::Ok(())
565 })??;
566 }
567 }
568
569 anyhow::Ok(open_buffers)
570}
571
572#[cfg(test)]
573mod tests {
574 use super::*;
575 use ::fs::FakeFs;
576 use gpui::TestAppContext;
577 use indoc::indoc;
578 use pretty_assertions::assert_eq;
579 use project::Project;
580 use serde_json::json;
581 use settings::SettingsStore;
582
// Applies a sequence of hunks to two files and checks the final buffer texts.
// Each `---`/`+++` pair in the diff starts a new hunk; later hunks for the
// same file match against the text produced by earlier ones.
#[gpui::test]
async fn test_apply_diff_successful(cx: &mut TestAppContext) {
    let buffer_1_text = indoc! {r#"
        one
        two
        three
        four
        five
    "# };

    let buffer_1_text_final = indoc! {r#"
        3
        4
        5
    "# };

    let buffer_2_text = indoc! {r#"
        six
        seven
        eight
        nine
        ten
    "# };

    let buffer_2_text_final = indoc! {r#"
        5
        six
        seven
        7.5
        eight
        nine
        ten
        11
    "# };

    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
    });

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(
        "/root",
        json!({
            "file1": buffer_1_text,
            "file2": buffer_2_text,
        }),
    )
    .await;

    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // file1: replace `three`, then `four`/`five`, then delete `one`/`two`.
    // file2: insert at the start, in the middle, and at the end.
    let diff = indoc! {r#"
        --- a/root/file1
        +++ b/root/file1
         one
         two
        -three
        +3
         four
         five
        --- a/root/file1
        +++ b/root/file1
         3
        -four
        -five
        +4
        +5
        --- a/root/file1
        +++ b/root/file1
        -one
        -two
         3
         4
        --- a/root/file2
        +++ b/root/file2
        +5
         six
        --- a/root/file2
        +++ b/root/file2
         seven
        +7.5
         eight
        --- a/root/file2
        +++ b/root/file2
         ten
        +11
    "#};

    let _buffers = apply_diff(diff, &project, &mut cx.to_async())
        .await
        .unwrap();
    let buffer_1 = project
        .update(cx, |project, cx| {
            let project_path = project.find_project_path("/root/file1", cx).unwrap();
            project.open_buffer(project_path, cx)
        })
        .await
        .unwrap();

    buffer_1.read_with(cx, |buffer, _cx| {
        assert_eq!(buffer.text(), buffer_1_text_final);
    });
    let buffer_2 = project
        .update(cx, |project, cx| {
            let project_path = project.find_project_path("/root/file2", cx).unwrap();
            project.open_buffer(project_path, cx)
        })
        .await
        .unwrap();

    buffer_2.read_with(cx, |buffer, _cx| {
        assert_eq!(buffer.text(), buffer_2_text_final);
    });
}
700
// The file contains the same five lines twice, so the hunk's context matches
// in two places and `apply_diff` must return an error.
#[gpui::test]
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
    let buffer_1_text = indoc! {r#"
        one
        two
        three
        four
        five
        one
        two
        three
        four
        five
    "# };

    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
    });

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(
        "/root",
        json!({
            "file1": buffer_1_text,
        }),
    )
    .await;

    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    let diff = indoc! {r#"
        --- a/root/file1
        +++ b/root/file1
         one
         two
        -three
        +3
         four
         five
    "#};

    apply_diff(diff, &project, &mut cx.to_async())
        .await
        .expect_err("Non-unique edits should fail");
}
749
// `four`/`five` occur twice in the file, but the hunk also carries the
// preceding `one`/`two`/`three` lines, so the context as a whole matches only
// once and the edit lands on the first occurrence.
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
    let start = indoc! {r#"
        one
        two
        three
        four
        five

        four
        five
    "# };

    let end = indoc! {r#"
        one
        two
        3
        four
        5

        four
        five
    "# };

    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
    });

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(
        "/root",
        json!({
            "file1": start,
        }),
    )
    .await;

    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    let diff = indoc! {r#"
        --- a/root/file1
        +++ b/root/file1
         one
         two
        -three
        +3
         four
        -five
        +5
    "#};

    let _buffers = apply_diff(diff, &project, &mut cx.to_async())
        .await
        .unwrap();

    let buffer_1 = project
        .update(cx, |project, cx| {
            let project_path = project.find_project_path("/root/file1", cx).unwrap();
            project.open_buffer(project_path, cx)
        })
        .await
        .unwrap();

    buffer_1.read_with(cx, |buffer, _cx| {
        assert_eq!(buffer.text(), end);
    });
}
820}