example.rs

  1use crate::PredictionProvider;
  2use crate::metrics::ClassificationMetrics;
  3use crate::paths::WORKTREES_DIR;
  4use crate::qa::QaResult;
  5use anyhow::{Context as _, Result};
  6use collections::HashMap;
  7use edit_prediction::example_spec::ExampleSpec;
  8use edit_prediction::udiff::OpenedBuffers;
  9use gpui::Entity;
 10use http_client::Url;
 11use language::{Anchor, Buffer};
 12use project::Project;
 13use serde::{Deserialize, Serialize};
 14use std::{
 15    borrow::Cow,
 16    collections::VecDeque,
 17    io::Read,
 18    path::{Path, PathBuf},
 19};
 20use zeta_prompt::ZetaPromptInput;
 21
 22#[derive(Clone, Debug, Serialize, Deserialize)]
 23pub struct Example {
 24    #[serde(flatten)]
 25    pub spec: ExampleSpec,
 26
 27    /// The full content of the file where an edit is being predicted, and the
 28    /// actual cursor offset.
 29    #[serde(skip_serializing_if = "Option::is_none")]
 30    pub prompt_inputs: Option<ZetaPromptInput>,
 31
 32    /// The input and expected output from the edit prediction model.
 33    #[serde(skip_serializing_if = "Option::is_none")]
 34    pub prompt: Option<ExamplePrompt>,
 35
 36    /// The actual predictions from the model.
 37    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 38    pub predictions: Vec<ExamplePrediction>,
 39
 40    /// The scores, for how well the actual predictions match the expected
 41    /// predictions.
 42    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 43    pub score: Vec<ExampleScore>,
 44
 45    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
 46    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 47    pub qa: Vec<Option<QaResult>>,
 48
 49    /// The Zed version used to generate this example.
 50    pub zed_version: Option<String>,
 51
 52    /// The application state used to process this example.
 53    #[serde(skip)]
 54    pub state: Option<ExampleState>,
 55}
 56
 57#[derive(Clone, Debug)]
 58pub struct ExampleState {
 59    pub project: Entity<Project>,
 60    pub buffer: Entity<Buffer>,
 61    pub cursor_position: Anchor,
 62    pub _open_buffers: OpenedBuffers,
 63}
 64
 65#[derive(Clone, Debug, Serialize, Deserialize)]
 66pub struct ExamplePrompt {
 67    pub input: String,
 68    #[serde(default)]
 69    pub expected_output: Option<String>,
 70    pub rejected_output: Option<String>, // For DPO
 71    #[serde(default)]
 72    pub prefill: Option<String>,
 73    pub provider: PredictionProvider,
 74}
 75
 76#[derive(Clone, Debug, Serialize, Deserialize)]
 77pub struct ExamplePrediction {
 78    #[serde(default, skip_serializing_if = "Option::is_none")]
 79    pub actual_patch: Option<String>,
 80    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
 81    pub actual_output: String,
 82    #[serde(default, skip_serializing_if = "Option::is_none")]
 83    pub actual_cursor: Option<ActualCursor>,
 84    #[serde(default, skip_serializing_if = "Option::is_none")]
 85    pub error: Option<String>,
 86    pub provider: PredictionProvider,
 87    #[serde(default, skip_serializing_if = "Option::is_none")]
 88    pub cumulative_logprob: Option<f64>,
 89    #[serde(default, skip_serializing_if = "Option::is_none")]
 90    pub avg_logprob: Option<f64>,
 91}
 92
 93#[derive(Clone, Debug, Serialize, Deserialize)]
 94pub struct ActualCursor {
 95    pub path: String,
 96    pub row: u32,
 97    pub column: u32,
 98    pub offset: usize,
 99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub editable_region_offset: Option<usize>,
101}
102
103impl ActualCursor {
104    /// Construct an `ActualCursor` from a cursor offset within the new editable region.
105    ///
106    /// - `path`: file path the cursor is in
107    /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
108    /// - `new_editable_region`: the full new editable region text (after marker removal)
109    /// - `content`: the full file content (before the edit)
110    /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
111    /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
112    pub fn from_editable_region(
113        path: &std::path::Path,
114        editable_region_cursor_offset: usize,
115        new_editable_region: &str,
116        content: &str,
117        editable_region_byte_offset: usize,
118        editable_region_start_line: usize,
119    ) -> Self {
120        let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
121        let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
122        let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
123        let column = match new_region_prefix.rfind('\n') {
124            Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
125            None => {
126                let content_prefix = &content[..editable_region_byte_offset];
127                let content_column = match content_prefix.rfind('\n') {
128                    Some(pos) => editable_region_byte_offset - pos - 1,
129                    None => editable_region_byte_offset,
130                };
131                (content_column + editable_region_cursor_offset) as u32
132            }
133        };
134        ActualCursor {
135            path: path.to_string_lossy().to_string(),
136            row,
137            column,
138            offset: global_offset,
139            editable_region_offset: Some(editable_region_cursor_offset),
140        }
141    }
142}
143
144fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
145where
146    D: serde::Deserializer<'de>,
147{
148    let opt = Option::<String>::deserialize(deserializer)?;
149    Ok(opt.unwrap_or_default())
150}
151
152#[derive(Clone, Debug, Serialize, Deserialize)]
153pub struct ExampleScore {
154    pub delta_chr_f: f32,
155    #[serde(default)]
156    pub delta_chr_f_true_positives: usize,
157    #[serde(default)]
158    pub delta_chr_f_false_positives: usize,
159    #[serde(default)]
160    pub delta_chr_f_false_negatives: usize,
161    #[serde(default)]
162    pub delta_chr_f_precision: f64,
163    #[serde(default)]
164    pub delta_chr_f_recall: f64,
165    #[serde(default)]
166    pub delta_chr_f_beta: f64,
167    pub braces_disbalance: usize,
168    #[serde(default)]
169    pub exact_lines_tp: usize,
170    #[serde(default)]
171    pub exact_lines_fp: usize,
172    #[serde(default)]
173    pub exact_lines_fn: usize,
174    #[serde(default)]
175    pub reversal_ratio: f32,
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub cursor_distance: Option<usize>,
178    #[serde(default, skip_serializing_if = "Option::is_none")]
179    pub cursor_exact_match: Option<bool>,
180    pub wrong_editable_region: Option<bool>,
181    #[serde(default)]
182    pub has_isolated_whitespace_changes: bool,
183    #[serde(default)]
184    pub inserted_tokens: usize,
185    #[serde(default)]
186    pub deleted_tokens: usize,
187    #[serde(default, skip_serializing_if = "Option::is_none")]
188    pub kept_rate: Option<f64>,
189    #[serde(default, skip_serializing_if = "Option::is_none")]
190    pub recall_rate: Option<f64>,
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub kept_chars: Option<usize>,
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub correctly_deleted_chars: Option<usize>,
195    #[serde(default, skip_serializing_if = "Option::is_none")]
196    pub discarded_chars: Option<usize>,
197    #[serde(default, skip_serializing_if = "Option::is_none")]
198    pub cumulative_logprob: Option<f64>,
199    #[serde(default, skip_serializing_if = "Option::is_none")]
200    pub avg_logprob: Option<f64>,
201}
202
203impl ExampleScore {
204    pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
205        ClassificationMetrics {
206            true_positives: self.delta_chr_f_true_positives,
207            false_positives: self.delta_chr_f_false_positives,
208            false_negatives: self.delta_chr_f_false_negatives,
209        }
210    }
211
212    pub fn exact_lines_counts(&self) -> ClassificationMetrics {
213        ClassificationMetrics {
214            true_positives: self.exact_lines_tp,
215            false_positives: self.exact_lines_fp,
216            false_negatives: self.exact_lines_fn,
217        }
218    }
219}
220
221impl Example {
222    pub fn repo_name(&self) -> Result<RepoName<'_>> {
223        // git@github.com:owner/repo.git
224        if self.spec.repository_url.contains('@') {
225            let (owner, repo) = self
226                .spec
227                .repository_url
228                .split_once(':')
229                .context("expected : in git url")?
230                .1
231                .split_once('/')
232                .context("expected / in git url")?;
233            Ok(RepoName {
234                owner: Cow::Borrowed(owner),
235                name: Cow::Borrowed(repo.trim_end_matches(".git")),
236            })
237        // http://github.com/owner/repo.git
238        } else {
239            let url = Url::parse(&self.spec.repository_url)?;
240            let mut segments = url.path_segments().context("empty http url")?;
241            let owner = segments
242                .next()
243                .context("expected owner path segment")?
244                .to_string();
245            let repo = segments
246                .next()
247                .context("expected repo path segment")?
248                .trim_end_matches(".git")
249                .to_string();
250            assert!(segments.next().is_none());
251
252            Ok(RepoName {
253                owner: Cow::Owned(owner),
254                name: Cow::Owned(repo),
255            })
256        }
257    }
258}
259
260pub struct RepoName<'a> {
261    pub owner: Cow<'a, str>,
262    pub name: Cow<'a, str>,
263}
264
265impl RepoName<'_> {
266    pub fn worktree_path(&self) -> PathBuf {
267        WORKTREES_DIR
268            .join(self.owner.as_ref())
269            .join(self.name.as_ref())
270    }
271}
272
273pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
274    let mut examples = Vec::new();
275
276    for path in inputs {
277        let is_stdin = path.as_path() == Path::new("-");
278        let content = if is_stdin {
279            let mut buffer = String::new();
280            std::io::stdin()
281                .read_to_string(&mut buffer)
282                .expect("Failed to read from stdin");
283            buffer
284        } else {
285            std::fs::read_to_string(path)
286                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
287        };
288        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
289        let ext = if !is_stdin {
290            path.extension()
291                .map(|ext| ext.to_string_lossy().to_string())
292                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
293        } else {
294            "jsonl".to_string()
295        };
296
297        match ext.as_ref() {
298            "json" => {
299                let mut example =
300                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
301                        panic!("Failed to parse example file: {}\n{error}", path.display())
302                    });
303                if example.spec.name.is_empty() {
304                    example.spec.name = filename;
305                }
306                examples.push(example);
307            }
308            "jsonl" => examples.extend(
309                content
310                    .lines()
311                    .enumerate()
312                    .map(|(line_ix, line)| {
313                        let mut example =
314                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
315                                panic!(
316                                    "Failed to parse example on {}:{}\n{error}",
317                                    path.display(),
318                                    line_ix + 1
319                                )
320                            });
321                        if example.spec.name.is_empty() {
322                            example.spec.name = format!("{filename}-{line_ix}")
323                        }
324                        example
325                    })
326                    .collect::<Vec<Example>>(),
327            ),
328            "md" => {
329                let mut example = parse_markdown_example(&content).unwrap();
330                if example.spec.name.is_empty() {
331                    example.spec.name = filename;
332                }
333                examples.push(example);
334            }
335            ext => {
336                panic!("{} has invalid example extension `{ext}`", path.display())
337            }
338        }
339    }
340
341    examples
342}
343
344pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
345    examples.sort_by(|a, b| {
346        a.spec
347            .repository_url
348            .cmp(&b.spec.repository_url)
349            .then(b.spec.revision.cmp(&a.spec.revision))
350    });
351}
352
353pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
354    let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
355    let mut ungrouped = Vec::new();
356    for example in examples {
357        if example.spec.repository_url.is_empty() {
358            ungrouped.push(example);
359        } else {
360            examples_by_repo
361                .entry(example.spec.repository_url.clone())
362                .or_insert_with(Vec::new)
363                .push(example);
364        }
365    }
366    let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
367    for example in ungrouped {
368        result.push_back(vec![example]);
369    }
370    result
371}
372
373fn parse_markdown_example(input: &str) -> Result<Example> {
374    let spec = ExampleSpec::from_markdown(input)?;
375    Ok(Example {
376        spec,
377        prompt_inputs: None,
378        prompt: None,
379        predictions: Vec::new(),
380        score: Vec::new(),
381        qa: Vec::new(),
382        state: None,
383        zed_version: None,
384    })
385}