1use crate::{
2 PredictionProvider, PromptFormat,
3 metrics::ClassificationMetrics,
4 paths::{REPOS_DIR, WORKTREES_DIR},
5};
6use anyhow::{Context as _, Result};
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::{
15 borrow::Cow,
16 io::{Read, Write},
17 mem,
18 path::{Path, PathBuf},
19};
20use zeta_prompt::RelatedFile;
21
/// A single edit-prediction example, together with the artifacts produced while processing it.
///
/// The core fields describe the input scenario (repository, revision, local changes, cursor,
/// edit history) and the expected result. The optional fields are filled in by later processing
/// steps and are omitted from the serialized output when absent or empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Name of the example. Defaults to an empty string when missing from the input.
    #[serde(default)]
    pub name: String,
    /// Git URL of the repository, either scp-like (`git@github.com:owner/repo.git`) or an
    /// HTTP(S) URL (`https://github.com/owner/repo.git`).
    pub repository_url: String,
    /// Revision of `repository_url` the example is based on (presumably a commit or ref —
    /// confirm with the checkout code).
    pub revision: String,
    /// Diff of changes that are not committed at `revision`.
    pub uncommitted_diff: String,
    /// Path of the file that contains the cursor.
    /// NOTE(review): presumably relative to the repository root — confirm.
    pub cursor_path: Arc<Path>,
    /// Text describing the cursor location within `cursor_path`. For Markdown examples this
    /// is the content of the "Cursor Position" code block.
    pub cursor_position: String,
    /// The edits that preceded the prediction, as text.
    pub edit_history: String,
    /// The patch the prediction is expected to produce.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// Scores describing how well the actual predictions match the expected
    /// prediction.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. Never serialized.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
61
/// In-memory application state used while processing an [`Example`].
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example's repository is loaded into.
    pub project: Entity<Project>,
    /// The buffer for the file containing the cursor.
    pub buffer: Entity<Buffer>,
    /// The cursor location inside `buffer`.
    pub cursor_position: Anchor,
    /// NOTE(review): the leading underscore suggests this field is never read and is held
    /// only so the opened buffers are not dropped while this state exists — confirm.
    pub _open_buffers: OpenedBuffers,
}
69
/// The context retrieved for a prediction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// Files (or excerpts of them) considered related to the edit location.
    pub files: Arc<[RelatedFile]>,
}
74
/// The full content of the file where an edit is being predicted, plus the cursor location
/// within it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// Full text of the file.
    pub content: String,
    /// Cursor row. NOTE(review): 0- vs 1-based is not visible here — confirm.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): 0- vs 1-based, and bytes vs chars — confirm.
    pub cursor_column: u32,
    /// Cursor offset into `content`. NOTE(review): presumably a byte offset — confirm.
    pub cursor_offset: usize,
}
82
/// The input and expected output of the edit prediction model for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt sent to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The format `input` was rendered in (presumably — confirm against the prompt builder).
    pub format: PromptFormat,
}
89
/// A prediction actually produced for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit expressed as a patch.
    pub actual_patch: String,
    /// The output returned by the provider (presumably the raw model output — confirm).
    pub actual_output: String,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
