1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::{
13 borrow::Cow,
14 collections::VecDeque,
15 io::Read,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19use zeta_prompt::RelatedFile;
20
21#[derive(Clone, Debug, Serialize, Deserialize)]
22pub struct Example {
23 #[serde(flatten)]
24 pub spec: ExampleSpec,
25
26 /// The full content of the file where an edit is being predicted, and the
27 /// actual cursor offset.
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub prompt_inputs: Option<ExamplePromptInputs>,
30
31 /// The input and expected output from the edit prediction model.
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub prompt: Option<ExamplePrompt>,
34
35 /// The actual predictions from the model.
36 #[serde(default, skip_serializing_if = "Vec::is_empty")]
37 pub predictions: Vec<ExamplePrediction>,
38
39 /// The scores, for how well the actual predictions match the expected
40 /// predictions.
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
42 pub score: Vec<ExampleScore>,
43
44 /// The application state used to process this example.
45 #[serde(skip)]
46 pub state: Option<ExampleState>,
47}
48
49#[derive(Clone, Debug)]
50pub struct ExampleState {
51 pub project: Entity<Project>,
52 pub buffer: Entity<Buffer>,
53 pub cursor_position: Anchor,
54 pub _open_buffers: OpenedBuffers,
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58pub struct ExamplePromptInputs {
59 pub content: String,
60 pub cursor_row: u32,
61 pub cursor_column: u32,
62 pub cursor_offset: usize,
63 pub edit_history: Vec<Arc<zeta_prompt::Event>>,
64 pub related_files: Option<Vec<RelatedFile>>,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ExamplePrompt {
69 pub input: String,
70 pub expected_output: String,
71 pub provider: PredictionProvider,
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize)]
75pub struct ExamplePrediction {
76 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub actual_patch: Option<String>,
78 pub actual_output: String,
79 pub provider: PredictionProvider,
80}
81
82#[derive(Clone, Debug, Serialize, Deserialize)]
83pub struct ExampleScore {
84 pub delta_chr_f: f32,
85 pub braces_disbalance: usize,
86 #[serde(default)]
87 pub exact_lines_tp: usize,
88 #[serde(default)]
89 pub exact_lines_fp: usize,
90 #[serde(default)]
91 pub exact_lines_fn: usize,
92}
93
94impl Example {
95 pub fn repo_name(&self) -> Result<RepoName<'_>> {
96 // git@github.com:owner/repo.git
97 if self.spec.repository_url.contains('@') {
98 let (owner, repo) = self
99 .spec
100 .repository_url
101 .split_once(':')
102 .context("expected : in git url")?
103 .1
104 .split_once('/')
105 .context("expected / in git url")?;
106 Ok(RepoName {
107 owner: Cow::Borrowed(owner),
108 name: Cow::Borrowed(repo.trim_end_matches(".git")),
109 })
110 // http://github.com/owner/repo.git
111 } else {
112 let url = Url::parse(&self.spec.repository_url)?;
113 let mut segments = url.path_segments().context("empty http url")?;
114 let owner = segments
115 .next()
116 .context("expected owner path segment")?
117 .to_string();
118 let repo = segments
119 .next()
120 .context("expected repo path segment")?
121 .trim_end_matches(".git")
122 .to_string();
123 assert!(segments.next().is_none());
124
125 Ok(RepoName {
126 owner: Cow::Owned(owner),
127 name: Cow::Owned(repo),
128 })
129 }
130 }
131}
132
/// The owner and name of a repository, either borrowed from a URL string or
/// owned.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The repository owner (the first path component of the repository URL).
    pub owner: Cow<'a, str>,
    /// The repository name (the second path component of the repository URL).
    pub name: Cow<'a, str>,
}
137
138impl RepoName<'_> {
139 pub fn worktree_path(&self) -> PathBuf {
140 WORKTREES_DIR
141 .join(self.owner.as_ref())
142 .join(self.name.as_ref())
143 }
144}
145
146pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
147 let mut examples = Vec::new();
148
149 for path in inputs {
150 let is_stdin = path.as_path() == Path::new("-");
151 let content = if is_stdin {
152 let mut buffer = String::new();
153 std::io::stdin()
154 .read_to_string(&mut buffer)
155 .expect("Failed to read from stdin");
156 buffer
157 } else {
158 std::fs::read_to_string(path)
159 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
160 };
161 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
162 let ext = if !is_stdin {
163 path.extension()
164 .map(|ext| ext.to_string_lossy().to_string())
165 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
166 } else {
167 "jsonl".to_string()
168 };
169
170 match ext.as_ref() {
171 "json" => {
172 let mut example =
173 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
174 panic!("Failed to parse example file: {}\n{error}", path.display())
175 });
176 if example.spec.name.is_empty() {
177 example.spec.name = filename;
178 }
179 examples.push(example);
180 }
181 "jsonl" => examples.extend(
182 content
183 .lines()
184 .enumerate()
185 .map(|(line_ix, line)| {
186 let mut example =
187 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
188 panic!(
189 "Failed to parse example on {}:{}\n{error}",
190 path.display(),
191 line_ix + 1
192 )
193 });
194 if example.spec.name.is_empty() {
195 example.spec.name = format!("{filename}-{line_ix}")
196 }
197 example
198 })
199 .collect::<Vec<Example>>(),
200 ),
201 "md" => {
202 let mut example = parse_markdown_example(&content).unwrap();
203 if example.spec.name.is_empty() {
204 example.spec.name = filename;
205 }
206 examples.push(example);
207 }
208 ext => {
209 panic!("{} has invalid example extension `{ext}`", path.display())
210 }
211 }
212 }
213
214 examples
215}
216
217pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
218 examples.sort_by(|a, b| {
219 a.spec
220 .repository_url
221 .cmp(&b.spec.repository_url)
222 .then(b.spec.revision.cmp(&a.spec.revision))
223 });
224}
225
226pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
227 let mut examples_by_repo = HashMap::default();
228 for example in examples {
229 examples_by_repo
230 .entry(example.spec.repository_url.clone())
231 .or_insert_with(Vec::new)
232 .push(example);
233 }
234 examples_by_repo.into_values().collect()
235}
236
237fn parse_markdown_example(input: &str) -> Result<Example> {
238 let spec = ExampleSpec::from_markdown(input)?;
239 Ok(Example {
240 spec,
241 prompt_inputs: None,
242 prompt: None,
243 predictions: Vec::new(),
244 score: Vec::new(),
245 state: None,
246 })
247}