example.rs

  1use crate::PredictionProvider;
  2use crate::paths::WORKTREES_DIR;
  3use crate::qa::QaResult;
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use edit_prediction::example_spec::ExampleSpec;
  7use edit_prediction::udiff::OpenedBuffers;
  8use gpui::Entity;
  9use http_client::Url;
 10use language::{Anchor, Buffer};
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    borrow::Cow,
 15    collections::VecDeque,
 16    io::Read,
 17    path::{Path, PathBuf},
 18    sync::Arc,
 19};
 20use zeta_prompt::RelatedFile;
 21
/// A single edit-prediction example together with everything computed for it
/// (prompt inputs, prompt, predictions, scores, QA results).
///
/// When (de)serialized, the fields of `spec` appear at the top level. Optional
/// and empty fields are omitted on serialization, and `state` is never
/// serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification; flattened into the top level of the
    /// serialized form.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The application state used to process this example.
    ///
    /// Runtime-only: `#[serde(skip)]` means it is never written out and is
    /// `None` after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 53
/// In-memory application state used while processing an [`Example`].
///
/// This type does not derive `Serialize`/`Deserialize`; it is held in
/// `Example::state`, which serde skips.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example is processed in.
    pub project: Entity<Project>,
    /// The buffer associated with the example.
    /// NOTE(review): presumably the buffer containing the cursor — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor position, as a buffer anchor.
    pub cursor_position: Anchor,
    /// Not read through this field (note the leading underscore).
    /// NOTE(review): presumably held only to keep the opened buffers alive for
    /// as long as this state exists — confirm.
    pub _open_buffers: OpenedBuffers,
}
 61
/// The full content of the file where an edit is being predicted, the cursor
/// location within it, and the surrounding context (edit history, related
/// files).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// Full text of the file in which the edit is predicted.
    pub content: String,
    /// Cursor row. NOTE(review): 0- vs 1-based is not established here — confirm.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): 0- vs 1-based, and byte vs char units, are
    /// not established here — confirm.
    pub cursor_column: u32,
    /// Cursor position as an offset into `content`.
    /// NOTE(review): presumably a byte offset — confirm.
    pub cursor_offset: usize,
    /// Edit history events supplied alongside the file content.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplied as extra context, if any.
    pub related_files: Option<Vec<RelatedFile>>,
}
 71
/// The input and expected output for the edit prediction model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// An optional rejected (dispreferred) output.
    pub rejected_output: Option<String>, // For DPO
    /// The prediction provider this prompt is associated with.
    pub provider: PredictionProvider,
}
 79
/// A single prediction produced for an example by a provider.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch, if any. Defaults to `None` when absent and is
    /// omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The model's output. A `null` value deserializes to an empty string.
    /// NOTE(review): there is no `#[serde(default)]`, so a missing key is
    /// presumably a deserialization error — confirm.
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// An error message recorded for this prediction, if any; omitted from the
    /// serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
 90
 91fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
 92where
 93    D: serde::Deserializer<'de>,
 94{
 95    let opt = Option::<String>::deserialize(deserializer)?;
 96    Ok(opt.unwrap_or_default())
 97}
 98
