1use crate::PredictionProvider;
2use crate::metrics::ClassificationMetrics;
3use crate::paths::WORKTREES_DIR;
4use crate::qa::QaResult;
5use anyhow::{Context as _, Result};
6use collections::HashMap;
7use edit_prediction::example_spec::ExampleSpec;
8use edit_prediction::udiff::OpenedBuffers;
9use gpui::Entity;
10use http_client::Url;
11use language::{Anchor, Buffer};
12use project::Project;
13use serde::{Deserialize, Serialize};
14use std::{
15 borrow::Cow,
16 collections::VecDeque,
17 io::Read,
18 path::{Path, PathBuf},
19};
20use zeta_prompt::ZetaPromptInput;
21
/// A single edit-prediction example plus the artifacts attached to it
/// (prompt inputs, prompt, predictions, scores, QA results).
///
/// Serialization notes: `spec` is flattened into the top-level object; empty
/// vectors and `None` values of `prompt_inputs` / `prompt` are omitted;
/// `zed_version` is always written (as `null` when absent); `state` is never
/// serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification. Its fields appear at the top level of the
    /// serialized object because of `#[serde(flatten)]`.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ZetaPromptInput>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores for how well the actual predictions match the expected
    /// predictions.
    // NOTE(review): presumably indexed parallel to `predictions`, like `qa` — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The Zed version used to generate this example.
    ///
    /// Unlike the other optional fields, this is serialized as `null` when absent.
    pub zed_version: Option<String>,

    /// The application state used to process this example.
    ///
    /// Runtime-only: `#[serde(skip)]` means it is never written out and is
    /// always `None` after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
56
/// Live application state for an example that has been opened in a project.
///
/// This type is not serializable; `Example::state` is marked `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example's repository is loaded into.
    pub project: Entity<Project>,
    /// The buffer in which the edit is predicted.
    pub buffer: Entity<Buffer>,
    /// The cursor position inside `buffer`.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is held only to keep the
    // opened buffers alive for the lifetime of the state; it is not read in this file.
    pub _open_buffers: OpenedBuffers,
}
64
/// The prompt sent to a model for an example, along with the output it is
/// expected to produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The rendered prompt text.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// A dispreferred output, used as the rejected sample for DPO training.
    pub rejected_output: Option<String>, // For DPO
    /// Optional text to prefill the model's response with; defaults to `None`
    /// when missing from the input.
    #[serde(default)]
    pub prefill: Option<String>,
    /// The prediction provider this prompt was formatted for.
    pub provider: PredictionProvider,
}
74
/// One prediction produced by a model for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit as a patch, when one was produced.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The raw model output.
    ///
    /// The field must be present when deserializing, but a JSON `null` is
    /// accepted and read as an empty string.
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// Where the model placed the cursor, if it did.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_cursor: Option<ActualCursor>,
    /// An error message, if producing this prediction failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
    /// Sum of token log-probabilities of the output, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    /// Average per-token log-probability of the output, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
