1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use crate::qa::QaResult;
4use anyhow::{Context as _, Result};
5use collections::HashMap;
6use edit_prediction::example_spec::ExampleSpec;
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::{
14 borrow::Cow,
15 collections::VecDeque,
16 io::Read,
17 path::{Path, PathBuf},
18};
19use zeta_prompt::ZetaPromptInput;
20
/// An example spec together with the optional data attached to it: prompt
/// inputs, the prompt, predictions, scores, and QA results.
///
/// The fields of `spec` are flattened into the same serialized object as the
/// other fields. `state` is never serialized or deserialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ZetaPromptInput>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The Zed version used to generate this example.
    // Unlike the other optional fields, this one is written out even when it
    // is `None` (there is no `skip_serializing_if` on it).
    pub zed_version: Option<String>,

    /// The application state used to process this example.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
55
/// In-memory handles used while processing an example.
///
/// Stored in `Example::state`, which is excluded from (de)serialization.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity.
    pub project: Entity<Project>,
    /// Handle to the buffer entity.
    // NOTE(review): presumably the buffer an edit is being predicted in — confirm.
    pub buffer: Entity<Buffer>,
    /// Cursor position, expressed as an anchor.
    // NOTE(review): presumably anchored in `buffer` — confirm against the code that builds this.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is held only to keep
    // the opened buffers alive for the lifetime of the state — confirm.
    pub _open_buffers: OpenedBuffers,
}
63
/// The input and expected output for the edit prediction model, as stored in
/// `Example::prompt`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    pub rejected_output: Option<String>, // For DPO
    /// Optional prefill text; falls back to `None` when absent from the input.
    // NOTE(review): presumably text the model's response is seeded with — confirm.
    #[serde(default)]
    pub prefill: Option<String>,
    /// The prediction provider recorded for this prompt.
    pub provider: PredictionProvider,
}
73
/// A single actual prediction from the model, as stored in
/// `Example::predictions`.
///
/// Optional fields are omitted from the serialized form when `None` and
/// default to `None` when absent from the input.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The model output. A `null` in the input is read as an empty string.
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// The predicted cursor position, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_cursor: Option<ActualCursor>,
    /// An error message, if any.
    // NOTE(review): presumably set when producing this prediction failed — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The prediction provider recorded for this prediction.
    pub provider: PredictionProvider,
    // NOTE(review): presumably the summed token log-probability of the output — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    // NOTE(review): presumably the per-token average log-probability — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
