1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use crate::qa::QaResult;
4use anyhow::{Context as _, Result};
5use collections::HashMap;
6use edit_prediction::example_spec::ExampleSpec;
7use edit_prediction::udiff::OpenedBuffers;
8use gpui::Entity;
9use http_client::Url;
10use language::{Anchor, Buffer};
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::{
14 borrow::Cow,
15 collections::VecDeque,
16 io::Read,
17 path::{Path, PathBuf},
18 sync::Arc,
19};
20use zeta_prompt::RelatedFile;
21
/// A single edit-prediction example, together with whatever derived data has
/// been computed for it so far.
///
/// The [`ExampleSpec`] fields are flattened into the top level of the
/// serialized form. `Option` fields are omitted from the output when `None`,
/// and `Vec` fields are omitted when empty and default to empty when absent
/// from the input.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset within it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output for the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// Scores measuring how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// QA evaluation results for each prediction (indexed parallel to `predictions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qa: Vec<Option<QaResult>>,

    /// The application state used to process this example. Never serialized;
    /// after deserialization it is always `None`.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
53
/// In-memory application state attached to an [`Example`] while it is being
/// processed. It is never serialized (`Example::state` is `#[serde(skip)]`).
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example is processed in.
    pub project: Entity<Project>,
    /// Presumably the buffer in which the edit is being predicted — confirm
    /// against the code that constructs this state.
    pub buffer: Entity<Buffer>,
    /// The cursor position, expressed as an anchor (presumably into `buffer`).
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this field is held only so
    // the opened buffers stay alive as long as this state does — confirm.
    pub _open_buffers: OpenedBuffers,
}
61
/// The raw inputs from which a model prompt is built for an [`Example`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// The full text of the file in which an edit is being predicted.
    pub content: String,
    // NOTE(review): whether row/column are 0- or 1-based is not visible here — confirm.
    /// The cursor's row within `content`.
    pub cursor_row: u32,
    /// The cursor's column within `content`.
    pub cursor_column: u32,
    /// The cursor's position within `content` as an offset (presumably in
    /// bytes — confirm).
    pub cursor_offset: usize,
    /// The edit history events provided as prompt context.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files provided as additional prompt context, if any.
    pub related_files: Option<Vec<RelatedFile>>,
}
71
/// A model prompt together with the output the model is expected to produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prediction provider this prompt was built for.
    pub provider: PredictionProvider,
}
78
/// A single prediction produced for an [`Example`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit in patch form, when one is available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub actual_patch: Option<String>,
    /// The model's output. A `null` value in the input deserializes to an
    /// empty string (see `deserialize_null_as_empty_string`).
    #[serde(deserialize_with = "deserialize_null_as_empty_string")]
    pub actual_output: String,
    /// An error message, if producing this prediction failed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