96
/// Scores describing how well an actual prediction matches the expected patch.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): name suggests a chrF-style character n-gram score computed on the edit
    /// delta — confirm against the scoring code.
    pub delta_chr_f: f32,
    /// Classification metrics for line-level matching between actual and expected edits.
    pub line_match: ClassificationMetrics,
}
102
103impl Example {
104 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
105 // git@github.com:owner/repo.git
106 if self.repository_url.contains('@') {
107 let (owner, repo) = self
108 .repository_url
109 .split_once(':')
110 .context("expected : in git url")?
111 .1
112 .split_once('/')
113 .context("expected / in git url")?;
114 Ok((
115 Cow::Borrowed(owner),
116 Cow::Borrowed(repo.trim_end_matches(".git")),
117 ))
118 // http://github.com/owner/repo.git
119 } else {
120 let url = Url::parse(&self.repository_url)?;
121 let mut segments = url.path_segments().context("empty http url")?;
122 let owner = segments
123 .next()
124 .context("expected owner path segment")?
125 .to_string();
126 let repo = segments
127 .next()
128 .context("expected repo path segment")?
129 .trim_end_matches(".git")
130 .to_string();
131 assert!(segments.next().is_none());
132
133 Ok((owner.into(), repo.into()))
134 }
135 }
136
137 pub fn worktree_path(&self) -> PathBuf {
138 WORKTREES_DIR
139 .join(&self.name)
140 .join(self.repo_name().unwrap().1.as_ref())
141 }
142
143 pub fn repo_path(&self) -> PathBuf {
144 let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
145 REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
146 }
147}
148
149pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
150 let mut examples = Vec::new();
151
152 let stdin_path: PathBuf = PathBuf::from("-");
153
154 let inputs = if inputs.is_empty() {
155 &[stdin_path]
156 } else {
157 inputs
158 };
159
160 for path in inputs {
161 let is_stdin = path.as_path() == Path::new("-");
162 let content = if is_stdin {
163 let mut buffer = String::new();
164 std::io::stdin()
165 .read_to_string(&mut buffer)
166 .expect("Failed to read from stdin");
167 buffer
168 } else {
169 std::fs::read_to_string(path)
170 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
171 };
172 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
173 let ext = if !is_stdin {
174 path.extension()
175 .map(|ext| ext.to_string_lossy().to_string())
176 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
177 } else {
178 "jsonl".to_string()
179 };
180
181 match ext.as_ref() {
182 "json" => {
183 let mut example =
184 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
185 panic!("Failed to parse example file: {}\n{error}", path.display())
186 });
187 if example.name.is_empty() {
188 example.name = filename;
189 }
190 examples.push(example);
191 }
192 "jsonl" => examples.extend(
193 content
194 .lines()
195 .enumerate()
196 .map(|(line_ix, line)| {
197 let mut example =
198 serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
199 panic!(
200 "Failed to parse example on {}:{}",
201 path.display(),
202 line_ix + 1
203 )
204 });
205 if example.name.is_empty() {
206 example.name = format!("{filename}-{line_ix}")
207 }
208 example
209 })
210 .collect::<Vec<Example>>(),
211 ),
212 "md" => {
213 examples.push(parse_markdown_example(filename, &content).unwrap());
214 }
215 ext => {
216 panic!("{} has invalid example extension `{ext}`", path.display())
217 }
218 }
219 }
220 examples
221}
222
223pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
224 let mut content = String::new();
225 for example in examples {
226 let line = serde_json::to_string(example).unwrap();
227 content.push_str(&line);
228 content.push('\n');
229 }
230 if let Some(output_path) = output_path {
231 std::fs::write(output_path, content).expect("Failed to write examples");
232 } else {
233 std::io::stdout().write_all(&content.as_bytes()).unwrap();
234 }
235}
236
237fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
238 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
239
240 const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
241 const EDIT_HISTORY_HEADING: &str = "Edit History";
242 const CURSOR_POSITION_HEADING: &str = "Cursor Position";
243 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
244 const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
245 const REPOSITORY_URL_FIELD: &str = "repository_url";
246 const REVISION_FIELD: &str = "revision";
247
248 let parser = Parser::new(input);
249
250 let mut example = Example {
251 name: id,
252 repository_url: String::new(),
253 revision: String::new(),
254 uncommitted_diff: String::new(),
255 cursor_path: PathBuf::new().into(),
256 cursor_position: String::new(),
257 edit_history: String::new(),
258 expected_patch: String::new(),
259 buffer: None,
260 context: None,
261 prompt: None,
262 predictions: Vec::new(),
263 score: Vec::new(),
264 state: None,
265 };
266
267 let mut name = String::new();
268 let mut text = String::new();
269 let mut block_info: CowStr = "".into();
270
271 #[derive(PartialEq)]
272 enum Section {
273 UncommittedDiff,
274 EditHistory,
275 CursorPosition,
276 ExpectedExcerpts,
277 ExpectedPatch,
278 Other,
279 }
280
281 let mut current_section = Section::Other;
282
283 for event in parser {
284 match event {
285 Event::Text(line) => {
286 text.push_str(&line);
287
288 if let Some((field, value)) = line.split_once('=') {
289 match field.trim() {
290 REPOSITORY_URL_FIELD => {
291 example.repository_url = value.trim().to_string();
292 }
293 REVISION_FIELD => {
294 example.revision = value.trim().to_string();
295 }
296 _ => {}
297 }
298 }
299 }
300 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
301 if !name.is_empty() {
302 anyhow::bail!(
303 "Found multiple H1 headings. There should only be one with the name of the example."
304 );
305 }
306 name = mem::take(&mut text);
307 }
308 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
309 let title = mem::take(&mut text);
310 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
311 Section::UncommittedDiff
312 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
313 Section::EditHistory
314 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
315 Section::CursorPosition
316 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
317 Section::ExpectedPatch
318 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
319 Section::ExpectedExcerpts
320 } else {
321 Section::Other
322 };
323 }
324 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
325 mem::take(&mut text);
326 }
327 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
328 mem::take(&mut text);
329 }
330 Event::End(TagEnd::Heading(level)) => {
331 anyhow::bail!("Unexpected heading level: {level}");
332 }
333 Event::Start(Tag::CodeBlock(kind)) => {
334 match kind {
335 CodeBlockKind::Fenced(info) => {
336 block_info = info;
337 }
338 CodeBlockKind::Indented => {
339 anyhow::bail!("Unexpected indented codeblock");
340 }
341 };
342 }
343 Event::Start(_) => {
344 text.clear();
345 block_info = "".into();
346 }
347 Event::End(TagEnd::CodeBlock) => {
348 let block_info = block_info.trim();
349 match current_section {
350 Section::UncommittedDiff => {
351 example.uncommitted_diff = mem::take(&mut text);
352 }
353 Section::EditHistory => {
354 example.edit_history.push_str(&mem::take(&mut text));
355 }
356 Section::CursorPosition => {
357 example.cursor_path = Path::new(block_info).into();
358 example.cursor_position = mem::take(&mut text);
359 }
360 Section::ExpectedExcerpts => {
361 mem::take(&mut text);
362 }
363 Section::ExpectedPatch => {
364 example.expected_patch = mem::take(&mut text);
365 }
366 Section::Other => {}
367 }
368 }
369 _ => {}
370 }
371 }
372 if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
373 anyhow::bail!("Missing cursor position codeblock");
374 }
375
376 Ok(example)
377}