1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use crate::qa::QaResult;
4use anyhow::{Context as _, Result};
5use collections::HashMap;
6use edit_prediction::example_spec::ExampleSpec;
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::{
14 borrow::Cow,
15 collections::VecDeque,
16 io::Read,
17 path::{Path, PathBuf},
18 sync::Arc,
19};
20use zeta_prompt::RelatedFile;
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23pub struct Example {
24 #[serde(flatten)]
25 pub spec: ExampleSpec,
26
27 /// The full content of the file where an edit is being predicted, and the
28 /// actual cursor offset.
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub prompt_inputs: Option<ExamplePromptInputs>,
31
32 /// The input and expected output from the edit prediction model.
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub prompt: Option<ExamplePrompt>,
35
36 /// The actual predictions from the model.
37 #[serde(default, skip_serializing_if = "Vec::is_empty")]
38 pub predictions: Vec<ExamplePrediction>,
39
40 /// The scores, for how well the actual predictions match the expected
41 /// predictions.
42 #[serde(default, skip_serializing_if = "Vec::is_empty")]
43 pub score: Vec<ExampleScore>,
44
45 /// QA evaluation results for each prediction (indexed parallel to `predictions`).
46 #[serde(default, skip_serializing_if = "Vec::is_empty")]
47 pub qa: Vec<Option<QaResult>>,
48
49 /// The application state used to process this example.
50 #[serde(skip)]
51 pub state: Option<ExampleState>,
52}
53
54#[derive(Clone, Debug)]
55pub struct ExampleState {
56 pub project: Entity<Project>,
57 pub buffer: Entity<Buffer>,
58 pub cursor_position: Anchor,
59 pub _open_buffers: OpenedBuffers,
60}
61
62#[derive(Clone, Debug, Serialize, Deserialize)]
63pub struct ExamplePromptInputs {
64 pub content: String,
65 pub cursor_row: u32,
66 pub cursor_column: u32,
67 pub cursor_offset: usize,
68 #[serde(default, skip_serializing_if = "Option::is_none")]
69 pub excerpt_start_row: Option<u32>,
70 pub edit_history: Vec<Arc<zeta_prompt::Event>>,
71 pub related_files: Option<Vec<RelatedFile>>,
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize)]
75pub struct ExamplePrompt {
76 pub input: String,
77 pub expected_output: String,
78 pub rejected_output: Option<String>, // For DPO
79 #[serde(default)]
80 pub prefill: Option<String>,
81 pub provider: PredictionProvider,
82}
83
84#[derive(Clone, Debug, Serialize, Deserialize)]
85pub struct ExamplePrediction {
86 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub actual_patch: Option<String>,
88 #[serde(deserialize_with = "deserialize_null_as_empty_string")]
89 pub actual_output: String,
90 #[serde(default, skip_serializing_if = "Option::is_none")]
91 pub actual_cursor: Option<ActualCursor>,
92 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub error: Option<String>,
94 pub provider: PredictionProvider,
95}
96
97#[derive(Clone, Debug, Serialize, Deserialize)]
98pub struct ActualCursor {
99 pub path: String,
100 pub row: u32,
101 pub column: u32,
102 pub offset: usize,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub editable_region_offset: Option<usize>,
105}
106
107impl ActualCursor {
108 /// Construct an `ActualCursor` from a cursor offset within the new editable region.
109 ///
110 /// - `path`: file path the cursor is in
111 /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
112 /// - `new_editable_region`: the full new editable region text (after marker removal)
113 /// - `content`: the full file content (before the edit)
114 /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
115 /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
116 pub fn from_editable_region(
117 path: &std::path::Path,
118 editable_region_cursor_offset: usize,
119 new_editable_region: &str,
120 content: &str,
121 editable_region_byte_offset: usize,
122 editable_region_start_line: usize,
123 ) -> Self {
124 let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
125 let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
126 let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
127 let column = match new_region_prefix.rfind('\n') {
128 Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
129 None => {
130 let content_prefix = &content[..editable_region_byte_offset];
131 let content_column = match content_prefix.rfind('\n') {
132 Some(pos) => editable_region_byte_offset - pos - 1,
133 None => editable_region_byte_offset,
134 };
135 (content_column + editable_region_cursor_offset) as u32
136 }
137 };
138 ActualCursor {
139 path: path.to_string_lossy().to_string(),
140 row,
141 column,
142 offset: global_offset,
143 editable_region_offset: Some(editable_region_cursor_offset),
144 }
145 }
146}
147
148fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
149where
150 D: serde::Deserializer<'de>,
151{
152 let opt = Option::<String>::deserialize(deserializer)?;
153 Ok(opt.unwrap_or_default())
154}
155
156#[derive(Clone, Debug, Serialize, Deserialize)]
157pub struct ExampleScore {
158 pub delta_chr_f: f32,
159 pub braces_disbalance: usize,
160 #[serde(default)]
161 pub exact_lines_tp: usize,
162 #[serde(default)]
163 pub exact_lines_fp: usize,
164 #[serde(default)]
165 pub exact_lines_fn: usize,
166 #[serde(default)]
167 pub reversal_ratio: f32,
168 #[serde(default, skip_serializing_if = "Option::is_none")]
169 pub cursor_distance: Option<usize>,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
171 pub cursor_exact_match: Option<bool>,
172 pub wrong_editable_region: Option<bool>,
173 #[serde(default)]
174 pub has_isolated_whitespace_changes: bool,
175 #[serde(default)]
176 pub inserted_tokens: usize,
177 #[serde(default)]
178 pub deleted_tokens: usize,
179}
180
181impl Example {
182 pub fn repo_name(&self) -> Result<RepoName<'_>> {
183 // git@github.com:owner/repo.git
184 if self.spec.repository_url.contains('@') {
185 let (owner, repo) = self
186 .spec
187 .repository_url
188 .split_once(':')
189 .context("expected : in git url")?
190 .1
191 .split_once('/')
192 .context("expected / in git url")?;
193 Ok(RepoName {
194 owner: Cow::Borrowed(owner),
195 name: Cow::Borrowed(repo.trim_end_matches(".git")),
196 })
197 // http://github.com/owner/repo.git
198 } else {
199 let url = Url::parse(&self.spec.repository_url)?;
200 let mut segments = url.path_segments().context("empty http url")?;
201 let owner = segments
202 .next()
203 .context("expected owner path segment")?
204 .to_string();
205 let repo = segments
206 .next()
207 .context("expected repo path segment")?
208 .trim_end_matches(".git")
209 .to_string();
210 assert!(segments.next().is_none());
211
212 Ok(RepoName {
213 owner: Cow::Owned(owner),
214 name: Cow::Owned(repo),
215 })
216 }
217 }
218}
219
/// Owner and name of a repository, as extracted from a repository URL.
///
/// Both parts are `Cow`s so they can either borrow from the URL they were
/// taken from or own their text.
pub struct RepoName<'a> {
    /// Owner part of the repository (the `owner` in `owner/repo`).
    pub owner: Cow<'a, str>,
    /// Repository name (the `repo` in `owner/repo`).
    pub name: Cow<'a, str>,
}
224
225impl RepoName<'_> {
226 pub fn worktree_path(&self) -> PathBuf {
227 WORKTREES_DIR
228 .join(self.owner.as_ref())
229 .join(self.name.as_ref())
230 }
231}
232
233pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
234 let mut examples = Vec::new();
235
236 for path in inputs {
237 let is_stdin = path.as_path() == Path::new("-");
238 let content = if is_stdin {
239 let mut buffer = String::new();
240 std::io::stdin()
241 .read_to_string(&mut buffer)
242 .expect("Failed to read from stdin");
243 buffer
244 } else {
245 std::fs::read_to_string(path)
246 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
247 };
248 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
249 let ext = if !is_stdin {
250 path.extension()
251 .map(|ext| ext.to_string_lossy().to_string())
252 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
253 } else {
254 "jsonl".to_string()
255 };
256
257 match ext.as_ref() {
258 "json" => {
259 let mut example =
260 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
261 panic!("Failed to parse example file: {}\n{error}", path.display())
262 });
263 if example.spec.name.is_empty() {
264 example.spec.name = filename;
265 }
266 examples.push(example);
267 }
268 "jsonl" => examples.extend(
269 content
270 .lines()
271 .enumerate()
272 .map(|(line_ix, line)| {
273 let mut example =
274 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
275 panic!(
276 "Failed to parse example on {}:{}\n{error}",
277 path.display(),
278 line_ix + 1
279 )
280 });
281 if example.spec.name.is_empty() {
282 example.spec.name = format!("{filename}-{line_ix}")
283 }
284 example
285 })
286 .collect::<Vec<Example>>(),
287 ),
288 "md" => {
289 let mut example = parse_markdown_example(&content).unwrap();
290 if example.spec.name.is_empty() {
291 example.spec.name = filename;
292 }
293 examples.push(example);
294 }
295 ext => {
296 panic!("{} has invalid example extension `{ext}`", path.display())
297 }
298 }
299 }
300
301 examples
302}
303
304pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
305 examples.sort_by(|a, b| {
306 a.spec
307 .repository_url
308 .cmp(&b.spec.repository_url)
309 .then(b.spec.revision.cmp(&a.spec.revision))
310 });
311}
312
313pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
314 let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
315 let mut ungrouped = Vec::new();
316 for example in examples {
317 if example.spec.repository_url.is_empty() {
318 ungrouped.push(example);
319 } else {
320 examples_by_repo
321 .entry(example.spec.repository_url.clone())
322 .or_insert_with(Vec::new)
323 .push(example);
324 }
325 }
326 let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
327 for example in ungrouped {
328 result.push_back(vec![example]);
329 }
330 result
331}
332
333fn parse_markdown_example(input: &str) -> Result<Example> {
334 let spec = ExampleSpec::from_markdown(input)?;
335 Ok(Example {
336 spec,
337 prompt_inputs: None,
338 prompt: None,
339 predictions: Vec::new(),
340 score: Vec::new(),
341 qa: Vec::new(),
342 state: None,
343 })
344}