90
/// A cursor position attached to a prediction.
///
/// The field descriptions below reflect how `ActualCursor::from_editable_region`
/// populates them; values built elsewhere may follow different conventions.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActualCursor {
    /// Path of the file the cursor is in (lossily converted to a string).
    pub path: String,
    /// 0-based line number of the cursor.
    pub row: u32,
    /// Number of bytes between the start of the cursor's line and the cursor.
    pub column: u32,
    /// Editable region start offset plus the cursor's offset within the new
    /// editable region text.
    pub offset: usize,
    /// Byte offset of the cursor within the new editable region text.
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub editable_region_offset: Option<usize>,
}
100
101impl ActualCursor {
102 /// Construct an `ActualCursor` from a cursor offset within the new editable region.
103 ///
104 /// - `path`: file path the cursor is in
105 /// - `editable_region_cursor_offset`: byte offset of the cursor within the new editable region text
106 /// - `new_editable_region`: the full new editable region text (after marker removal)
107 /// - `content`: the full file content (before the edit)
108 /// - `editable_region_byte_offset`: byte offset where the editable region starts in `content`
109 /// - `editable_region_start_line`: 0-based line number where the editable region starts in `content`
110 pub fn from_editable_region(
111 path: &std::path::Path,
112 editable_region_cursor_offset: usize,
113 new_editable_region: &str,
114 content: &str,
115 editable_region_byte_offset: usize,
116 editable_region_start_line: usize,
117 ) -> Self {
118 let global_offset = editable_region_byte_offset + editable_region_cursor_offset;
119 let new_region_prefix = &new_editable_region[..editable_region_cursor_offset];
120 let row = (editable_region_start_line + new_region_prefix.matches('\n').count()) as u32;
121 let column = match new_region_prefix.rfind('\n') {
122 Some(pos) => (editable_region_cursor_offset - pos - 1) as u32,
123 None => {
124 let content_prefix = &content[..editable_region_byte_offset];
125 let content_column = match content_prefix.rfind('\n') {
126 Some(pos) => editable_region_byte_offset - pos - 1,
127 None => editable_region_byte_offset,
128 };
129 (content_column + editable_region_cursor_offset) as u32
130 }
131 };
132 ActualCursor {
133 path: path.to_string_lossy().to_string(),
134 row,
135 column,
136 offset: global_offset,
137 editable_region_offset: Some(editable_region_cursor_offset),
138 }
139 }
140}
141
142fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
143where
144 D: serde::Deserializer<'de>,
145{
146 let opt = Option::<String>::deserialize(deserializer)?;
147 Ok(opt.unwrap_or_default())
148}
149
/// Scores for how well an actual prediction matches the expected predictions,
/// as stored in `Example::score`.
///
/// Fields marked `#[serde(default)]` take their type's default value when they
/// are absent from the input.
// NOTE(review): the `_tp` / `_fp` / `_fn` suffixes presumably stand for true
// positives, false positives and false negatives — confirm against the scorer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF-based similarity measure — confirm.
    pub delta_chr_f: f32,
    // NOTE(review): presumably a count of unbalanced braces in the result — confirm.
    pub braces_disbalance: usize,
    #[serde(default)]
    pub exact_lines_tp: usize,
    #[serde(default)]
    pub exact_lines_fp: usize,
    #[serde(default)]
    pub exact_lines_fn: usize,
    #[serde(default)]
    pub token_match_tp: usize,
    #[serde(default)]
    pub token_match_fp: usize,
    #[serde(default)]
    pub token_match_fn: usize,
    #[serde(default)]
    pub token_match_precision: f64,
    #[serde(default)]
    pub token_match_recall: f64,
    #[serde(default)]
    pub reversal_ratio: f32,
    // Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_distance: Option<usize>,
    // Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match: Option<bool>,
    // No `skip_serializing_if` here, so this is written out even when `None`.
    pub wrong_editable_region: Option<bool>,
    #[serde(default)]
    pub has_isolated_whitespace_changes: bool,
    #[serde(default)]
    pub inserted_tokens: usize,
    #[serde(default)]
    pub deleted_tokens: usize,
    // Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cumulative_logprob: Option<f64>,
    // Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
}
188
189impl Example {
190 pub fn repo_name(&self) -> Result<RepoName<'_>> {
191 // git@github.com:owner/repo.git
192 if self.spec.repository_url.contains('@') {
193 let (owner, repo) = self
194 .spec
195 .repository_url
196 .split_once(':')
197 .context("expected : in git url")?
198 .1
199 .split_once('/')
200 .context("expected / in git url")?;
201 Ok(RepoName {
202 owner: Cow::Borrowed(owner),
203 name: Cow::Borrowed(repo.trim_end_matches(".git")),
204 })
205 // http://github.com/owner/repo.git
206 } else {
207 let url = Url::parse(&self.spec.repository_url)?;
208 let mut segments = url.path_segments().context("empty http url")?;
209 let owner = segments
210 .next()
211 .context("expected owner path segment")?
212 .to_string();
213 let repo = segments
214 .next()
215 .context("expected repo path segment")?
216 .trim_end_matches(".git")
217 .to_string();
218 assert!(segments.next().is_none());
219
220 Ok(RepoName {
221 owner: Cow::Owned(owner),
222 name: Cow::Owned(repo),
223 })
224 }
225 }
226}
227
/// The owner and name of a repository, as produced by `Example::repo_name`.
///
/// Both parts are `Cow`s: `Example::repo_name` borrows them from the URL
/// string for git-style URLs and allocates them for URLs parsed with `Url`.
pub struct RepoName<'a> {
    /// The repository owner.
    pub owner: Cow<'a, str>,
    /// The repository name, with trailing `.git` removed by `Example::repo_name`.
    pub name: Cow<'a, str>,
}
232
233impl RepoName<'_> {
234 pub fn worktree_path(&self) -> PathBuf {
235 WORKTREES_DIR
236 .join(self.owner.as_ref())
237 .join(self.name.as_ref())
238 }
239}
240
241pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
242 let mut examples = Vec::new();
243
244 for path in inputs {
245 let is_stdin = path.as_path() == Path::new("-");
246 let content = if is_stdin {
247 let mut buffer = String::new();
248 std::io::stdin()
249 .read_to_string(&mut buffer)
250 .expect("Failed to read from stdin");
251 buffer
252 } else {
253 std::fs::read_to_string(path)
254 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
255 };
256 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
257 let ext = if !is_stdin {
258 path.extension()
259 .map(|ext| ext.to_string_lossy().to_string())
260 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
261 } else {
262 "jsonl".to_string()
263 };
264
265 match ext.as_ref() {
266 "json" => {
267 let mut example =
268 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
269 panic!("Failed to parse example file: {}\n{error}", path.display())
270 });
271 if example.spec.name.is_empty() {
272 example.spec.name = filename;
273 }
274 examples.push(example);
275 }
276 "jsonl" => examples.extend(
277 content
278 .lines()
279 .enumerate()
280 .map(|(line_ix, line)| {
281 let mut example =
282 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
283 panic!(
284 "Failed to parse example on {}:{}\n{error}",
285 path.display(),
286 line_ix + 1
287 )
288 });
289 if example.spec.name.is_empty() {
290 example.spec.name = format!("{filename}-{line_ix}")
291 }
292 example
293 })
294 .collect::<Vec<Example>>(),
295 ),
296 "md" => {
297 let mut example = parse_markdown_example(&content).unwrap();
298 if example.spec.name.is_empty() {
299 example.spec.name = filename;
300 }
301 examples.push(example);
302 }
303 ext => {
304 panic!("{} has invalid example extension `{ext}`", path.display())
305 }
306 }
307 }
308
309 examples
310}
311
312pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
313 examples.sort_by(|a, b| {
314 a.spec
315 .repository_url
316 .cmp(&b.spec.repository_url)
317 .then(b.spec.revision.cmp(&a.spec.revision))
318 });
319}
320
321pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
322 let mut examples_by_repo: HashMap<String, Vec<Example>> = HashMap::default();
323 let mut ungrouped = Vec::new();
324 for example in examples {
325 if example.spec.repository_url.is_empty() {
326 ungrouped.push(example);
327 } else {
328 examples_by_repo
329 .entry(example.spec.repository_url.clone())
330 .or_insert_with(Vec::new)
331 .push(example);
332 }
333 }
334 let mut result: VecDeque<Vec<Example>> = examples_by_repo.into_values().collect();
335 for example in ungrouped {
336 result.push_back(vec![example]);
337 }
338 result
339}
340
341fn parse_markdown_example(input: &str) -> Result<Example> {
342 let spec = ExampleSpec::from_markdown(input)?;
343 Ok(Example {
344 spec,
345 prompt_inputs: None,
346 prompt: None,
347 predictions: Vec::new(),
348 score: Vec::new(),
349 qa: Vec::new(),
350 state: None,
351 zed_version: None,
352 })
353}