example.rs

  1use crate::PredictionProvider;
  2use crate::paths::WORKTREES_DIR;
  3use crate::qa::QaResult;
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use edit_prediction::example_spec::ExampleSpec;
  7use edit_prediction::udiff::OpenedBuffers;
  8use gpui::Entity;
  9use http_client::Url;
 10use language::{Anchor, Buffer};
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    borrow::Cow,
 15    collections::VecDeque,
 16    io::Read,
 17    path::{Path, PathBuf},
 18    sync::Arc,
 19};
 20use zeta_prompt::RelatedFile;
 21
/// A single edit prediction example: its specification plus the data that is
/// attached to it (prompt inputs, prompt, predictions, scores, QA results)
/// and the non-serialized runtime state.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example's specification. Flattened, so its fields are serialized
    /// at the same level as the other fields of `Example`.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The application state used to process this example.
    ///
    /// Never serialized or deserialized (`#[serde(skip)]`), so it is `None`
    /// on every freshly deserialized example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 53
/// In-memory handles used while an example is being processed.
///
/// This type derives neither `Serialize` nor `Deserialize`; it only exists at
/// runtime.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity.
    pub project: Entity<Project>,
    /// Handle to the buffer entity.
    // NOTE(review): presumably the buffer in which the edit is predicted — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor position, expressed as a buffer anchor.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is held only to keep
    // the opened buffers alive for as long as the state exists — confirm.
    pub _open_buffers: OpenedBuffers,
}
 61
/// The inputs from which a prompt is built: file content, cursor location,
/// edit history and related files.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// Text content of the file.
    pub content: String,
    /// Row of the cursor.
    // NOTE(review): presumably 0-based, relative to `content` — confirm.
    pub cursor_row: u32,
    /// Column of the cursor.
    pub cursor_column: u32,
    /// Offset of the cursor.
    // NOTE(review): presumably a byte offset into `content` — confirm.
    pub cursor_offset: usize,
    /// Optional start row of the excerpt; omitted from the serialized form
    /// when `None` and defaulted to `None` when missing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub excerpt_start_row: Option<u32>,
    /// Edit events, shared via `Arc`.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files, if any were collected.
    pub related_files: Option<Vec<RelatedFile>>,
}
 73
/// The prompt for an example: the model input together with the expected
/// (and optionally rejected) output.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce.
    pub expected_output: String,
    pub rejected_output: Option<String>, // For DPO
    /// Optional prefill text; defaults to `None` when missing from the
    /// serialized form.
    #[serde(default)]
    pub prefill: Option<String>,
    /// The prediction provider this prompt is associated with.
    pub provider: PredictionProvider,
}
 83
/// One prediction recorded for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch, if any; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The model output. A serialized `null` is accepted on input and read
    /// as an empty string.
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// The predicted cursor location, if any; omitted from the serialized
    /// form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_cursor: Option<ActualCursor>,
    /// An error message, if any; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The prediction provider this prediction is associated with.
    pub provider: PredictionProvider,
}
 96
