1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 fmt::{self, Display},
5 fs,
6 io::Write,
7 mem,
8 path::{Path, PathBuf},
9 sync::{Arc, OnceLock},
10};
11
12use crate::headless::ZetaCliAppState;
13use anyhow::{Context as _, Result, anyhow};
14use clap::ValueEnum;
15use cloud_zeta2_prompt::CURSOR_MARKER;
16use collections::HashMap;
17use futures::{
18 AsyncWriteExt as _,
19 lock::{Mutex, OwnedMutexGuard},
20};
21use futures::{FutureExt as _, future::Shared};
22use gpui::{AsyncApp, Entity, Task, http_client::Url};
23use language::{Anchor, Buffer};
24use project::{Project, ProjectPath};
25use pulldown_cmark::CowStr;
26use serde::{Deserialize, Serialize};
27use util::{paths::PathStyle, rel_path::RelPath};
28use zeta::udiff::OpenedBuffers;
29
30use crate::paths::{REPOS_DIR, WORKTREES_DIR};
31
// H2 section titles of the markdown example format. `NamedExample::parse_md`
// compares them ASCII-case-insensitively against each H2 heading.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
// Keys of the `key = value` metadata lines read by `NamedExample::parse_md`
// before the first recognized H2 section.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
39
/// An [`Example`] together with its name.
///
/// `load` takes the name from the file stem for JSON and TOML files, while
/// `parse_md` takes it from the document's single H1 heading.
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// Name of the example; `file_name` derives a lowercase, dash-separated
    /// identifier from it.
    pub name: String,
    /// The example data itself.
    pub example: Example,
}
45
/// The serializable payload of an example (JSON, TOML or markdown).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git URL of the repository; `setup_worktree` registers it as the `origin` remote.
    pub repository_url: String,
    /// Git revision that `setup_worktree` resolves (fetching it when missing) and checks out.
    pub revision: String,
    /// Patch piped to `git apply` inside the worktree after checkout; skipped when empty.
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor; `cursor_position` opens it as a
    /// POSIX-style path relative to the worktree.
    pub cursor_path: PathBuf,
    /// Excerpt of the file at `cursor_path` with `CURSOR_MARKER` inserted where the cursor is.
    pub cursor_position: String,
    /// Diff text that `apply_edit_history` passes to `zeta::udiff::apply_diff`.
    pub edit_history: String,
    /// Contents of the "Expected Patch" section; may be empty.
    /// NOTE(review): presumably the patch a prediction is compared against — confirm with callers.
    pub expected_patch: String,
}
56
/// Alias for [`Excerpt`].
/// NOTE(review): presumably names excerpts produced at runtime rather than expected ones — confirm with callers.
pub type ActualExcerpt = Excerpt;
58
/// A piece of text associated with a file path.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    /// Path of the file the text belongs to.
    pub path: PathBuf,
    /// The excerpt's text.
    pub text: String,
}
64
/// Output formats supported by `NamedExample::write`.
///
/// Derives `clap::ValueEnum`, so it can be parsed from a command-line value.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    /// Serialize the `Example` with `serde_json`.
    Json,
    /// Serialize the `Example` as pretty-printed TOML.
    Toml,
    /// Render the `NamedExample` through its `Display` impl (markdown).
    Md,
}
71
72impl NamedExample {
73 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
74 let path = path.as_ref();
75 let content = std::fs::read_to_string(path)?;
76 let ext = path.extension();
77
78 match ext.and_then(|s| s.to_str()) {
79 Some("json") => Ok(Self {
80 name: path.file_stem().unwrap_or_default().display().to_string(),
81 example: serde_json::from_str(&content)?,
82 }),
83 Some("toml") => Ok(Self {
84 name: path.file_stem().unwrap_or_default().display().to_string(),
85 example: toml::from_str(&content)?,
86 }),
87 Some("md") => Self::parse_md(&content),
88 Some(_) => {
89 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
90 }
91 None => {
92 anyhow::bail!(
93 "Failed to determine example type since the file does not have an extension."
94 );
95 }
96 }
97 }
98
99 pub fn parse_md(input: &str) -> Result<Self> {
100 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
101
102 let parser = Parser::new(input);
103
104 let mut named = NamedExample {
105 name: String::new(),
106 example: Example {
107 repository_url: String::new(),
108 revision: String::new(),
109 uncommitted_diff: String::new(),
110 cursor_path: PathBuf::new(),
111 cursor_position: String::new(),
112 edit_history: String::new(),
113 expected_patch: String::new(),
114 },
115 };
116
117 let mut text = String::new();
118 let mut block_info: CowStr = "".into();
119
120 #[derive(PartialEq)]
121 enum Section {
122 UncommittedDiff,
123 EditHistory,
124 CursorPosition,
125 ExpectedExcerpts,
126 ExpectedPatch,
127 Other,
128 }
129
130 let mut current_section = Section::Other;
131
132 for event in parser {
133 match event {
134 Event::Text(line) => {
135 text.push_str(&line);
136
137 if !named.name.is_empty()
138 && current_section == Section::Other
139 // in h1 section
140 && let Some((field, value)) = line.split_once('=')
141 {
142 match field.trim() {
143 REPOSITORY_URL_FIELD => {
144 named.example.repository_url = value.trim().to_string();
145 }
146 REVISION_FIELD => {
147 named.example.revision = value.trim().to_string();
148 }
149 _ => {}
150 }
151 }
152 }
153 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
154 if !named.name.is_empty() {
155 anyhow::bail!(
156 "Found multiple H1 headings. There should only be one with the name of the example."
157 );
158 }
159 named.name = mem::take(&mut text);
160 }
161 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
162 let title = mem::take(&mut text);
163 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
164 Section::UncommittedDiff
165 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
166 Section::EditHistory
167 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
168 Section::CursorPosition
169 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
170 Section::ExpectedPatch
171 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
172 Section::ExpectedExcerpts
173 } else {
174 Section::Other
175 };
176 }
177 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
178 mem::take(&mut text);
179 }
180 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
181 mem::take(&mut text);
182 }
183 Event::End(TagEnd::Heading(level)) => {
184 anyhow::bail!("Unexpected heading level: {level}");
185 }
186 Event::Start(Tag::CodeBlock(kind)) => {
187 match kind {
188 CodeBlockKind::Fenced(info) => {
189 block_info = info;
190 }
191 CodeBlockKind::Indented => {
192 anyhow::bail!("Unexpected indented codeblock");
193 }
194 };
195 }
196 Event::Start(_) => {
197 text.clear();
198 block_info = "".into();
199 }
200 Event::End(TagEnd::CodeBlock) => {
201 let block_info = block_info.trim();
202 match current_section {
203 Section::UncommittedDiff => {
204 named.example.uncommitted_diff = mem::take(&mut text);
205 }
206 Section::EditHistory => {
207 named.example.edit_history.push_str(&mem::take(&mut text));
208 }
209 Section::CursorPosition => {
210 named.example.cursor_path = block_info.into();
211 named.example.cursor_position = mem::take(&mut text);
212 }
213 Section::ExpectedExcerpts => {
214 mem::take(&mut text);
215 }
216 Section::ExpectedPatch => {
217 named.example.expected_patch = mem::take(&mut text);
218 }
219 Section::Other => {}
220 }
221 }
222 _ => {}
223 }
224 }
225
226 if named.example.cursor_path.as_path() == Path::new("")
227 || named.example.cursor_position.is_empty()
228 {
229 anyhow::bail!("Missing cursor position codeblock");
230 }
231
232 Ok(named)
233 }
234
235 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
236 match format {
237 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
238 ExampleFormat::Toml => {
239 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
240 }
241 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
242 }
243 }
244
245 pub async fn setup_project(
246 &self,
247 app_state: &Arc<ZetaCliAppState>,
248 cx: &mut AsyncApp,
249 ) -> Result<Entity<Project>> {
250 let worktree_path = self.setup_worktree().await?;
251
252 static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
253
254 AUTHENTICATED
255 .get_or_init(|| {
256 let client = app_state.client.clone();
257 cx.spawn(async move |cx| {
258 client
259 .sign_in_with_optional_connect(true, cx)
260 .await
261 .unwrap();
262 })
263 .shared()
264 })
265 .clone()
266 .await;
267
268 let project = cx.update(|cx| {
269 Project::local(
270 app_state.client.clone(),
271 app_state.node_runtime.clone(),
272 app_state.user_store.clone(),
273 app_state.languages.clone(),
274 app_state.fs.clone(),
275 None,
276 cx,
277 )
278 })?;
279
280 let worktree = project
281 .update(cx, |project, cx| {
282 project.create_worktree(&worktree_path, true, cx)
283 })?
284 .await?;
285 worktree
286 .read_with(cx, |worktree, _cx| {
287 worktree.as_local().unwrap().scan_complete()
288 })?
289 .await;
290
291 anyhow::Ok(project)
292 }
293
294 pub async fn setup_worktree(&self) -> Result<PathBuf> {
295 let (repo_owner, repo_name) = self.repo_name()?;
296 let file_name = self.file_name();
297
298 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
299 let repo_lock = lock_repo(&repo_dir).await;
300
301 if !repo_dir.is_dir() {
302 fs::create_dir_all(&repo_dir)?;
303 run_git(&repo_dir, &["init"]).await?;
304 run_git(
305 &repo_dir,
306 &["remote", "add", "origin", &self.example.repository_url],
307 )
308 .await?;
309 }
310
311 // Resolve the example to a revision, fetching it if needed.
312 let revision = run_git(
313 &repo_dir,
314 &[
315 "rev-parse",
316 &format!("{}^{{commit}}", self.example.revision),
317 ],
318 )
319 .await;
320 let revision = if let Ok(revision) = revision {
321 revision
322 } else {
323 run_git(
324 &repo_dir,
325 &["fetch", "--depth", "1", "origin", &self.example.revision],
326 )
327 .await?;
328 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
329 if revision != self.example.revision {
330 run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
331 }
332 revision
333 };
334
335 // Create the worktree for this example if needed.
336 let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
337 if worktree_path.is_dir() {
338 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
339 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
340 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
341 } else {
342 let worktree_path_string = worktree_path.to_string_lossy();
343 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
344 run_git(
345 &repo_dir,
346 &["worktree", "add", "-f", &worktree_path_string, &file_name],
347 )
348 .await?;
349 }
350 drop(repo_lock);
351
352 // Apply the uncommitted diff for this example.
353 if !self.example.uncommitted_diff.is_empty() {
354 let mut apply_process = smol::process::Command::new("git")
355 .current_dir(&worktree_path)
356 .args(&["apply", "-"])
357 .stdin(std::process::Stdio::piped())
358 .spawn()?;
359
360 let mut stdin = apply_process.stdin.take().unwrap();
361 stdin
362 .write_all(self.example.uncommitted_diff.as_bytes())
363 .await?;
364 stdin.close().await?;
365 drop(stdin);
366
367 let apply_result = apply_process.output().await?;
368 if !apply_result.status.success() {
369 anyhow::bail!(
370 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
371 apply_result.status,
372 String::from_utf8_lossy(&apply_result.stderr),
373 String::from_utf8_lossy(&apply_result.stdout),
374 );
375 }
376 }
377
378 Ok(worktree_path)
379 }
380
381 pub fn file_name(&self) -> String {
382 self.name
383 .chars()
384 .map(|c| {
385 if c.is_whitespace() {
386 '-'
387 } else {
388 c.to_ascii_lowercase()
389 }
390 })
391 .collect()
392 }
393
394 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
395 // git@github.com:owner/repo.git
396 if self.example.repository_url.contains('@') {
397 let (owner, repo) = self
398 .example
399 .repository_url
400 .split_once(':')
401 .context("expected : in git url")?
402 .1
403 .split_once('/')
404 .context("expected / in git url")?;
405 Ok((
406 Cow::Borrowed(owner),
407 Cow::Borrowed(repo.trim_end_matches(".git")),
408 ))
409 // http://github.com/owner/repo.git
410 } else {
411 let url = Url::parse(&self.example.repository_url)?;
412 let mut segments = url.path_segments().context("empty http url")?;
413 let owner = segments
414 .next()
415 .context("expected owner path segment")?
416 .to_string();
417 let repo = segments
418 .next()
419 .context("expected repo path segment")?
420 .trim_end_matches(".git")
421 .to_string();
422 assert!(segments.next().is_none());
423
424 Ok((owner.into(), repo.into()))
425 }
426 }
427
428 pub async fn cursor_position(
429 &self,
430 project: &Entity<Project>,
431 cx: &mut AsyncApp,
432 ) -> Result<(Entity<Buffer>, Anchor)> {
433 let worktree = project.read_with(cx, |project, cx| {
434 project.visible_worktrees(cx).next().unwrap()
435 })?;
436 let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
437 let cursor_buffer = project
438 .update(cx, |project, cx| {
439 project.open_buffer(
440 ProjectPath {
441 worktree_id: worktree.read(cx).id(),
442 path: cursor_path,
443 },
444 cx,
445 )
446 })?
447 .await?;
448 let cursor_offset_within_excerpt = self
449 .example
450 .cursor_position
451 .find(CURSOR_MARKER)
452 .ok_or_else(|| anyhow!("missing cursor marker"))?;
453 let mut cursor_excerpt = self.example.cursor_position.clone();
454 cursor_excerpt.replace_range(
455 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
456 "",
457 );
458 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
459 let text = buffer.text();
460
461 let mut matches = text.match_indices(&cursor_excerpt);
462 let Some((excerpt_offset, _)) = matches.next() else {
463 anyhow::bail!(
464 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
465 );
466 };
467 assert!(matches.next().is_none());
468
469 Ok(excerpt_offset)
470 })??;
471
472 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
473 let cursor_anchor =
474 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
475 Ok((cursor_buffer, cursor_anchor))
476 }
477
    /// Applies this example's `edit_history` diff to `project` by delegating to
    /// `zeta::udiff::apply_diff`, and returns the `OpenedBuffers` it produces.
    #[must_use]
    pub async fn apply_edit_history(
        &self,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<OpenedBuffers<'_>> {
        zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
    }
