1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use crate::qa::QaResult;
4use anyhow::{Context as _, Result};
5use collections::HashMap;
6use edit_prediction::example_spec::ExampleSpec;
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::{
14 borrow::Cow,
15 collections::VecDeque,
16 io::Read,
17 path::{Path, PathBuf},
18 sync::Arc,
19};
20use zeta_prompt::RelatedFile;
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23pub struct Example {
24 #[serde(flatten)]
25 pub spec: ExampleSpec,
26
27 /// The full content of the file where an edit is being predicted, and the
28 /// actual cursor offset.
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub prompt_inputs: Option<ExamplePromptInputs>,
31
32 /// The input and expected output from the edit prediction model.
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub prompt: Option<ExamplePrompt>,
35
36 /// The actual predictions from the model.
37 #[serde(default, skip_serializing_if = "Vec::is_empty")]
38 pub predictions: Vec<ExamplePrediction>,
39
40 /// The scores, for how well the actual predictions match the expected
41 /// predictions.
42 #[serde(default, skip_serializing_if = "Vec::is_empty")]
43 pub score: Vec<ExampleScore>,
44
45 /// QA evaluation results for each prediction (indexed parallel to `predictions`).
46 #[serde(default, skip_serializing_if = "Vec::is_empty")]
47 pub qa: Vec<Option<QaResult>>,
48
49 /// The application state used to process this example.
50 #[serde(skip)]
51 pub state: Option<ExampleState>,
52}
53
54#[derive(Clone, Debug)]
55pub struct ExampleState {
56 pub project: Entity<Project>,
57 pub buffer: Entity<Buffer>,
58 pub cursor_position: Anchor,
59 pub _open_buffers: OpenedBuffers,
60}
61
62#[derive(Clone, Debug, Serialize, Deserialize)]
63pub struct ExamplePromptInputs {
64 pub content: String,
65 pub cursor_row: u32,
66 pub cursor_column: u32,
67 pub cursor_offset: usize,
68 #[serde(default, skip_serializing_if = "Option::is_none")]
69 pub excerpt_start_row: Option<u32>,
70 pub edit_history: Vec<Arc<zeta_prompt::Event>>,
71 pub related_files: Option<Vec<RelatedFile>>,
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize)]
75pub struct ExamplePrompt {
76 pub input: String,
77 pub expected_output: String,
78 pub rejected_output: Option<String>, // For DPO
79 pub provider: PredictionProvider,
80}
81
82#[derive(Clone, Debug, Serialize, Deserialize)]
83pub struct ExamplePrediction {
84 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub actual_patch: Option<String>,
86 #[serde(deserialize_with = "deserialize_null_as_empty_string")]
87 pub actual_output: String,
88 #[serde(default, skip_serializing_if = "Option::is_none")]
89 pub actual_cursor_offset: Option<usize>,
90 #[serde(default, skip_serializing_if = "Option::is_none")]
91 pub error: Option<String>,
92 pub provider: PredictionProvider,
93}
94
95fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
96where
97 D: serde::Deserializer<'de>,
98{
99 let opt = Option::<String>::deserialize(deserializer)?;
100 Ok(opt.unwrap_or_default())
101}
102
103#[derive(Clone, Debug, Serialize, Deserialize)]
104pub struct ExampleScore {
105 pub delta_chr_f: f32,
106 pub braces_disbalance: usize,
107 #[serde(default)]
108 pub exact_lines_tp: usize,
109 #[serde(default)]
110 pub exact_lines_fp: usize,
111 #[serde(default)]
112 pub exact_lines_fn: usize,
113 #[serde(default)]
114 pub reversal_ratio: f32,
115 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub cursor_distance: Option<usize>,
117 #[serde(default, skip_serializing_if = "Option::is_none")]
118 pub cursor_exact_match: Option<bool>,
119 pub wrong_editable_region: Option<bool>,
120 #[serde(default)]
121 pub has_isolated_whitespace_changes: bool,
122}
123
124impl Example {
125 pub fn repo_name(&self) -> Result<RepoName<'_>> {
126 // git@github.com:owner/repo.git
127 if self.spec.repository_url.contains('@') {
128 let (owner, repo) = self
129 .spec
130 .repository_url
131 .split_once(':')
132 .context("expected : in git url")?
133 .1
134 .split_once('/')
135 .context("expected / in git url")?;
136 Ok(RepoName {
137 owner: Cow::Borrowed(owner),
138 name: Cow::Borrowed(repo.trim_end_matches(".git")),
139 })
140 // http://github.com/owner/repo.git
141 } else {
142 let url = Url::parse(&self.spec.repository_url)?;
143 let mut segments = url.path_segments().context("empty http url")?;
144 let owner = segments
145 .next()
146 .context("expected owner path segment")?
147 .to_string();
148 let repo = segments
149 .next()
150 .context("expected repo path segment")?
151 .trim_end_matches(".git")
152 .to_string();
153 assert!(segments.next().is_none());
154
155 Ok(RepoName {
156 owner: Cow::Owned(owner),
157 name: Cow::Owned(repo),
158 })
159 }
160 }
161}
162
/// Owner and name of a repository, either borrowed from the URL they were
/// parsed from or owned.
pub struct RepoName<'a> {
    /// The owner path component of the repository URL.
    pub owner: Cow<'a, str>,
    /// The repository name component of the URL.
    pub name: Cow<'a, str>,
}
167
168impl RepoName<'_> {
169 pub fn worktree_path(&self) -> PathBuf {
170 WORKTREES_DIR
171 .join(self.owner.as_ref())
172 .join(self.name.as_ref())
173 }
174}
175
176pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
177 let mut examples = Vec::new();
178
179 for path in inputs {
180 let is_stdin = path.as_path() == Path::new("-");
181 let content = if is_stdin {
182 let mut buffer = String::new();
183 std::io::stdin()
184 .read_to_string(&mut buffer)
185 .expect("Failed to read from stdin");
186 buffer
187 } else {
188 std::fs::read_to_string(path)
189 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
190 };
191 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
192 let ext = if !is_stdin {
193 path.extension()
194 .map(|ext| ext.to_string_lossy().to_string())
195 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
196 } else {
197 "jsonl".to_string()
198 };
199
200 match ext.as_ref() {
201 "json" => {
202 let mut example =
203 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
204 panic!("Failed to parse example file: {}\n{error}", path.display())
205 });
206 if example.spec.name.is_empty() {
207 example.spec.name = filename;
208 }
209 examples.push(example);
210 }
211 "jsonl" => examples.extend(
212 content
213 .lines()
214 .enumerate()
215 .map(|(line_ix, line)| {
216 let mut example =
217 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
218 panic!(
219 "Failed to parse example on {}:{}\n{error}",
220 path.display(),
221 line_ix + 1
222 )
223 });
224 if example.spec.name.is_empty() {
225 example.spec.name = format!("{filename}-{line_ix}")
226 }
227 example
228 })
229 .collect::<Vec<Example>>(),
230 ),
231 "md" => {
232 let mut example = parse_markdown_example(&content).unwrap();
233 if example.spec.name.is_empty() {
234 example.spec.name = filename;
235 }
236 examples.push(example);
237 }
238 ext => {
239 panic!("{} has invalid example extension `{ext}`", path.display())
240 }
241 }
242 }
243
244 examples
245}
246
247pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
248 examples.sort_by(|a, b| {
249 a.spec
250 .repository_url
251 .cmp(&b.spec.repository_url)
252 .then(b.spec.revision.cmp(&a.spec.revision))
253 });
254}
255
256pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
257 let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
258 let mut ungrouped = Vec::new();
259 for example in examples {
260 if example.spec.repository_url.is_empty() {
261 ungrouped.push(example);
262 } else {
263 examples_by_repo
264 .entry(example.spec.repository_url.clone())
265 .or_insert_with(Vec::new)
266 .push(example);
267 }
268 }
269 let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
270 for example in ungrouped {
271 result.push_back(vec![example]);
272 }
273 result
274}
275
276fn parse_markdown_example(input: &str) -> Result<Example> {
277 let spec = ExampleSpec::from_markdown(input)?;
278 Ok(Example {
279 spec,
280 prompt_inputs: None,
281 prompt: None,
282 predictions: Vec::new(),
283 score: Vec::new(),
284 qa: Vec::new(),
285 state: None,
286 })
287}