1use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use edit_prediction::udiff::OpenedBuffers;
5use gpui::Entity;
6use http_client::Url;
7use language::{Anchor, Buffer};
8use project::Project;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::{
12 borrow::Cow,
13 io::{Read, Write},
14 mem,
15 path::{Path, PathBuf},
16};
17use zeta_prompt::RelatedFile;
18
/// A single edit-prediction example.
///
/// Examples are read from `.json`, `.jsonl`, or markdown files and written back
/// out as JSON Lines. The fields after `expected_patch` hold the results of
/// later processing steps; they are omitted from serialized output while they
/// are `None`/empty, and `state` is never serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Identifier for the example. May be absent from the input, in which case
    /// `read_examples` derives one from the file name.
    #[serde(default)]
    pub name: String,
    /// Git remote URL, either SCP-like (`git@github.com:owner/repo.git`) or a
    /// regular URL (`https://github.com/owner/repo.git`); see [`Example::repo_name`].
    pub repository_url: String,
    /// Revision of the repository that the example is based on.
    /// NOTE(review): presumably a commit SHA — confirm.
    pub revision: String,
    /// Uncommitted changes as a diff; empty when absent from the input.
    /// NOTE(review): presumably applied on top of `revision` — confirm.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor. In markdown examples this is the
    /// info string of the "Cursor Position" code block.
    pub cursor_path: Arc<Path>,
    /// The cursor position. In markdown examples this is the contents of the
    /// "Cursor Position" code block. NOTE(review): exact format not visible here.
    pub cursor_position: String,
    /// The edit history preceding the prediction. Markdown examples concatenate
    /// every code block of the "Edit History" section.
    pub edit_history: String,
    /// The patch the prediction is expected to produce.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
59
/// In-memory application state used while processing an example.
/// Never serialized: [`Example::state`] is marked `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project used to process the example.
    pub project: Entity<Project>,
    /// The buffer for the cursor file.
    /// NOTE(review): presumably opened from `Example::cursor_path` — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor location as an anchor, presumably into `buffer` — confirm.
    pub cursor_position: Anchor,
    // Never read in this file (note the leading underscore); presumably held so
    // the opened buffers stay alive as long as this state does — confirm.
    pub _open_buffers: OpenedBuffers,
}
67
/// The context retrieved for a prediction (stored in [`Example::context`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the retrieved context.
    pub files: Arc<[RelatedFile]>,
}
72
/// The full content of the file where an edit is being predicted, together with
/// the cursor location within it (stored in [`Example::buffer`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full text of the file.
    pub content: String,
    /// Cursor row. NOTE(review): presumably zero-based — confirm.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): base and units (bytes vs. chars) are not
    /// visible here — confirm.
    pub cursor_column: u32,
    /// Cursor offset into `content`. NOTE(review): presumably a byte offset — confirm.
    pub cursor_offset: usize,
}
80
/// The input to the edit prediction model and the output it is expected to
/// produce (stored in [`Example::prompt`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt given to the model.
    pub input: String,
    /// The model output the example expects.
    pub expected_output: String,
    /// The prompt format; presumably the one `input` was rendered with — confirm.
    pub format: PromptFormat,
}
87
/// One actual prediction produced for an example (stored in
/// [`Example::predictions`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The patch produced by the prediction.
    pub actual_patch: String,
    /// The model's output. NOTE(review): presumably the raw text `actual_patch`
    /// was derived from — confirm.
    pub actual_output: String,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
