1use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use edit_prediction::example_spec::ExampleSpec;
5use edit_prediction::udiff::OpenedBuffers;
6use gpui::Entity;
7use http_client::Url;
8use language::{Anchor, Buffer};
9use project::Project;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::{
13 borrow::Cow,
14 io::{Read, Write},
15 path::{Path, PathBuf},
16};
17use zeta_prompt::RelatedFile;
18
/// A single edit-prediction example, together with whatever data has been
/// computed for it so far.
///
/// Serialized as one JSON object: the fields of [`ExampleSpec`] are flattened
/// into it, and the computed fields below are omitted while unset or empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example's specification. Its fields appear at the top level of the
    /// serialized object rather than under a `spec` key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model. Omitted when empty; a missing
    /// field deserializes as an empty list.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores for how well the actual predictions match the expected
    /// predictions. Omitted when empty; a missing field deserializes as an
    /// empty list.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. It is never
    /// serialized, and is always `None` right after deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
51
/// In-memory application state attached to an [`Example`] while it is being
/// processed. Not serialized (see `Example::state`).
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project used to process the example.
    pub project: Entity<Project>,
    /// The buffer associated with the example; presumably the file in which
    /// the edit is being predicted.
    pub buffer: Entity<Buffer>,
    /// The cursor position; presumably an anchor into `buffer`.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this field is never read and
    // is held only to keep the opened buffers alive as long as this state — confirm.
    pub _open_buffers: OpenedBuffers,
}
59
/// The context retrieved for a prediction (stored in `Example::context`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files retrieved as context for the prediction.
    pub files: Arc<[RelatedFile]>,
}
64
/// Snapshot of the file in which an edit is being predicted, together with
/// the cursor location inside it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full text of the file.
    pub content: String,
    /// Row of the cursor. NOTE(review): zero- vs one-based is not visible here — confirm.
    pub cursor_row: u32,
    /// Column of the cursor. NOTE(review): zero- vs one-based, and byte vs
    /// character units, are not visible here — confirm.
    pub cursor_column: u32,
    /// Cursor offset into `content`; presumably a byte offset — confirm.
    pub cursor_offset: usize,
}
72
/// The input and expected output for the edit prediction model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt sent to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format used to build `input`.
    pub format: PromptFormat,
}
79
/// A prediction actually produced by a model for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit, as a patch.
    pub actual_patch: String,
    /// The raw output returned by the model; presumably what `actual_patch`
    /// was derived from — confirm.
    pub actual_output: String,
    /// The provider that produced this prediction.
    pub provider: PredictionProvider,
}
86
/// How well an actual prediction matches the expected one.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// NOTE(review): by its name, presumably a chrF (character n-gram F-score)
    /// computed over the edit delta — confirm against the metrics module.
    pub delta_chr_f: f32,
    /// Line-level match metrics between the actual and expected predictions.
    pub line_match: ClassificationMetrics,
}
92
93impl Example {
94 pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
95 // git@github.com:owner/repo.git
96 if self.spec.repository_url.contains('@') {
97 let (owner, repo) = self
98 .spec
99 .repository_url
100 .split_once(':')
101 .context("expected : in git url")?
102 .1
103 .split_once('/')
104 .context("expected / in git url")?;
105 Ok((
106 Cow::Borrowed(owner),
107 Cow::Borrowed(repo.trim_end_matches(".git")),
108 ))
109 // http://github.com/owner/repo.git
110 } else {
111 let url = Url::parse(&self.spec.repository_url)?;
112 let mut segments = url.path_segments().context("empty http url")?;
113 let owner = segments
114 .next()
115 .context("expected owner path segment")?
116 .to_string();
117 let repo = segments
118 .next()
119 .context("expected repo path segment")?
120 .trim_end_matches(".git")
121 .to_string();
122 assert!(segments.next().is_none());
123
124 Ok((owner.into(), repo.into()))
125 }
126 }
127}
128
129pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
130 let mut examples = Vec::new();
131
132 let stdin_path: PathBuf = PathBuf::from("-");
133
134 let inputs = if inputs.is_empty() {
135 &[stdin_path]
136 } else {
137 inputs
138 };
139
140 for path in inputs {
141 let is_stdin = path.as_path() == Path::new("-");
142 let content = if is_stdin {
143 let mut buffer = String::new();
144 std::io::stdin()
145 .read_to_string(&mut buffer)
146 .expect("Failed to read from stdin");
147 buffer
148 } else {
149 std::fs::read_to_string(path)
150 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
151 };
152 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
153 let ext = if !is_stdin {
154 path.extension()
155 .map(|ext| ext.to_string_lossy().to_string())
156 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
157 } else {
158 "jsonl".to_string()
159 };
160
161 match ext.as_ref() {
162 "json" => {
163 let mut example =
164 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
165 panic!("Failed to parse example file: {}\n{error}", path.display())
166 });
167 if example.spec.name.is_empty() {
168 example.spec.name = filename;
169 }
170 examples.push(example);
171 }
172 "jsonl" => examples.extend(
173 content
174 .lines()
175 .enumerate()
176 .map(|(line_ix, line)| {
177 let mut example =
178 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
179 panic!(
180 "Failed to parse example on {}:{}\n{error}",
181 path.display(),
182 line_ix + 1
183 )
184 });
185 if example.spec.name.is_empty() {
186 example.spec.name = format!("{filename}-{line_ix}")
187 }
188 example
189 })
190 .collect::<Vec<Example>>(),
191 ),
192 "md" => {
193 let mut example = parse_markdown_example(&content).unwrap();
194 if example.spec.name.is_empty() {
195 example.spec.name = filename;
196 }
197 examples.push(example);
198 }
199 ext => {
200 panic!("{} has invalid example extension `{ext}`", path.display())
201 }
202 }
203 }
204
205 sort_examples_by_repo_and_rev(&mut examples);
206 examples
207}
208
209pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
210 let mut content = String::new();
211 for example in examples {
212 let line = serde_json::to_string(example).unwrap();
213 content.push_str(&line);
214 content.push('\n');
215 }
216 if let Some(output_path) = output_path {
217 std::fs::write(output_path, content).expect("Failed to write examples");
218 } else {
219 std::io::stdout().write_all(&content.as_bytes()).unwrap();
220 }
221}
222
223pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
224 examples.sort_by(|a, b| {
225 a.spec
226 .repository_url
227 .cmp(&b.spec.repository_url)
228 .then(b.spec.revision.cmp(&a.spec.revision))
229 });
230}
231
232pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
233 let mut examples_by_repo = HashMap::default();
234 for example in examples.iter_mut() {
235 examples_by_repo
236 .entry(example.spec.repository_url.clone())
237 .or_insert_with(Vec::new)
238 .push(example);
239 }
240 examples_by_repo.into_values().collect()
241}
242
243fn parse_markdown_example(input: &str) -> Result<Example> {
244 let spec = ExampleSpec::from_markdown(input)?;
245 Ok(Example {
246 spec,
247 buffer: None,
248 context: None,
249 prompt: None,
250 predictions: Vec::new(),
251 score: Vec::new(),
252 state: None,
253 })
254}