example.rs

  1use crate::{
  2    PredictionProvider, PromptFormat,
  3    metrics::ClassificationMetrics,
  4    paths::{REPOS_DIR, WORKTREES_DIR},
  5};
  6use anyhow::{Context as _, Result};
  7use edit_prediction::udiff::OpenedBuffers;
  8use gpui::Entity;
  9use http_client::Url;
 10use language::{Anchor, Buffer};
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::sync::Arc;
 14use std::{
 15    borrow::Cow,
 16    io::{Read, Write},
 17    mem,
 18    path::{Path, PathBuf},
 19};
 20use zeta_prompt::RelatedFile;
 21
/// A single edit-prediction example, together with the intermediate results
/// that are attached to it while it is processed.
///
/// The required fields describe the scenario itself. The optional fields are
/// omitted from the serialized form while they are `None`/empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Name of the example. May be absent in serialized input, in which case
    /// `read_examples` derives one from the input file name.
    #[serde(default)]
    pub name: String,
    /// URL of the git repository, either in the `git@github.com:owner/repo.git`
    /// form or as a URL such as `http://github.com/owner/repo.git`.
    pub repository_url: String,
    /// Revision of the repository this example refers to.
    /// NOTE(review): presumably a commit hash or other git revision — confirm
    /// against the code that checks the repository out.
    pub revision: String,
    /// Diff of uncommitted changes. For markdown examples this is the code
    /// block under the `Uncommitted Diff` heading.
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor. For markdown examples this is
    /// the info string of the code block under the `Cursor Position` heading.
    pub cursor_path: Arc<Path>,
    /// Textual description of the cursor position. For markdown examples this
    /// is the content of the code block under the `Cursor Position` heading.
    pub cursor_position: String,
    /// Edit history leading up to the prediction. For markdown examples this is
    /// the concatenation of all code blocks under the `Edit History` heading.
    pub edit_history: String,
    /// The patch the model is expected to predict. For markdown examples this
    /// is the code block under the `Expected Patch` heading.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. Never serialized
    /// or deserialized.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 61
/// In-memory application state attached to an [`Example`] while it is being
/// processed. It is not serializable; the `state` field on [`Example`] is
/// marked `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project entity associated with this example.
    pub project: Entity<Project>,
    /// The buffer entity associated with this example.
    /// NOTE(review): presumably the buffer for `cursor_path` — confirm at the
    /// site that constructs this state.
    pub buffer: Entity<Buffer>,
    /// The cursor position expressed as an anchor.
    pub cursor_position: Anchor,
    /// Buffers opened for this example.
    /// NOTE(review): the leading underscore suggests this is held only to keep
    /// the buffers alive — confirm.
    pub _open_buffers: OpenedBuffers,
}
 69
/// The context retrieved for a prediction (see the `context` field on
/// [`Example`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the context.
    pub files: Arc<[RelatedFile]>,
}
 74
/// The content of the file where an edit is being predicted, along with the
/// resolved cursor location (see the `buffer` field on [`Example`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// Full text of the file.
    pub content: String,
    /// Row of the cursor.
    /// NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_row: u32,
    /// Column of the cursor within `cursor_row`.
    /// NOTE(review): unit (bytes vs. characters) is not visible here — confirm.
    pub cursor_column: u32,
    /// Offset of the cursor within `content`.
    /// NOTE(review): presumably a byte offset — confirm against the producer.
    pub cursor_offset: usize,
}
 82
/// The input and expected output for the edit prediction model (see the
/// `prompt` field on [`Example`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format associated with `input` and `expected_output`.
    pub format: PromptFormat,
}
 89
/// An actual prediction produced for an example (see the `predictions` field
/// on [`Example`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch.
    /// NOTE(review): presumably comparable to `Example::expected_patch` —
    /// confirm against the scoring code.
    pub actual_patch: String,
    /// The output associated with this prediction.
    /// NOTE(review): presumably the raw model output — confirm.
    pub actual_output: String,
    /// The provider associated with this prediction.
    pub provider: PredictionProvider,
}
 96
