1use std::{
2 borrow::Cow,
3 env,
4 fmt::{self, Display},
5 fs,
6 io::Write,
7 mem,
8 path::{Path, PathBuf},
9};
10
11use anyhow::{Context as _, Result};
12use clap::ValueEnum;
13use collections::HashSet;
14use futures::AsyncWriteExt as _;
15use gpui::{AsyncApp, Entity, http_client::Url};
16use language::Buffer;
17use project::{Project, ProjectPath};
18use pulldown_cmark::CowStr;
19use serde::{Deserialize, Serialize};
20
// H2 section headings and H1-section `key = value` field names of the Markdown
// example format, shared by `NamedExample::parse_md` and the `Display` impl for
// `NamedExample`. Headings are compared case-insensitively when parsing; field
// names are compared exactly after trimming surrounding whitespace.
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
28
/// An [`Example`] together with its human-readable name.
#[derive(Debug)]
pub struct NamedExample {
    /// Example name: the file stem for JSON/TOML files, or the text of the
    /// single H1 heading for Markdown files.
    pub name: String,
    /// The example payload.
    pub example: Example,
}
34
/// The data describing one example. This is what gets (de)serialized for the
/// JSON and TOML formats; the name is not part of it.
#[derive(Debug, Serialize, Deserialize)]
pub struct Example {
    /// Git URL of the repository to check out (SSH scp-like or URL syntax).
    pub repository_url: String,
    /// Git revision the worktree is checked out at.
    pub revision: String,
    /// Diff applied with `git apply` on top of `revision` when setting up the worktree.
    pub uncommitted_diff: String,
    /// Path of the file holding the cursor (the info string of the cursor block).
    pub cursor_path: PathBuf,
    /// Contents of the "Cursor Position" code block; its format is not
    /// interpreted in this file.
    pub cursor_position: String,
    /// Diff applied to project buffers by `NamedExample::apply_edit_history`.
    pub edit_history: String,
    /// Expected patch — presumably the diff a prediction should produce; not
    /// interpreted in this file.
    pub expected_patch: String,
    /// Expected excerpts — presumably the context expected to be retrieved; not
    /// interpreted in this file.
    pub expected_excerpts: Vec<ExpectedExcerpt>,
}
46
/// One code block from the "Expected Excerpts" section of an example.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExpectedExcerpt {
    /// Path of the file the excerpt belongs to (from the code block info string).
    path: PathBuf,
    /// Contents of the excerpt code block.
    text: String,
}
52
// Output format accepted by `NamedExample::write`.
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// `ValueEnum` would become user-visible `--help` text.
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
    // Compact JSON of the `Example` payload (no name).
    Json,
    // Pretty-printed TOML of the `Example` payload (no name).
    Toml,
    // Markdown as produced by `impl Display for NamedExample`.
    Md,
}
59
60impl NamedExample {
61 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
62 let path = path.as_ref();
63 let content = std::fs::read_to_string(path)?;
64 let ext = path.extension();
65
66 match ext.and_then(|s| s.to_str()) {
67 Some("json") => Ok(Self {
68 name: path.file_stem().unwrap_or_default().display().to_string(),
69 example: serde_json::from_str(&content)?,
70 }),
71 Some("toml") => Ok(Self {
72 name: path.file_stem().unwrap_or_default().display().to_string(),
73 example: toml::from_str(&content)?,
74 }),
75 Some("md") => Self::parse_md(&content),
76 Some(_) => {
77 anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
78 }
79 None => {
80 anyhow::bail!(
81 "Failed to determine example type since the file does not have an extension."
82 );
83 }
84 }
85 }
86
87 pub fn parse_md(input: &str) -> Result<Self> {
88 use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
89
90 let parser = Parser::new(input);
91
92 let mut named = NamedExample {
93 name: String::new(),
94 example: Example {
95 repository_url: String::new(),
96 revision: String::new(),
97 uncommitted_diff: String::new(),
98 cursor_path: PathBuf::new(),
99 cursor_position: String::new(),
100 edit_history: String::new(),
101 expected_patch: String::new(),
102 expected_excerpts: Vec::new(),
103 },
104 };
105
106 let mut text = String::new();
107 let mut current_section = String::new();
108 let mut block_info: CowStr = "".into();
109
110 for event in parser {
111 match event {
112 Event::Text(line) => {
113 text.push_str(&line);
114
115 if !named.name.is_empty()
116 && current_section.is_empty()
117 // in h1 section
118 && let Some((field, value)) = line.split_once('=')
119 {
120 match field.trim() {
121 REPOSITORY_URL_FIELD => {
122 named.example.repository_url = value.trim().to_string();
123 }
124 REVISION_FIELD => {
125 named.example.revision = value.trim().to_string();
126 }
127 _ => {
128 eprintln!("Warning: Unrecognized field `{field}`");
129 }
130 }
131 }
132 }
133 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
134 if !named.name.is_empty() {
135 anyhow::bail!(
136 "Found multiple H1 headings. There should only be one with the name of the example."
137 );
138 }
139 named.name = mem::take(&mut text);
140 }
141 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
142 current_section = mem::take(&mut text);
143 }
144 Event::End(TagEnd::Heading(level)) => {
145 anyhow::bail!("Unexpected heading level: {level}");
146 }
147 Event::Start(Tag::CodeBlock(kind)) => {
148 match kind {
149 CodeBlockKind::Fenced(info) => {
150 block_info = info;
151 }
152 CodeBlockKind::Indented => {
153 anyhow::bail!("Unexpected indented codeblock");
154 }
155 };
156 }
157 Event::Start(_) => {
158 text.clear();
159 block_info = "".into();
160 }
161 Event::End(TagEnd::CodeBlock) => {
162 let block_info = block_info.trim();
163 if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
164 named.example.uncommitted_diff = mem::take(&mut text);
165 } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
166 named.example.edit_history.push_str(&mem::take(&mut text));
167 } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
168 named.example.cursor_path = block_info.into();
169 named.example.cursor_position = mem::take(&mut text);
170 } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
171 named.example.expected_patch = mem::take(&mut text);
172 } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
173 named.example.expected_excerpts.push(ExpectedExcerpt {
174 path: block_info.into(),
175 text: mem::take(&mut text),
176 });
177 } else {
178 eprintln!("Warning: Unrecognized section `{current_section:?}`")
179 }
180 }
181 _ => {}
182 }
183 }
184
185 if named.example.cursor_path.as_path() == Path::new("")
186 || named.example.cursor_position.is_empty()
187 {
188 anyhow::bail!("Missing cursor position codeblock");
189 }
190
191 Ok(named)
192 }
193
194 pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
195 match format {
196 ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
197 ExampleFormat::Toml => {
198 Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
199 }
200 ExampleFormat::Md => Ok(write!(out, "{}", self)?),
201 }
202 }
203
204 #[allow(unused)]
205 pub async fn setup_worktree(&self) -> Result<PathBuf> {
206 let (repo_owner, repo_name) = self.repo_name()?;
207 let file_name = self.file_name();
208
209 let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
210 let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
211 fs::create_dir_all(&repos_dir)?;
212 fs::create_dir_all(&worktrees_dir)?;
213
214 let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
215 if !repo_dir.is_dir() {
216 fs::create_dir_all(&repo_dir)?;
217 run_git(&repo_dir, &["init"]).await?;
218 run_git(
219 &repo_dir,
220 &["remote", "add", "origin", &self.example.repository_url],
221 )
222 .await?;
223 }
224
225 // Resolve the example to a revision, fetching it if needed.
226 let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
227 let revision = if let Ok(revision) = revision {
228 revision
229 } else {
230 run_git(
231 &repo_dir,
232 &["fetch", "--depth", "1", "origin", &self.example.revision],
233 )
234 .await?;
235 let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
236 if revision != self.example.revision {
237 run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
238 }
239 revision
240 };
241
242 // Create the worktree for this example if needed.
243 let worktree_path = worktrees_dir.join(&file_name);
244 if worktree_path.is_dir() {
245 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
246 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
247 run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
248 } else {
249 let worktree_path_string = worktree_path.to_string_lossy();
250 run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
251 run_git(
252 &repo_dir,
253 &["worktree", "add", "-f", &worktree_path_string, &file_name],
254 )
255 .await?;
256 }
257
258 // Apply the uncommitted diff for this example.
259 if !self.example.uncommitted_diff.is_empty() {
260 let mut apply_process = smol::process::Command::new("git")
261 .current_dir(&worktree_path)
262 .args(&["apply", "-"])
263 .stdin(std::process::Stdio::piped())
264 .spawn()?;
265
266 let mut stdin = apply_process.stdin.take().unwrap();
267 stdin
268 .write_all(self.example.uncommitted_diff.as_bytes())
269 .await?;
270 stdin.close().await?;
271 drop(stdin);
272
273 let apply_result = apply_process.output().await?;
274 if !apply_result.status.success() {
275 anyhow::bail!(
276 "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
277 apply_result.status,
278 String::from_utf8_lossy(&apply_result.stderr),
279 String::from_utf8_lossy(&apply_result.stdout),
280 );
281 }
282 }
283
284 Ok(worktree_path)
285 }
286
287 fn file_name(&self) -> String {
288 self.name
289 .chars()
290 .map(|c| {
291 if c.is_whitespace() {
292 '-'
293 } else {
294 c.to_ascii_lowercase()
295 }
296 })
297 .collect()
298 }
299
300 #[allow(unused)]
301 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
302 // git@github.com:owner/repo.git
303 if self.example.repository_url.contains('@') {
304 let (owner, repo) = self
305 .example
306 .repository_url
307 .split_once(':')
308 .context("expected : in git url")?
309 .1
310 .split_once('/')
311 .context("expected / in git url")?;
312 Ok((
313 Cow::Borrowed(owner),
314 Cow::Borrowed(repo.trim_end_matches(".git")),
315 ))
316 // http://github.com/owner/repo.git
317 } else {
318 let url = Url::parse(&self.example.repository_url)?;
319 let mut segments = url.path_segments().context("empty http url")?;
320 let owner = segments
321 .next()
322 .context("expected owner path segment")?
323 .to_string();
324 let repo = segments
325 .next()
326 .context("expected repo path segment")?
327 .trim_end_matches(".git")
328 .to_string();
329 assert!(segments.next().is_none());
330
331 Ok((owner.into(), repo.into()))
332 }
333 }
334
335 #[must_use]
336 pub async fn apply_edit_history(
337 &self,
338 project: &Entity<Project>,
339 cx: &mut AsyncApp,
340 ) -> Result<HashSet<Entity<Buffer>>> {
341 apply_diff(&self.example.edit_history, project, cx).await
342 }
343}
344
345async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
346 let output = smol::process::Command::new("git")
347 .current_dir(repo_path)
348 .args(args)
349 .output()
350 .await?;
351
352 anyhow::ensure!(
353 output.status.success(),
354 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
355 args.join(" "),
356 repo_path.display(),
357 output.status,
358 String::from_utf8_lossy(&output.stderr),
359 String::from_utf8_lossy(&output.stdout),
360 );
361 Ok(String::from_utf8(output.stdout)?.trim().to_string())
362}
363
364impl Display for NamedExample {
365 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 write!(f, "# {}\n\n", self.name)?;
367 write!(
368 f,
369 "{REPOSITORY_URL_FIELD} = {}\n",
370 self.example.repository_url
371 )?;
372 write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
373
374 write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
375 write!(f, "`````diff\n")?;
376 write!(f, "{}", self.example.uncommitted_diff)?;
377 write!(f, "`````\n")?;
378
379 if !self.example.edit_history.is_empty() {
380 write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
381 }
382
383 write!(
384 f,
385 "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
386 self.example.cursor_path.display(),
387 self.example.cursor_position
388 )?;
389 write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
390
391 if !self.example.expected_patch.is_empty() {
392 write!(
393 f,
394 "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
395 self.example.expected_patch
396 )?;
397 }
398
399 if !self.example.expected_excerpts.is_empty() {
400 write!(f, "\n## {EXPECTED_EXCERPTS_HEADING}\n\n")?;
401
402 for excerpt in &self.example.expected_excerpts {
403 write!(
404 f,
405 "`````{}{}\n{}`````\n\n",
406 excerpt
407 .path
408 .extension()
409 .map(|ext| format!("{} ", ext.to_string_lossy()))
410 .unwrap_or_default(),
411 excerpt.path.display(),
412 excerpt.text
413 )?;
414 }
415 }
416
417 Ok(())
418 }
419}
420
421#[must_use]
422pub async fn apply_diff(
423 diff: &str,
424 project: &Entity<Project>,
425 cx: &mut AsyncApp,
426) -> Result<HashSet<Entity<Buffer>>> {
427 use cloud_llm_client::udiff::DiffLine;
428 use std::fmt::Write;
429
430 #[derive(Debug, Default)]
431 struct Edit {
432 context: String,
433 deletion_start: Option<usize>,
434 addition: String,
435 }
436
437 let mut old_path = None;
438 let mut new_path = None;
439 let mut pending = Edit::default();
440 let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
441 let mut open_buffers = HashSet::default();
442
443 while let Some(diff_line) = diff_lines.next() {
444 match diff_line {
445 DiffLine::OldPath { path } => {
446 mem::take(&mut pending);
447 old_path = Some(path)
448 }
449 DiffLine::HunkHeader(_) => {
450 mem::take(&mut pending);
451 }
452 DiffLine::NewPath { path } => {
453 if old_path.is_none() {
454 anyhow::bail!(
455 "Found a new path header (`+++`) before an (`---`) old path header"
456 );
457 }
458 new_path = Some(path)
459 }
460 DiffLine::Context(ctx) => {
461 writeln!(&mut pending.context, "{ctx}")?;
462 }
463 DiffLine::Deletion(del) => {
464 pending.deletion_start.get_or_insert(pending.context.len());
465 writeln!(&mut pending.context, "{del}")?;
466 }
467 DiffLine::Addition(add) => {
468 if pending.context.is_empty() {
469 anyhow::bail!("Found an addition before any context or deletion lines");
470 }
471
472 writeln!(&mut pending.addition, "{add}")?;
473 }
474 DiffLine::Garbage => {}
475 }
476
477 let commit_pending = match diff_lines.peek() {
478 Some(DiffLine::OldPath { .. })
479 | Some(DiffLine::HunkHeader(_))
480 | Some(DiffLine::Context(_))
481 | None => {
482 // commit pending edit cluster
483 !pending.addition.is_empty() || pending.deletion_start.is_some()
484 }
485 Some(DiffLine::Deletion(_)) => {
486 // start a new cluster if we have any additions specifically
487 // if we only have deletions, we continue to aggregate them
488 !pending.addition.is_empty()
489 }
490 _ => false,
491 };
492
493 if commit_pending {
494 let edit = mem::take(&mut pending);
495
496 let Some(old_path) = old_path.as_deref() else {
497 anyhow::bail!("Missing old path (`---`) header")
498 };
499
500 let Some(new_path) = new_path.as_deref() else {
501 anyhow::bail!("Missing new path (`+++`) header")
502 };
503
504 let buffer = project
505 .update(cx, |project, cx| {
506 let project_path = project
507 .find_project_path(old_path, cx)
508 .context("Failed to find old_path in project")?;
509
510 anyhow::Ok(project.open_buffer(project_path, cx))
511 })??
512 .await?;
513 open_buffers.insert(buffer.clone());
514
515 if old_path != new_path {
516 project
517 .update(cx, |project, cx| {
518 let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
519 let new_path = ProjectPath {
520 worktree_id: project_file.worktree_id(cx),
521 path: project_file.path.clone(),
522 };
523 project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
524 })?
525 .await?;
526 }
527
528 // TODO is it worth using project search?
529 buffer.update(cx, |buffer, cx| {
530 let text = buffer.text();
531 // todo! check there's only one
532 if let Some(context_offset) = text.find(&edit.context) {
533 let end = context_offset + edit.context.len();
534 let start = if let Some(deletion_start) = edit.deletion_start {
535 context_offset + deletion_start
536 } else {
537 end
538 };
539
540 buffer.edit([(start..end, edit.addition)], None, cx);
541
542 anyhow::Ok(())
543 } else {
544 anyhow::bail!("Failed to match context:\n{}", edit.context);
545 }
546 })??;
547 }
548 }
549
550 anyhow::Ok(open_buffers)
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    // `::fs` disambiguates the `fs` crate from `std::fs`, which `super::*` brings in.
    use ::fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use serde_json::json;
    use settings::SettingsStore;

    /// Registers the global settings a test `Project` needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    #[gpui::test]
    async fn test_apply_diff(cx: &mut TestAppContext) {
        init_test(cx);

        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
        "# };

        let buffer_2_text = indoc! {r#"
            six
            seven
            eight
            nine
            ten
        "# };

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "file1": buffer_1_text,
                "file2": buffer_2_text,
            }),
        )
        .await;

        let project = Project::test(fs, ["/root".as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/root/file1
            +++ b/root/file1
             one
             two
            -three
            +3
             four
             five
        "#};

        // Keep the returned buffers alive so reopening below yields the edited buffer.
        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();

        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("/root/file1", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {r#"
                    one
                    two
                    3
                    four
                    five
                "#}
            );
        });
    }
}