/// A cursor location within a file.
///
/// The field descriptions below reflect how `ActualCursor::from_editable_region`
/// populates them; values read from serialized data are taken as-is.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActualCursor {
    /// Path of the file the cursor is in.
    pub path: String,
    /// Line number of the cursor (the editable region's start line plus the
    /// number of newlines before the cursor within the region).
    pub row: u32,
    /// Number of bytes between the start of the cursor's line and the cursor.
    pub column: u32,
    /// Offset of the cursor: the editable region's start offset plus the
    /// cursor's offset within the region.
    pub offset: usize,
    /// Byte offset of the cursor within the editable region, when known;
    /// omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub editable_region_offset: Option<usize>,
}
106
107impl ActualCursor {
108    /// Construct an `ActualCursor` from a cursor offset within the new editable region.
109    ///
110    /// - `path`: file path the cursor is in
111    /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
112    /// - `new_editable_region`: the full new editable region text (after marker removal)
113    /// - `content`: the full file content (before the edit)
114    /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
115    /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
116    pub fn from_editable_region(
117        path: &std::path::Path,
118        editable_region_cursor_offset: usize,
119        new_editable_region: &str,
120        content: &str,
121        editable_region_byte_offset: usize,
122        editable_region_start_line: usize,
123    ) -> Self {
124        let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
125        let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
126        let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
127        let column = match new_region_prefix.rfind('\n') {
128            Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
129            None => {
130                let content_prefix = &content[..editable_region_byte_offset];
131                let content_column = match content_prefix.rfind('\n') {
132                    Some(pos) => editable_region_byte_offset - pos - 1,
133                    None => editable_region_byte_offset,
134                };
135                (content_column + editable_region_cursor_offset) as u32
136            }
137        };
138        ActualCursor {
139            path: path.to_string_lossy().to_string(),
140            row,
141            column,
142            offset: global_offset,
143            editable_region_offset: Some(editable_region_cursor_offset),
144        }
145    }
146}
147
148fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
149where
150    D: serde::Deserializer<'de>,
151{
152    let opt = Option::<String>::deserialize(deserializer)?;
153    Ok(opt.unwrap_or_default())
154}
155
/// Metrics recorded for a prediction.
///
/// Fields marked `#[serde(default)]` fall back to their type's default value
/// (zero, `false` or `None`) when missing from the serialized form.
// NOTE(review): the per-field meanings below are inferred from the field
// names only; the code computing them is not in this file — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // Presumably a chrF-style similarity delta — confirm.
    pub delta_chr_f: f32,
    // Presumably a count of unbalanced braces — confirm.
    pub braces_disbalance: usize,
    // Presumably true-positive / false-positive / false-negative counts of
    // exactly matching lines — confirm.
    #[serde(default)]
    pub exact_lines_tp: usize,
    #[serde(default)]
    pub exact_lines_fp: usize,
    #[serde(default)]
    pub exact_lines_fn: usize,
    #[serde(default)]
    pub reversal_ratio: f32,
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_distance: Option<usize>,
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match: Option<bool>,
    pub wrong_editable_region: Option<bool>,
    #[serde(default)]
    pub has_isolated_whitespace_changes: bool,
    #[serde(default)]
    pub inserted_tokens: usize,
    #[serde(default)]
    pub deleted_tokens: usize,
}
180
181impl Example {
182    pub fn repo_name(&self) -> Result<RepoName<'_>> {
183        // git@github.com:owner/repo.git
184        if self.spec.repository_url.contains('@') {
185            let (owner, repo) = self
186                .spec
187                .repository_url
188                .split_once(':')
189                .context("expected : in git url")?
190                .1
191                .split_once('/')
192                .context("expected / in git url")?;
193            Ok(RepoName {
194                owner: Cow::Borrowed(owner),
195                name: Cow::Borrowed(repo.trim_end_matches(".git")),
196            })
197        // http://github.com/owner/repo.git
198        } else {
199            let url = Url::parse(&self.spec.repository_url)?;
200            let mut segments = url.path_segments().context("empty http url")?;
201            let owner = segments
202                .next()
203                .context("expected owner path segment")?
204                .to_string();
205            let repo = segments
206                .next()
207                .context("expected repo path segment")?
208                .trim_end_matches(".git")
209                .to_string();
210            assert!(segments.next().is_none());
211
212            Ok(RepoName {
213                owner: Cow::Owned(owner),
214                name: Cow::Owned(repo),
215            })
216        }
217    }
218}
219
/// The owner and name of a repository.
///
/// Both parts are `Cow`s so they can either borrow from an existing string or
/// own their data.
pub struct RepoName<'a> {
    /// The repository owner.
    pub owner: Cow<'a, str>,
    /// The repository name.
    pub name: Cow<'a, str>,
}
224
225impl RepoName<'_> {
226    pub fn worktree_path(&self) -> PathBuf {
227        WORKTREES_DIR
228            .join(self.owner.as_ref())
229            .join(self.name.as_ref())
230    }
231}
232
233pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
234    let mut examples = Vec::new();
235
236    for path in inputs {
237        let is_stdin = path.as_path() == Path::new("-");
238        let content = if is_stdin {
239            let mut buffer = String::new();
240            std::io::stdin()
241                .read_to_string(&mut buffer)
242                .expect("Failed to read from stdin");
243            buffer
244        } else {
245            std::fs::read_to_string(path)
246                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
247        };
248        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
249        let ext = if !is_stdin {
250            path.extension()
251                .map(|ext| ext.to_string_lossy().to_string())
252                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
253        } else {
254            "jsonl".to_string()
255        };
256
257        match ext.as_ref() {
258            "json" => {
259                let mut example =
260                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
261                        panic!("Failed to parse example file: {}\n{error}", path.display())
262                    });
263                if example.spec.name.is_empty() {
264                    example.spec.name = filename;
265                }
266                examples.push(example);
267            }
268            "jsonl" => examples.extend(
269                content
270                    .lines()
271                    .enumerate()
272                    .map(|(line_ix, line)| {
273                        let mut example =
274                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
275                                panic!(
276                                    "Failed to parse example on {}:{}\n{error}",
277                                    path.display(),
278                                    line_ix + 1
279                                )
280                            });
281                        if example.spec.name.is_empty() {
282                            example.spec.name = format!("{filename}-{line_ix}")
283                        }
284                        example
285                    })
286                    .collect::<Vec<Example>>(),
287            ),
288            "md" => {
289                let mut example = parse_markdown_example(&content).unwrap();
290                if example.spec.name.is_empty() {
291                    example.spec.name = filename;
292                }
293                examples.push(example);
294            }
295            ext => {
296                panic!("{} has invalid example extension `{ext}`", path.display())
297            }
298        }
299    }
300
301    examples
302}
303
304pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
305    examples.sort_by(|a, b| {
306        a.spec
307            .repository_url
308            .cmp(&b.spec.repository_url)
309            .then(b.spec.revision.cmp(&a.spec.revision))
310    });
311}
312
313pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
314    let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
315    let mut ungrouped = Vec::new();
316    for example in examples {
317        if example.spec.repository_url.is_empty() {
318            ungrouped.push(example);
319        } else {
320            examples_by_repo
321                .entry(example.spec.repository_url.clone())
322                .or_insert_with(Vec::new)
323                .push(example);
324        }
325    }
326    let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
327    for example in ungrouped {
328        result.push_back(vec![example]);
329    }
330    result
331}
332
333fn parse_markdown_example(input: &str) -> Result<Example> {
334    let spec = ExampleSpec::from_markdown(input)?;
335    Ok(Example {
336        spec,
337        prompt_inputs: None,
338        prompt: None,
339        predictions: Vec::new(),
340        score: Vec::new(),
341        qa: Vec::new(),
342        state: None,
343    })
344}