1use crate::{PredictionProvider, PromptFormat};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use edit_prediction::example_spec::ExampleSpec;
5use edit_prediction::udiff::OpenedBuffers;
6use gpui::Entity;
7use http_client::Url;
8use language::{Anchor, Buffer};
9use project::Project;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::{
13 borrow::Cow,
14 io::{Read, Write},
15 path::{Path, PathBuf},
16};
17use zeta_prompt::RelatedFile;
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct Example {
21 #[serde(flatten)]
22 pub spec: ExampleSpec,
23
24 /// The full content of the file where an edit is being predicted, and the
25 /// actual cursor offset.
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub buffer: Option<ExampleBuffer>,
28
29 /// The context retrieved for the prediction. This requires the worktree to
30 /// be loaded and the language server to be started.
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub context: Option<ExampleContext>,
33
34 /// The input and expected output from the edit prediction model.
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub prompt: Option<ExamplePrompt>,
37
38 /// The actual predictions from the model.
39 #[serde(default, skip_serializing_if = "Vec::is_empty")]
40 pub predictions: Vec<ExamplePrediction>,
41
42 /// The scores, for how well the actual predictions match the expected
43 /// predictions.
44 #[serde(default, skip_serializing_if = "Vec::is_empty")]
45 pub score: Vec<ExampleScore>,
46
47 /// The application state used to process this example.
48 #[serde(skip)]
49 pub state: Option<ExampleState>,
50}
51
52#[derive(Clone, Debug)]
53pub struct ExampleState {
54 pub project: Entity<Project>,
55 pub buffer: Entity<Buffer>,
56 pub cursor_position: Anchor,
57 pub _open_buffers: OpenedBuffers,
58}
59
60#[derive(Clone, Debug, Serialize, Deserialize)]
61pub struct ExampleContext {
62 pub files: Arc<[RelatedFile]>,
63}
64
65#[derive(Clone, Debug, Serialize, Deserialize)]
66pub struct ExampleBuffer {
67 pub content: String,
68 pub cursor_row: u32,
69 pub cursor_column: u32,
70 pub cursor_offset: usize,
71}
72
73#[derive(Clone, Debug, Serialize, Deserialize)]
74pub struct ExamplePrompt {
75 pub input: String,
76 pub expected_output: String,
77 pub format: PromptFormat,
78}
79
80#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct ExamplePrediction {
82 pub actual_patch: String,
83 pub actual_output: String,
84 pub provider: PredictionProvider,
85}
86
87#[derive(Clone, Debug, Serialize, Deserialize)]
88pub struct ExampleScore {
89 pub delta_chr_f: f32,
90}
91
92impl Example {
93 pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
94 // git@github.com:owner/repo.git
95 if self.spec.repository_url.contains('@') {
96 let (owner, repo) = self
97 .spec
98 .repository_url
99 .split_once(':')
100 .context("expected : in git url")?
101 .1
102 .split_once('/')
103 .context("expected / in git url")?;
104 Ok((
105 Cow::Borrowed(owner),
106 Cow::Borrowed(repo.trim_end_matches(".git")),
107 ))
108 // http://github.com/owner/repo.git
109 } else {
110 let url = Url::parse(&self.spec.repository_url)?;
111 let mut segments = url.path_segments().context("empty http url")?;
112 let owner = segments
113 .next()
114 .context("expected owner path segment")?
115 .to_string();
116 let repo = segments
117 .next()
118 .context("expected repo path segment")?
119 .trim_end_matches(".git")
120 .to_string();
121 assert!(segments.next().is_none());
122
123 Ok((owner.into(), repo.into()))
124 }
125 }
126}
127
128pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
129 let mut examples = Vec::new();
130
131 let stdin_path: PathBuf = PathBuf::from("-");
132
133 let inputs = if inputs.is_empty() {
134 &[stdin_path]
135 } else {
136 inputs
137 };
138
139 for path in inputs {
140 let is_stdin = path.as_path() == Path::new("-");
141 let content = if is_stdin {
142 let mut buffer = String::new();
143 std::io::stdin()
144 .read_to_string(&mut buffer)
145 .expect("Failed to read from stdin");
146 buffer
147 } else {
148 std::fs::read_to_string(path)
149 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
150 };
151 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
152 let ext = if !is_stdin {
153 path.extension()
154 .map(|ext| ext.to_string_lossy().to_string())
155 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
156 } else {
157 "jsonl".to_string()
158 };
159
160 match ext.as_ref() {
161 "json" => {
162 let mut example =
163 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
164 panic!("Failed to parse example file: {}\n{error}", path.display())
165 });
166 if example.spec.name.is_empty() {
167 example.spec.name = filename;
168 }
169 examples.push(example);
170 }
171 "jsonl" => examples.extend(
172 content
173 .lines()
174 .enumerate()
175 .map(|(line_ix, line)| {
176 let mut example =
177 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
178 panic!(
179 "Failed to parse example on {}:{}\n{error}",
180 path.display(),
181 line_ix + 1
182 )
183 });
184 if example.spec.name.is_empty() {
185 example.spec.name = format!("{filename}-{line_ix}")
186 }
187 example
188 })
189 .collect::<Vec<Example>>(),
190 ),
191 "md" => {
192 let mut example = parse_markdown_example(&content).unwrap();
193 if example.spec.name.is_empty() {
194 example.spec.name = filename;
195 }
196 examples.push(example);
197 }
198 ext => {
199 panic!("{} has invalid example extension `{ext}`", path.display())
200 }
201 }
202 }
203
204 sort_examples_by_repo_and_rev(&mut examples);
205 examples
206}
207
208pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
209 let mut content = String::new();
210 for example in examples {
211 let line = serde_json::to_string(example).unwrap();
212 content.push_str(&line);
213 content.push('\n');
214 }
215 if let Some(output_path) = output_path {
216 std::fs::write(output_path, content).expect("Failed to write examples");
217 } else {
218 std::io::stdout().write_all(&content.as_bytes()).unwrap();
219 }
220}
221
222pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
223 examples.sort_by(|a, b| {
224 a.spec
225 .repository_url
226 .cmp(&b.spec.repository_url)
227 .then(b.spec.revision.cmp(&a.spec.revision))
228 });
229}
230
231pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
232 let mut examples_by_repo = HashMap::default();
233 for example in examples.iter_mut() {
234 examples_by_repo
235 .entry(example.spec.repository_url.clone())
236 .or_insert_with(Vec::new)
237 .push(example);
238 }
239 examples_by_repo.into_values().collect()
240}
241
242fn parse_markdown_example(input: &str) -> Result<Example> {
243 let spec = ExampleSpec::from_markdown(input)?;
244 Ok(Example {
245 spec,
246 buffer: None,
247 context: None,
248 prompt: None,
249 predictions: Vec::new(),
250 score: Vec::new(),
251 state: None,
252 })
253}