1use crate::paths::WORKTREES_DIR;
2use crate::{PredictionProvider, PromptFormat};
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::{
14 borrow::Cow,
15 io::{Read, Write},
16 path::{Path, PathBuf},
17};
18use zeta_prompt::RelatedFile;
19
/// An example spec together with the optional artifacts attached to it while
/// it is processed.
///
/// `spec` is flattened, so its fields are serialized at the same level as the
/// fields declared here. Optional fields that are `None` and vectors that are
/// empty are left out of the serialized form, and `state` is never serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example's specification. Its fields are inlined into this struct's
    /// serialized representation.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    ///
    /// Skipped by serde in both directions, so it is always `None` right after
    /// deserialization.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
52
/// Live, in-memory handles used while processing an example.
///
/// This type derives neither `Serialize` nor `Deserialize`; it only exists at
/// runtime.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity.
    pub project: Entity<Project>,
    /// Handle to the buffer entity.
    // NOTE(review): presumably the buffer in which the edit is predicted — confirm.
    pub buffer: Entity<Buffer>,
    /// Cursor position, expressed as a buffer anchor.
    pub cursor_position: Anchor,
    // NOTE(review): the leading underscore suggests this is held only to keep
    // the opened buffers alive for the lifetime of the state — confirm.
    pub _open_buffers: OpenedBuffers,
}
60
/// The context retrieved for a prediction, stored in `Example::context`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the context. Held in a shared slice, so
    /// cloning the context does not copy the files.
    pub files: Arc<[RelatedFile]>,
}
65
/// Content of the file where an edit is being predicted, plus the cursor
/// location, stored in `Example::buffer`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full text of the file.
    pub content: String,
    /// Row of the cursor.
    // NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_row: u32,
    /// Column of the cursor within `cursor_row`.
    // NOTE(review): unit (bytes vs. characters) is not visible here — confirm.
    pub cursor_column: u32,
    /// Offset of the cursor.
    // NOTE(review): presumably a byte offset into `content` — confirm.
    pub cursor_offset: usize,
}
73
/// The input and expected output for the edit prediction model, stored in
/// `Example::prompt`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format associated with this prompt.
    pub format: PromptFormat,
}
80
/// One actual prediction from the model, stored in `Example::predictions`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit, as a patch.
    // NOTE(review): patch format (presumably a unified diff) is not visible here — confirm.
    pub actual_patch: String,
    /// The output text associated with this prediction.
    // NOTE(review): presumably the raw model output — confirm.
    pub actual_output: String,
    /// The provider associated with this prediction.
    pub provider: PredictionProvider,
}
87
/// A score for how well an actual prediction matches the expected one, stored
/// in `Example::score`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a chrF-based delta metric; its exact definition
    // and range are not visible here — confirm.
    pub delta_chr_f: f32,
}
92
93impl Example {
94 pub fn repo_name(&self) -> Result<RepoName<'_>> {
95 // git@github.com:owner/repo.git
96 if self.spec.repository_url.contains('@') {
97 let (owner, repo) = self
98 .spec
99 .repository_url
100 .split_once(':')
101 .context("expected : in git url")?
102 .1
103 .split_once('/')
104 .context("expected / in git url")?;
105 Ok(RepoName {
106 owner: Cow::Borrowed(owner),
107 name: Cow::Borrowed(repo.trim_end_matches(".git")),
108 })
109 // http://github.com/owner/repo.git
110 } else {
111 let url = Url::parse(&self.spec.repository_url)?;
112 let mut segments = url.path_segments().context("empty http url")?;
113 let owner = segments
114 .next()
115 .context("expected owner path segment")?
116 .to_string();
117 let repo = segments
118 .next()
119 .context("expected repo path segment")?
120 .trim_end_matches(".git")
121 .to_string();
122 assert!(segments.next().is_none());
123
124 Ok(RepoName {
125 owner: Cow::Owned(owner),
126 name: Cow::Owned(repo),
127 })
128 }
129 }
130}
131
/// The owner and name of a repository, as extracted from a repository URL.
///
/// Both parts are `Cow`s so they can either borrow from the URL they were
/// parsed from or own their text.
pub struct RepoName<'a> {
    /// The repository owner (the `owner` in `owner/repo`).
    pub owner: Cow<'a, str>,
    /// The repository name (the `repo` in `owner/repo`).
    pub name: Cow<'a, str>,
}
136
137impl RepoName<'_> {
138 pub fn worktree_path(&self) -> PathBuf {
139 WORKTREES_DIR
140 .join(self.owner.as_ref())
141 .join(self.name.as_ref())
142 }
143}
144
145pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
146 let mut examples = Vec::new();
147
148 for path in inputs {
149 let is_stdin = path.as_path() == Path::new("-");
150 let content = if is_stdin {
151 let mut buffer = String::new();
152 std::io::stdin()
153 .read_to_string(&mut buffer)
154 .expect("Failed to read from stdin");
155 buffer
156 } else {
157 std::fs::read_to_string(path)
158 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
159 };
160 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
161 let ext = if !is_stdin {
162 path.extension()
163 .map(|ext| ext.to_string_lossy().to_string())
164 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
165 } else {
166 "jsonl".to_string()
167 };
168
169 match ext.as_ref() {
170 "json" => {
171 let mut example =
172 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
173 panic!("Failed to parse example file: {}\n{error}", path.display())
174 });
175 if example.spec.name.is_empty() {
176 example.spec.name = filename;
177 }
178 examples.push(example);
179 }
180 "jsonl" => examples.extend(
181 content
182 .lines()
183 .enumerate()
184 .map(|(line_ix, line)| {
185 let mut example =
186 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
187 panic!(
188 "Failed to parse example on {}:{}\n{error}",
189 path.display(),
190 line_ix + 1
191 )
192 });
193 if example.spec.name.is_empty() {
194 example.spec.name = format!("{filename}-{line_ix}")
195 }
196 example
197 })
198 .collect::<Vec<Example>>(),
199 ),
200 "md" => {
201 let mut example = parse_markdown_example(&content).unwrap();
202 if example.spec.name.is_empty() {
203 example.spec.name = filename;
204 }
205 examples.push(example);
206 }
207 ext => {
208 panic!("{} has invalid example extension `{ext}`", path.display())
209 }
210 }
211 }
212
213 examples
214}
215
216pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
217 let mut content = String::new();
218 for example in examples {
219 let line = serde_json::to_string(example).unwrap();
220 content.push_str(&line);
221 content.push('\n');
222 }
223 if let Some(output_path) = output_path {
224 std::fs::write(output_path, content).expect("Failed to write examples");
225 } else {
226 std::io::stdout().write_all(&content.as_bytes()).unwrap();
227 }
228}
229
230pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
231 examples.sort_by(|a, b| {
232 a.spec
233 .repository_url
234 .cmp(&b.spec.repository_url)
235 .then(b.spec.revision.cmp(&a.spec.revision))
236 });
237}
238
239pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
240 let mut examples_by_repo = HashMap::default();
241 for example in examples.iter_mut() {
242 examples_by_repo
243 .entry(example.spec.repository_url.clone())
244 .or_insert_with(Vec::new)
245 .push(example);
246 }
247 examples_by_repo.into_values().collect()
248}
249
250fn parse_markdown_example(input: &str) -> Result<Example> {
251 let spec = ExampleSpec::from_markdown(input)?;
252 Ok(Example {
253 spec,
254 buffer: None,
255 context: None,
256 prompt: None,
257 predictions: Vec::new(),
258 score: Vec::new(),
259 state: None,
260 })
261}