example.rs

  1use crate::paths::WORKTREES_DIR;
  2use crate::{PredictionProvider, PromptFormat};
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use edit_prediction::example_spec::ExampleSpec;
  6use edit_prediction::udiff::OpenedBuffers;
  7use gpui::Entity;
  8use http_client::Url;
  9use language::{Anchor, Buffer};
 10use project::Project;
 11use serde::{Deserialize, Serialize};
 12use std::ops::Range;
 13use std::{
 14    borrow::Cow,
 15    io::Read,
 16    path::{Path, PathBuf},
 17};
 18use zeta_prompt::RelatedFile;
 19
 20#[derive(Clone, Debug, Serialize, Deserialize)]
 21pub struct Example {
 22    #[serde(flatten)]
 23    pub spec: ExampleSpec,
 24
 25    /// The full content of the file where an edit is being predicted, and the
 26    /// actual cursor offset.
 27    #[serde(skip_serializing_if = "Option::is_none")]
 28    pub buffer: Option<ExampleBuffer>,
 29
 30    /// The context retrieved for the prediction. This requires the worktree to
 31    /// be loaded and the language server to be started.
 32    #[serde(skip_serializing_if = "Option::is_none")]
 33    pub context: Option<ExampleContext>,
 34
 35    /// The input and expected output from the edit prediction model.
 36    #[serde(skip_serializing_if = "Option::is_none")]
 37    pub prompt: Option<ExamplePrompt>,
 38
 39    /// The actual predictions from the model.
 40    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 41    pub predictions: Vec<ExamplePrediction>,
 42
 43    /// The scores, for how well the actual predictions match the expected
 44    /// predictions.
 45    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 46    pub score: Vec<ExampleScore>,
 47
 48    /// The application state used to process this example.
 49    #[serde(skip)]
 50    pub state: Option<ExampleState>,
 51}
 52
 53#[derive(Clone, Debug)]
 54pub struct ExampleState {
 55    pub project: Entity<Project>,
 56    pub buffer: Entity<Buffer>,
 57    pub cursor_position: Anchor,
 58    pub _open_buffers: OpenedBuffers,
 59}
 60
 61#[derive(Clone, Debug, Serialize, Deserialize)]
 62pub struct ExampleContext {
 63    pub files: Vec<RelatedFile>,
 64}
 65
 66#[derive(Clone, Debug, Serialize, Deserialize)]
 67pub struct ExampleBuffer {
 68    pub content: String,
 69    pub cursor_row: u32,
 70    pub cursor_column: u32,
 71    pub cursor_offset: usize,
 72    pub context_range: Range<usize>,
 73    pub editable_range: Range<usize>,
 74}
 75
 76#[derive(Clone, Debug, Serialize, Deserialize)]
 77pub struct ExamplePrompt {
 78    pub input: String,
 79    pub expected_output: String,
 80    pub format: PromptFormat,
 81}
 82
 83#[derive(Clone, Debug, Serialize, Deserialize)]
 84pub struct ExamplePrediction {
 85    pub actual_patch: String,
 86    pub actual_output: String,
 87    pub provider: PredictionProvider,
 88}
 89
 90#[derive(Clone, Debug, Serialize, Deserialize)]
 91pub struct ExampleScore {
 92    pub delta_chr_f: f32,
 93}
 94
 95impl Example {
 96    pub fn repo_name(&self) -> Result<RepoName<'_>> {
 97        // git@github.com:owner/repo.git
 98        if self.spec.repository_url.contains('@') {
 99            let (owner, repo) = self
100                .spec
101                .repository_url
102                .split_once(':')
103                .context("expected : in git url")?
104                .1
105                .split_once('/')
106                .context("expected / in git url")?;
107            Ok(RepoName {
108                owner: Cow::Borrowed(owner),
109                name: Cow::Borrowed(repo.trim_end_matches(".git")),
110            })
111        // http://github.com/owner/repo.git
112        } else {
113            let url = Url::parse(&self.spec.repository_url)?;
114            let mut segments = url.path_segments().context("empty http url")?;
115            let owner = segments
116                .next()
117                .context("expected owner path segment")?
118                .to_string();
119            let repo = segments
120                .next()
121                .context("expected repo path segment")?
122                .trim_end_matches(".git")
123                .to_string();
124            assert!(segments.next().is_none());
125
126            Ok(RepoName {
127                owner: Cow::Owned(owner),
128                name: Cow::Owned(repo),
129            })
130        }
131    }
132}
133
/// The owner and name of a repository.
///
/// Both parts are [`Cow`]s, so a value can either borrow from an existing
/// string or own freshly allocated strings.
#[derive(Clone, Debug)]
pub struct RepoName<'a> {
    /// The repository owner.
    pub owner: Cow<'a, str>,
    /// The repository name.
    pub name: Cow<'a, str>,
}
138
139impl RepoName<'_> {
140    pub fn worktree_path(&self) -> PathBuf {
141        WORKTREES_DIR
142            .join(self.owner.as_ref())
143            .join(self.name.as_ref())
144    }
145}
146
147pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
148    let mut examples = Vec::new();
149
150    for path in inputs {
151        let is_stdin = path.as_path() == Path::new("-");
152        let content = if is_stdin {
153            let mut buffer = String::new();
154            std::io::stdin()
155                .read_to_string(&mut buffer)
156                .expect("Failed to read from stdin");
157            buffer
158        } else {
159            std::fs::read_to_string(path)
160                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
161        };
162        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
163        let ext = if !is_stdin {
164            path.extension()
165                .map(|ext| ext.to_string_lossy().to_string())
166                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
167        } else {
168            "jsonl".to_string()
169        };
170
171        match ext.as_ref() {
172            "json" => {
173                let mut example =
174                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
175                        panic!("Failed to parse example file: {}\n{error}", path.display())
176                    });
177                if example.spec.name.is_empty() {
178                    example.spec.name = filename;
179                }
180                examples.push(example);
181            }
182            "jsonl" => examples.extend(
183                content
184                    .lines()
185                    .enumerate()
186                    .map(|(line_ix, line)| {
187                        let mut example =
188                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
189                                panic!(
190                                    "Failed to parse example on {}:{}\n{error}",
191                                    path.display(),
192                                    line_ix + 1
193                                )
194                            });
195                        if example.spec.name.is_empty() {
196                            example.spec.name = format!("{filename}-{line_ix}")
197                        }
198                        example
199                    })
200                    .collect::<Vec<Example>>(),
201            ),
202            "md" => {
203                let mut example = parse_markdown_example(&content).unwrap();
204                if example.spec.name.is_empty() {
205                    example.spec.name = filename;
206                }
207                examples.push(example);
208            }
209            ext => {
210                panic!("{} has invalid example extension `{ext}`", path.display())
211            }
212        }
213    }
214
215    examples
216}
217
218pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
219    examples.sort_by(|a, b| {
220        a.spec
221            .repository_url
222            .cmp(&b.spec.repository_url)
223            .then(b.spec.revision.cmp(&a.spec.revision))
224    });
225}
226
227pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
228    let mut examples_by_repo = HashMap::default();
229    for example in examples.iter_mut() {
230        examples_by_repo
231            .entry(example.spec.repository_url.clone())
232            .or_insert_with(Vec::new)
233            .push(example);
234    }
235    examples_by_repo.into_values().collect()
236}
237
238fn parse_markdown_example(input: &str) -> Result<Example> {
239    let spec = ExampleSpec::from_markdown(input)?;
240    Ok(Example {
241        spec,
242        buffer: None,
243        context: None,
244        prompt: None,
245        predictions: Vec::new(),
246        score: Vec::new(),
247        state: None,
248    })
249}