1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 fmt::{self, Display},
5 fs,
6 hash::Hash,
7 hash::Hasher,
8 io::Write,
9 mem,
10 path::{Path, PathBuf},
11 sync::{Arc, OnceLock},
12};
13
14use crate::headless::ZetaCliAppState;
15use anyhow::{Context as _, Result, anyhow};
16use clap::ValueEnum;
17use cloud_zeta2_prompt::CURSOR_MARKER;
18use collections::HashMap;
19use edit_prediction::udiff::OpenedBuffers;
20use futures::{
21 AsyncWriteExt as _,
22 lock::{Mutex, OwnedMutexGuard},
23};
24use futures::{FutureExt as _, future::Shared};
25use gpui::{AsyncApp, Entity, Task, http_client::Url};
26use language::{Anchor, Buffer};
27use project::{Project, ProjectPath};
28use pulldown_cmark::CowStr;
29use serde::{Deserialize, Serialize};
30use util::{paths::PathStyle, rel_path::RelPath};
31
32use crate::paths::{REPOS_DIR, WORKTREES_DIR};
33
// `key = value` fields written under the H1 heading of a markdown example.
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";

// H2 section headings of a markdown example, listed in the order they appear in a
// document. `parse_md` compares them case-insensitively.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
41
/// An [`Example`] together with a human-readable name.
#[derive(Debug, Clone)]
pub struct NamedExample {
    /// The example's name: the H1 heading for markdown files, the file stem for JSON and
    /// TOML files. It is also turned into the worktree directory and branch name.
    pub name: String,
    /// The example's repository state, cursor location, history and expected patch.
    pub example: Example,
}
47
/// A single example: a repository state plus the cursor location, the edits that led up
/// to it, and the expected patch.
///
/// The derived `Hash` covers every field in declaration order and feeds
/// [`Example::unique_name`], so reordering or adding fields changes generated names.
/// Declaration order is also the serialized field order.
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
pub struct Example {
    /// Git remote for the repository, either scp-like (`git@github.com:owner/repo.git`)
    /// or a URL (`https://github.com/owner/repo.git`).
    pub repository_url: String,
    /// Revision checked out in the example's worktree. It is resolved with
    /// `git rev-parse <revision>^{commit}` and fetched from `origin` when missing locally.
    pub revision: String,
    /// Patch applied with `git apply` on top of `revision`; empty means no patch.
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor, relative to the worktree root and
    /// interpreted with POSIX separators.
    pub cursor_path: PathBuf,
    /// Excerpt of `cursor_path` containing `CURSOR_MARKER` at the cursor. With the marker
    /// removed, the excerpt is expected to occur exactly once in the file.
    pub cursor_position: String,
    /// Diff replayed onto the project by `edit_prediction::udiff::apply_diff`.
    pub edit_history: String,
    /// Expected patch (the "Expected Patch" section of a markdown example).
    pub expected_patch: String,
}
58
59impl Example {
60 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
61 // git@github.com:owner/repo.git
62 if self.repository_url.contains('@') {
63 let (owner, repo) = self
64 .repository_url
65 .split_once(':')
66 .context("expected : in git url")?
67 .1
68 .split_once('/')
69 .context("expected / in git url")?;
70 Ok((
71 Cow::Borrowed(owner),
72 Cow::Borrowed(repo.trim_end_matches(".git")),
73 ))
74 // http://github.com/owner/repo.git
75 } else {
76 let url = Url::parse(&self.repository_url)?;
77 let mut segments = url.path_segments().context("empty http url")?;
78 let owner = segments
79 .next()
80 .context("expected owner path segment")?
81 .to_string();
82 let repo = segments
83 .next()
84 .context("expected repo path segment")?
85 .trim_end_matches(".git")
86 .to_string();
87 assert!(segments.next().is_none());
88
89 Ok((owner.into(), repo.into()))
90 }
91 }
92
93 pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
94 let (repo_owner, repo_name) = self.repo_name()?;
95
96 let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
97 let repo_lock = lock_repo(&repo_dir).await;
98
99 if !repo_dir.is_dir() {
100 fs::create_dir_all(&repo_dir)?;
101 run_git(&repo_dir, &["init"]).await?;
102 run_git(
103 &repo_dir,
104 &["remote", "add", "origin", &self.repository_url],
105 )
106 .await?;
107 }
108
109 // Resolve the example to a revision, fetching it if needed.
110 let revision = run_git(
111 &repo_dir,
112 &["rev-parse", &format!("{}^{{commit}}", self.revision)],
113 )
114 .await;
115 let revision = if let Ok(revision) = revision {
116 revision
117 } else {
118 if run_git(
119 &repo_dir,
120 &["fetch", "--depth", "1", "origin", &self.revision],
121 )
122 .await
123 .is_err()
124 {
125 run_git(&repo_dir, &["fetch", "origin"]).await?;
126 }
127 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
128 if revision != self.revision {
129 run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
130 }
131 revision
132 };
133
134 // Create the worktree for this example if needed.
135 let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
136 if worktree_path.is_dir() {
137 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
138 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
139 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
140 } else {
141 let worktree_path_string = worktree_path.to_string_lossy();
142 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
143 run_git(
144 &repo_dir,
145 &["worktree", "add", "-f", &worktree_path_string, &file_name],
146 )
147 .await?;
148 }
149 drop(repo_lock);
150
151 // Apply the uncommitted diff for this example.
152 if !self.uncommitted_diff.is_empty() {
153 let mut apply_process = smol::process::Command::new("git")
154 .current_dir(&worktree_path)
155 .args(&["apply", "-"])
156 .stdin(std::process::Stdio::piped())
157 .spawn()?;
158
159 let mut stdin = apply_process.stdin.take().unwrap();
160 stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
161 stdin.close().await?;
162 drop(stdin);
163
164 let apply_result = apply_process.output().await?;
165 if !apply_result.status.success() {
166 anyhow::bail!(
167 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
168 apply_result.status,
169 String::from_utf8_lossy(&apply_result.stderr),
170 String::from_utf8_lossy(&apply_result.stdout),
171 );
172 }
173 }
174
175 Ok(worktree_path)
176 }
177
178 pub fn unique_name(&self) -> String {
179 let mut hasher = std::hash::DefaultHasher::new();
180 self.hash(&mut hasher);
181 let disambiguator = hasher.finish();
182 let hash = format!("{:04x}", disambiguator);
183 format!("{}_{}", &self.revision[..8], &hash[..4])
184 }
185}
186
/// Alias of [`Excerpt`]. NOTE(review): not referenced in this file; presumably used for
/// excerpts produced at prediction time as opposed to expected ones — confirm with callers.
pub type ActualExcerpt = Excerpt;