94
/// How well an actual prediction matches the expected one (stored in
/// [`Example::score`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// Delta chrF score. NOTE(review): presumably chrF computed over the edit
    /// delta — confirm the exact definition.
    pub delta_chr_f: f32,
    /// Classification metrics for line matching (per the field name — confirm
    /// what counts as a match).
    pub line_match: ClassificationMetrics,
}
100
101impl Example {
102 pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
103 // git@github.com:owner/repo.git
104 if self.repository_url.contains('@') {
105 let (owner, repo) = self
106 .repository_url
107 .split_once(':')
108 .context("expected : in git url")?
109 .1
110 .split_once('/')
111 .context("expected / in git url")?;
112 Ok((
113 Cow::Borrowed(owner),
114 Cow::Borrowed(repo.trim_end_matches(".git")),
115 ))
116 // http://github.com/owner/repo.git
117 } else {
118 let url = Url::parse(&self.repository_url)?;
119 let mut segments = url.path_segments().context("empty http url")?;
120 let owner = segments
121 .next()
122 .context("expected owner path segment")?
123 .to_string();
124 let repo = segments
125 .next()
126 .context("expected repo path segment")?
127 .trim_end_matches(".git")
128 .to_string();
129 assert!(segments.next().is_none());
130
131 Ok((owner.into(), repo.into()))
132 }
133 }
134}
135
136pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
137 let mut examples = Vec::new();
138
139 let stdin_path: PathBuf = PathBuf::from("-");
140
141 let inputs = if inputs.is_empty() {
142 &[stdin_path]
143 } else {
144 inputs
145 };
146
147 for path in inputs {
148 let is_stdin = path.as_path() == Path::new("-");
149 let content = if is_stdin {
150 let mut buffer = String::new();
151 std::io::stdin()
152 .read_to_string(&mut buffer)
153 .expect("Failed to read from stdin");
154 buffer
155 } else {
156 std::fs::read_to_string(path)
157 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
158 };
159 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
160 let ext = if !is_stdin {
161 path.extension()
162 .map(|ext| ext.to_string_lossy().to_string())
163 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
164 } else {
165 "jsonl".to_string()
166 };
167
168 match ext.as_ref() {
169 "json" => {
170 let mut example =
171 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
172 panic!("Failed to parse example file: {}\n{error}", path.display())
173 });
174 if example.name.is_empty() {
175 example.name = filename;
176 }
177 examples.push(example);
178 }
179 "jsonl" => examples.extend(
180 content
181 .lines()
182 .enumerate()
183 .map(|(line_ix, line)| {
184 let mut example =
185 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
186 panic!(
187 "Failed to parse example on {}:{}\n{error}",
188 path.display(),
189 line_ix + 1
190 )
191 });
192 if example.name.is_empty() {
193 example.name = format!("{filename}-{line_ix}")
194 }
195 example
196 })
197 .collect::<Vec<Example>>(),
198 ),
199 "md" => {
200 examples.push(parse_markdown_example(filename, &content).unwrap());
201 }
202 ext => {
203 panic!("{} has invalid example extension `{ext}`", path.display())
204 }
205 }
206 }
207
208 sort_examples_by_repo_and_rev(&mut examples);
209 examples
210}
211
212pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
213 let mut content = String::new();
214 for example in examples {
215 let line = serde_json::to_string(example).unwrap();
216 content.push_str(&line);
217 content.push('\n');
218 }
219 if let Some(output_path) = output_path {
220 std::fs::write(output_path, content).expect("Failed to write examples");
221 } else {
222 std::io::stdout().write_all(&content.as_bytes()).unwrap();
223 }
224}
225
226pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
227 examples.sort_by(|a, b| {
228 a.repository_url
229 .cmp(&b.repository_url)
230 .then(b.revision.cmp(&a.revision))
231 });
232}
233
234pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
235 let mut examples_by_repo = HashMap::default();
236 for example in examples.iter_mut() {
237 examples_by_repo
238 .entry(example.repository_url.clone())
239 .or_insert_with(Vec::new)
240 .push(example);
241 }
242 examples_by_repo.into_values().collect()
243}
244
245fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
246 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
247
248 const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
249 const EDIT_HISTORY_HEADING: &str = "Edit History";
250 const CURSOR_POSITION_HEADING: &str = "Cursor Position";
251 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
252 const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
253 const REPOSITORY_URL_FIELD: &str = "repository_url";
254 const REVISION_FIELD: &str = "revision";
255
256 let parser = Parser::new(input);
257
258 let mut example = Example {
259 name: id,
260 repository_url: String::new(),
261 revision: String::new(),
262 uncommitted_diff: String::new(),
263 cursor_path: PathBuf::new().into(),
264 cursor_position: String::new(),
265 edit_history: String::new(),
266 expected_patch: String::new(),
267 buffer: None,
268 context: None,
269 prompt: None,
270 predictions: Vec::new(),
271 score: Vec::new(),
272 state: None,
273 };
274
275 let mut text = String::new();
276 let mut block_info: CowStr = "".into();
277
278 #[derive(PartialEq)]
279 enum Section {
280 Start,
281 UncommittedDiff,
282 EditHistory,
283 CursorPosition,
284 ExpectedExcerpts,
285 ExpectedPatch,
286 Other,
287 }
288
289 let mut current_section = Section::Start;
290
291 for event in parser {
292 match event {
293 Event::Text(line) => {
294 text.push_str(&line);
295
296 if let Section::Start = current_section
297 && let Some((field, value)) = line.split_once('=')
298 {
299 match field.trim() {
300 REPOSITORY_URL_FIELD => {
301 example.repository_url = value.trim().to_string();
302 }
303 REVISION_FIELD => {
304 example.revision = value.trim().to_string();
305 }
306 _ => {}
307 }
308 }
309 }
310 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
311 let title = mem::take(&mut text);
312 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
313 Section::UncommittedDiff
314 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
315 Section::EditHistory
316 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
317 Section::CursorPosition
318 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
319 Section::ExpectedPatch
320 } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
321 Section::ExpectedExcerpts
322 } else {
323 Section::Other
324 };
325 }
326 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
327 mem::take(&mut text);
328 }
329 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
330 mem::take(&mut text);
331 }
332 Event::End(TagEnd::Heading(level)) => {
333 anyhow::bail!("Unexpected heading level: {level}");
334 }
335 Event::Start(Tag::CodeBlock(kind)) => {
336 match kind {
337 CodeBlockKind::Fenced(info) => {
338 block_info = info;
339 }
340 CodeBlockKind::Indented => {
341 anyhow::bail!("Unexpected indented codeblock");
342 }
343 };
344 }
345 Event::Start(_) => {
346 text.clear();
347 block_info = "".into();
348 }
349 Event::End(TagEnd::CodeBlock) => {
350 let block_info = block_info.trim();
351 match current_section {
352 Section::UncommittedDiff => {
353 example.uncommitted_diff = mem::take(&mut text);
354 }
355 Section::EditHistory => {
356 example.edit_history.push_str(&mem::take(&mut text));
357 }
358 Section::CursorPosition => {
359 example.cursor_path = Path::new(block_info).into();
360 example.cursor_position = mem::take(&mut text);
361 }
362 Section::ExpectedExcerpts => {
363 mem::take(&mut text);
364 }
365 Section::ExpectedPatch => {
366 example.expected_patch = mem::take(&mut text);
367 }
368 Section::Start | Section::Other => {}
369 }
370 }
371 _ => {}
372 }
373 }
374 if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
375 anyhow::bail!("Missing cursor position codeblock");
376 }
377
378 Ok(example)
379}