/// Scores for how well an actual prediction matches the expected prediction.
///
/// Fields marked `#[serde(default)]` fall back to zero when missing from the
/// input.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): presumably a chrF-based score difference — confirm
    /// against the scoring code.
    pub delta_chr_f: f32,
    /// NOTE(review): presumably a count of unbalanced braces in the
    /// prediction — confirm.
    pub braces_disbalance: usize,
    /// Exact-line match count (presumably true positives). Defaults to 0.
    #[serde(default)]
    pub exact_lines_tp: usize,
    /// Exact-line match count (presumably false positives). Defaults to 0.
    #[serde(default)]
    pub exact_lines_fp: usize,
    /// Exact-line match count (presumably false negatives). Defaults to 0.
    #[serde(default)]
    pub exact_lines_fn: usize,
    /// Reversal ratio. Defaults to 0.0.
    /// NOTE(review): exact definition lives in the scorer — confirm.
    #[serde(default)]
    pub reversal_ratio: f32,
}
112
113impl Example {
114    pub fn repo_name(&self) -> Result<RepoName<'_>> {
115        // git@github.com:owner/repo.git
116        if self.spec.repository_url.contains('@') {
117            let (owner, repo) = self
118                .spec
119                .repository_url
120                .split_once(':')
121                .context("expected : in git url")?
122                .1
123                .split_once('/')
124                .context("expected / in git url")?;
125            Ok(RepoName {
126                owner: Cow::Borrowed(owner),
127                name: Cow::Borrowed(repo.trim_end_matches(".git")),
128            })
129        // http://github.com/owner/repo.git
130        } else {
131            let url = Url::parse(&self.spec.repository_url)?;
132            let mut segments = url.path_segments().context("empty http url")?;
133            let owner = segments
134                .next()
135                .context("expected owner path segment")?
136                .to_string();
137            let repo = segments
138                .next()
139                .context("expected repo path segment")?
140                .trim_end_matches(".git")
141                .to_string();
142            assert!(segments.next().is_none());
143
144            Ok(RepoName {
145                owner: Cow::Owned(owner),
146                name: Cow::Owned(repo),
147            })
148        }
149    }
150}
151
/// The owner and name of a git repository, as produced by [`Example::repo_name`].
///
/// Each part either borrows from the example's repository URL or is owned,
/// depending on which URL form was parsed.
pub struct RepoName<'a> {
    /// The repository owner, i.e. `owner` in `owner/repo`.
    pub owner: Cow<'a, str>,
    /// The repository name with any trailing `.git` removed.
    pub name: Cow<'a, str>,
}
156
157impl RepoName<'_> {
158    pub fn worktree_path(&self) -> PathBuf {
159        WORKTREES_DIR
160            .join(self.owner.as_ref())
161            .join(self.name.as_ref())
162    }
163}
164
165pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
166    let mut examples = Vec::new();
167
168    for path in inputs {
169        let is_stdin = path.as_path() == Path::new("-");
170        let content = if is_stdin {
171            let mut buffer = String::new();
172            std::io::stdin()
173                .read_to_string(&mut buffer)
174                .expect("Failed to read from stdin");
175            buffer
176        } else {
177            std::fs::read_to_string(path)
178                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
179        };
180        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
181        let ext = if !is_stdin {
182            path.extension()
183                .map(|ext| ext.to_string_lossy().to_string())
184                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
185        } else {
186            "jsonl".to_string()
187        };
188
189        match ext.as_ref() {
190            "json" => {
191                let mut example =
192                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
193                        panic!("Failed to parse example file: {}\n{error}", path.display())
194                    });
195                if example.spec.name.is_empty() {
196                    example.spec.name = filename;
197                }
198                examples.push(example);
199            }
200            "jsonl" => examples.extend(
201                content
202                    .lines()
203                    .enumerate()
204                    .map(|(line_ix, line)| {
205                        let mut example =
206                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
207                                panic!(
208                                    "Failed to parse example on {}:{}\n{error}",
209                                    path.display(),
210                                    line_ix + 1
211                                )
212                            });
213                        if example.spec.name.is_empty() {
214                            example.spec.name = format!("{filename}-{line_ix}")
215                        }
216                        example
217                    })
218                    .collect::<Vec<Example>>(),
219            ),
220            "md" => {
221                let mut example = parse_markdown_example(&content).unwrap();
222                if example.spec.name.is_empty() {
223                    example.spec.name = filename;
224                }
225                examples.push(example);
226            }
227            ext => {
228                panic!("{} has invalid example extension `{ext}`", path.display())
229            }
230        }
231    }
232
233    examples
234}
235
236pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
237    examples.sort_by(|a, b| {
238        a.spec
239            .repository_url
240            .cmp(&b.spec.repository_url)
241            .then(b.spec.revision.cmp(&a.spec.revision))
242    });
243}
244
245pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
246    let mut examples_by_repo = HashMap::default();
247    for example in examples {
248        examples_by_repo
249            .entry(example.spec.repository_url.clone())
250            .or_insert_with(Vec::new)
251            .push(example);
252    }
253    examples_by_repo.into_values().collect()
254}
255
256fn parse_markdown_example(input: &str) -> Result<Example> {
257    let spec = ExampleSpec::from_markdown(input)?;
258    Ok(Example {
259        spec,
260        prompt_inputs: None,
261        prompt: None,
262        predictions: Vec::new(),
263        score: Vec::new(),
264        qa: Vec::new(),
265        state: None,
266    })
267}