1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 fmt::{self, Display},
5 fs,
6 io::Write,
7 mem,
8 path::{Path, PathBuf},
9 sync::Arc,
10};
11
12use anyhow::{Context as _, Result, anyhow};
13use clap::ValueEnum;
14use cloud_zeta2_prompt::CURSOR_MARKER;
15use collections::HashMap;
16use edit_prediction_context::Line;
17use futures::{
18 AsyncWriteExt as _,
19 lock::{Mutex, OwnedMutexGuard},
20};
21use gpui::{AsyncApp, Entity, http_client::Url};
22use language::{Anchor, Buffer};
23use project::{Project, ProjectPath};
24use pulldown_cmark::CowStr;
25use serde::{Deserialize, Serialize};
26use util::{paths::PathStyle, rel_path::RelPath};
27use zeta2::udiff::OpenedBuffers;
28
29use crate::paths::{REPOS_DIR, WORKTREES_DIR};
30
// Titles of the H2 sections recognized in the Markdown example format.
// They are compared case-insensitively when parsing and written verbatim when rendering.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";

// Keys of the `key = value` lines that follow the H1 title of a Markdown example.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
38
39#[derive(Debug, Clone)]
40pub struct NamedExample {
41 pub name: String,
42 pub example: Example,
43}
44
45#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct Example {
47 pub repository_url: String,
48 pub revision: String,
49 pub uncommitted_diff: String,
50 pub cursor_path: PathBuf,
51 pub cursor_position: String,
52 pub edit_history: String,
53 pub expected_patch: String,
54 pub expected_context: Vec<ExpectedContextEntry>,
55}
56
57pub type ActualExcerpt = Excerpt;
58
59#[derive(Clone, Debug, Serialize, Deserialize)]
60pub struct Excerpt {
61 pub path: PathBuf,
62 pub text: String,
63}
64
65#[derive(Default, Clone, Debug, Serialize, Deserialize)]
66pub struct ExpectedContextEntry {
67 pub heading: String,
68 pub alternatives: Vec<ExpectedExcerptSet>,
69}
70
71#[derive(Default, Clone, Debug, Serialize, Deserialize)]
72pub struct ExpectedExcerptSet {
73 pub heading: String,
74 pub excerpts: Vec<ExpectedExcerpt>,
75}
76
77#[derive(Clone, Debug, Serialize, Deserialize)]
78pub struct ExpectedExcerpt {
79 pub path: PathBuf,
80 pub text: String,
81 pub required_lines: Vec<Line>,
82}
83
84#[derive(ValueEnum, Debug, Clone)]
85pub enum ExampleFormat {
86 Json,
87 Toml,
88 Md,
89}
90
91impl NamedExample {
92 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
93 let path = path.as_ref();
94 let content = std::fs::read_to_string(path)?;
95 let ext = path.extension();
96
97 match ext.and_then(|s| s.to_str()) {
98 Some("json") => Ok(Self {
99 name: path.file_stem().unwrap_or_default().display().to_string(),
100 example: serde_json::from_str(&content)?,
101 }),
102 Some("toml") => Ok(Self {
103 name: path.file_stem().unwrap_or_default().display().to_string(),
104 example: toml::from_str(&content)?,
105 }),
106 Some("md") => Self::parse_md(&content),
107 Some(_) => {
108 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
109 }
110 None => {
111 anyhow::bail!(
112 "Failed to determine example type since the file does not have an extension."
113 );
114 }
115 }
116 }
117
118 pub fn parse_md(input: &str) -> Result<Self> {
119 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
120
121 let parser = Parser::new(input);
122
123 let mut named = NamedExample {
124 name: String::new(),
125 example: Example {
126 repository_url: String::new(),
127 revision: String::new(),
128 uncommitted_diff: String::new(),
129 cursor_path: PathBuf::new(),
130 cursor_position: String::new(),
131 edit_history: String::new(),
132 expected_patch: String::new(),
133 expected_context: Vec::new(),
134 },
135 };
136
137 let mut text = String::new();
138 let mut block_info: CowStr = "".into();
139
140 #[derive(PartialEq)]
141 enum Section {
142 UncommittedDiff,
143 EditHistory,
144 CursorPosition,
145 ExpectedExcerpts,
146 ExpectedPatch,
147 Other,
148 }
149
150 let mut current_section = Section::Other;
151
152 for event in parser {
153 match event {
154 Event::Text(line) => {
155 text.push_str(&line);
156
157 if !named.name.is_empty()
158 && current_section == Section::Other
159 // in h1 section
160 && let Some((field, value)) = line.split_once('=')
161 {
162 match field.trim() {
163 REPOSITORY_URL_FIELD => {
164 named.example.repository_url = value.trim().to_string();
165 }
166 REVISION_FIELD => {
167 named.example.revision = value.trim().to_string();
168 }
169 _ => {
170 eprintln!("Warning: Unrecognized field `{field}`");
171 }
172 }
173 }
174 }
175 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
176 if !named.name.is_empty() {
177 anyhow::bail!(
178 "Found multiple H1 headings. There should only be one with the name of the example."
179 );
180 }
181 named.name = mem::take(&mut text);
182 }
183 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
184 let title = mem::take(&mut text);
185 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
186 Section::UncommittedDiff
187 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
188 Section::EditHistory
189 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
190 Section::CursorPosition
191 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
192 Section::ExpectedPatch
193 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
194 Section::ExpectedExcerpts
195 } else {
196 eprintln!("Warning: Unrecognized section `{title:?}`");
197 Section::Other
198 };
199 }
200 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
201 let heading = mem::take(&mut text);
202 match current_section {
203 Section::ExpectedExcerpts => {
204 named.example.expected_context.push(ExpectedContextEntry {
205 heading,
206 alternatives: Vec::new(),
207 });
208 }
209 _ => {}
210 }
211 }
212 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
213 let heading = mem::take(&mut text);
214 match current_section {
215 Section::ExpectedExcerpts => {
216 let expected_context = &mut named.example.expected_context;
217 let last_entry = expected_context.last_mut().unwrap();
218 last_entry.alternatives.push(ExpectedExcerptSet {
219 heading,
220 excerpts: Vec::new(),
221 })
222 }
223 _ => {}
224 }
225 }
226 Event::End(TagEnd::Heading(level)) => {
227 anyhow::bail!("Unexpected heading level: {level}");
228 }
229 Event::Start(Tag::CodeBlock(kind)) => {
230 match kind {
231 CodeBlockKind::Fenced(info) => {
232 block_info = info;
233 }
234 CodeBlockKind::Indented => {
235 anyhow::bail!("Unexpected indented codeblock");
236 }
237 };
238 }
239 Event::Start(_) => {
240 text.clear();
241 block_info = "".into();
242 }
243 Event::End(TagEnd::CodeBlock) => {
244 let block_info = block_info.trim();
245 match current_section {
246 Section::UncommittedDiff => {
247 named.example.uncommitted_diff = mem::take(&mut text);
248 }
249 Section::EditHistory => {
250 named.example.edit_history.push_str(&mem::take(&mut text));
251 }
252 Section::CursorPosition => {
253 named.example.cursor_path = block_info.into();
254 named.example.cursor_position = mem::take(&mut text);
255 }
256 Section::ExpectedExcerpts => {
257 let text = mem::take(&mut text);
258 for excerpt in text.split("\n…\n") {
259 let (mut text, required_lines) = extract_required_lines(&excerpt);
260 if !text.ends_with('\n') {
261 text.push('\n');
262 }
263 let alternatives = &mut named
264 .example
265 .expected_context
266 .last_mut()
267 .unwrap()
268 .alternatives;
269
270 if alternatives.is_empty() {
271 alternatives.push(ExpectedExcerptSet {
272 heading: String::new(),
273 excerpts: vec![],
274 });
275 }
276
277 alternatives
278 .last_mut()
279 .unwrap()
280 .excerpts
281 .push(ExpectedExcerpt {
282 path: block_info.into(),
283 text,
284 required_lines,
285 });
286 }
287 }
288 Section::ExpectedPatch => {
289 named.example.expected_patch = mem::take(&mut text);
290 }
291 Section::Other => {}
292 }
293 }
294 _ => {}
295 }
296 }
297
298 if named.example.cursor_path.as_path() == Path::new("")
299 || named.example.cursor_position.is_empty()
300 {
301 anyhow::bail!("Missing cursor position codeblock");
302 }
303
304 Ok(named)
305 }
306
307 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
308 match format {
309 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
310 ExampleFormat::Toml => {
311 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
312 }
313 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
314 }
315 }
316
317 pub async fn setup_worktree(&self) -> Result<PathBuf> {
318 let (repo_owner, repo_name) = self.repo_name()?;
319 let file_name = self.file_name();
320
321 fs::create_dir_all(&*REPOS_DIR)?;
322 fs::create_dir_all(&*WORKTREES_DIR)?;
323
324 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
325 let repo_lock = lock_repo(&repo_dir).await;
326
327 if !repo_dir.is_dir() {
328 fs::create_dir_all(&repo_dir)?;
329 run_git(&repo_dir, &["init"]).await?;
330 run_git(
331 &repo_dir,
332 &["remote", "add", "origin", &self.example.repository_url],
333 )
334 .await?;
335 }
336
337 // Resolve the example to a revision, fetching it if needed.
338 let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
339 let revision = if let Ok(revision) = revision {
340 revision
341 } else {
342 run_git(
343 &repo_dir,
344 &["fetch", "--depth", "1", "origin", &self.example.revision],
345 )
346 .await?;
347 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
348 if revision != self.example.revision {
349 run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
350 }
351 revision
352 };
353
354 // Create the worktree for this example if needed.
355 let worktree_path = WORKTREES_DIR.join(&file_name);
356 if worktree_path.is_dir() {
357 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
358 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
359 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
360 } else {
361 let worktree_path_string = worktree_path.to_string_lossy();
362 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
363 run_git(
364 &repo_dir,
365 &["worktree", "add", "-f", &worktree_path_string, &file_name],
366 )
367 .await?;
368 }
369 drop(repo_lock);
370
371 // Apply the uncommitted diff for this example.
372 if !self.example.uncommitted_diff.is_empty() {
373 let mut apply_process = smol::process::Command::new("git")
374 .current_dir(&worktree_path)
375 .args(&["apply", "-"])
376 .stdin(std::process::Stdio::piped())
377 .spawn()?;
378
379 let mut stdin = apply_process.stdin.take().unwrap();
380 stdin
381 .write_all(self.example.uncommitted_diff.as_bytes())
382 .await?;
383 stdin.close().await?;
384 drop(stdin);
385
386 let apply_result = apply_process.output().await?;
387 if !apply_result.status.success() {
388 anyhow::bail!(
389 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
390 apply_result.status,
391 String::from_utf8_lossy(&apply_result.stderr),
392 String::from_utf8_lossy(&apply_result.stdout),
393 );
394 }
395 }
396
397 Ok(worktree_path)
398 }
399
400 fn file_name(&self) -> String {
401 self.name
402 .chars()
403 .map(|c| {
404 if c.is_whitespace() {
405 '-'
406 } else {
407 c.to_ascii_lowercase()
408 }
409 })
410 .collect()
411 }
412
413 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
414 // git@github.com:owner/repo.git
415 if self.example.repository_url.contains('@') {
416 let (owner, repo) = self
417 .example
418 .repository_url
419 .split_once(':')
420 .context("expected : in git url")?
421 .1
422 .split_once('/')
423 .context("expected / in git url")?;
424 Ok((
425 Cow::Borrowed(owner),
426 Cow::Borrowed(repo.trim_end_matches(".git")),
427 ))
428 // http://github.com/owner/repo.git
429 } else {
430 let url = Url::parse(&self.example.repository_url)?;
431 let mut segments = url.path_segments().context("empty http url")?;
432 let owner = segments
433 .next()
434 .context("expected owner path segment")?
435 .to_string();
436 let repo = segments
437 .next()
438 .context("expected repo path segment")?
439 .trim_end_matches(".git")
440 .to_string();
441 assert!(segments.next().is_none());
442
443 Ok((owner.into(), repo.into()))
444 }
445 }
446
447 pub async fn cursor_position(
448 &self,
449 project: &Entity<Project>,
450 cx: &mut AsyncApp,
451 ) -> Result<(Entity<Buffer>, Anchor)> {
452 let worktree = project.read_with(cx, |project, cx| {
453 project.visible_worktrees(cx).next().unwrap()
454 })?;
455 let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
456 let cursor_buffer = project
457 .update(cx, |project, cx| {
458 project.open_buffer(
459 ProjectPath {
460 worktree_id: worktree.read(cx).id(),
461 path: cursor_path,
462 },
463 cx,
464 )
465 })?
466 .await?;
467 let cursor_offset_within_excerpt = self
468 .example
469 .cursor_position
470 .find(CURSOR_MARKER)
471 .ok_or_else(|| anyhow!("missing cursor marker"))?;
472 let mut cursor_excerpt = self.example.cursor_position.clone();
473 cursor_excerpt.replace_range(
474 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
475 "",
476 );
477 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
478 let text = buffer.text();
479
480 let mut matches = text.match_indices(&cursor_excerpt);
481 let Some((excerpt_offset, _)) = matches.next() else {
482 anyhow::bail!(
483 "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
484 );
485 };
486 assert!(matches.next().is_none());
487
488 Ok(excerpt_offset)
489 })??;
490
491 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
492 let cursor_anchor =
493 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
494 Ok((cursor_buffer, cursor_anchor))
495 }
496
497 #[must_use]
498 pub async fn apply_edit_history(
499 &self,
500 project: &Entity<Project>,
501 cx: &mut AsyncApp,
502 ) -> Result<OpenedBuffers<'_>> {
503 zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
504 }
505}
506
507fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
508 const MARKER: &str = "[ZETA]";
509 let mut new_text = String::new();
510 let mut required_lines = Vec::new();
511 let mut skipped_lines = 0_u32;
512
513 for (row, mut line) in text.split('\n').enumerate() {
514 if let Some(marker_column) = line.find(MARKER) {
515 let mut strip_column = marker_column;
516
517 while strip_column > 0 {
518 let prev_char = line[strip_column - 1..].chars().next().unwrap();
519 if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
520 strip_column -= 1;
521 } else {
522 break;
523 }
524 }
525
526 let metadata = &line[marker_column + MARKER.len()..];
527 if metadata.contains("required") {
528 required_lines.push(Line(row as u32 - skipped_lines));
529 }
530
531 if strip_column == 0 {
532 skipped_lines += 1;
533 continue;
534 }
535
536 line = &line[..strip_column];
537 }
538
539 new_text.push_str(line);
540 new_text.push('\n');
541 }
542
543 new_text.pop();
544
545 (new_text, required_lines)
546}
547
548async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
549 let output = smol::process::Command::new("git")
550 .current_dir(repo_path)
551 .args(args)
552 .output()
553 .await?;
554
555 anyhow::ensure!(
556 output.status.success(),
557 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
558 args.join(" "),
559 repo_path.display(),
560 output.status,
561 String::from_utf8_lossy(&output.stderr),
562 String::from_utf8_lossy(&output.stdout),
563 );
564 Ok(String::from_utf8(output.stdout)?.trim().to_string())
565}
566
567impl Display for NamedExample {
568 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
569 write!(f, "# {}\n\n", self.name)?;
570 write!(
571 f,
572 "{REPOSITORY_URL_FIELD} = {}\n",
573 self.example.repository_url
574 )?;
575 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
576
577 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
578 write!(f, "`````diff\n")?;
579 write!(f, "{}", self.example.uncommitted_diff)?;
580 write!(f, "`````\n")?;
581
582 if !self.example.edit_history.is_empty() {
583 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
584 }
585
586 write!(
587 f,
588 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
589 self.example.cursor_path.display(),
590 self.example.cursor_position
591 )?;
592 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
593
594 if !self.example.expected_patch.is_empty() {
595 write!(
596 f,
597 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
598 self.example.expected_patch
599 )?;
600 }
601
602 if !self.example.expected_context.is_empty() {
603 write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
604
605 for entry in &self.example.expected_context {
606 write!(f, "\n### {}\n\n", entry.heading)?;
607
608 let skip_h4 =
609 entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
610
611 for excerpt_set in &entry.alternatives {
612 if !skip_h4 {
613 write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
614 }
615
616 for excerpt in &excerpt_set.excerpts {
617 write!(
618 f,
619 "`````{}{}\n{}`````\n\n",
620 excerpt
621 .path
622 .extension()
623 .map(|ext| format!("{} ", ext.to_string_lossy()))
624 .unwrap_or_default(),
625 excerpt.path.display(),
626 excerpt.text
627 )?;
628 }
629 }
630 }
631 }
632
633 Ok(())
634 }
635}
636
637thread_local! {
638 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
639}
640
641#[must_use]
642pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
643 REPO_LOCKS
644 .with(|cell| {
645 cell.borrow_mut()
646 .entry(path.as_ref().to_path_buf())
647 .or_default()
648 .clone()
649 })
650 .lock_owned()
651 .await
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Markers are stripped, marker-only lines are dropped, and rows of `required` lines
    /// refer to the cleaned text.
    #[test]
    fn test_extract_required_lines() {
        let annotated = indoc! {"
            zero
            one // [ZETA] required
            two
            // [ZETA] something
            three
            four # [ZETA] required
            five
        "};

        let stripped = indoc! {"
            zero
            one
            two
            three
            four
            five
        "};

        let (actual_text, actual_required) = extract_required_lines(annotated);

        assert_eq!(actual_text, stripped);
        assert_eq!(actual_required, vec![Line(1), Line(4)]);
    }
}