91
/// A cursor location reported by a prediction.
///
/// When built by `ActualCursor::from_editable_region`, `row` is 0-based,
/// `column` is a byte count from the start of the line, and `offset` is a
/// byte offset into the file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActualCursor {
    /// Path of the file containing the cursor.
    pub path: String,
    /// Line of the cursor.
    pub row: u32,
    /// Column of the cursor, in bytes from the start of `row`.
    pub column: u32,
    /// Byte offset of the cursor within the file.
    pub offset: usize,
    /// Byte offset of the cursor within the (new) editable region, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub editable_region_offset: Option<usize>,
}
101
102impl ActualCursor {
103 /// Construct an `ActualCursor` from a cursor offset within the new editable region.
104 ///
105 /// - `path`: file path the cursor is in
106 /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
107 /// - `new_editable_region`: the full new editable region text (after marker removal)
108 /// - `content`: the full file content (before the edit)
109 /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
110 /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
111 pub fn from_editable_region(
112 path: &std::path::Path,
113 editable_region_cursor_offset: usize,
114 new_editable_region: &str,
115 content: &str,
116 editable_region_byte_offset: usize,
117 editable_region_start_line: usize,
118 ) -> Self {
119 let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
120 let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
121 let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
122 let column = match new_region_prefix.rfind('\n') {
123 Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
124 None => {
125 let content_prefix = &content[..editable_region_byte_offset];
126 let content_column = match content_prefix.rfind('\n') {
127 Some(pos) => editable_region_byte_offset - pos - 1,
128 None => editable_region_byte_offset,
129 };
130 (content_column + editable_region_cursor_offset) as u32
131 }
132 };
133 ActualCursor {
134 path: path.to_string_lossy().to_string(),
135 row,
136 column,
137 offset: global_offset,
138 editable_region_offset: Some(editable_region_cursor_offset),
139 }
140 }
141}
142
143fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
144where
145 D: serde::Deserializer<'de>,
146{
147 let opt = Option::<String>::deserialize(deserializer)?;
148 Ok(opt.unwrap_or_default())
149}
150
/// Evaluation metrics comparing one actual prediction against the expected one.
///
/// Most fields carry `#[serde(default)]` so older serialized scores still load.
/// `delta_chr_f` and `braces_disbalance` have no default and must be present
/// when deserializing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// Delta chrF score. Required when deserializing.
    pub delta_chr_f: f32,
    /// True-positive count underlying `delta_chr_f`.
    #[serde(default)]
    pub delta_chr_f_true_positives: usize,
    /// False-positive count underlying `delta_chr_f`.
    #[serde(default)]
    pub delta_chr_f_false_positives: usize,
    /// False-negative count underlying `delta_chr_f`.
    #[serde(default)]
    pub delta_chr_f_false_negatives: usize,
    /// Precision component of the delta chrF computation.
    #[serde(default)]
    pub delta_chr_f_precision: f64,
    /// Recall component of the delta chrF computation.
    #[serde(default)]
    pub delta_chr_f_recall: f64,
    /// The beta parameter used for the delta chrF F-score.
    #[serde(default)]
    pub delta_chr_f_beta: f64,
    // NOTE(review): presumably a measure of unbalanced braces in the prediction —
    // confirm. Required when deserializing (no `#[serde(default)]`).
    pub braces_disbalance: usize,
    /// Exact-line-match true positives.
    #[serde(default)]
    pub exact_lines_tp: usize,
    /// Exact-line-match false positives.
    #[serde(default)]
    pub exact_lines_fp: usize,
    /// Exact-line-match false negatives.
    #[serde(default)]
    pub exact_lines_fn: usize,
    // NOTE(review): semantics not visible here; presumably the fraction of the
    // prediction that reverts earlier edits — confirm.
    #[serde(default)]
    pub reversal_ratio: f32,
    /// Distance between predicted and expected cursor, when both are known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_distance: Option<usize>,
    /// Whether the predicted cursor matched the expected cursor exactly, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match: Option<bool>,
    /// Whether the prediction targeted the wrong editable region, when known.
    /// Serialized as `null` when `None`.
    pub wrong_editable_region: Option<bool>,
    /// Whether the prediction contains whitespace-only changes in isolation.
    #[serde(default)]
    pub has_isolated_whitespace_changes: bool,
    /// Number of tokens inserted by the prediction.
    #[serde(default)]
    pub inserted_tokens: usize,
    /// Number of tokens deleted by the prediction.
    #[serde(default)]
    pub deleted_tokens: usize,
    // NOTE(review): presumably mirrors `ExamplePrediction::cumulative_logprob` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    // NOTE(review): presumably mirrors `ExamplePrediction::avg_logprob` — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
