1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 fmt::{self, Display},
5 fs,
6 io::Write,
7 mem,
8 path::{Path, PathBuf},
9 sync::{Arc, OnceLock},
10};
11
12use crate::headless::ZetaCliAppState;
13use anyhow::{Context as _, Result, anyhow};
14use clap::ValueEnum;
15use cloud_zeta2_prompt::CURSOR_MARKER;
16use collections::HashMap;
17use edit_prediction_context::Line;
18use futures::{
19 AsyncWriteExt as _,
20 lock::{Mutex, OwnedMutexGuard},
21};
22use futures::{FutureExt as _, future::Shared};
23use gpui::{AppContext as _, AsyncApp, Entity, Task, http_client::Url};
24use language::{Anchor, Buffer};
25use project::{Project, ProjectPath};
26use pulldown_cmark::CowStr;
27use serde::{Deserialize, Serialize};
28use util::{paths::PathStyle, rel_path::RelPath};
29use zeta2::{Zeta, udiff::OpenedBuffers};
30
31use crate::paths::{REPOS_DIR, WORKTREES_DIR};
32
// Titles of the Markdown `##` sections. `NamedExample::parse_md` compares them
// ASCII-case-insensitively and the `Display` impl writes them verbatim.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// Keys of the `key = value` metadata lines that follow the H1 title of a
// Markdown example.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
40
/// An [`Example`] together with its name.
///
/// For JSON/TOML files the name is the file stem; for Markdown files it is the
/// text of the H1 heading.
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// Name of the example; `file_name` derives a lowercase, dash-separated
    /// form of it.
    pub name: String,
    /// The example payload.
    pub example: Example,
}
46
/// Serializable description of a single example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git URL of the repository; added as the `origin` remote and parsed for
    /// the owner/repo names.
    pub repository_url: String,
    /// Git revision the worktree is checked out at (fetched from `origin` when
    /// it cannot be resolved locally).
    pub revision: String,
    /// Patch applied with `git apply` on top of `revision`; may be empty.
    pub uncommitted_diff: String,
    /// Worktree-relative path of the file that contains the cursor.
    pub cursor_path: PathBuf,
    /// Excerpt of the cursor file that contains the cursor marker.
    pub cursor_position: String,
    /// Diff handed to `zeta2::udiff::apply_diff` when the project is set up.
    pub edit_history: String,
    /// Contents of the "Expected Patch" section (presumably the patch a
    /// prediction should produce — confirm against the evaluation code).
    pub expected_patch: String,
    /// Entries of the "Expected Context" section.
    pub expected_context: Vec<ExpectedContextEntry>,
}
58
/// Alias of [`Excerpt`].
// NOTE(review): presumably names excerpts that were actually retrieved, as
// opposed to `ExpectedExcerpt` — confirm against callers.
pub type ActualExcerpt = Excerpt;
60
/// A piece of text associated with a file path.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    /// Path of the file the text belongs to.
    pub path: PathBuf,
    /// The excerpt text.
    pub text: String,
}
66
/// One entry of the "Expected Context" section; in Markdown each H3 heading
/// starts a new entry.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedContextEntry {
    /// Text of the H3 heading (empty for an implicitly created entry).
    pub heading: String,
    /// Excerpt sets of this entry; in Markdown each H4 heading starts a new one.
    pub alternatives: Vec<ExpectedExcerptSet>,
}
72
/// A group of expected excerpts that forms one alternative of an
/// [`ExpectedContextEntry`].
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerptSet {
    /// Text of the H4 heading (empty for an implicitly created set).
    pub heading: String,
    /// Excerpts parsed from the code blocks below the heading.
    pub excerpts: Vec<ExpectedExcerpt>,
}
78
/// An excerpt listed in the "Expected Context" section.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// File path; in Markdown this is the info string of the fenced code block.
    pub path: PathBuf,
    /// Excerpt text with `[ZETA]` marker comments removed.
    pub text: String,
    /// Zero-based line indices into `text` whose marker mentioned "required".
    pub required_lines: Vec<Line>,
}
85
/// Output formats supported by `NamedExample::write`.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    /// JSON serialization of the [`Example`].
    Json,
    /// Pretty-printed TOML serialization of the [`Example`].
    Toml,
    /// Markdown, as produced by the `Display` impl of [`NamedExample`].
    Md,
}
92
93impl NamedExample {
94 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
95 let path = path.as_ref();
96 let content = std::fs::read_to_string(path)?;
97 let ext = path.extension();
98
99 match ext.and_then(|s| s.to_str()) {
100 Some("json") => Ok(Self {
101 name: path.file_stem().unwrap_or_default().display().to_string(),
102 example: serde_json::from_str(&content)?,
103 }),
104 Some("toml") => Ok(Self {
105 name: path.file_stem().unwrap_or_default().display().to_string(),
106 example: toml::from_str(&content)?,
107 }),
108 Some("md") => Self::parse_md(&content),
109 Some(_) => {
110 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
111 }
112 None => {
113 anyhow::bail!(
114 "Failed to determine example type since the file does not have an extension."
115 );
116 }
117 }
118 }
119
120 pub fn parse_md(input: &str) -> Result<Self> {
121 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
122
123 let parser = Parser::new(input);
124
125 let mut named = NamedExample {
126 name: String::new(),
127 example: Example {
128 repository_url: String::new(),
129 revision: String::new(),
130 uncommitted_diff: String::new(),
131 cursor_path: PathBuf::new(),
132 cursor_position: String::new(),
133 edit_history: String::new(),
134 expected_patch: String::new(),
135 expected_context: Vec::new(),
136 },
137 };
138
139 let mut text = String::new();
140 let mut block_info: CowStr = "".into();
141
142 #[derive(PartialEq)]
143 enum Section {
144 UncommittedDiff,
145 EditHistory,
146 CursorPosition,
147 ExpectedExcerpts,
148 ExpectedPatch,
149 Other,
150 }
151
152 let mut current_section = Section::Other;
153
154 for event in parser {
155 match event {
156 Event::Text(line) => {
157 text.push_str(&line);
158
159 if !named.name.is_empty()
160 && current_section == Section::Other
161 // in h1 section
162 && let Some((field, value)) = line.split_once('=')
163 {
164 match field.trim() {
165 REPOSITORY_URL_FIELD => {
166 named.example.repository_url = value.trim().to_string();
167 }
168 REVISION_FIELD => {
169 named.example.revision = value.trim().to_string();
170 }
171 _ => {}
172 }
173 }
174 }
175 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
176 if !named.name.is_empty() {
177 anyhow::bail!(
178 "Found multiple H1 headings. There should only be one with the name of the example."
179 );
180 }
181 named.name = mem::take(&mut text);
182 }
183 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
184 let title = mem::take(&mut text);
185 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
186 Section::UncommittedDiff
187 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
188 Section::EditHistory
189 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
190 Section::CursorPosition
191 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
192 Section::ExpectedPatch
193 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
194 Section::ExpectedExcerpts
195 } else {
196 Section::Other
197 };
198 }
199 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
200 let heading = mem::take(&mut text);
201 match current_section {
202 Section::ExpectedExcerpts => {
203 named.example.expected_context.push(ExpectedContextEntry {
204 heading,
205 alternatives: Vec::new(),
206 });
207 }
208 _ => {}
209 }
210 }
211 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
212 let heading = mem::take(&mut text);
213 match current_section {
214 Section::ExpectedExcerpts => {
215 let expected_context = &mut named.example.expected_context;
216 let last_entry = expected_context.last_mut().unwrap();
217 last_entry.alternatives.push(ExpectedExcerptSet {
218 heading,
219 excerpts: Vec::new(),
220 })
221 }
222 _ => {}
223 }
224 }
225 Event::End(TagEnd::Heading(level)) => {
226 anyhow::bail!("Unexpected heading level: {level}");
227 }
228 Event::Start(Tag::CodeBlock(kind)) => {
229 match kind {
230 CodeBlockKind::Fenced(info) => {
231 block_info = info;
232 }
233 CodeBlockKind::Indented => {
234 anyhow::bail!("Unexpected indented codeblock");
235 }
236 };
237 }
238 Event::Start(_) => {
239 text.clear();
240 block_info = "".into();
241 }
242 Event::End(TagEnd::CodeBlock) => {
243 let block_info = block_info.trim();
244 match current_section {
245 Section::UncommittedDiff => {
246 named.example.uncommitted_diff = mem::take(&mut text);
247 }
248 Section::EditHistory => {
249 named.example.edit_history.push_str(&mem::take(&mut text));
250 }
251 Section::CursorPosition => {
252 named.example.cursor_path = block_info.into();
253 named.example.cursor_position = mem::take(&mut text);
254 }
255 Section::ExpectedExcerpts => {
256 let text = mem::take(&mut text);
257 for excerpt in text.split("\n…\n") {
258 let (mut text, required_lines) = extract_required_lines(&excerpt);
259 if !text.ends_with('\n') {
260 text.push('\n');
261 }
262
263 if named.example.expected_context.is_empty() {
264 named.example.expected_context.push(Default::default());
265 }
266
267 let alternatives = &mut named
268 .example
269 .expected_context
270 .last_mut()
271 .unwrap()
272 .alternatives;
273
274 if alternatives.is_empty() {
275 alternatives.push(ExpectedExcerptSet {
276 heading: String::new(),
277 excerpts: vec![],
278 });
279 }
280
281 alternatives
282 .last_mut()
283 .unwrap()
284 .excerpts
285 .push(ExpectedExcerpt {
286 path: block_info.into(),
287 text,
288 required_lines,
289 });
290 }
291 }
292 Section::ExpectedPatch => {
293 named.example.expected_patch = mem::take(&mut text);
294 }
295 Section::Other => {}
296 }
297 }
298 _ => {}
299 }
300 }
301
302 if named.example.cursor_path.as_path() == Path::new("")
303 || named.example.cursor_position.is_empty()
304 {
305 anyhow::bail!("Missing cursor position codeblock");
306 }
307
308 Ok(named)
309 }
310
311 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
312 match format {
313 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
314 ExampleFormat::Toml => {
315 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
316 }
317 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
318 }
319 }
320
/// Prepares everything needed to run this example.
///
/// Sets up the git worktree, signs in (once per process), creates a local
/// [`Project`] with the worktree added and fully scanned, creates
/// `repetitions` [`Zeta`] instances, and finally applies the example's edit
/// history.
///
/// Returns the project, the zeta instances, and the buffers returned by
/// [`NamedExample::apply_edit_history`].
pub async fn setup_project<'a>(
    &'a self,
    app_state: &Arc<ZetaCliAppState>,
    repetitions: u16,
    cx: &mut AsyncApp,
) -> Result<(Entity<Project>, Vec<Entity<Zeta>>, OpenedBuffers<'a>)> {
    let worktree_path = self.setup_worktree().await?;

    // The sign-in task is created at most once per process; every caller
    // awaits a clone of the same shared task.
    static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();

    AUTHENTICATED
        .get_or_init(|| {
            let client = app_state.client.clone();
            cx.spawn(async move |cx| {
                // NOTE(review): a failed sign-in panics inside the spawned task.
                client
                    .sign_in_with_optional_connect(true, cx)
                    .await
                    .unwrap();
            })
            .shared()
        })
        .clone()
        .await;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    // Wait for the initial scan of the worktree before continuing.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // One `Zeta` per repetition; each registers every buffer subsequently
    // added to the project's buffer store.
    let zetas = (0..repetitions)
        .map(|_| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            })?;

            cx.subscribe(&buffer_store, {
                let project = project.clone();
                let zeta = zeta.clone();
                move |_, event, cx| match event {
                    project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                        zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
                    }
                    _ => {}
                }
            })?
            .detach();

            anyhow::Ok(zeta)
        })
        .collect::<Result<Vec<_>>>()?;

    // Applied after the subscriptions exist.
    let edited_buffers = self.apply_edit_history(&project, cx).await?;

    anyhow::Ok((project, zetas, edited_buffers))
}
396
397 pub async fn setup_worktree(&self) -> Result<PathBuf> {
398 let (repo_owner, repo_name) = self.repo_name()?;
399 let file_name = self.file_name();
400
401 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
402 let repo_lock = lock_repo(&repo_dir).await;
403
404 if !repo_dir.is_dir() {
405 fs::create_dir_all(&repo_dir)?;
406 run_git(&repo_dir, &["init"]).await?;
407 run_git(
408 &repo_dir,
409 &["remote", "add", "origin", &self.example.repository_url],
410 )
411 .await?;
412 }
413
414 // Resolve the example to a revision, fetching it if needed.
415 let revision = run_git(
416 &repo_dir,
417 &[
418 "rev-parse",
419 &format!("{}^{{commit}}", self.example.revision),
420 ],
421 )
422 .await;
423 let revision = if let Ok(revision) = revision {
424 revision
425 } else {
426 run_git(
427 &repo_dir,
428 &["fetch", "--depth", "1", "origin", &self.example.revision],
429 )
430 .await?;
431 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
432 if revision != self.example.revision {
433 run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
434 }
435 revision
436 };
437
438 // Create the worktree for this example if needed.
439 let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
440 if worktree_path.is_dir() {
441 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
442 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
443 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
444 } else {
445 let worktree_path_string = worktree_path.to_string_lossy();
446 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
447 run_git(
448 &repo_dir,
449 &["worktree", "add", "-f", &worktree_path_string, &file_name],
450 )
451 .await?;
452 }
453 drop(repo_lock);
454
455 // Apply the uncommitted diff for this example.
456 if !self.example.uncommitted_diff.is_empty() {
457 let mut apply_process = smol::process::Command::new("git")
458 .current_dir(&worktree_path)
459 .args(&["apply", "-"])
460 .stdin(std::process::Stdio::piped())
461 .spawn()?;
462
463 let mut stdin = apply_process.stdin.take().unwrap();
464 stdin
465 .write_all(self.example.uncommitted_diff.as_bytes())
466 .await?;
467 stdin.close().await?;
468 drop(stdin);
469
470 let apply_result = apply_process.output().await?;
471 if !apply_result.status.success() {
472 anyhow::bail!(
473 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
474 apply_result.status,
475 String::from_utf8_lossy(&apply_result.stderr),
476 String::from_utf8_lossy(&apply_result.stdout),
477 );
478 }
479 }
480
481 Ok(worktree_path)
482 }
483
484 pub fn file_name(&self) -> String {
485 self.name
486 .chars()
487 .map(|c| {
488 if c.is_whitespace() {
489 '-'
490 } else {
491 c.to_ascii_lowercase()
492 }
493 })
494 .collect()
495 }
496
497 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
498 // git@github.com:owner/repo.git
499 if self.example.repository_url.contains('@') {
500 let (owner, repo) = self
501 .example
502 .repository_url
503 .split_once(':')
504 .context("expected : in git url")?
505 .1
506 .split_once('/')
507 .context("expected / in git url")?;
508 Ok((
509 Cow::Borrowed(owner),
510 Cow::Borrowed(repo.trim_end_matches(".git")),
511 ))
512 // http://github.com/owner/repo.git
513 } else {
514 let url = Url::parse(&self.example.repository_url)?;
515 let mut segments = url.path_segments().context("empty http url")?;
516 let owner = segments
517 .next()
518 .context("expected owner path segment")?
519 .to_string();
520 let repo = segments
521 .next()
522 .context("expected repo path segment")?
523 .trim_end_matches(".git")
524 .to_string();
525 assert!(segments.next().is_none());
526
527 Ok((owner.into(), repo.into()))
528 }
529 }
530
531 pub async fn cursor_position(
532 &self,
533 project: &Entity<Project>,
534 cx: &mut AsyncApp,
535 ) -> Result<(Entity<Buffer>, Anchor)> {
536 let worktree = project.read_with(cx, |project, cx| {
537 project.visible_worktrees(cx).next().unwrap()
538 })?;
539 let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
540 let cursor_buffer = project
541 .update(cx, |project, cx| {
542 project.open_buffer(
543 ProjectPath {
544 worktree_id: worktree.read(cx).id(),
545 path: cursor_path,
546 },
547 cx,
548 )
549 })?
550 .await?;
551 let cursor_offset_within_excerpt = self
552 .example
553 .cursor_position
554 .find(CURSOR_MARKER)
555 .ok_or_else(|| anyhow!("missing cursor marker"))?;
556 let mut cursor_excerpt = self.example.cursor_position.clone();
557 cursor_excerpt.replace_range(
558 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
559 "",
560 );
561 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
562 let text = buffer.text();
563
564 let mut matches = text.match_indices(&cursor_excerpt);
565 let Some((excerpt_offset, _)) = matches.next() else {
566 anyhow::bail!(
567 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
568 );
569 };
570 assert!(matches.next().is_none());
571
572 Ok(excerpt_offset)
573 })??;
574
575 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
576 let cursor_anchor =
577 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
578 Ok((cursor_buffer, cursor_anchor))
579 }
580
/// Applies the example's edit-history diff to `project` by delegating to
/// `zeta2::udiff::apply_diff`, and returns the [`OpenedBuffers`] it produces.
#[must_use]
pub async fn apply_edit_history(
    &self,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
    zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
589}
590
591fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
592 const MARKER: &str = "[ZETA]";
593 let mut new_text = String::new();
594 let mut required_lines = Vec::new();
595 let mut skipped_lines = 0_u32;
596
597 for (row, mut line) in text.split('\n').enumerate() {
598 if let Some(marker_column) = line.find(MARKER) {
599 let mut strip_column = marker_column;
600
601 while strip_column > 0 {
602 let prev_char = line[strip_column - 1..].chars().next().unwrap();
603 if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
604 strip_column -= 1;
605 } else {
606 break;
607 }
608 }
609
610 let metadata = &line[marker_column + MARKER.len()..];
611 if metadata.contains("required") {
612 required_lines.push(Line(row as u32 - skipped_lines));
613 }
614
615 if strip_column == 0 {
616 skipped_lines += 1;
617 continue;
618 }
619
620 line = &line[..strip_column];
621 }
622
623 new_text.push_str(line);
624 new_text.push('\n');
625 }
626
627 new_text.pop();
628
629 (new_text, required_lines)
630}
631
632async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
633 let output = smol::process::Command::new("git")
634 .current_dir(repo_path)
635 .args(args)
636 .output()
637 .await?;
638
639 anyhow::ensure!(
640 output.status.success(),
641 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
642 args.join(" "),
643 repo_path.display(),
644 output.status,
645 String::from_utf8_lossy(&output.stderr),
646 String::from_utf8_lossy(&output.stdout),
647 );
648 Ok(String::from_utf8(output.stdout)?.trim().to_string())
649}
650
651impl Display for NamedExample {
652 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
653 write!(f, "# {}\n\n", self.name)?;
654 write!(
655 f,
656 "{REPOSITORY_URL_FIELD} = {}\n",
657 self.example.repository_url
658 )?;
659 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
660
661 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
662 write!(f, "`````diff\n")?;
663 write!(f, "{}", self.example.uncommitted_diff)?;
664 write!(f, "`````\n")?;
665
666 if !self.example.edit_history.is_empty() {
667 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
668 }
669
670 write!(
671 f,
672 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
673 self.example.cursor_path.display(),
674 self.example.cursor_position
675 )?;
676 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
677
678 if !self.example.expected_patch.is_empty() {
679 write!(
680 f,
681 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
682 self.example.expected_patch
683 )?;
684 }
685
686 if !self.example.expected_context.is_empty() {
687 write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
688
689 for entry in &self.example.expected_context {
690 write!(f, "\n### {}\n\n", entry.heading)?;
691
692 let skip_h4 =
693 entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
694
695 for excerpt_set in &entry.alternatives {
696 if !skip_h4 {
697 write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
698 }
699
700 for excerpt in &excerpt_set.excerpts {
701 write!(
702 f,
703 "`````{}{}\n{}`````\n\n",
704 excerpt
705 .path
706 .extension()
707 .map(|ext| format!("{} ", ext.to_string_lossy()))
708 .unwrap_or_default(),
709 excerpt.path.display(),
710 excerpt.text
711 )?;
712 }
713 }
714 }
715 }
716
717 Ok(())
718 }
719}
720
// Registry of one async mutex per repository path, populated lazily by
// `lock_repo`. It is thread-local, so each thread has its own registry.
thread_local! {
    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
724
725#[must_use]
726pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
727 REPO_LOCKS
728 .with(|cell| {
729 cell.borrow_mut()
730 .entry(path.as_ref().to_path_buf())
731 .or_default()
732 .clone()
733 })
734 .lock_owned()
735 .await
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Markers are stripped, marker-only lines are dropped, and required line
    /// indices refer to the cleaned-up text.
    #[test]
    fn test_extract_required_lines() {
        let input = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};

        let expected_updated_input = indoc! {"
            zero
            one
            two
            three
            four
            five
        "};

        // "four" is on input row 5, but one marker-only line before it was
        // dropped, so its index in the output is 4.
        let expected_required_lines = vec![Line(1), Line(4)];

        let (updated_input, required_lines) = extract_required_lines(input);
        assert_eq!(updated_input, expected_updated_input);
        assert_eq!(required_lines, expected_required_lines);
    }
}