1use crate::{
2 PredictionProvider, PromptFormat,
3 metrics::ClassificationMetrics,
4 paths::{REPOS_DIR, WORKTREES_DIR},
5};
6use anyhow::{Context as _, Result};
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::{
15 borrow::Cow,
16 io::{Read, Write},
17 mem,
18 path::{Path, PathBuf},
19};
20use zeta_prompt::RelatedFile;
21
/// A single edit-prediction example together with the artifacts produced while
/// processing it.
///
/// Stored as one JSON object (one per line in `.jsonl` files). Optional pipeline
/// outputs are omitted from the serialized form when they are absent or empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Identifier of the example. May be omitted in the input; `read_examples` then
    /// derives it from the input file name.
    #[serde(default)]
    pub name: String,
    /// Git URL of the repository: either SCP-like (`git@host:owner/repo.git`) or a URL
    /// with a scheme. Split into owner/repo by `repo_name`.
    pub repository_url: String,
    /// Revision of the repository the example applies to.
    /// NOTE(review): presumably a commit SHA or ref name — confirm.
    pub revision: String,
    /// Diff of uncommitted changes; empty when absent from the input.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor (in markdown input, the info string of the
    /// `Cursor Position` code fence).
    pub cursor_path: Arc<Path>,
    /// Cursor location text (in markdown input, the body of the `Cursor Position` code
    /// block). NOTE(review): exact format is not visible here — confirm.
    pub cursor_position: String,
    /// Edit history text (in markdown input, the concatenated `Edit History` code blocks).
    pub edit_history: String,
    /// Patch the prediction is expected to produce.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. Never serialized.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