191
192impl ExampleScore {
193 pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
194 ClassificationMetrics {
195 true_positives: self.delta_chr_f_true_positives,
196 false_positives: self.delta_chr_f_false_positives,
197 false_negatives: self.delta_chr_f_false_negatives,
198 }
199 }
200
201 pub fn exact_lines_counts(&self) -> ClassificationMetrics {
202 ClassificationMetrics {
203 true_positives: self.exact_lines_tp,
204 false_positives: self.exact_lines_fp,
205 false_negatives: self.exact_lines_fn,
206 }
207 }
208}
209
210impl Example {
211 pub fn repo_name(&self) -> Result<RepoName<'_>> {
212 // git@github.com:owner/repo.git
213 if self.spec.repository_url.contains('@') {
214 let (owner, repo) = self
215 .spec
216 .repository_url
217 .split_once(':')
218 .context("expected : in git url")?
219 .1
220 .split_once('/')
221 .context("expected / in git url")?;
222 Ok(RepoName {
223 owner: Cow::Borrowed(owner),
224 name: Cow::Borrowed(repo.trim_end_matches(".git")),
225 })
226 // http://github.com/owner/repo.git
227 } else {
228 let url = Url::parse(&self.spec.repository_url)?;
229 let mut segments = url.path_segments().context("empty http url")?;
230 let owner = segments
231 .next()
232 .context("expected owner path segment")?
233 .to_string();
234 let repo = segments
235 .next()
236 .context("expected repo path segment")?
237 .trim_end_matches(".git")
238 .to_string();
239 assert!(segments.next().is_none());
240
241 Ok(RepoName {
242 owner: Cow::Owned(owner),
243 name: Cow::Owned(repo),
244 })
245 }
246 }
247}
248
/// Owner and name of a repository, as extracted by `Example::repo_name`.
///
/// The strings borrow from the example's repository URL when possible
/// (scp-like remotes) and are owned otherwise (parsed URLs).
pub struct RepoName<'a> {
    /// The repository owner (user or organization).
    pub owner: Cow<'a, str>,
    /// The repository name, without a trailing `.git`.
    pub name: Cow<'a, str>,
}
253
254impl RepoName<'_> {
255 pub fn worktree_path(&self) -> PathBuf {
256 WORKTREES_DIR
257 .join(self.owner.as_ref())
258 .join(self.name.as_ref())
259 }
260}
261
262pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
263 let mut examples = Vec::new();
264
265 for path in inputs {
266 let is_stdin = path.as_path() == Path::new("-");
267 let content = if is_stdin {
268 let mut buffer = String::new();
269 std::io::stdin()
270 .read_to_string(&mut buffer)
271 .expect("Failed to read from stdin");
272 buffer
273 } else {
274 std::fs::read_to_string(path)
275 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
276 };
277 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
278 let ext = if !is_stdin {
279 path.extension()
280 .map(|ext| ext.to_string_lossy().to_string())
281 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
282 } else {
283 "jsonl".to_string()
284 };
285
286 match ext.as_ref() {
287 "json" => {
288 let mut example =
289 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
290 panic!("Failed to parse example file: {}\n{error}", path.display())
291 });
292 if example.spec.name.is_empty() {
293 example.spec.name = filename;
294 }
295 examples.push(example);
296 }
297 "jsonl" => examples.extend(
298 content
299 .lines()
300 .enumerate()
301 .map(|(line_ix, line)| {
302 let mut example =
303 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
304 panic!(
305 "Failed to parse example on {}:{}\n{error}",
306 path.display(),
307 line_ix + 1
308 )
309 });
310 if example.spec.name.is_empty() {
311 example.spec.name = format!("{filename}-{line_ix}")
312 }
313 example
314 })
315 .collect::<Vec<Example>>(),
316 ),
317 "md" => {
318 let mut example = parse_markdown_example(&content).unwrap();
319 if example.spec.name.is_empty() {
320 example.spec.name = filename;
321 }
322 examples.push(example);
323 }
324 ext => {
325 panic!("{} has invalid example extension `{ext}`", path.display())
326 }
327 }
328 }
329
330 examples
331}
332
333pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
334 examples.sort_by(|a, b| {
335 a.spec
336 .repository_url
337 .cmp(&b.spec.repository_url)
338 .then(b.spec.revision.cmp(&a.spec.revision))
339 });
340}
341
342pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
343 let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
344 let mut ungrouped = Vec::new();
345 for example in examples {
346 if example.spec.repository_url.is_empty() {
347 ungrouped.push(example);
348 } else {
349 examples_by_repo
350 .entry(example.spec.repository_url.clone())
351 .or_insert_with(Vec::new)
352 .push(example);
353 }
354 }
355 let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
356 for example in ungrouped {
357 result.push_back(vec![example]);
358 }
359 result
360}
361
362fn parse_markdown_example(input: &str) -> Result<Example> {
363 let spec = ExampleSpec::from_markdown(input)?;
364 Ok(Example {
365 spec,
366 prompt_inputs: None,
367 prompt: None,
368 predictions: Vec::new(),
369 score: Vec::new(),
370 qa: Vec::new(),
371 state: None,
372 zed_version: None,
373 })
374}