1use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use edit_prediction::example_spec::ExampleSpec;
5use edit_prediction::udiff::OpenedBuffers;
6use gpui::Entity;
7use http_client::Url;
8use language::{Anchor, Buffer};
9use project::Project;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::{
13 borrow::Cow,
14 io::{Read, Write},
15 path::{Path, PathBuf},
16};
17use zeta_prompt::RelatedFile;
18
/// A single edit-prediction example, together with the artifacts produced as
/// it is processed: the loaded buffer, the retrieved context, the prompt, the
/// model's predictions, and their scores.
///
/// `Option` stages are omitted from the serialized JSON while they are `None`,
/// and `Vec` stages while they are empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example's specification. Its fields are flattened into the
    /// top-level JSON object rather than nested under a `spec` key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, along
    /// with the actual cursor position.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output for the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model. Deserializes to an empty list
    /// when the key is absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// Scores measuring how well the actual predictions match the expected
    /// predictions. Deserializes to an empty list when the key is absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. It is never
    /// serialized and is always `None` right after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
51
/// Live application state attached to an [`Example`] while it is processed.
///
/// This type is not serializable; `Example::state` is skipped by serde.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example was opened in.
    pub project: Entity<Project>,
    /// The buffer in which the edit is being predicted (presumably the file
    /// described by `Example::buffer` — confirm with the loader).
    pub buffer: Entity<Buffer>,
    /// The cursor location, presumably anchored within `buffer`.
    pub cursor_position: Anchor,
    /// Opened buffers that are held but not read here (note the leading
    /// underscore). NOTE(review): presumably kept so those buffers stay open
    /// for as long as this state lives — confirm.
    pub _open_buffers: OpenedBuffers,
}
59
/// Context retrieved for a prediction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// Files related to the edit location. Stored as `Arc<[_]>` so cloning the
    /// context shares one allocation instead of copying the files.
    pub files: Arc<[RelatedFile]>,
}
64
/// Snapshot of the file being edited and the cursor location within it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// Full text of the file.
    pub content: String,
    /// Cursor row; presumably 0-based — TODO confirm with the producer.
    pub cursor_row: u32,
    /// Cursor column; presumably 0-based — TODO confirm base and units
    /// (bytes vs. characters).
    pub cursor_column: u32,
    /// Cursor offset into `content`; presumably a byte offset — TODO confirm.
    pub cursor_offset: usize,
}
72
/// The model input for an example and the output it is expected to produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format; presumably the one `input` and `expected_output`
    /// were rendered with — confirm.
    pub format: PromptFormat,
}
79
/// One prediction produced for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit as a patch (presumably a unified diff — confirm).
    pub actual_patch: String,
    /// The output returned by the model, presumably before conversion into
    /// `actual_patch` — confirm.
    pub actual_output: String,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
86
/// Scores for a prediction against the expected output (presumably one per
/// entry in `Example::predictions` — confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): from the name, a chrF (character n-gram F-score) computed
    /// over the edit delta — confirm against the metrics implementation.
    pub delta_chr_f: f32,
    /// Line-level match metrics for the prediction.
    pub line_match: ClassificationMetrics,
}
92
93impl Example {
94 pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
95 // git@github.com:owner/repo.git
96 if self.spec.repository_url.contains('@') {
97 let (owner, repo) = self
98 .spec
99 .repository_url
100 .split_once(':')
101 .context("expected : in git url")?
102 .1
103 .split_once('/')
104 .context("expected / in git url")?;
105 Ok((
106 Cow::Borrowed(owner),
107 Cow::Borrowed(repo.trim_end_matches(".git")),
108 ))
109 // http://github.com/owner/repo.git
110 } else {
111 let url = Url::parse(&self.spec.repository_url)?;
112 let mut segments = url.path_segments().context("empty http url")?;
113 let owner = segments
114 .next()
115 .context("expected owner path segment")?
116 .to_string();
117 let repo = segments
118 .next()
119 .context("expected repo path segment")?
120 .trim_end_matches(".git")
121 .to_string();
122 assert!(segments.next().is_none());
123
124 Ok((owner.into(), repo.into()))
125 }
126 }
127}
128
129pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
130 let mut examples = Vec::new();
131
132 let stdin_path: PathBuf = PathBuf::from("-");
133
134 let inputs = if inputs.is_empty() {
135 &[stdin_path]
136 } else {
137 inputs
138 };
139
140 for path in inputs {
141 let is_stdin = path.as_path() == Path::new("-");
142 let content = if is_stdin {
143 let mut buffer = String::new();
144 std::io::stdin()
145 .read_to_string(&mut buffer)
146 .expect("Failed to read from stdin");
147 buffer
148 } else {
149 std::fs::read_to_string(path)
150 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
151 };
152 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
153 let ext = if !is_stdin {
154 path.extension()
155 .map(|ext| ext.to_string_lossy().to_string())
156 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
157 } else {
158 "jsonl".to_string()
159 };
160
161 match ext.as_ref() {
162 "json" => {
163 let mut example =
164 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
165 panic!("Failed to parse example file: {}\n{error}", path.display())
166 });
167 if example.spec.name.is_empty() {
168 example.spec.name = filename;
169 }
170 examples.push(example);
171 }
172 "jsonl" => examples.extend(
173 content
174 .lines()
175 .enumerate()
176 .map(|(line_ix, line)| {
177 let mut example =
178 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
179 panic!(
180 "Failed to parse example on {}:{}\n{error}",
181 path.display(),
182 line_ix + 1
183 )
184 });
185 if example.spec.name.is_empty() {
186 example.spec.name = format!("{filename}-{line_ix}")
187 }
188 example
189 })
190 .collect::<Vec<Example>>(),
191 ),
192 "md" => {
193 examples.push(parse_markdown_example(filename, &content).unwrap());
194 }
195 ext => {
196 panic!("{} has invalid example extension `{ext}`", path.display())
197 }
198 }
199 }
200
201 sort_examples_by_repo_and_rev(&mut examples);
202 examples
203}
204
205pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
206 let mut content = String::new();
207 for example in examples {
208 let line = serde_json::to_string(example).unwrap();
209 content.push_str(&line);
210 content.push('\n');
211 }
212 if let Some(output_path) = output_path {
213 std::fs::write(output_path, content).expect("Failed to write examples");
214 } else {
215 std::io::stdout().write_all(&content.as_bytes()).unwrap();
216 }
217}
218
219pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
220 examples.sort_by(|a, b| {
221 a.spec
222 .repository_url
223 .cmp(&b.spec.repository_url)
224 .then(b.spec.revision.cmp(&a.spec.revision))
225 });
226}
227
228pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
229 let mut examples_by_repo = HashMap::default();
230 for example in examples.iter_mut() {
231 examples_by_repo
232 .entry(example.spec.repository_url.clone())
233 .or_insert_with(Vec::new)
234 .push(example);
235 }
236 examples_by_repo.into_values().collect()
237}
238
239fn parse_markdown_example(name: String, input: &str) -> Result<Example> {
240 let spec = ExampleSpec::from_markdown(name, input)?;
241 Ok(Example {
242 spec,
243 buffer: None,
244 context: None,
245 prompt: None,
246 predictions: Vec::new(),
247 score: Vec::new(),
248 state: None,
249 })
250}