1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::{
13 borrow::Cow,
14 collections::VecDeque,
15 io::Read,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19use zeta_prompt::RelatedFile;
20
21#[derive(Clone, Debug, Serialize, Deserialize)]
22pub struct Example {
23 #[serde(flatten)]
24 pub spec: ExampleSpec,
25
26 /// The full content of the file where an edit is being predicted, and the
27 /// actual cursor offset.
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub prompt_inputs: Option<ExamplePromptInputs>,
30
31 /// The input and expected output from the edit prediction model.
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub prompt: Option<ExamplePrompt>,
34
35 /// The actual predictions from the model.
36 #[serde(default, skip_serializing_if = "Vec::is_empty")]
37 pub predictions: Vec<ExamplePrediction>,
38
39 /// The scores, for how well the actual predictions match the expected
40 /// predictions.
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
42 pub score: Vec<ExampleScore>,
43
44 /// The application state used to process this example.
45 #[serde(skip)]
46 pub state: Option<ExampleState>,
47}
48
49#[derive(Clone, Debug)]
50pub struct ExampleState {
51 pub project: Entity<Project>,
52 pub buffer: Entity<Buffer>,
53 pub cursor_position: Anchor,
54 pub _open_buffers: OpenedBuffers,
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58pub struct ExamplePromptInputs {
59 pub content: String,
60 pub cursor_row: u32,
61 pub cursor_column: u32,
62 pub cursor_offset: usize,
63 pub edit_history: Vec<Arc<zeta_prompt::Event>>,
64 pub related_files: Option<Vec<RelatedFile>>,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ExamplePrompt {
69 pub input: String,
70 pub expected_output: String,
71 pub provider: PredictionProvider,
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize)]
75pub struct ExamplePrediction {
76 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub actual_patch: Option<String>,
78 #[serde(deserialize_with = "deserialize_null_as_empty_string")]
79 pub actual_output: String,
80 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub error: Option<String>,
82 pub provider: PredictionProvider,
83}
84
85fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
86where
87 D: serde::Deserializer<'de>,
88{
89 let opt = Option::<String>::deserialize(deserializer)?;
90 Ok(opt.unwrap_or_default())
91}
92
93#[derive(Clone, Debug, Serialize, Deserialize)]
94pub struct ExampleScore {
95 pub delta_chr_f: f32,
96 pub braces_disbalance: usize,
97 #[serde(default)]
98 pub exact_lines_tp: usize,
99 #[serde(default)]
100 pub exact_lines_fp: usize,
101 #[serde(default)]
102 pub exact_lines_fn: usize,
103}
104
105impl Example {
106 pub fn repo_name(&self) -> Result<RepoName<'_>> {
107 // git@github.com:owner/repo.git
108 if self.spec.repository_url.contains('@') {
109 let (owner, repo) = self
110 .spec
111 .repository_url
112 .split_once(':')
113 .context("expected : in git url")?
114 .1
115 .split_once('/')
116 .context("expected / in git url")?;
117 Ok(RepoName {
118 owner: Cow::Borrowed(owner),
119 name: Cow::Borrowed(repo.trim_end_matches(".git")),
120 })
121 // http://github.com/owner/repo.git
122 } else {
123 let url = Url::parse(&self.spec.repository_url)?;
124 let mut segments = url.path_segments().context("empty http url")?;
125 let owner = segments
126 .next()
127 .context("expected owner path segment")?
128 .to_string();
129 let repo = segments
130 .next()
131 .context("expected repo path segment")?
132 .trim_end_matches(".git")
133 .to_string();
134 assert!(segments.next().is_none());
135
136 Ok(RepoName {
137 owner: Cow::Owned(owner),
138 name: Cow::Owned(repo),
139 })
140 }
141 }
142}
143
/// Owner and name of a repository.
///
/// Both parts are [`Cow`] strings, so they can either borrow from the string
/// they were extracted from or own their data.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RepoName<'a> {
    /// Owner part of the repository path.
    pub owner: Cow<'a, str>,
    /// Repository name part of the repository path.
    pub name: Cow<'a, str>,
}
148
149impl RepoName<'_> {
150 pub fn worktree_path(&self) -> PathBuf {
151 WORKTREES_DIR
152 .join(self.owner.as_ref())
153 .join(self.name.as_ref())
154 }
155}
156
157pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
158 let mut examples = Vec::new();
159
160 for path in inputs {
161 let is_stdin = path.as_path() == Path::new("-");
162 let content = if is_stdin {
163 let mut buffer = String::new();
164 std::io::stdin()
165 .read_to_string(&mut buffer)
166 .expect("Failed to read from stdin");
167 buffer
168 } else {
169 std::fs::read_to_string(path)
170 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
171 };
172 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
173 let ext = if !is_stdin {
174 path.extension()
175 .map(|ext| ext.to_string_lossy().to_string())
176 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
177 } else {
178 "jsonl".to_string()
179 };
180
181 match ext.as_ref() {
182 "json" => {
183 let mut example =
184 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
185 panic!("Failed to parse example file: {}\n{error}", path.display())
186 });
187 if example.spec.name.is_empty() {
188 example.spec.name = filename;
189 }
190 examples.push(example);
191 }
192 "jsonl" => examples.extend(
193 content
194 .lines()
195 .enumerate()
196 .map(|(line_ix, line)| {
197 let mut example =
198 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
199 panic!(
200 "Failed to parse example on {}:{}\n{error}",
201 path.display(),
202 line_ix + 1
203 )
204 });
205 if example.spec.name.is_empty() {
206 example.spec.name = format!("{filename}-{line_ix}")
207 }
208 example
209 })
210 .collect::<Vec<Example>>(),
211 ),
212 "md" => {
213 let mut example = parse_markdown_example(&content).unwrap();
214 if example.spec.name.is_empty() {
215 example.spec.name = filename;
216 }
217 examples.push(example);
218 }
219 ext => {
220 panic!("{} has invalid example extension `{ext}`", path.display())
221 }
222 }
223 }
224
225 examples
226}
227
228pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
229 examples.sort_by(|a, b| {
230 a.spec
231 .repository_url
232 .cmp(&b.spec.repository_url)
233 .then(b.spec.revision.cmp(&a.spec.revision))
234 });
235}
236
237pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
238 let mut examples_by_repo = HashMap::default();
239 for example in examples {
240 examples_by_repo
241 .entry(example.spec.repository_url.clone())
242 .or_insert_with(Vec::new)
243 .push(example);
244 }
245 examples_by_repo.into_values().collect()
246}
247
248fn parse_markdown_example(input: &str) -> Result<Example> {
249 let spec = ExampleSpec::from_markdown(input)?;
250 Ok(Example {
251 spec,
252 prompt_inputs: None,
253 prompt: None,
254 predictions: Vec::new(),
255 score: Vec::new(),
256 state: None,
257 })
258}