/// A score describing how well an actual prediction matches the expected one
/// (see the `score` field on [`Example`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): presumably a chrF-based score computed on the change
    /// (delta) — confirm against the scoring implementation.
    pub delta_chr_f: f32,
    /// Classification metrics for line-level matching.
    /// NOTE(review): exact definition lives in `metrics::ClassificationMetrics`.
    pub line_match: ClassificationMetrics,
}
102
103impl Example {
104    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
105        // git@github.com:owner/repo.git
106        if self.repository_url.contains('@') {
107            let (owner, repo) = self
108                .repository_url
109                .split_once(':')
110                .context("expected : in git url")?
111                .1
112                .split_once('/')
113                .context("expected / in git url")?;
114            Ok((
115                Cow::Borrowed(owner),
116                Cow::Borrowed(repo.trim_end_matches(".git")),
117            ))
118        // http://github.com/owner/repo.git
119        } else {
120            let url = Url::parse(&self.repository_url)?;
121            let mut segments = url.path_segments().context("empty http url")?;
122            let owner = segments
123                .next()
124                .context("expected owner path segment")?
125                .to_string();
126            let repo = segments
127                .next()
128                .context("expected repo path segment")?
129                .trim_end_matches(".git")
130                .to_string();
131            assert!(segments.next().is_none());
132
133            Ok((owner.into(), repo.into()))
134        }
135    }
136
137    pub fn worktree_path(&self) -> PathBuf {
138        WORKTREES_DIR
139            .join(&self.name)
140            .join(self.repo_name().unwrap().1.as_ref())
141    }
142
143    pub fn repo_path(&self) -> PathBuf {
144        let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
145        REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
146    }
147}
148
149pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
150    let mut examples = Vec::new();
151
152    let stdin_path: PathBuf = PathBuf::from("-");
153
154    let inputs = if inputs.is_empty() {
155        &[stdin_path]
156    } else {
157        inputs
158    };
159
160    for path in inputs {
161        let is_stdin = path.as_path() == Path::new("-");
162        let content = if is_stdin {
163            let mut buffer = String::new();
164            std::io::stdin()
165                .read_to_string(&mut buffer)
166                .expect("Failed to read from stdin");
167            buffer
168        } else {
169            std::fs::read_to_string(path)
170                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
171        };
172        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
173        let ext = if !is_stdin {
174            path.extension()
175                .map(|ext| ext.to_string_lossy().to_string())
176                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
177        } else {
178            "jsonl".to_string()
179        };
180
181        match ext.as_ref() {
182            "json" => {
183                let mut example =
184                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
185                        panic!("Failed to parse example file: {}\n{error}", path.display())
186                    });
187                if example.name.is_empty() {
188                    example.name = filename;
189                }
190                examples.push(example);
191            }
192            "jsonl" => examples.extend(
193                content
194                    .lines()
195                    .enumerate()
196                    .map(|(line_ix, line)| {
197                        let mut example =
198                            serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
199                                panic!(
200                                    "Failed to parse example on {}:{}",
201                                    path.display(),
202                                    line_ix + 1
203                                )
204                            });
205                        if example.name.is_empty() {
206                            example.name = format!("{filename}-{line_ix}")
207                        }
208                        example
209                    })
210                    .collect::<Vec<Example>>(),
211            ),
212            "md" => {
213                examples.push(parse_markdown_example(filename, &content).unwrap());
214            }
215            ext => {
216                panic!("{} has invalid example extension `{ext}`", path.display())
217            }
218        }
219    }
220    examples
221}
222
223pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
224    let mut content = String::new();
225    for example in examples {
226        let line = serde_json::to_string(example).unwrap();
227        content.push_str(&line);
228        content.push('\n');
229    }
230    if let Some(output_path) = output_path {
231        std::fs::write(output_path, content).expect("Failed to write examples");
232    } else {
233        std::io::stdout().write_all(&content.as_bytes()).unwrap();
234    }
235}
236
237fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
238    use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
239
240    const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
241    const EDIT_HISTORY_HEADING: &str = "Edit History";
242    const CURSOR_POSITION_HEADING: &str = "Cursor Position";
243    const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
244    const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
245    const REPOSITORY_URL_FIELD: &str = "repository_url";
246    const REVISION_FIELD: &str = "revision";
247
248    let parser = Parser::new(input);
249
250    let mut example = Example {
251        name: id,
252        repository_url: String::new(),
253        revision: String::new(),
254        uncommitted_diff: String::new(),
255        cursor_path: PathBuf::new().into(),
256        cursor_position: String::new(),
257        edit_history: String::new(),
258        expected_patch: String::new(),
259        buffer: None,
260        context: None,
261        prompt: None,
262        predictions: Vec::new(),
263        score: Vec::new(),
264        state: None,
265    };
266
267    let mut name = String::new();
268    let mut text = String::new();
269    let mut block_info: CowStr = "".into();
270
271    #[derive(PartialEq)]
272    enum Section {
273        UncommittedDiff,
274        EditHistory,
275        CursorPosition,
276        ExpectedExcerpts,
277        ExpectedPatch,
278        Other,
279    }
280
281    let mut current_section = Section::Other;
282
283    for event in parser {
284        match event {
285            Event::Text(line) => {
286                text.push_str(&line);
287
288                if let Some((field, value)) = line.split_once('=') {
289                    match field.trim() {
290                        REPOSITORY_URL_FIELD => {
291                            example.repository_url = value.trim().to_string();
292                        }
293                        REVISION_FIELD => {
294                            example.revision = value.trim().to_string();
295                        }
296                        _ => {}
297                    }
298                }
299            }
300            Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
301                if !name.is_empty() {
302                    anyhow::bail!(
303                        "Found multiple H1 headings. There should only be one with the name of the example."
304                    );
305                }
306                name = mem::take(&mut text);
307            }
308            Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
309                let title = mem::take(&mut text);
310                current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
311                    Section::UncommittedDiff
312                } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
313                    Section::EditHistory
314                } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
315                    Section::CursorPosition
316                } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
317                    Section::ExpectedPatch
318                } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
319                    Section::ExpectedExcerpts
320                } else {
321                    Section::Other
322                };
323            }
324            Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
325                mem::take(&mut text);
326            }
327            Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
328                mem::take(&mut text);
329            }
330            Event::End(TagEnd::Heading(level)) => {
331                anyhow::bail!("Unexpected heading level: {level}");
332            }
333            Event::Start(Tag::CodeBlock(kind)) => {
334                match kind {
335                    CodeBlockKind::Fenced(info) => {
336                        block_info = info;
337                    }
338                    CodeBlockKind::Indented => {
339                        anyhow::bail!("Unexpected indented codeblock");
340                    }
341                };
342            }
343            Event::Start(_) => {
344                text.clear();
345                block_info = "".into();
346            }
347            Event::End(TagEnd::CodeBlock) => {
348                let block_info = block_info.trim();
349                match current_section {
350                    Section::UncommittedDiff => {
351                        example.uncommitted_diff = mem::take(&mut text);
352                    }
353                    Section::EditHistory => {
354                        example.edit_history.push_str(&mem::take(&mut text));
355                    }
356                    Section::CursorPosition => {
357                        example.cursor_path = Path::new(block_info).into();
358                        example.cursor_position = mem::take(&mut text);
359                    }
360                    Section::ExpectedExcerpts => {
361                        mem::take(&mut text);
362                    }
363                    Section::ExpectedPatch => {
364                        example.expected_patch = mem::take(&mut text);
365                    }
366                    Section::Other => {}
367                }
368            }
369            _ => {}
370        }
371    }
372    if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
373        anyhow::bail!("Missing cursor position codeblock");
374    }
375
376    Ok(example)
377}