example.rs

  1use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use edit_prediction::udiff::OpenedBuffers;
  5use gpui::Entity;
  6use http_client::Url;
  7use language::{Anchor, Buffer};
  8use project::Project;
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11use std::{
 12    borrow::Cow,
 13    io::{Read, Write},
 14    mem,
 15    path::{Path, PathBuf},
 16};
 17use zeta_prompt::RelatedFile;
 18
/// A single edit prediction example.
///
/// The required fields describe the scenario itself. The optional fields
/// (`buffer`, `context`, `prompt`, `predictions`, `score`) are left out of the
/// serialized form while they are unset or empty, and `state` is never
/// serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Name of the example. Defaults to an empty string when absent from the
    /// serialized form.
    #[serde(default)]
    pub name: String,
    /// URL of the repository the example refers to.
    pub repository_url: String,
    /// Revision of the repository.
    // NOTE(review): presumably a git commit SHA or ref — confirm against the
    // code that checks the repository out.
    pub revision: String,
    /// Diff of uncommitted changes. Defaults to an empty string when absent
    /// from the serialized form.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file that contains the cursor.
    // NOTE(review): presumably relative to the repository root — confirm.
    pub cursor_path: Arc<Path>,
    /// Textual description of the cursor position.
    // NOTE(review): the exact format is not defined in this file — confirm
    // against the code that resolves it.
    pub cursor_position: String,
    /// The edit history for the example, stored as text.
    pub edit_history: String,
    /// The patch that a prediction is expected to produce.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 59
/// In-memory application state attached to an [`Example`] while it is being
/// processed. It is not serializable; `Example::state` is marked
/// `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project entity used for this example.
    pub project: Entity<Project>,
    /// The buffer entity used for this example.
    // NOTE(review): presumably the buffer for `Example::cursor_path` — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor position, as an anchor.
    // NOTE(review): presumably an anchor into `buffer` — confirm.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is held only to keep
    // the opened buffers alive — confirm against `OpenedBuffers`.
    pub _open_buffers: OpenedBuffers,
}
 67
/// The context stored in `Example::context`, as a list of related files.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the context.
    pub files: Arc<[RelatedFile]>,
}
 72
/// The content of the file where an edit is being predicted, together with
/// the resolved cursor location (stored in `Example::buffer`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full text of the file.
    pub content: String,
    /// Row of the cursor.
    // NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_row: u32,
    /// Column of the cursor.
    // NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_column: u32,
    /// Offset of the cursor.
    // NOTE(review): presumably a byte offset into `content` — confirm.
    pub cursor_offset: usize,
}
 80
/// The input and expected output for the edit prediction model (stored in
/// `Example::prompt`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format associated with this prompt.
    pub format: PromptFormat,
}
 87
/// One actual prediction from the model (an element of
/// `Example::predictions`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch.
    pub actual_patch: String,
    /// The model output.
    // NOTE(review): presumably the raw output that `actual_patch` was derived
    // from — confirm.
    pub actual_output: String,
    /// The provider associated with this prediction.
    pub provider: PredictionProvider,
}
 94