486}
487
488async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
489 let output = smol::process::Command::new("git")
490 .current_dir(repo_path)
491 .args(args)
492 .output()
493 .await?;
494
495 anyhow::ensure!(
496 output.status.success(),
497 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
498 args.join(" "),
499 repo_path.display(),
500 output.status,
501 String::from_utf8_lossy(&output.stderr),
502 String::from_utf8_lossy(&output.stdout),
503 );
504 Ok(String::from_utf8(output.stdout)?.trim().to_string())
505}
506
507impl Display for NamedExample {
508 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
509 write!(f, "# {}\n\n", self.name)?;
510 write!(
511 f,
512 "{REPOSITORY_URL_FIELD} = {}\n",
513 self.example.repository_url
514 )?;
515 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
516
517 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
518 write!(f, "`````diff\n")?;
519 write!(f, "{}", self.example.uncommitted_diff)?;
520 write!(f, "`````\n")?;
521
522 if !self.example.edit_history.is_empty() {
523 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
524 }
525
526 write!(
527 f,
528 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
529 self.example.cursor_path.display(),
530 self.example.cursor_position
531 )?;
532 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
533
534 if !self.example.expected_patch.is_empty() {
535 write!(
536 f,
537 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
538 self.example.expected_patch
539 )?;
540 }
541
542 Ok(())
543 }
544}
545
thread_local! {
    // Thread-local map from repository path to the async mutex handed out by
    // `lock_repo`; entries are created on first use.
    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
549
550#[must_use]
551pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
552 REPO_LOCKS
553 .with(|cell| {
554 cell.borrow_mut()
555 .entry(path.as_ref().to_path_buf())
556 .or_default()
557 .clone()
558 })
559 .lock_owned()
560 .await
561}