1use crate::PredictionProvider;
2use crate::metrics::ClassificationMetrics;
3use crate::paths::WORKTREES_DIR;
4use crate::qa::QaResult;
5use anyhow::{Context as _, Result};
6use collections::HashMap;
7use edit_prediction::example_spec::ExampleSpec;
8use edit_prediction::udiff::OpenedBuffers;
9use gpui::Entity;
10use http_client::Url;
11use language::{Anchor, Buffer};
12use project::Project;
13use serde::{Deserialize, Serialize};
14use std::{
15 borrow::Cow,
16 collections::VecDeque,
17 io::Read,
18 path::{Path, PathBuf},
19};
20use zeta_prompt::ZetaPromptInput;
21
/// A single edit-prediction example: the spec it was loaded from, plus any
/// prompt, predictions, scores, and QA results attached to it afterwards.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification. Its fields are flattened into this struct's
    /// serialized form rather than nested under a `spec` key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ZetaPromptInput>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The Zed version used to generate this example.
    ///
    /// Unlike the other optional fields, this has no `skip_serializing_if`, so it
    /// is written out (as `null`) even when absent.
    pub zed_version: Option<String>,

    /// The application state used to process this example.
    ///
    /// Runtime-only: `#[serde(skip)]` excludes it from serialization and leaves it
    /// `None` after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
