example.rs

  1use crate::PredictionProvider;
  2use crate::paths::WORKTREES_DIR;
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use edit_prediction::example_spec::ExampleSpec;
  6use edit_prediction::udiff::OpenedBuffers;
  7use gpui::Entity;
  8use http_client::Url;
  9use language::{Anchor, Buffer};
 10use project::Project;
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    borrow::Cow,
 14    collections::VecDeque,
 15    io::Read,
 16    path::{Path, PathBuf},
 17    sync::Arc,
 18};
 19use zeta_prompt::RelatedFile;
 20
 21#[derive(Clone, Debug, Serialize, Deserialize)]
 22pub struct Example {
 23    #[serde(flatten)]
 24    pub spec: ExampleSpec,
 25
 26    /// The full content of the file where an edit is being predicted, and the
 27    /// actual cursor offset.
 28    #[serde(skip_serializing_if = "Option::is_none")]
 29    pub prompt_inputs: Option<ExamplePromptInputs>,
 30
 31    /// The input and expected output from the edit prediction model.
 32    #[serde(skip_serializing_if = "Option::is_none")]
 33    pub prompt: Option<ExamplePrompt>,
 34
 35    /// The actual predictions from the model.
 36    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 37    pub predictions: Vec<ExamplePrediction>,
 38
 39    /// The scores, for how well the actual predictions match the expected
 40    /// predictions.
 41    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 42    pub score: Vec<ExampleScore>,
 43
 44    /// The application state used to process this example.
 45    #[serde(skip)]
 46    pub state: Option<ExampleState>,
 47}
 48
 49#[derive(Clone, Debug)]
 50pub struct ExampleState {
 51    pub project: Entity<Project>,
 52    pub buffer: Entity<Buffer>,
 53    pub cursor_position: Anchor,
 54    pub _open_buffers: OpenedBuffers,
 55}
 56
 57#[derive(Clone, Debug, Serialize, Deserialize)]
 58pub struct ExamplePromptInputs {
 59    pub content: String,
 60    pub cursor_row: u32,
 61    pub cursor_column: u32,
 62    pub cursor_offset: usize,
 63    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
 64    pub related_files: Option<Vec<RelatedFile>>,
 65}
 66
 67#[derive(Clone, Debug, Serialize, Deserialize)]
 68pub struct ExamplePrompt {
 69    pub input: String,
 70    pub expected_output: String,
 71    pub provider: PredictionProvider,
 72}
 73
 74#[derive(Clone, Debug, Serialize, Deserialize)]
 75pub struct ExamplePrediction {
 76    #[serde(default, skip_serializing_if = "Option::is_none")]
 77    pub actual_patch: Option<String>,
 78    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
 79    pub actual_output: String,
 80    #[serde(default, skip_serializing_if = "Option::is_none")]
 81    pub error: Option<String>,
 82    pub provider: PredictionProvider,
 83}
 84
 85fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
 86where
 87    D: serde::Deserializer<'de>,
 88{
 89    let opt = Option::<String>::deserialize(deserializer)?;
 90    Ok(opt.unwrap_or_default())
 91}
 92
 93#[derive(Clone, Debug, Serialize, Deserialize)]
 94pub struct ExampleScore {
 95    pub delta_chr_f: f32,
 96    pub braces_disbalance: usize,
 97    #[serde(default)]
 98    pub exact_lines_tp: usize,
 99    #[serde(default)]
100    pub exact_lines_fp: usize,
101    #[serde(default)]
102    pub exact_lines_fn: usize,
103}
104
105impl Example {
106    pub fn repo_name(&self) -> Result<RepoName<'_>> {
107        // git@github.com:owner/repo.git
108        if self.spec.repository_url.contains('@') {
109            let (owner, repo) = self
110                .spec
111                .repository_url
112                .split_once(':')
113                .context("expected : in git url")?
114                .1
115                .split_once('/')
116                .context("expected / in git url")?;
117            Ok(RepoName {
118                owner: Cow::Borrowed(owner),
119                name: Cow::Borrowed(repo.trim_end_matches(".git")),
120            })
121        // http://github.com/owner/repo.git
122        } else {
123            let url = Url::parse(&self.spec.repository_url)?;
124            let mut segments = url.path_segments().context("empty http url")?;
125            let owner = segments
126                .next()
127                .context("expected owner path segment")?
128                .to_string();
129            let repo = segments
130                .next()
131                .context("expected repo path segment")?
132                .trim_end_matches(".git")
133                .to_string();
134            assert!(segments.next().is_none());
135
136            Ok(RepoName {
137                owner: Cow::Owned(owner),
138                name: Cow::Owned(repo),
139            })
140        }
141    }
142}
143
/// A repository's owner and name, as produced by `Example::repo_name`.
///
/// Each component may either borrow from the source URL or own its data.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The repository owner (the first path component of the URL).
    pub owner: Cow<'a, str>,
    /// The repository name (the second path component, without a trailing `.git`).
    pub name: Cow<'a, str>,
}
148
149impl RepoName<'_> {
150    pub fn worktree_path(&self) -> PathBuf {
151        WORKTREES_DIR
152            .join(self.owner.as_ref())
153            .join(self.name.as_ref())
154    }
155}
156
157pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
158    let mut examples = Vec::new();
159
160    for path in inputs {
161        let is_stdin = path.as_path() == Path::new("-");
162        let content = if is_stdin {
163            let mut buffer = String::new();
164            std::io::stdin()
165                .read_to_string(&mut buffer)
166                .expect("Failed to read from stdin");
167            buffer
168        } else {
169            std::fs::read_to_string(path)
170                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
171        };
172        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
173        let ext = if !is_stdin {
174            path.extension()
175                .map(|ext| ext.to_string_lossy().to_string())
176                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
177        } else {
178            "jsonl".to_string()
179        };
180
181        match ext.as_ref() {
182            "json" => {
183                let mut example =
184                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
185                        panic!("Failed to parse example file: {}\n{error}", path.display())
186                    });
187                if example.spec.name.is_empty() {
188                    example.spec.name = filename;
189                }
190                examples.push(example);
191            }
192            "jsonl" => examples.extend(
193                content
194                    .lines()
195                    .enumerate()
196                    .map(|(line_ix, line)| {
197                        let mut example =
198                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
199                                panic!(
200                                    "Failed to parse example on {}:{}\n{error}",
201                                    path.display(),
202                                    line_ix + 1
203                                )
204                            });
205                        if example.spec.name.is_empty() {
206                            example.spec.name = format!("{filename}-{line_ix}")
207                        }
208                        example
209                    })
210                    .collect::<Vec<Example>>(),
211            ),
212            "md" => {
213                let mut example = parse_markdown_example(&content).unwrap();
214                if example.spec.name.is_empty() {
215                    example.spec.name = filename;
216                }
217                examples.push(example);
218            }
219            ext => {
220                panic!("{} has invalid example extension `{ext}`", path.display())
221            }
222        }
223    }
224
225    examples
226}
227
228pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
229    examples.sort_by(|a, b| {
230        a.spec
231            .repository_url
232            .cmp(&b.spec.repository_url)
233            .then(b.spec.revision.cmp(&a.spec.revision))
234    });
235}
236
237pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
238    let mut examples_by_repo = HashMap::default();
239    for example in examples {
240        examples_by_repo
241            .entry(example.spec.repository_url.clone())
242            .or_insert_with(Vec::new)
243            .push(example);
244    }
245    examples_by_repo.into_values().collect()
246}
247
248fn parse_markdown_example(input: &str) -> Result<Example> {
249    let spec = ExampleSpec::from_markdown(input)?;
250    Ok(Example {
251        spec,
252        prompt_inputs: None,
253        prompt: None,
254        predictions: Vec::new(),
255        score: Vec::new(),
256        state: None,
257    })
258}