example.rs

  1use crate::PredictionProvider;
  2use crate::paths::WORKTREES_DIR;
  3use crate::qa::QaResult;
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use edit_prediction::example_spec::ExampleSpec;
  7use edit_prediction::udiff::OpenedBuffers;
  8use gpui::Entity;
  9use http_client::Url;
 10use language::{Anchor, Buffer};
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    borrow::Cow,
 15    collections::VecDeque,
 16    io::Read,
 17    path::{Path, PathBuf},
 18};
 19use zeta_prompt::ZetaPromptInput;
 20
/// An edit-prediction example: its spec together with optional derived data
/// (prompt inputs, prompt, predictions, scores, and QA results).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification. Its fields are flattened into this struct's
    /// serialized form rather than nested under a separate key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ZetaPromptInput>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The Zed version used to generate this example.
    ///
    /// Unlike the other optional fields, this has no `skip_serializing_if`, so
    /// it is serialized as `null` when absent.
    pub zed_version: Option<String>,

    /// The application state used to process this example.
    ///
    /// Excluded from (de)serialization via `#[serde(skip)]`, so it is always
    /// `None` on a freshly deserialized example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 55
/// Live application state attached to an [`Example`] while it is processed.
///
/// This is the type of `Example::state`, which is `#[serde(skip)]`, so this
/// state never appears in serialized examples.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example is loaded into.
    pub project: Entity<Project>,
    /// The buffer in which the edit is being predicted.
    pub buffer: Entity<Buffer>,
    /// The cursor location within `buffer`.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is never read and is
    // held only to keep the opened buffers alive for as long as this state
    // exists — confirm against `OpenedBuffers`.
    pub _open_buffers: OpenedBuffers,
}
 63
/// A rendered model prompt together with its expected completion.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text sent to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// An optional rejected output (for DPO).
    pub rejected_output: Option<String>,
    /// Optional prefill text; defaults to `None` when absent from the input.
    #[serde(default)]
    pub prefill: Option<String>,
    /// The prediction provider this prompt was built for.
    pub provider: PredictionProvider,
}
 73
/// One prediction produced by a provider for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit as a patch, if one was produced.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The raw model output. A JSON `null` is accepted and read as an empty
    /// string (see `deserialize_null_as_empty_string`).
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// Where the cursor ends up according to the prediction, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_cursor: Option<ActualCursor>,
    /// An error message recorded for this prediction, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
    /// Cumulative log-probability of the output, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    /// Average log-probability of the output, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
 90
/// A cursor location in a file, in several coordinate systems.
///
/// When built by `ActualCursor::from_editable_region`, all offsets and the
/// column are measured in bytes, and `row` is 0-based.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActualCursor {
    /// The file path the cursor is in.
    pub path: String,
    /// 0-based line of the cursor.
    pub row: u32,
    /// Byte column of the cursor within its line.
    pub column: u32,
    /// Byte offset of the cursor from the start of the file.
    pub offset: usize,
    /// Byte offset of the cursor within the editable region, when known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub editable_region_offset: Option<usize>,
}
100
101impl ActualCursor {
102    /// Construct an `ActualCursor` from a cursor offset within the new editable region.
103    ///
104    /// - `path`: file path the cursor is in
105    /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
106    /// - `new_editable_region`: the full new editable region text (after marker removal)
107    /// - `content`: the full file content (before the edit)
108    /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
109    /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
110    pub fn from_editable_region(
111        path: &std::path::Path,
112        editable_region_cursor_offset: usize,
113        new_editable_region: &str,
114        content: &str,
115        editable_region_byte_offset: usize,
116        editable_region_start_line: usize,
117    ) -> Self {
118        let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
119        let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
120        let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
121        let column = match new_region_prefix.rfind('\n') {
122            Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
123            None => {
124                let content_prefix = &content[..editable_region_byte_offset];
125                let content_column = match content_prefix.rfind('\n') {
126                    Some(pos) => editable_region_byte_offset - pos - 1,
127                    None => editable_region_byte_offset,
128                };
129                (content_column + editable_region_cursor_offset) as u32
130            }
131        };
132        ActualCursor {
133            path: path.to_string_lossy().to_string(),
134            row,
135            column,
136            offset: global_offset,
137            editable_region_offset: Some(editable_region_cursor_offset),
138        }
139    }
140}
141
142fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
143where
144    D: serde::Deserializer<'de>,
145{
146    let opt = Option::<String>::deserialize(deserializer)?;
147    Ok(opt.unwrap_or_default())
148}
149
/// Metrics describing how well an actual prediction matches the expected output.
///
/// Fields marked `#[serde(default)]` read as zero/`false` when missing, so
/// older score records still deserialize.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a delta-chrF similarity score — confirm against the scorer.
    pub delta_chr_f: f32,
    // NOTE(review): presumably a count of unbalanced braces in the prediction — confirm.
    pub braces_disbalance: usize,
    // Exact-line match counts: true positives, false positives, false negatives.
    #[serde(default)]
    pub exact_lines_tp: usize,
    #[serde(default)]
    pub exact_lines_fp: usize,
    #[serde(default)]
    pub exact_lines_fn: usize,
    // Token-level match counts and the derived precision/recall.
    #[serde(default)]
    pub token_match_tp: usize,
    #[serde(default)]
    pub token_match_fp: usize,
    #[serde(default)]
    pub token_match_fn: usize,
    #[serde(default)]
    pub token_match_precision: f64,
    #[serde(default)]
    pub token_match_recall: f64,
    #[serde(default)]
    pub reversal_ratio: f32,
    // Cursor-placement metrics; omitted from output when not computed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_distance: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match: Option<bool>,
    // No `skip_serializing_if` here: serialized as `null` when absent.
    pub wrong_editable_region: Option<bool>,
    #[serde(default)]
    pub has_isolated_whitespace_changes: bool,
    #[serde(default)]
    pub inserted_tokens: usize,
    #[serde(default)]
    pub deleted_tokens: usize,
    // Log-probability statistics of the scored prediction, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