56
/// Runtime state attached to an [`Example`] while it is being processed.
///
/// It holds GPUI entity handles and derives only `Clone` and `Debug`, so it is
/// not serializable.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project in which the example is being processed.
    pub project: Entity<Project>,
    /// The buffer being edited (presumably the one `cursor_position` is anchored in).
    pub buffer: Entity<Buffer>,
    /// The cursor position, as a buffer anchor.
    pub cursor_position: Anchor,
    /// Buffers opened for this example. This field is never read here; presumably
    /// it is held so the opened buffers stay alive while the example is in use.
    /// TODO confirm.
    pub _open_buffers: OpenedBuffers,
}
64
/// A model prompt together with its expected (and optionally rejected) output.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The expected model output, if known.
    #[serde(default)]
    pub expected_output: Option<String>,
    /// A dispreferred output, used to form DPO (preference-optimization) pairs.
    pub rejected_output: Option<String>,
    /// Presumably text the model's response is prefilled with. Confirm against
    /// the provider code.
    #[serde(default)]
    pub prefill: Option<String>,
    /// The prediction provider this prompt was built for.
    pub provider: PredictionProvider,
}
75
76#[derive(Clone, Debug, Serialize, Deserialize)]
77pub struct ExamplePrediction {
78 #[serde(default, skip_serializing_if = "Option::is_none")]
79 pub actual_patch: Option<String>,
80 #[serde(deserialize_with = "deserialize_null_as_empty_string")]
81 pub actual_output: String,
82 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub actual_cursor: Option<ActualCursor>,
84 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub error: Option<String>,
86 pub provider: PredictionProvider,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub cumulative_logprob: Option<f64>,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub avg_logprob: Option<f64>,
91}
92
/// A cursor position reported alongside a prediction.
///
/// As computed by [`ActualCursor::from_editable_region`], `row` is 0-based,
/// `column` is a byte count from the start of the line, and `offset` is a byte
/// offset into the file with the edit applied.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActualCursor {
    /// Path of the file the cursor is in.
    pub path: String,
    /// Line number of the cursor.
    pub row: u32,
    /// Column of the cursor within its line.
    pub column: u32,
    /// Offset of the cursor within the file.
    pub offset: usize,
    /// Byte offset of the cursor within the new editable region, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub editable_region_offset: Option<usize>,
}
102
103impl ActualCursor {
104 /// Construct an `ActualCursor` from a cursor offset within the new editable region.
105 ///
106 /// - `path`: file path the cursor is in
107 /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
108 /// - `new_editable_region`: the full new editable region text (after marker removal)
109 /// - `content`: the full file content (before the edit)
110 /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
111 /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
112 pub fn from_editable_region(
113 path: &std::path::Path,
114 editable_region_cursor_offset: usize,
115 new_editable_region: &str,
116 content: &str,
117 editable_region_byte_offset: usize,
118 editable_region_start_line: usize,
119 ) -> Self {
120 let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
121 let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
122 let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
123 let column = match new_region_prefix.rfind('\n') {
124 Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
125 None => {
126 let content_prefix = &content[..editable_region_byte_offset];
127 let content_column = match content_prefix.rfind('\n') {
128 Some(pos) => editable_region_byte_offset - pos - 1,
129 None => editable_region_byte_offset,
130 };
131 (content_column + editable_region_cursor_offset) as u32
132 }
133 };
134 ActualCursor {
135 path: path.to_string_lossy().to_string(),
136 row,
137 column,
138 offset: global_offset,
139 editable_region_offset: Some(editable_region_cursor_offset),
140 }
141 }
142}
143
144fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
145where
146 D: serde::Deserializer<'de>,
147{
148 let opt = Option::<String>::deserialize(deserializer)?;
149 Ok(opt.unwrap_or_default())
150}
151
/// Metrics comparing one actual prediction against the expected edit.
///
/// Every field except `delta_chr_f` and `braces_disbalance` may be missing when
/// deserializing, either via `#[serde(default)]` or by being an `Option`.
/// Presumably this lets scores written before a metric existed still load. TODO
/// confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // Delta chrF. chrF is a character n-gram F-score; "delta" presumably means it
    // is computed over the edit rather than the whole text (confirm). The three
    // counts are exposed together via `ExampleScore::delta_chr_f_counts`.
    pub delta_chr_f: f32,
    #[serde(default)]
    pub delta_chr_f_true_positives: usize,
    #[serde(default)]
    pub delta_chr_f_false_positives: usize,
    #[serde(default)]
    pub delta_chr_f_false_negatives: usize,
    #[serde(default)]
    pub delta_chr_f_precision: f64,
    #[serde(default)]
    pub delta_chr_f_recall: f64,
    #[serde(default)]
    pub delta_chr_f_beta: f64,
    pub braces_disbalance: usize,
    // Exact line-match counts, exposed together via `ExampleScore::exact_lines_counts`.
    #[serde(default)]
    pub exact_lines_tp: usize,
    #[serde(default)]
    pub exact_lines_fp: usize,
    #[serde(default)]
    pub exact_lines_fn: usize,
    #[serde(default)]
    pub reversal_ratio: f32,
    // Cursor metrics; omitted from the output when not computed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_distance: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match: Option<bool>,
    // No `skip_serializing_if`: serialized as `null` when absent.
    pub wrong_editable_region: Option<bool>,
    #[serde(default)]
    pub has_isolated_whitespace_changes: bool,
    #[serde(default)]
    pub inserted_tokens: usize,
    #[serde(default)]
    pub deleted_tokens: usize,
    // Character-level keep/recall metrics; omitted from the output when not computed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kept_rate: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub recall_rate: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kept_chars: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub correctly_deleted_chars: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub discarded_chars: Option<usize>,
    // Log-probabilities of the scored prediction, when the provider reports them.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