62
/// In-memory application state attached to an example while it is processed.
///
/// Not serializable (it derives neither `Serialize` nor `Deserialize`).
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Project entity used to process the example.
    pub project: Entity<Project>,
    /// Buffer the cursor is in.
    /// NOTE(review): presumably the buffer for `Example::cursor_path` — confirm.
    pub buffer: Entity<Buffer>,
    /// Cursor location, expressed as an anchor.
    pub cursor_position: Anchor,
    /// NOTE(review): the leading underscore suggests this is held only to keep the
    /// opened buffers alive for the lifetime of the state — confirm against callers.
    pub _open_buffers: OpenedBuffers,
}
70
/// Context retrieved for a prediction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// Related files gathered as prediction context.
    pub files: Arc<[RelatedFile]>,
}
75
/// Snapshot of the file being edited, with the cursor location in several forms.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// Full text of the file.
    pub content: String,
    /// Cursor row. NOTE(review): zero- vs one-based is not visible here — confirm.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): zero- vs one-based is not visible here — confirm.
    pub cursor_column: u32,
    /// Cursor offset into `content`. NOTE(review): presumably a byte offset — confirm.
    pub cursor_offset: usize,
}
83
/// Model input for an example and the output the model is expected to produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// Prompt text given to the model.
    pub input: String,
    /// Output the model is expected to produce for `input`.
    pub expected_output: String,
    /// Prompt format used to build `input`.
    pub format: PromptFormat,
}
90
/// One prediction actually produced for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// Patch derived from the model's output.
    pub actual_patch: String,
    /// Raw output returned by the model.
    pub actual_output: String,
    /// Provider that produced this prediction.
    pub provider: PredictionProvider,
}
97
/// Scores describing how well a prediction matches the expected patch.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): presumably a chrF character n-gram score computed over the edit
    /// delta — confirm against the scoring code.
    pub delta_chr_f: f32,
    /// Line-level match metrics.
    pub line_match: ClassificationMetrics,
}
103
104impl Example {
105 fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
106 // git@github.com:owner/repo.git
107 if self.repository_url.contains('@') {
108 let (owner, repo) = self
109 .repository_url
110 .split_once(':')
111 .context("expected : in git url")?
112 .1
113 .split_once('/')
114 .context("expected / in git url")?;
115 Ok((
116 Cow::Borrowed(owner),
117 Cow::Borrowed(repo.trim_end_matches(".git")),
118 ))
119 // http://github.com/owner/repo.git
120 } else {
121 let url = Url::parse(&self.repository_url)?;
122 let mut segments = url.path_segments().context("empty http url")?;
123 let owner = segments
124 .next()
125 .context("expected owner path segment")?
126 .to_string();
127 let repo = segments
128 .next()
129 .context("expected repo path segment")?
130 .trim_end_matches(".git")
131 .to_string();
132 assert!(segments.next().is_none());
133
134 Ok((owner.into(), repo.into()))
135 }
136 }
137
138 pub fn worktree_path(&self) -> PathBuf {
139 WORKTREES_DIR
140 .join(&self.name)
141 .join(self.repo_name().unwrap().1.as_ref())
142 }
143
144 pub fn repo_path(&self) -> PathBuf {
145 let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
146 REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
147 }
148}
149
150pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
151 let mut examples = Vec::new();
152
153 let stdin_path: PathBuf = PathBuf::from("-");
154
155 let inputs = if inputs.is_empty() {
156 &[stdin_path]
157 } else {
158 inputs
159 };
160
161 for path in inputs {
162 let is_stdin = path.as_path() == Path::new("-");
163 let content = if is_stdin {
164 let mut buffer = String::new();
165 std::io::stdin()
166 .read_to_string(&mut buffer)
167 .expect("Failed to read from stdin");
168 buffer
169 } else {
170 std::fs::read_to_string(path)
171 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
172 };
173 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
174 let ext = if !is_stdin {
175 path.extension()
176 .map(|ext| ext.to_string_lossy().to_string())
177 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
178 } else {
179 "jsonl".to_string()
180 };
181
182 match ext.as_ref() {
183 "json" => {
184 let mut example =
185 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
186 panic!("Failed to parse example file: {}\n{error}", path.display())
187 });
188 if example.name.is_empty() {
189 example.name = filename;
190 }
191 examples.push(example);
192 }
193 "jsonl" => examples.extend(
194 content
195 .lines()
196 .enumerate()
197 .map(|(line_ix, line)| {
198 let mut example =
199 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
200 panic!(
201 "Failed to parse example on {}:{}\n{error}",
202 path.display(),
203 line_ix + 1
204 )
205 });
206 if example.name.is_empty() {
207 example.name = format!("{filename}-{line_ix}")
208 }
209 example
210 })
211 .collect::<Vec<Example>>(),
212 ),
213 "md" => {
214 examples.push(parse_markdown_example(filename, &content).unwrap());
215 }
216 ext => {
217 panic!("{} has invalid example extension `{ext}`", path.display())
218 }
219 }
220 }
221 examples
222}
223
224pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
225 let mut content = String::new();
226 for example in examples {
227 let line = serde_json::to_string(example).unwrap();
228 content.push_str(&line);
229 content.push('\n');
230 }
231 if let Some(output_path) = output_path {
232 std::fs::write(output_path, content).expect("Failed to write examples");
233 } else {
234 std::io::stdout().write_all(&content.as_bytes()).unwrap();
235 }
236}
237
238fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
239 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
240
241 const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
242 const EDIT_HISTORY_HEADING: &str = "Edit History";
243 const CURSOR_POSITION_HEADING: &str = "Cursor Position";
244 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
245 const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
246 const REPOSITORY_URL_FIELD: &str = "repository_url";
247 const REVISION_FIELD: &str = "revision";
248
249 let parser = Parser::new(input);
250
251 let mut example = Example {
252 name: id,
253 repository_url: String::new(),
254 revision: String::new(),
255 uncommitted_diff: String::new(),
256 cursor_path: PathBuf::new().into(),
257 cursor_position: String::new(),
258 edit_history: String::new(),
259 expected_patch: String::new(),
260 buffer: None,
261 context: None,
262 prompt: None,
263 predictions: Vec::new(),
264 score: Vec::new(),
265 state: None,
266 };
267
268 let mut text = String::new();
269 let mut block_info: CowStr = "".into();
270
271 #[derive(PartialEq)]
272 enum Section {
273 Start,
274 UncommittedDiff,
275 EditHistory,
276 CursorPosition,
277 ExpectedExcerpts,
278 ExpectedPatch,
279 Other,
280 }
281
282 let mut current_section = Section::Start;
283
284 for event in parser {
285 match event {
286 Event::Text(line) => {
287 text.push_str(&line);
288
289 if let Section::Start = current_section
290 && let Some((field, value)) = line.split_once('=')
291 {
292 match field.trim() {
293 REPOSITORY_URL_FIELD => {
294 example.repository_url = value.trim().to_string();
295 }
296 REVISION_FIELD => {
297 example.revision = value.trim().to_string();
298 }
299 _ => {}
300 }
301 }
302 }
303 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
304 let title = mem::take(&mut text);
305 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
306 Section::UncommittedDiff
307 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
308 Section::EditHistory
309 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
310 Section::CursorPosition
311 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
312 Section::ExpectedPatch
313 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
314 Section::ExpectedExcerpts
315 } else {
316 Section::Other
317 };
318 }
319 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
320 mem::take(&mut text);
321 }
322 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
323 mem::take(&mut text);
324 }
325 Event::End(TagEnd::Heading(level)) => {
326 anyhow::bail!("Unexpected heading level: {level}");
327 }
328 Event::Start(Tag::CodeBlock(kind)) => {
329 match kind {
330 CodeBlockKind::Fenced(info) => {
331 block_info = info;
332 }
333 CodeBlockKind::Indented => {
334 anyhow::bail!("Unexpected indented codeblock");
335 }
336 };
337 }
338 Event::Start(_) => {
339 text.clear();
340 block_info = "".into();
341 }
342 Event::End(TagEnd::CodeBlock) => {
343 let block_info = block_info.trim();
344 match current_section {
345 Section::UncommittedDiff => {
346 example.uncommitted_diff = mem::take(&mut text);
347 }
348 Section::EditHistory => {
349 example.edit_history.push_str(&mem::take(&mut text));
350 }
351 Section::CursorPosition => {
352 example.cursor_path = Path::new(block_info).into();
353 example.cursor_position = mem::take(&mut text);
354 }
355 Section::ExpectedExcerpts => {
356 mem::take(&mut text);
357 }
358 Section::ExpectedPatch => {
359 example.expected_patch = mem::take(&mut text);
360 }
361 Section::Start | Section::Other => {}
362 }
363 }
364 _ => {}
365 }
366 }
367 if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
368 anyhow::bail!("Missing cursor position codeblock");
369 }
370
371 Ok(example)
372}