188
189impl Example {
190    pub fn repo_name(&self) -> Result<RepoName<'_>> {
191        // git@github.com:owner/repo.git
192        if self.spec.repository_url.contains('@') {
193            let (owner, repo) = self
194                .spec
195                .repository_url
196                .split_once(':')
197                .context("expected : in git url")?
198                .1
199                .split_once('/')
200                .context("expected / in git url")?;
201            Ok(RepoName {
202                owner: Cow::Borrowed(owner),
203                name: Cow::Borrowed(repo.trim_end_matches(".git")),
204            })
205        // http://github.com/owner/repo.git
206        } else {
207            let url = Url::parse(&self.spec.repository_url)?;
208            let mut segments = url.path_segments().context("empty http url")?;
209            let owner = segments
210                .next()
211                .context("expected owner path segment")?
212                .to_string();
213            let repo = segments
214                .next()
215                .context("expected repo path segment")?
216                .trim_end_matches(".git")
217                .to_string();
218            assert!(segments.next().is_none());
219
220            Ok(RepoName {
221                owner: Cow::Owned(owner),
222                name: Cow::Owned(repo),
223            })
224        }
225    }
226}
227
/// The owner and name of a git repository, as parsed from its URL.
///
/// The parts may borrow from the source URL string or be owned copies,
/// depending on how the URL was parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The repository owner (user or organization).
    pub owner: Cow<'a, str>,
    /// The repository name, without any `.git` suffix.
    pub name: Cow<'a, str>,
}
232
233impl RepoName<'_> {
234    pub fn worktree_path(&self) -> PathBuf {
235        WORKTREES_DIR
236            .join(self.owner.as_ref())
237            .join(self.name.as_ref())
238    }
239}
240
241pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
242    let mut examples = Vec::new();
243
244    for path in inputs {
245        let is_stdin = path.as_path() == Path::new("-");
246        let content = if is_stdin {
247            let mut buffer = String::new();
248            std::io::stdin()
249                .read_to_string(&mut buffer)
250                .expect("Failed to read from stdin");
251            buffer
252        } else {
253            std::fs::read_to_string(path)
254                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
255        };
256        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
257        let ext = if !is_stdin {
258            path.extension()
259                .map(|ext| ext.to_string_lossy().to_string())
260                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
261        } else {
262            "jsonl".to_string()
263        };
264
265        match ext.as_ref() {
266            "json" => {
267                let mut example =
268                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
269                        panic!("Failed to parse example file: {}\n{error}", path.display())
270                    });
271                if example.spec.name.is_empty() {
272                    example.spec.name = filename;
273                }
274                examples.push(example);
275            }
276            "jsonl" => examples.extend(
277                content
278                    .lines()
279                    .enumerate()
280                    .map(|(line_ix, line)| {
281                        let mut example =
282                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
283                                panic!(
284                                    "Failed to parse example on {}:{}\n{error}",
285                                    path.display(),
286                                    line_ix + 1
287                                )
288                            });
289                        if example.spec.name.is_empty() {
290                            example.spec.name = format!("{filename}-{line_ix}")
291                        }
292                        example
293                    })
294                    .collect::<Vec<Example>>(),
295            ),
296            "md" => {
297                let mut example = parse_markdown_example(&content).unwrap();
298                if example.spec.name.is_empty() {
299                    example.spec.name = filename;
300                }
301                examples.push(example);
302            }
303            ext => {
304                panic!("{} has invalid example extension `{ext}`", path.display())
305            }
306        }
307    }
308
309    examples
310}
311
312pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
313    examples.sort_by(|a, b| {
314        a.spec
315            .repository_url
316            .cmp(&b.spec.repository_url)
317            .then(b.spec.revision.cmp(&a.spec.revision))
318    });
319}
320
321pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
322    let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
323    let mut ungrouped = Vec::new();
324    for example in examples {
325        if example.spec.repository_url.is_empty() {
326            ungrouped.push(example);
327        } else {
328            examples_by_repo
329                .entry(example.spec.repository_url.clone())
330                .or_insert_with(Vec::new)
331                .push(example);
332        }
333    }
334    let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
335    for example in ungrouped {
336        result.push_back(vec![example]);
337    }
338    result
339}
340
341fn parse_markdown_example(input: &str) -> Result<Example> {
342    let spec = ExampleSpec::from_markdown(input)?;
343    Ok(Example {
344        spec,
345        prompt_inputs: None,
346        prompt: None,
347        predictions: Vec::new(),
348        score: Vec::new(),
349        qa: Vec::new(),
350        state: None,
351        zed_version: None,
352    })
353}