example.rs

  1use crate::{PredictionProvider, PromptFormat};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use edit_prediction::example_spec::ExampleSpec;
  5use edit_prediction::udiff::OpenedBuffers;
  6use gpui::Entity;
  7use http_client::Url;
  8use language::{Anchor, Buffer};
  9use project::Project;
 10use serde::{Deserialize, Serialize};
 11use std::sync::Arc;
 12use std::{
 13    borrow::Cow,
 14    io::{Read, Write},
 15    path::{Path, PathBuf},
 16};
 17use zeta_prompt::RelatedFile;
 18
/// A single edit prediction example: its spec plus the data that is attached
/// to it (buffer, context, prompt, predictions and scores).
///
/// When serialized, the fields of `spec` are flattened into the top-level
/// object, optional fields that are `None` and lists that are empty are
/// omitted, and `state` is never serialized or deserialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The specification of the example. Its fields appear inline in the
    /// serialized form rather than under a `spec` key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    ///
    /// Skipped by serde, so it is always `None` right after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 51
/// In-memory application state attached to an [`Example`] while it is being
/// processed. It is not serializable and is stored in `Example::state`, which
/// serde skips.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity used for this example.
    pub project: Entity<Project>,
    /// Handle to the buffer entity used for this example.
    pub buffer: Entity<Buffer>,
    /// The cursor position, expressed as a buffer anchor.
    pub cursor_position: Anchor,
    // NOTE(review): presumably held only to keep the opened buffers alive for
    // as long as this state exists (the leading underscore suggests it is not
    // read directly) — confirm against `OpenedBuffers`.
    pub _open_buffers: OpenedBuffers,
}
 59
/// The context retrieved for a prediction, stored in `Example::context`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the context, shared behind an `Arc`.
    pub files: Arc<[RelatedFile]>,
}
 64
/// The file in which an edit is being predicted, together with the cursor
/// location, stored in `Example::buffer`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full content of the file.
    pub content: String,
    /// Row of the cursor. NOTE(review): presumably zero-based — confirm.
    pub cursor_row: u32,
    /// Column of the cursor. NOTE(review): presumably zero-based — confirm.
    pub cursor_column: u32,
    /// Offset of the cursor into `content`.
    /// NOTE(review): presumably a byte offset — confirm against the producer.
    pub cursor_offset: usize,
}
 72
/// The input and expected output for the edit prediction model, stored in
/// `Example::prompt`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format associated with this prompt.
    pub format: PromptFormat,
}
 79
/// One actual prediction from the model, stored in `Example::predictions`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit as a patch.
    /// NOTE(review): presumably a unified diff — confirm against the producer.
    pub actual_patch: String,
    /// The output associated with this prediction.
    /// NOTE(review): presumably the raw model output — confirm.
    pub actual_output: String,
    /// The provider associated with this prediction.
    pub provider: PredictionProvider,
}
 86
/// A score for how well an actual prediction matches the expected prediction,
/// stored in `Example::score`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF-based delta metric — confirm the exact
    // definition and value range where the score is computed.
    pub delta_chr_f: f32,
}
 91
 92impl Example {
 93    pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
 94        // git@github.com:owner/repo.git
 95        if self.spec.repository_url.contains('@') {
 96            let (owner, repo) = self
 97                .spec
 98                .repository_url
 99                .split_once(':')
100                .context("expected : in git url")?
101                .1
102                .split_once('/')
103                .context("expected / in git url")?;
104            Ok((
105                Cow::Borrowed(owner),
106                Cow::Borrowed(repo.trim_end_matches(".git")),
107            ))
108        // http://github.com/owner/repo.git
109        } else {
110            let url = Url::parse(&self.spec.repository_url)?;
111            let mut segments = url.path_segments().context("empty http url")?;
112            let owner = segments
113                .next()
114                .context("expected owner path segment")?
115                .to_string();
116            let repo = segments
117                .next()
118                .context("expected repo path segment")?
119                .trim_end_matches(".git")
120                .to_string();
121            assert!(segments.next().is_none());
122
123            Ok((owner.into(), repo.into()))
124        }
125    }
126}
127
128pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
129    let mut examples = Vec::new();
130
131    let stdin_path: PathBuf = PathBuf::from("-");
132
133    let inputs = if inputs.is_empty() {
134        &[stdin_path]
135    } else {
136        inputs
137    };
138
139    for path in inputs {
140        let is_stdin = path.as_path() == Path::new("-");
141        let content = if is_stdin {
142            let mut buffer = String::new();
143            std::io::stdin()
144                .read_to_string(&mut buffer)
145                .expect("Failed to read from stdin");
146            buffer
147        } else {
148            std::fs::read_to_string(path)
149                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
150        };
151        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
152        let ext = if !is_stdin {
153            path.extension()
154                .map(|ext| ext.to_string_lossy().to_string())
155                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
156        } else {
157            "jsonl".to_string()
158        };
159
160        match ext.as_ref() {
161            "json" => {
162                let mut example =
163                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
164                        panic!("Failed to parse example file: {}\n{error}", path.display())
165                    });
166                if example.spec.name.is_empty() {
167                    example.spec.name = filename;
168                }
169                examples.push(example);
170            }
171            "jsonl" => examples.extend(
172                content
173                    .lines()
174                    .enumerate()
175                    .map(|(line_ix, line)| {
176                        let mut example =
177                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
178                                panic!(
179                                    "Failed to parse example on {}:{}\n{error}",
180                                    path.display(),
181                                    line_ix + 1
182                                )
183                            });
184                        if example.spec.name.is_empty() {
185                            example.spec.name = format!("{filename}-{line_ix}")
186                        }
187                        example
188                    })
189                    .collect::<Vec<Example>>(),
190            ),
191            "md" => {
192                let mut example = parse_markdown_example(&content).unwrap();
193                if example.spec.name.is_empty() {
194                    example.spec.name = filename;
195                }
196                examples.push(example);
197            }
198            ext => {
199                panic!("{} has invalid example extension `{ext}`", path.display())
200            }
201        }
202    }
203
204    sort_examples_by_repo_and_rev(&mut examples);
205    examples
206}
207
208pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
209    let mut content = String::new();
210    for example in examples {
211        let line = serde_json::to_string(example).unwrap();
212        content.push_str(&line);
213        content.push('\n');
214    }
215    if let Some(output_path) = output_path {
216        std::fs::write(output_path, content).expect("Failed to write examples");
217    } else {
218        std::io::stdout().write_all(&content.as_bytes()).unwrap();
219    }
220}
221
222pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
223    examples.sort_by(|a, b| {
224        a.spec
225            .repository_url
226            .cmp(&b.spec.repository_url)
227            .then(b.spec.revision.cmp(&a.spec.revision))
228    });
229}
230
231pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
232    let mut examples_by_repo = HashMap::default();
233    for example in examples.iter_mut() {
234        examples_by_repo
235            .entry(example.spec.repository_url.clone())
236            .or_insert_with(Vec::new)
237            .push(example);
238    }
239    examples_by_repo.into_values().collect()
240}
241
242fn parse_markdown_example(input: &str) -> Result<Example> {
243    let spec = ExampleSpec::from_markdown(input)?;
244    Ok(Example {
245        spec,
246        buffer: None,
247        context: None,
248        prompt: None,
249        predictions: Vec::new(),
250        score: Vec::new(),
251        state: None,
252    })
253}