/// A piece of text associated with a file path.
///
/// NOTE(review): not constructed in this file; presumably `text` is a snippet of the
/// file at `path` — confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
    pub path: PathBuf,
    pub text: String,
}
194
/// Output format accepted by `NamedExample::write`: serialized JSON, pretty-printed TOML,
/// or the markdown layout produced by the `Display` impl.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    Json,
    Toml,
    Md,
}
201
202impl NamedExample {
203 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
204 let path = path.as_ref();
205 let content = std::fs::read_to_string(path)?;
206 let ext = path.extension();
207
208 match ext.and_then(|s| s.to_str()) {
209 Some("json") => Ok(Self {
210 name: path.file_stem().unwrap_or_default().display().to_string(),
211 example: serde_json::from_str(&content)?,
212 }),
213 Some("toml") => Ok(Self {
214 name: path.file_stem().unwrap_or_default().display().to_string(),
215 example: toml::from_str(&content)?,
216 }),
217 Some("md") => Self::parse_md(&content),
218 Some(_) => {
219 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
220 }
221 None => {
222 anyhow::bail!(
223 "Failed to determine example type since the file does not have an extension."
224 );
225 }
226 }
227 }
228
229 pub fn parse_md(input: &str) -> Result<Self> {
230 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
231
232 let parser = Parser::new(input);
233
234 let mut named = NamedExample {
235 name: String::new(),
236 example: Example {
237 repository_url: String::new(),
238 revision: String::new(),
239 uncommitted_diff: String::new(),
240 cursor_path: PathBuf::new(),
241 cursor_position: String::new(),
242 edit_history: String::new(),
243 expected_patch: String::new(),
244 },
245 };
246
247 let mut text = String::new();
248 let mut block_info: CowStr = "".into();
249
250 #[derive(PartialEq)]
251 enum Section {
252 UncommittedDiff,
253 EditHistory,
254 CursorPosition,
255 ExpectedExcerpts,
256 ExpectedPatch,
257 Other,
258 }
259
260 let mut current_section = Section::Other;
261
262 for event in parser {
263 match event {
264 Event::Text(line) => {
265 text.push_str(&line);
266
267 if !named.name.is_empty()
268 && current_section == Section::Other
269 // in h1 section
270 && let Some((field, value)) = line.split_once('=')
271 {
272 match field.trim() {
273 REPOSITORY_URL_FIELD => {
274 named.example.repository_url = value.trim().to_string();
275 }
276 REVISION_FIELD => {
277 named.example.revision = value.trim().to_string();
278 }
279 _ => {}
280 }
281 }
282 }
283 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
284 if !named.name.is_empty() {
285 anyhow::bail!(
286 "Found multiple H1 headings. There should only be one with the name of the example."
287 );
288 }
289 named.name = mem::take(&mut text);
290 }
291 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
292 let title = mem::take(&mut text);
293 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
294 Section::UncommittedDiff
295 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
296 Section::EditHistory
297 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
298 Section::CursorPosition
299 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
300 Section::ExpectedPatch
301 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
302 Section::ExpectedExcerpts
303 } else {
304 Section::Other
305 };
306 }
307 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
308 mem::take(&mut text);
309 }
310 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
311 mem::take(&mut text);
312 }
313 Event::End(TagEnd::Heading(level)) => {
314 anyhow::bail!("Unexpected heading level: {level}");
315 }
316 Event::Start(Tag::CodeBlock(kind)) => {
317 match kind {
318 CodeBlockKind::Fenced(info) => {
319 block_info = info;
320 }
321 CodeBlockKind::Indented => {
322 anyhow::bail!("Unexpected indented codeblock");
323 }
324 };
325 }
326 Event::Start(_) => {
327 text.clear();
328 block_info = "".into();
329 }
330 Event::End(TagEnd::CodeBlock) => {
331 let block_info = block_info.trim();
332 match current_section {
333 Section::UncommittedDiff => {
334 named.example.uncommitted_diff = mem::take(&mut text);
335 }
336 Section::EditHistory => {
337 named.example.edit_history.push_str(&mem::take(&mut text));
338 }
339 Section::CursorPosition => {
340 named.example.cursor_path = block_info.into();
341 named.example.cursor_position = mem::take(&mut text);
342 }
343 Section::ExpectedExcerpts => {
344 mem::take(&mut text);
345 }
346 Section::ExpectedPatch => {
347 named.example.expected_patch = mem::take(&mut text);
348 }
349 Section::Other => {}
350 }
351 }
352 _ => {}
353 }
354 }
355
356 if named.example.cursor_path.as_path() == Path::new("")
357 || named.example.cursor_position.is_empty()
358 {
359 anyhow::bail!("Missing cursor position codeblock");
360 }
361
362 Ok(named)
363 }
364
365 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
366 match format {
367 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
368 ExampleFormat::Toml => {
369 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
370 }
371 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
372 }
373 }
374
375 pub async fn setup_project(
376 &self,
377 app_state: &Arc<ZetaCliAppState>,
378 cx: &mut AsyncApp,
379 ) -> Result<Entity<Project>> {
380 let worktree_path = self.setup_worktree().await?;
381
382 static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
383
384 AUTHENTICATED
385 .get_or_init(|| {
386 let client = app_state.client.clone();
387 cx.spawn(async move |cx| {
388 client
389 .sign_in_with_optional_connect(true, cx)
390 .await
391 .unwrap();
392 })
393 .shared()
394 })
395 .clone()
396 .await;
397
398 let project = cx.update(|cx| {
399 Project::local(
400 app_state.client.clone(),
401 app_state.node_runtime.clone(),
402 app_state.user_store.clone(),
403 app_state.languages.clone(),
404 app_state.fs.clone(),
405 None,
406 cx,
407 )
408 })?;
409
410 let worktree = project
411 .update(cx, |project, cx| {
412 project.create_worktree(&worktree_path, true, cx)
413 })?
414 .await?;
415 worktree
416 .read_with(cx, |worktree, _cx| {
417 worktree.as_local().unwrap().scan_complete()
418 })?
419 .await;
420
421 anyhow::Ok(project)
422 }
423
424 pub async fn setup_worktree(&self) -> Result<PathBuf> {
425 self.example.setup_worktree(self.file_name()).await
426 }
427
428 pub fn file_name(&self) -> String {
429 self.name
430 .chars()
431 .map(|c| {
432 if c.is_whitespace() {
433 '-'
434 } else {
435 c.to_ascii_lowercase()
436 }
437 })
438 .collect()
439 }
440
441 pub async fn cursor_position(
442 &self,
443 project: &Entity<Project>,
444 cx: &mut AsyncApp,
445 ) -> Result<(Entity<Buffer>, Anchor)> {
446 let worktree = project.read_with(cx, |project, cx| {
447 project.visible_worktrees(cx).next().unwrap()
448 })?;
449 let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
450 let cursor_buffer = project
451 .update(cx, |project, cx| {
452 project.open_buffer(
453 ProjectPath {
454 worktree_id: worktree.read(cx).id(),
455 path: cursor_path,
456 },
457 cx,
458 )
459 })?
460 .await?;
461 let cursor_offset_within_excerpt = self
462 .example
463 .cursor_position
464 .find(CURSOR_MARKER)
465 .ok_or_else(|| anyhow!("missing cursor marker"))?;
466 let mut cursor_excerpt = self.example.cursor_position.clone();
467 cursor_excerpt.replace_range(
468 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
469 "",
470 );
471 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
472 let text = buffer.text();
473
474 let mut matches = text.match_indices(&cursor_excerpt);
475 let Some((excerpt_offset, _)) = matches.next() else {
476 anyhow::bail!(
477 "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
478 );
479 };
480 assert!(matches.next().is_none());
481
482 Ok(excerpt_offset)
483 })??;
484
485 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
486 let cursor_anchor =
487 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
488 Ok((cursor_buffer, cursor_anchor))
489 }
490
491 #[must_use]
492 pub async fn apply_edit_history(
493 &self,
494 project: &Entity<Project>,
495 cx: &mut AsyncApp,
496 ) -> Result<OpenedBuffers<'_>> {
497 edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
498 }
499}
500
501async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
502 let output = smol::process::Command::new("git")
503 .current_dir(repo_path)
504 .args(args)
505 .output()
506 .await?;
507
508 anyhow::ensure!(
509 output.status.success(),
510 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
511 args.join(" "),
512 repo_path.display(),
513 output.status,
514 String::from_utf8_lossy(&output.stderr),
515 String::from_utf8_lossy(&output.stdout),
516 );
517 Ok(String::from_utf8(output.stdout)?.trim().to_string())
518}
519
520impl Display for NamedExample {
521 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522 write!(f, "# {}\n\n", self.name)?;
523 write!(
524 f,
525 "{REPOSITORY_URL_FIELD} = {}\n",
526 self.example.repository_url
527 )?;
528 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
529
530 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
531 write!(f, "`````diff\n")?;
532 write!(f, "{}", self.example.uncommitted_diff)?;
533 write!(f, "`````\n")?;
534
535 if !self.example.edit_history.is_empty() {
536 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
537 }
538
539 write!(
540 f,
541 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
542 self.example.cursor_path.display(),
543 self.example.cursor_position
544 )?;
545 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
546
547 if !self.example.expected_patch.is_empty() {
548 write!(
549 f,
550 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
551 self.example.expected_patch
552 )?;
553 }
554
555 Ok(())
556 }
557}
558
559thread_local! {
560 static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
561}
562
563#[must_use]
564pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
565 REPO_LOCKS
566 .with(|cell| {
567 cell.borrow_mut()
568 .entry(path.as_ref().to_path_buf())
569 .or_default()
570 .clone()
571 })
572 .lock_owned()
573 .await
574}