example.rs

  1use crate::PredictionProvider;
  2use crate::paths::WORKTREES_DIR;
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use edit_prediction::example_spec::ExampleSpec;
  6use edit_prediction::udiff::OpenedBuffers;
  7use gpui::Entity;
  8use http_client::Url;
  9use language::{Anchor, Buffer};
 10use project::Project;
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    borrow::Cow,
 14    io::Read,
 15    path::{Path, PathBuf},
 16    sync::Arc,
 17};
 18use zeta_prompt::RelatedFile;
 19
/// One edit-prediction example together with the data produced for it.
///
/// The fields of `spec` are flattened into this struct's own serialized
/// form. The `Option` fields are left out of serialized output when they are
/// `None`, and the `Vec` fields are left out when they are empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification; its fields are serialized inline.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    ///
    /// Skipped by serde in both directions, so it is never written out and
    /// is `None` on a freshly deserialized example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 47
/// In-memory application state attached to an [`Example`] while it is being
/// processed.
///
/// This type has no serde derives; the `Example::state` field that holds it
/// is marked `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project entity associated with the example.
    pub project: Entity<Project>,
    /// The buffer entity associated with the example.
    pub buffer: Entity<Buffer>,
    /// The cursor location, expressed as a buffer anchor.
    pub cursor_position: Anchor,
    // NOTE(review): presumably held only to keep the opened buffers alive for
    // the lifetime of this state (leading underscore) — confirm with callers.
    pub _open_buffers: OpenedBuffers,
}
 55
/// The inputs gathered for building a prediction prompt: the file content,
/// the cursor location, and surrounding context.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// Content of the file in which the edit is being predicted.
    pub content: String,
    /// Cursor row. NOTE(review): presumably zero-based — confirm with the
    /// code that fills this in.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): presumably zero-based — confirm.
    pub cursor_column: u32,
    /// Cursor offset into `content`. NOTE(review): presumably a byte offset
    /// — confirm.
    pub cursor_offset: usize,
    /// Edit events supplied as history for the prompt.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplied as extra context, when present.
    pub related_files: Option<Vec<RelatedFile>>,
}
 65
/// The prompt for an example: the model input, the output expected from the
/// model, and the provider the prompt is associated with.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The text given to the edit prediction model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prediction provider this prompt belongs to.
    pub provider: PredictionProvider,
}
 72
/// A single prediction actually produced for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted change. NOTE(review): presumably in unified-diff form
    /// — confirm with the code that produces it.
    pub actual_patch: String,
    /// The output of the model for this prediction.
    pub actual_output: String,
    /// The prediction provider that produced this prediction.
    pub provider: PredictionProvider,
}
 79
/// A score describing how well an actual prediction matches the expected one.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF-style (character n-gram F-score) metric
    // computed on the edit delta — confirm against the scoring code.
    pub delta_chr_f: f32,
}
 84
 85impl Example {
 86    pub fn repo_name(&self) -> Result<RepoName<'_>> {
 87        // git@github.com:owner/repo.git
 88        if self.spec.repository_url.contains('@') {
 89            let (owner, repo) = self
 90                .spec
 91                .repository_url
 92                .split_once(':')
 93                .context("expected : in git url")?
 94                .1
 95                .split_once('/')
 96                .context("expected / in git url")?;
 97            Ok(RepoName {
 98                owner: Cow::Borrowed(owner),
 99                name: Cow::Borrowed(repo.trim_end_matches(".git")),
100            })
101        // http://github.com/owner/repo.git
102        } else {
103            let url = Url::parse(&self.spec.repository_url)?;
104            let mut segments = url.path_segments().context("empty http url")?;
105            let owner = segments
106                .next()
107                .context("expected owner path segment")?
108                .to_string();
109            let repo = segments
110                .next()
111                .context("expected repo path segment")?
112                .trim_end_matches(".git")
113                .to_string();
114            assert!(segments.next().is_none());
115
116            Ok(RepoName {
117                owner: Cow::Owned(owner),
118                name: Cow::Owned(repo),
119            })
120        }
121    }
122}
123
/// The owner and name of a repository, either borrowed from a URL string or
/// owned.
pub struct RepoName<'a> {
    /// The repository owner (the first path component of the repository URL).
    pub owner: Cow<'a, str>,
    /// The repository name (the second path component of the repository URL).
    pub name: Cow<'a, str>,
}
128
129impl RepoName<'_> {
130    pub fn worktree_path(&self) -> PathBuf {
131        WORKTREES_DIR
132            .join(self.owner.as_ref())
133            .join(self.name.as_ref())
134    }
135}
136
137pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
138    let mut examples = Vec::new();
139
140    for path in inputs {
141        let is_stdin = path.as_path() == Path::new("-");
142        let content = if is_stdin {
143            let mut buffer = String::new();
144            std::io::stdin()
145                .read_to_string(&mut buffer)
146                .expect("Failed to read from stdin");
147            buffer
148        } else {
149            std::fs::read_to_string(path)
150                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
151        };
152        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
153        let ext = if !is_stdin {
154            path.extension()
155                .map(|ext| ext.to_string_lossy().to_string())
156                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
157        } else {
158            "jsonl".to_string()
159        };
160
161        match ext.as_ref() {
162            "json" => {
163                let mut example =
164                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
165                        panic!("Failed to parse example file: {}\n{error}", path.display())
166                    });
167                if example.spec.name.is_empty() {
168                    example.spec.name = filename;
169                }
170                examples.push(example);
171            }
172            "jsonl" => examples.extend(
173                content
174                    .lines()
175                    .enumerate()
176                    .map(|(line_ix, line)| {
177                        let mut example =
178                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
179                                panic!(
180                                    "Failed to parse example on {}:{}\n{error}",
181                                    path.display(),
182                                    line_ix + 1
183                                )
184                            });
185                        if example.spec.name.is_empty() {
186                            example.spec.name = format!("{filename}-{line_ix}")
187                        }
188                        example
189                    })
190                    .collect::<Vec<Example>>(),
191            ),
192            "md" => {
193                let mut example = parse_markdown_example(&content).unwrap();
194                if example.spec.name.is_empty() {
195                    example.spec.name = filename;
196                }
197                examples.push(example);
198            }
199            ext => {
200                panic!("{} has invalid example extension `{ext}`", path.display())
201            }
202        }
203    }
204
205    examples
206}
207
208pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
209    examples.sort_by(|a, b| {
210        a.spec
211            .repository_url
212            .cmp(&b.spec.repository_url)
213            .then(b.spec.revision.cmp(&a.spec.revision))
214    });
215}
216
217pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
218    let mut examples_by_repo = HashMap::default();
219    for example in examples.iter_mut() {
220        examples_by_repo
221            .entry(example.spec.repository_url.clone())
222            .or_insert_with(Vec::new)
223            .push(example);
224    }
225    examples_by_repo.into_values().collect()
226}
227
228fn parse_markdown_example(input: &str) -> Result<Example> {
229    let spec = ExampleSpec::from_markdown(input)?;
230    Ok(Example {
231        spec,
232        prompt_inputs: None,
233        prompt: None,
234        predictions: Vec::new(),
235        score: Vec::new(),
236        state: None,
237    })
238}