example.rs

  1use crate::paths::WORKTREES_DIR;
  2use crate::{PredictionProvider, PromptFormat};
  3use anyhow::{Context as _, Result};
  4use collections::HashMap;
  5use edit_prediction::example_spec::ExampleSpec;
  6use edit_prediction::udiff::OpenedBuffers;
  7use gpui::Entity;
  8use http_client::Url;
  9use language::{Anchor, Buffer};
 10use project::Project;
 11use serde::{Deserialize, Serialize};
 12use std::ops::Range;
 13use std::sync::Arc;
 14use std::{
 15    borrow::Cow,
 16    io::{Read, Write},
 17    path::{Path, PathBuf},
 18};
 19use zeta_prompt::RelatedFile;
 20
 21#[derive(Clone, Debug, Serialize, Deserialize)]
 22pub struct Example {
 23    #[serde(flatten)]
 24    pub spec: ExampleSpec,
 25
 26    /// The full content of the file where an edit is being predicted, and the
 27    /// actual cursor offset.
 28    #[serde(skip_serializing_if = "Option::is_none")]
 29    pub buffer: Option<ExampleBuffer>,
 30
 31    /// The context retrieved for the prediction. This requires the worktree to
 32    /// be loaded and the language server to be started.
 33    #[serde(skip_serializing_if = "Option::is_none")]
 34    pub context: Option<ExampleContext>,
 35
 36    /// The input and expected output from the edit prediction model.
 37    #[serde(skip_serializing_if = "Option::is_none")]
 38    pub prompt: Option<ExamplePrompt>,
 39
 40    /// The actual predictions from the model.
 41    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 42    pub predictions: Vec<ExamplePrediction>,
 43
 44    /// The scores, for how well the actual predictions match the expected
 45    /// predictions.
 46    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 47    pub score: Vec<ExampleScore>,
 48
 49    /// The application state used to process this example.
 50    #[serde(skip)]
 51    pub state: Option<ExampleState>,
 52}
 53
 54#[derive(Clone, Debug)]
 55pub struct ExampleState {
 56    pub project: Entity<Project>,
 57    pub buffer: Entity<Buffer>,
 58    pub cursor_position: Anchor,
 59    pub _open_buffers: OpenedBuffers,
 60}
 61
 62#[derive(Clone, Debug, Serialize, Deserialize)]
 63pub struct ExampleContext {
 64    pub files: Arc<[RelatedFile]>,
 65}
 66
 67#[derive(Clone, Debug, Serialize, Deserialize)]
 68pub struct ExampleBuffer {
 69    pub content: String,
 70    pub cursor_row: u32,
 71    pub cursor_column: u32,
 72    pub cursor_offset: usize,
 73    pub context_range: Range<usize>,
 74    pub editable_range: Range<usize>,
 75}
 76
 77#[derive(Clone, Debug, Serialize, Deserialize)]
 78pub struct ExamplePrompt {
 79    pub input: String,
 80    pub expected_output: String,
 81    pub format: PromptFormat,
 82}
 83
 84#[derive(Clone, Debug, Serialize, Deserialize)]
 85pub struct ExamplePrediction {
 86    pub actual_patch: String,
 87    pub actual_output: String,
 88    pub provider: PredictionProvider,
 89}
 90
 91#[derive(Clone, Debug, Serialize, Deserialize)]
 92pub struct ExampleScore {
 93    pub delta_chr_f: f32,
 94}
 95
 96impl Example {
 97    pub fn repo_name(&self) -> Result<RepoName<'_>> {
 98        // git@github.com:owner/repo.git
 99        if self.spec.repository_url.contains('@') {
100            let (owner, repo) = self
101                .spec
102                .repository_url
103                .split_once(':')
104                .context("expected : in git url")?
105                .1
106                .split_once('/')
107                .context("expected / in git url")?;
108            Ok(RepoName {
109                owner: Cow::Borrowed(owner),
110                name: Cow::Borrowed(repo.trim_end_matches(".git")),
111            })
112        // http://github.com/owner/repo.git
113        } else {
114            let url = Url::parse(&self.spec.repository_url)?;
115            let mut segments = url.path_segments().context("empty http url")?;
116            let owner = segments
117                .next()
118                .context("expected owner path segment")?
119                .to_string();
120            let repo = segments
121                .next()
122                .context("expected repo path segment")?
123                .trim_end_matches(".git")
124                .to_string();
125            assert!(segments.next().is_none());
126
127            Ok(RepoName {
128                owner: Cow::Owned(owner),
129                name: Cow::Owned(repo),
130            })
131        }
132    }
133}
134
/// Owner and name of a git repository, as derived by [`Example::repo_name`].
///
/// Both components are `Cow`s so they can either borrow from the source URL or
/// own freshly parsed strings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// Repository owner (user or organization), e.g. `zed-industries`.
    pub owner: Cow<'a, str>,
    /// Repository name without a trailing `.git`, e.g. `zed`.
    pub name: Cow<'a, str>,
}
139
140impl RepoName<'_> {
141    pub fn worktree_path(&self) -> PathBuf {
142        WORKTREES_DIR
143            .join(self.owner.as_ref())
144            .join(self.name.as_ref())
145    }
146}
147
148pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
149    let mut examples = Vec::new();
150
151    for path in inputs {
152        let is_stdin = path.as_path() == Path::new("-");
153        let content = if is_stdin {
154            let mut buffer = String::new();
155            std::io::stdin()
156                .read_to_string(&mut buffer)
157                .expect("Failed to read from stdin");
158            buffer
159        } else {
160            std::fs::read_to_string(path)
161                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
162        };
163        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
164        let ext = if !is_stdin {
165            path.extension()
166                .map(|ext| ext.to_string_lossy().to_string())
167                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
168        } else {
169            "jsonl".to_string()
170        };
171
172        match ext.as_ref() {
173            "json" => {
174                let mut example =
175                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
176                        panic!("Failed to parse example file: {}\n{error}", path.display())
177                    });
178                if example.spec.name.is_empty() {
179                    example.spec.name = filename;
180                }
181                examples.push(example);
182            }
183            "jsonl" => examples.extend(
184                content
185                    .lines()
186                    .enumerate()
187                    .map(|(line_ix, line)| {
188                        let mut example =
189                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
190                                panic!(
191                                    "Failed to parse example on {}:{}\n{error}",
192                                    path.display(),
193                                    line_ix + 1
194                                )
195                            });
196                        if example.spec.name.is_empty() {
197                            example.spec.name = format!("{filename}-{line_ix}")
198                        }
199                        example
200                    })
201                    .collect::<Vec<Example>>(),
202            ),
203            "md" => {
204                let mut example = parse_markdown_example(&content).unwrap();
205                if example.spec.name.is_empty() {
206                    example.spec.name = filename;
207                }
208                examples.push(example);
209            }
210            ext => {
211                panic!("{} has invalid example extension `{ext}`", path.display())
212            }
213        }
214    }
215
216    examples
217}
218
219pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
220    let mut content = String::new();
221    for example in examples {
222        let line = serde_json::to_string(example).unwrap();
223        content.push_str(&line);
224        content.push('\n');
225    }
226    if let Some(output_path) = output_path {
227        std::fs::write(output_path, content).expect("Failed to write examples");
228    } else {
229        std::io::stdout().write_all(&content.as_bytes()).unwrap();
230    }
231}
232
233pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
234    examples.sort_by(|a, b| {
235        a.spec
236            .repository_url
237            .cmp(&b.spec.repository_url)
238            .then(b.spec.revision.cmp(&a.spec.revision))
239    });
240}
241
242pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
243    let mut examples_by_repo = HashMap::default();
244    for example in examples.iter_mut() {
245        examples_by_repo
246            .entry(example.spec.repository_url.clone())
247            .or_insert_with(Vec::new)
248            .push(example);
249    }
250    examples_by_repo.into_values().collect()
251}
252
253fn parse_markdown_example(input: &str) -> Result<Example> {
254    let spec = ExampleSpec::from_markdown(input)?;
255    Ok(Example {
256        spec,
257        buffer: None,
258        context: None,
259        prompt: None,
260        predictions: Vec::new(),
261        score: Vec::new(),
262        state: None,
263    })
264}