89
90fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
91where
92 D: serde::Deserializer<'de>,
93{
94 let opt = Option::<String>::deserialize(deserializer)?;
95 Ok(opt.unwrap_or_default())
96}
97
/// Metrics describing how well actual predictions match the expected output.
///
/// Fields marked `#[serde(default)]` fall back to zero when missing from the
/// input, so serialized scores that lack them still deserialize.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF score computed over the edit delta — confirm.
    pub delta_chr_f: f32,
    // Presumably a count of unbalanced braces in the prediction — confirm.
    pub braces_disbalance: usize,
    // Presumably true-positive / false-positive / false-negative counts of
    // exactly matching lines between actual and expected output — confirm.
    #[serde(default)]
    pub exact_lines_tp: usize,
    #[serde(default)]
    pub exact_lines_fp: usize,
    #[serde(default)]
    pub exact_lines_fn: usize,
    // NOTE(review): the meaning of this ratio is not visible here — confirm.
    #[serde(default)]
    pub reversal_ratio: f32,
}
111
112impl Example {
113 pub fn repo_name(&self) -> Result<RepoName<'_>> {
114 // git@github.com:owner/repo.git
115 if self.spec.repository_url.contains('@') {
116 let (owner, repo) = self
117 .spec
118 .repository_url
119 .split_once(':')
120 .context("expected : in git url")?
121 .1
122 .split_once('/')
123 .context("expected / in git url")?;
124 Ok(RepoName {
125 owner: Cow::Borrowed(owner),
126 name: Cow::Borrowed(repo.trim_end_matches(".git")),
127 })
128 // http://github.com/owner/repo.git
129 } else {
130 let url = Url::parse(&self.spec.repository_url)?;
131 let mut segments = url.path_segments().context("empty http url")?;
132 let owner = segments
133 .next()
134 .context("expected owner path segment")?
135 .to_string();
136 let repo = segments
137 .next()
138 .context("expected repo path segment")?
139 .trim_end_matches(".git")
140 .to_string();
141 assert!(segments.next().is_none());
142
143 Ok(RepoName {
144 owner: Cow::Owned(owner),
145 name: Cow::Owned(repo),
146 })
147 }
148 }
149}
150
/// The owner and name of a repository, as produced by `Example::repo_name`.
///
/// Each component either borrows from the example's repository URL or owns
/// its text, depending on the URL form it was parsed from.
pub struct RepoName<'a> {
    /// The repository owner (the first path component of the repository URL).
    pub owner: Cow<'a, str>,
    /// The repository name, with any trailing `.git` removed.
    pub name: Cow<'a, str>,
}
155
156impl RepoName<'_> {
157 pub fn worktree_path(&self) -> PathBuf {
158 WORKTREES_DIR
159 .join(self.owner.as_ref())
160 .join(self.name.as_ref())
161 }
162}
163
164pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
165 let mut examples = Vec::new();
166
167 for path in inputs {
168 let is_stdin = path.as_path() == Path::new("-");
169 let content = if is_stdin {
170 let mut buffer = String::new();
171 std::io::stdin()
172 .read_to_string(&mut buffer)
173 .expect("Failed to read from stdin");
174 buffer
175 } else {
176 std::fs::read_to_string(path)
177 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
178 };
179 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
180 let ext = if !is_stdin {
181 path.extension()
182 .map(|ext| ext.to_string_lossy().to_string())
183 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
184 } else {
185 "jsonl".to_string()
186 };
187
188 match ext.as_ref() {
189 "json" => {
190 let mut example =
191 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
192 panic!("Failed to parse example file: {}\n{error}", path.display())
193 });
194 if example.spec.name.is_empty() {
195 example.spec.name = filename;
196 }
197 examples.push(example);
198 }
199 "jsonl" => examples.extend(
200 content
201 .lines()
202 .enumerate()
203 .map(|(line_ix, line)| {
204 let mut example =
205 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
206 panic!(
207 "Failed to parse example on {}:{}\n{error}",
208 path.display(),
209 line_ix + 1
210 )
211 });
212 if example.spec.name.is_empty() {
213 example.spec.name = format!("{filename}-{line_ix}")
214 }
215 example
216 })
217 .collect::<Vec<Example>>(),
218 ),
219 "md" => {
220 let mut example = parse_markdown_example(&content).unwrap();
221 if example.spec.name.is_empty() {
222 example.spec.name = filename;
223 }
224 examples.push(example);
225 }
226 ext => {
227 panic!("{} has invalid example extension `{ext}`", path.display())
228 }
229 }
230 }
231
232 examples
233}
234
235pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
236 examples.sort_by(|a, b| {
237 a.spec
238 .repository_url
239 .cmp(&b.spec.repository_url)
240 .then(b.spec.revision.cmp(&a.spec.revision))
241 });
242}
243
244pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
245 let mut examples_by_repo = HashMap::default();
246 for example in examples {
247 examples_by_repo
248 .entry(example.spec.repository_url.clone())
249 .or_insert_with(Vec::new)
250 .push(example);
251 }
252 examples_by_repo.into_values().collect()
253}
254
255fn parse_markdown_example(input: &str) -> Result<Example> {
256 let spec = ExampleSpec::from_markdown(input)?;
257 Ok(Example {
258 spec,
259 prompt_inputs: None,
260 prompt: None,
261 predictions: Vec::new(),
262 score: Vec::new(),
263 qa: Vec::new(),
264 state: None,
265 })
266}