example.rs

  1use crate::{
  2    PredictionProvider, PromptFormat,
  3    metrics::ClassificationMetrics,
  4    paths::{REPOS_DIR, WORKTREES_DIR},
  5};
  6use anyhow::{Context as _, Result};
  7use edit_prediction::udiff::OpenedBuffers;
  8use gpui::Entity;
  9use http_client::Url;
 10use language::{Anchor, Buffer};
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::sync::Arc;
 14use std::{
 15    borrow::Cow,
 16    io::{Read, Write},
 17    mem,
 18    path::{Path, PathBuf},
 19};
 20use zeta_prompt::RelatedFile;
 21
/// A single edit-prediction example, together with the artifacts produced while
/// processing it (buffer snapshot, retrieved context, prompt, predictions, scores).
///
/// Examples are read from `.json`, `.jsonl`, or `.md` inputs by `read_examples` and
/// written back as JSON Lines by `write_examples`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// Identifier of the example. Empty when absent from the input; `read_examples`
    /// then derives it from the input file name (plus the line index for JSONL).
    #[serde(default)]
    pub name: String,
    /// URL of the repository the example is based on. `repo_name` accepts scp-like
    /// SSH URLs (`git@host:owner/repo.git`) and URLs with a scheme
    /// (`https://host/owner/repo.git`).
    pub repository_url: String,
    /// Repository revision the example refers to (presumably a commit SHA or ref to
    /// check out — confirm).
    pub revision: String,
    /// Diff of uncommitted changes (presumably applied on top of `revision` —
    /// confirm). Empty when absent from the input.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor. In markdown examples this is the
    /// info string of the `Cursor Position` code block.
    pub cursor_path: Arc<Path>,
    /// Cursor description; in markdown examples, the contents of the
    /// `Cursor Position` code block (exact format not visible here — confirm).
    pub cursor_position: String,
    /// Edit history leading up to the prediction; in markdown examples, the
    /// concatenated contents of the `Edit History` code blocks.
    pub edit_history: String,
    /// The patch the prediction is expected to produce.
    pub expected_patch: String,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// Scores measuring how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. Never serialized, so
    /// it is always `None` right after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
 62
/// Live application state attached to an [`Example`] while it is being processed.
/// Stored in `Example::state`, which is skipped by serde.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example is processed in.
    pub project: Entity<Project>,
    /// Presumably the buffer of the file containing the cursor — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor location, as an anchor (presumably into `buffer` — confirm).
    pub cursor_position: Anchor,
    /// Never read here (leading underscore); presumably held only to keep the
    /// opened buffers alive for as long as this state exists — confirm.
    pub _open_buffers: OpenedBuffers,
}
 70
/// Context retrieved for a prediction; stored in `Example::context`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// Related files retrieved as context (selection criteria not visible here).
    pub files: Arc<[RelatedFile]>,
}
 75
/// Snapshot of the file an edit is being predicted in, together with the cursor
/// location; stored in `Example::buffer`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// Full text of the file.
    pub content: String,
    /// Cursor row (presumably zero-based — confirm).
    pub cursor_row: u32,
    /// Cursor column (presumably zero-based; unit not specified here — confirm).
    pub cursor_column: u32,
    /// Cursor offset into `content` (presumably a byte offset — confirm).
    pub cursor_offset: usize,
}
 83
/// The model input built for an example and the output the model is expected to
/// return; stored in `Example::prompt`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// Prompt text given to the edit prediction model.
    pub input: String,
    /// Output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format `input` was built with.
    pub format: PromptFormat,
}
 90
/// One prediction produced for an example; collected in `Example::predictions`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit as a patch (presumably derived from `actual_output` —
    /// confirm).
    pub actual_patch: String,
    /// Output returned by the prediction provider.
    pub actual_output: String,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
 97
