example.rs

  1use crate::paths::WORKTREES_DIR;
  2use crate::{PredictionProvider, PromptFormat};
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use edit_prediction::example_spec::ExampleSpec;
  6use edit_prediction::udiff::OpenedBuffers;
  7use gpui::Entity;
  8use http_client::Url;
  9use language::{Anchor, Buffer};
 10use project::Project;
 11use serde::{Deserialize, Serialize};
 12use std::sync::Arc;
 13use std::{
 14    borrow::Cow,
 15    io::{Read, Write},
 16    path::{Path, PathBuf},
 17};
 18use zeta_prompt::RelatedFile;
 19
 20#[derive(Clone, Debug, Serialize, Deserialize)]
 21pub struct Example {
 22    #[serde(flatten)]
 23    pub spec: ExampleSpec,
 24
 25    /// The full content of the file where an edit is being predicted, and the
 26    /// actual cursor offset.
 27    #[serde(skip_serializing_if = "Option::is_none")]
 28    pub buffer: Option<ExampleBuffer>,
 29
 30    /// The context retrieved for the prediction. This requires the worktree to
 31    /// be loaded and the language server to be started.
 32    #[serde(skip_serializing_if = "Option::is_none")]
 33    pub context: Option<ExampleContext>,
 34
 35    /// The input and expected output from the edit prediction model.
 36    #[serde(skip_serializing_if = "Option::is_none")]
 37    pub prompt: Option<ExamplePrompt>,
 38
 39    /// The actual predictions from the model.
 40    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 41    pub predictions: Vec<ExamplePrediction>,
 42
 43    /// The scores, for how well the actual predictions match the expected
 44    /// predictions.
 45    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 46    pub score: Vec<ExampleScore>,
 47
 48    /// The application state used to process this example.
 49    #[serde(skip)]
 50    pub state: Option<ExampleState>,
 51}
 52
 53#[derive(Clone, Debug)]
 54pub struct ExampleState {
 55    pub project: Entity<Project>,
 56    pub buffer: Entity<Buffer>,
 57    pub cursor_position: Anchor,
 58    pub _open_buffers: OpenedBuffers,
 59}
 60
 61#[derive(Clone, Debug, Serialize, Deserialize)]
 62pub struct ExampleContext {
 63    pub files: Arc<[RelatedFile]>,
 64}
 65
 66#[derive(Clone, Debug, Serialize, Deserialize)]
 67pub struct ExampleBuffer {
 68    pub content: String,
 69    pub cursor_row: u32,
 70    pub cursor_column: u32,
 71    pub cursor_offset: usize,
 72}
 73
 74#[derive(Clone, Debug, Serialize, Deserialize)]
 75pub struct ExamplePrompt {
 76    pub input: String,
 77    pub expected_output: String,
 78    pub format: PromptFormat,
 79}
 80
 81#[derive(Clone, Debug, Serialize, Deserialize)]
 82pub struct ExamplePrediction {
 83    pub actual_patch: String,
 84    pub actual_output: String,
 85    pub provider: PredictionProvider,
 86}
 87
 88#[derive(Clone, Debug, Serialize, Deserialize)]
 89pub struct ExampleScore {
 90    pub delta_chr_f: f32,
 91}
 92
 93impl Example {
 94    pub fn repo_name(&self) -> Result<RepoName<'_>> {
 95        // git@github.com:owner/repo.git
 96        if self.spec.repository_url.contains('@') {
 97            let (owner, repo) = self
 98                .spec
 99                .repository_url
100                .split_once(':')
101                .context("expected : in git url")?
102                .1
103                .split_once('/')
104                .context("expected / in git url")?;
105            Ok(RepoName {
106                owner: Cow::Borrowed(owner),
107                name: Cow::Borrowed(repo.trim_end_matches(".git")),
108            })
109        // http://github.com/owner/repo.git
110        } else {
111            let url = Url::parse(&self.spec.repository_url)?;
112            let mut segments = url.path_segments().context("empty http url")?;
113            let owner = segments
114                .next()
115                .context("expected owner path segment")?
116                .to_string();
117            let repo = segments
118                .next()
119                .context("expected repo path segment")?
120                .trim_end_matches(".git")
121                .to_string();
122            assert!(segments.next().is_none());
123
124            Ok(RepoName {
125                owner: Cow::Owned(owner),
126                name: Cow::Owned(repo),
127            })
128        }
129    }
130}
131
/// Owner and name of a git repository, as parsed from its URL.
///
/// Both parts either borrow from the URL string or own a copy, depending on how
/// the URL was parsed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// Repository owner (user or organization).
    pub owner: Cow<'a, str>,
    /// Repository name.
    pub name: Cow<'a, str>,
}
136
137impl RepoName<'_> {
138    pub fn worktree_path(&self) -> PathBuf {
139        WORKTREES_DIR
140            .join(self.owner.as_ref())
141            .join(self.name.as_ref())
142    }
143}
144
145pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
146    let mut examples = Vec::new();
147
148    for path in inputs {
149        let is_stdin = path.as_path() == Path::new("-");
150        let content = if is_stdin {
151            let mut buffer = String::new();
152            std::io::stdin()
153                .read_to_string(&mut buffer)
154                .expect("Failed to read from stdin");
155            buffer
156        } else {
157            std::fs::read_to_string(path)
158                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
159        };
160        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
161        let ext = if !is_stdin {
162            path.extension()
163                .map(|ext| ext.to_string_lossy().to_string())
164                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
165        } else {
166            "jsonl".to_string()
167        };
168
169        match ext.as_ref() {
170            "json" => {
171                let mut example =
172                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
173                        panic!("Failed to parse example file: {}\n{error}", path.display())
174                    });
175                if example.spec.name.is_empty() {
176                    example.spec.name = filename;
177                }
178                examples.push(example);
179            }
180            "jsonl" => examples.extend(
181                content
182                    .lines()
183                    .enumerate()
184                    .map(|(line_ix, line)| {
185                        let mut example =
186                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
187                                panic!(
188                                    "Failed to parse example on {}:{}\n{error}",
189                                    path.display(),
190                                    line_ix + 1
191                                )
192                            });
193                        if example.spec.name.is_empty() {
194                            example.spec.name = format!("{filename}-{line_ix}")
195                        }
196                        example
197                    })
198                    .collect::<Vec<Example>>(),
199            ),
200            "md" => {
201                let mut example = parse_markdown_example(&content).unwrap();
202                if example.spec.name.is_empty() {
203                    example.spec.name = filename;
204                }
205                examples.push(example);
206            }
207            ext => {
208                panic!("{} has invalid example extension `{ext}`", path.display())
209            }
210        }
211    }
212
213    examples
214}
215
216pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
217    let mut content = String::new();
218    for example in examples {
219        let line = serde_json::to_string(example).unwrap();
220        content.push_str(&line);
221        content.push('\n');
222    }
223    if let Some(output_path) = output_path {
224        std::fs::write(output_path, content).expect("Failed to write examples");
225    } else {
226        std::io::stdout().write_all(&content.as_bytes()).unwrap();
227    }
228}
229
230pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
231    examples.sort_by(|a, b| {
232        a.spec
233            .repository_url
234            .cmp(&b.spec.repository_url)
235            .then(b.spec.revision.cmp(&a.spec.revision))
236    });
237}
238
239pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
240    let mut examples_by_repo = HashMap::default();
241    for example in examples.iter_mut() {
242        examples_by_repo
243            .entry(example.spec.repository_url.clone())
244            .or_insert_with(Vec::new)
245            .push(example);
246    }
247    examples_by_repo.into_values().collect()
248}
249
250fn parse_markdown_example(input: &str) -> Result<Example> {
251    let spec = ExampleSpec::from_markdown(input)?;
252    Ok(Example {
253        spec,
254        buffer: None,
255        context: None,
256        prompt: None,
257        predictions: Vec::new(),
258        score: Vec::new(),
259        state: None,
260    })
261}