202
203impl ExampleScore {
204 pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
205 ClassificationMetrics {
206 true_positives: self.delta_chr_f_true_positives,
207 false_positives: self.delta_chr_f_false_positives,
208 false_negatives: self.delta_chr_f_false_negatives,
209 }
210 }
211
212 pub fn exact_lines_counts(&self) -> ClassificationMetrics {
213 ClassificationMetrics {
214 true_positives: self.exact_lines_tp,
215 false_positives: self.exact_lines_fp,
216 false_negatives: self.exact_lines_fn,
217 }
218 }
219}
220
221impl Example {
222 pub fn repo_name(&self) -> Result<RepoName<'_>> {
223 // git@github.com:owner/repo.git
224 if self.spec.repository_url.contains('@') {
225 let (owner, repo) = self
226 .spec
227 .repository_url
228 .split_once(':')
229 .context("expected : in git url")?
230 .1
231 .split_once('/')
232 .context("expected / in git url")?;
233 Ok(RepoName {
234 owner: Cow::Borrowed(owner),
235 name: Cow::Borrowed(repo.trim_end_matches(".git")),
236 })
237 // http://github.com/owner/repo.git
238 } else {
239 let url = Url::parse(&self.spec.repository_url)?;
240 let mut segments = url.path_segments().context("empty http url")?;
241 let owner = segments
242 .next()
243 .context("expected owner path segment")?
244 .to_string();
245 let repo = segments
246 .next()
247 .context("expected repo path segment")?
248 .trim_end_matches(".git")
249 .to_string();
250 assert!(segments.next().is_none());
251
252 Ok(RepoName {
253 owner: Cow::Owned(owner),
254 name: Cow::Owned(repo),
255 })
256 }
257 }
258}
259
/// A repository's owner and name, as extracted from a repository URL by
/// `Example::repo_name`.
///
/// Both fields are `Cow` so they can either borrow from the source URL or own
/// their text. `Example::repo_name` borrows for scp-style URLs and owns for
/// parsed URLs.
pub struct RepoName<'a> {
    /// The first path component of the repository URL.
    pub owner: Cow<'a, str>,
    /// The repository component, with a trailing `.git` removed.
    pub name: Cow<'a, str>,
}
264
265impl RepoName<'_> {
266 pub fn worktree_path(&self) -> PathBuf {
267 WORKTREES_DIR
268 .join(self.owner.as_ref())
269 .join(self.name.as_ref())
270 }
271}
272
273pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
274 let mut examples = Vec::new();
275
276 for path in inputs {
277 let is_stdin = path.as_path() == Path::new("-");
278 let content = if is_stdin {
279 let mut buffer = String::new();
280 std::io::stdin()
281 .read_to_string(&mut buffer)
282 .expect("Failed to read from stdin");
283 buffer
284 } else {
285 std::fs::read_to_string(path)
286 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
287 };
288 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
289 let ext = if !is_stdin {
290 path.extension()
291 .map(|ext| ext.to_string_lossy().to_string())
292 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
293 } else {
294 "jsonl".to_string()
295 };
296
297 match ext.as_ref() {
298 "json" => {
299 let mut example =
300 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
301 panic!("Failed to parse example file: {}\n{error}", path.display())
302 });
303 if example.spec.name.is_empty() {
304 example.spec.name = filename;
305 }
306 examples.push(example);
307 }
308 "jsonl" => examples.extend(
309 content
310 .lines()
311 .enumerate()
312 .map(|(line_ix, line)| {
313 let mut example =
314 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
315 panic!(
316 "Failed to parse example on {}:{}\n{error}",
317 path.display(),
318 line_ix + 1
319 )
320 });
321 if example.spec.name.is_empty() {
322 example.spec.name = format!("{filename}-{line_ix}")
323 }
324 example
325 })
326 .collect::<Vec<Example>>(),
327 ),
328 "md" => {
329 let mut example = parse_markdown_example(&content).unwrap();
330 if example.spec.name.is_empty() {
331 example.spec.name = filename;
332 }
333 examples.push(example);
334 }
335 ext => {
336 panic!("{} has invalid example extension `{ext}`", path.display())
337 }
338 }
339 }
340
341 examples
342}
343
344pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
345 examples.sort_by(|a, b| {
346 a.spec
347 .repository_url
348 .cmp(&b.spec.repository_url)
349 .then(b.spec.revision.cmp(&a.spec.revision))
350 });
351}
352
353pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
354 let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
355 let mut ungrouped = Vec::new();
356 for example in examples {
357 if example.spec.repository_url.is_empty() {
358 ungrouped.push(example);
359 } else {
360 examples_by_repo
361 .entry(example.spec.repository_url.clone())
362 .or_insert_with(Vec::new)
363 .push(example);
364 }
365 }
366 let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
367 for example in ungrouped {
368 result.push_back(vec![example]);
369 }
370 result
371}
372
373fn parse_markdown_example(input: &str) -> Result<Example> {
374 let spec = ExampleSpec::from_markdown(input)?;
375 Ok(Example {
376 spec,
377 prompt_inputs: None,
378 prompt: None,
379 predictions: Vec::new(),
380 score: Vec::new(),
381 qa: Vec::new(),
382 state: None,
383 zed_version: None,
384 })
385}