/// Metrics for how well an actual prediction matches the expected one; collected
/// in `Example::score`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// Presumably a chrF (character n-gram F-score) computed over the edit delta —
    /// confirm against the scoring code.
    pub delta_chr_f: f32,
    /// Line-level classification metrics for the predicted changes.
    pub line_match: ClassificationMetrics,
}
103
104impl Example {
105    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
106        // git@github.com:owner/repo.git
107        if self.repository_url.contains('@') {
108            let (owner, repo) = self
109                .repository_url
110                .split_once(':')
111                .context("expected : in git url")?
112                .1
113                .split_once('/')
114                .context("expected / in git url")?;
115            Ok((
116                Cow::Borrowed(owner),
117                Cow::Borrowed(repo.trim_end_matches(".git")),
118            ))
119        // http://github.com/owner/repo.git
120        } else {
121            let url = Url::parse(&self.repository_url)?;
122            let mut segments = url.path_segments().context("empty http url")?;
123            let owner = segments
124                .next()
125                .context("expected owner path segment")?
126                .to_string();
127            let repo = segments
128                .next()
129                .context("expected repo path segment")?
130                .trim_end_matches(".git")
131                .to_string();
132            assert!(segments.next().is_none());
133
134            Ok((owner.into(), repo.into()))
135        }
136    }
137
138    pub fn worktree_path(&self) -> PathBuf {
139        WORKTREES_DIR
140            .join(&self.name)
141            .join(self.repo_name().unwrap().1.as_ref())
142    }
143
144    pub fn repo_path(&self) -> PathBuf {
145        let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
146        REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
147    }
148}
149
150pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
151    let mut examples = Vec::new();
152
153    let stdin_path: PathBuf = PathBuf::from("-");
154
155    let inputs = if inputs.is_empty() {
156        &[stdin_path]
157    } else {
158        inputs
159    };
160
161    for path in inputs {
162        let is_stdin = path.as_path() == Path::new("-");
163        let content = if is_stdin {
164            let mut buffer = String::new();
165            std::io::stdin()
166                .read_to_string(&mut buffer)
167                .expect("Failed to read from stdin");
168            buffer
169        } else {
170            std::fs::read_to_string(path)
171                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
172        };
173        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
174        let ext = if !is_stdin {
175            path.extension()
176                .map(|ext| ext.to_string_lossy().to_string())
177                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
178        } else {
179            "jsonl".to_string()
180        };
181
182        match ext.as_ref() {
183            "json" => {
184                let mut example =
185                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
186                        panic!("Failed to parse example file: {}\n{error}", path.display())
187                    });
188                if example.name.is_empty() {
189                    example.name = filename;
190                }
191                examples.push(example);
192            }
193            "jsonl" => examples.extend(
194                content
195                    .lines()
196                    .enumerate()
197                    .map(|(line_ix, line)| {
198                        let mut example =
199                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
200                                panic!(
201                                    "Failed to parse example on {}:{}\n{error}",
202                                    path.display(),
203                                    line_ix + 1
204                                )
205                            });
206                        if example.name.is_empty() {
207                            example.name = format!("{filename}-{line_ix}")
208                        }
209                        example
210                    })
211                    .collect::<Vec<Example>>(),
212            ),
213            "md" => {
214                examples.push(parse_markdown_example(filename, &content).unwrap());
215            }
216            ext => {
217                panic!("{} has invalid example extension `{ext}`", path.display())
218            }
219        }
220    }
221    examples
222}
223
224pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
225    let mut content = String::new();
226    for example in examples {
227        let line = serde_json::to_string(example).unwrap();
228        content.push_str(&line);
229        content.push('\n');
230    }
231    if let Some(output_path) = output_path {
232        std::fs::write(output_path, content).expect("Failed to write examples");
233    } else {
234        std::io::stdout().write_all(&content.as_bytes()).unwrap();
235    }
236}
237
238fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
239    use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
240
241    const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
242    const EDIT_HISTORY_HEADING: &str = "Edit History";
243    const CURSOR_POSITION_HEADING: &str = "Cursor Position";
244    const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
245    const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
246    const REPOSITORY_URL_FIELD: &str = "repository_url";
247    const REVISION_FIELD: &str = "revision";
248
249    let parser = Parser::new(input);
250
251    let mut example = Example {
252        name: id,
253        repository_url: String::new(),
254        revision: String::new(),
255        uncommitted_diff: String::new(),
256        cursor_path: PathBuf::new().into(),
257        cursor_position: String::new(),
258        edit_history: String::new(),
259        expected_patch: String::new(),
260        buffer: None,
261        context: None,
262        prompt: None,
263        predictions: Vec::new(),
264        score: Vec::new(),
265        state: None,
266    };
267
268    let mut text = String::new();
269    let mut block_info: CowStr = "".into();
270
271    #[derive(PartialEq)]
272    enum Section {
273        Start,
274        UncommittedDiff,
275        EditHistory,
276        CursorPosition,
277        ExpectedExcerpts,
278        ExpectedPatch,
279        Other,
280    }
281
282    let mut current_section = Section::Start;
283
284    for event in parser {
285        match event {
286            Event::Text(line) => {
287                text.push_str(&line);
288
289                if let Section::Start = current_section
290                    && let Some((field, value)) = line.split_once('=')
291                {
292                    match field.trim() {
293                        REPOSITORY_URL_FIELD => {
294                            example.repository_url = value.trim().to_string();
295                        }
296                        REVISION_FIELD => {
297                            example.revision = value.trim().to_string();
298                        }
299                        _ => {}
300                    }
301                }
302            }
303            Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
304                let title = mem::take(&mut text);
305                current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
306                    Section::UncommittedDiff
307                } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
308                    Section::EditHistory
309                } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
310                    Section::CursorPosition
311                } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
312                    Section::ExpectedPatch
313                } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
314                    Section::ExpectedExcerpts
315                } else {
316                    Section::Other
317                };
318            }
319            Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
320                mem::take(&mut text);
321            }
322            Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
323                mem::take(&mut text);
324            }
325            Event::End(TagEnd::Heading(level)) => {
326                anyhow::bail!("Unexpected heading level: {level}");
327            }
328            Event::Start(Tag::CodeBlock(kind)) => {
329                match kind {
330                    CodeBlockKind::Fenced(info) => {
331                        block_info = info;
332                    }
333                    CodeBlockKind::Indented => {
334                        anyhow::bail!("Unexpected indented codeblock");
335                    }
336                };
337            }
338            Event::Start(_) => {
339                text.clear();
340                block_info = "".into();
341            }
342            Event::End(TagEnd::CodeBlock) => {
343                let block_info = block_info.trim();
344                match current_section {
345                    Section::UncommittedDiff => {
346                        example.uncommitted_diff = mem::take(&mut text);
347                    }
348                    Section::EditHistory => {
349                        example.edit_history.push_str(&mem::take(&mut text));
350                    }
351                    Section::CursorPosition => {
352                        example.cursor_path = Path::new(block_info).into();
353                        example.cursor_position = mem::take(&mut text);
354                    }
355                    Section::ExpectedExcerpts => {
356                        mem::take(&mut text);
357                    }
358                    Section::ExpectedPatch => {
359                        example.expected_patch = mem::take(&mut text);
360                    }
361                    Section::Start | Section::Other => {}
362                }
363            }
364            _ => {}
365        }
366    }
367    if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
368        anyhow::bail!("Missing cursor position codeblock");
369    }
370
371    Ok(example)
372}