/// A score for how well an actual prediction matches the expected prediction
/// (an element of `Example::score`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF-based metric computed on the deltas —
    // confirm against the scoring code.
    pub delta_chr_f: f32,
    /// Line-level match metrics.
    // NOTE(review): presumably counts of matching and mismatching lines between
    // the actual and expected patches — confirm against `ClassificationMetrics`.
    pub line_match: ClassificationMetrics,
}
100
101impl Example {
102    pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
103        // git@github.com:owner/repo.git
104        if self.repository_url.contains('@') {
105            let (owner, repo) = self
106                .repository_url
107                .split_once(':')
108                .context("expected : in git url")?
109                .1
110                .split_once('/')
111                .context("expected / in git url")?;
112            Ok((
113                Cow::Borrowed(owner),
114                Cow::Borrowed(repo.trim_end_matches(".git")),
115            ))
116        // http://github.com/owner/repo.git
117        } else {
118            let url = Url::parse(&self.repository_url)?;
119            let mut segments = url.path_segments().context("empty http url")?;
120            let owner = segments
121                .next()
122                .context("expected owner path segment")?
123                .to_string();
124            let repo = segments
125                .next()
126                .context("expected repo path segment")?
127                .trim_end_matches(".git")
128                .to_string();
129            assert!(segments.next().is_none());
130
131            Ok((owner.into(), repo.into()))
132        }
133    }
134}
135
136pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
137    let mut examples = Vec::new();
138
139    let stdin_path: PathBuf = PathBuf::from("-");
140
141    let inputs = if inputs.is_empty() {
142        &[stdin_path]
143    } else {
144        inputs
145    };
146
147    for path in inputs {
148        let is_stdin = path.as_path() == Path::new("-");
149        let content = if is_stdin {
150            let mut buffer = String::new();
151            std::io::stdin()
152                .read_to_string(&mut buffer)
153                .expect("Failed to read from stdin");
154            buffer
155        } else {
156            std::fs::read_to_string(path)
157                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
158        };
159        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
160        let ext = if !is_stdin {
161            path.extension()
162                .map(|ext| ext.to_string_lossy().to_string())
163                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
164        } else {
165            "jsonl".to_string()
166        };
167
168        match ext.as_ref() {
169            "json" => {
170                let mut example =
171                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
172                        panic!("Failed to parse example file: {}\n{error}", path.display())
173                    });
174                if example.name.is_empty() {
175                    example.name = filename;
176                }
177                examples.push(example);
178            }
179            "jsonl" => examples.extend(
180                content
181                    .lines()
182                    .enumerate()
183                    .map(|(line_ix, line)| {
184                        let mut example =
185                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
186                                panic!(
187                                    "Failed to parse example on {}:{}\n{error}",
188                                    path.display(),
189                                    line_ix + 1
190                                )
191                            });
192                        if example.name.is_empty() {
193                            example.name = format!("{filename}-{line_ix}")
194                        }
195                        example
196                    })
197                    .collect::<Vec<Example>>(),
198            ),
199            "md" => {
200                examples.push(parse_markdown_example(filename, &content).unwrap());
201            }
202            ext => {
203                panic!("{} has invalid example extension `{ext}`", path.display())
204            }
205        }
206    }
207
208    sort_examples_by_repo_and_rev(&mut examples);
209    examples
210}
211
212pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
213    let mut content = String::new();
214    for example in examples {
215        let line = serde_json::to_string(example).unwrap();
216        content.push_str(&line);
217        content.push('\n');
218    }
219    if let Some(output_path) = output_path {
220        std::fs::write(output_path, content).expect("Failed to write examples");
221    } else {
222        std::io::stdout().write_all(&content.as_bytes()).unwrap();
223    }
224}
225
226pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
227    examples.sort_by(|a, b| {
228        a.repository_url
229            .cmp(&b.repository_url)
230            .then(b.revision.cmp(&a.revision))
231    });
232}
233
234pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
235    let mut examples_by_repo = HashMap::default();
236    for example in examples.iter_mut() {
237        examples_by_repo
238            .entry(example.repository_url.clone())
239            .or_insert_with(Vec::new)
240            .push(example);
241    }
242    examples_by_repo.into_values().collect()
243}
244
245fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
246    use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
247
248    const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
249    const EDIT_HISTORY_HEADING: &str = "Edit History";
250    const CURSOR_POSITION_HEADING: &str = "Cursor Position";
251    const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
252    const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
253    const REPOSITORY_URL_FIELD: &str = "repository_url";
254    const REVISION_FIELD: &str = "revision";
255
256    let parser = Parser::new(input);
257
258    let mut example = Example {
259        name: id,
260        repository_url: String::new(),
261        revision: String::new(),
262        uncommitted_diff: String::new(),
263        cursor_path: PathBuf::new().into(),
264        cursor_position: String::new(),
265        edit_history: String::new(),
266        expected_patch: String::new(),
267        buffer: None,
268        context: None,
269        prompt: None,
270        predictions: Vec::new(),
271        score: Vec::new(),
272        state: None,
273    };
274
275    let mut text = String::new();
276    let mut block_info: CowStr = "".into();
277
278    #[derive(PartialEq)]
279    enum Section {
280        Start,
281        UncommittedDiff,
282        EditHistory,
283        CursorPosition,
284        ExpectedExcerpts,
285        ExpectedPatch,
286        Other,
287    }
288
289    let mut current_section = Section::Start;
290
291    for event in parser {
292        match event {
293            Event::Text(line) => {
294                text.push_str(&line);
295
296                if let Section::Start = current_section
297                    && let Some((field, value)) = line.split_once('=')
298                {
299                    match field.trim() {
300                        REPOSITORY_URL_FIELD => {
301                            example.repository_url = value.trim().to_string();
302                        }
303                        REVISION_FIELD => {
304                            example.revision = value.trim().to_string();
305                        }
306                        _ => {}
307                    }
308                }
309            }
310            Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
311                let title = mem::take(&mut text);
312                current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
313                    Section::UncommittedDiff
314                } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
315                    Section::EditHistory
316                } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
317                    Section::CursorPosition
318                } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
319                    Section::ExpectedPatch
320                } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
321                    Section::ExpectedExcerpts
322                } else {
323                    Section::Other
324                };
325            }
326            Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
327                mem::take(&mut text);
328            }
329            Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
330                mem::take(&mut text);
331            }
332            Event::End(TagEnd::Heading(level)) => {
333                anyhow::bail!("Unexpected heading level: {level}");
334            }
335            Event::Start(Tag::CodeBlock(kind)) => {
336                match kind {
337                    CodeBlockKind::Fenced(info) => {
338                        block_info = info;
339                    }
340                    CodeBlockKind::Indented => {
341                        anyhow::bail!("Unexpected indented codeblock");
342                    }
343                };
344            }
345            Event::Start(_) => {
346                text.clear();
347                block_info = "".into();
348            }
349            Event::End(TagEnd::CodeBlock) => {
350                let block_info = block_info.trim();
351                match current_section {
352                    Section::UncommittedDiff => {
353                        example.uncommitted_diff = mem::take(&mut text);
354                    }
355                    Section::EditHistory => {
356                        example.edit_history.push_str(&mem::take(&mut text));
357                    }
358                    Section::CursorPosition => {
359                        example.cursor_path = Path::new(block_info).into();
360                        example.cursor_position = mem::take(&mut text);
361                    }
362                    Section::ExpectedExcerpts => {
363                        mem::take(&mut text);
364                    }
365                    Section::ExpectedPatch => {
366                        example.expected_patch = mem::take(&mut text);
367                    }
368                    Section::Start | Section::Other => {}
369                }
370            }
371            _ => {}
372        }
373    }
374    if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
375        anyhow::bail!("Missing cursor position codeblock");
376    }
377
378    Ok(example)
379}