1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::{
13 borrow::Cow,
14 io::Read,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use zeta_prompt::RelatedFile;
19
20#[derive(Clone, Debug, Serialize, Deserialize)]
21pub struct Example {
22 #[serde(flatten)]
23 pub spec: ExampleSpec,
24
25 /// The full content of the file where an edit is being predicted, and the
26 /// actual cursor offset.
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub prompt_inputs: Option<ExamplePromptInputs>,
29
30 /// The input and expected output from the edit prediction model.
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub prompt: Option<ExamplePrompt>,
33
34 /// The actual predictions from the model.
35 #[serde(default, skip_serializing_if = "Vec::is_empty")]
36 pub predictions: Vec<ExamplePrediction>,
37
38 /// The scores, for how well the actual predictions match the expected
39 /// predictions.
40 #[serde(default, skip_serializing_if = "Vec::is_empty")]
41 pub score: Vec<ExampleScore>,
42
43 /// The application state used to process this example.
44 #[serde(skip)]
45 pub state: Option<ExampleState>,
46}
47
48#[derive(Clone, Debug)]
49pub struct ExampleState {
50 pub project: Entity<Project>,
51 pub buffer: Entity<Buffer>,
52 pub cursor_position: Anchor,
53 pub _open_buffers: OpenedBuffers,
54}
55
56#[derive(Clone, Debug, Serialize, Deserialize)]
57pub struct ExamplePromptInputs {
58 pub content: String,
59 pub cursor_row: u32,
60 pub cursor_column: u32,
61 pub cursor_offset: usize,
62 pub edit_history: Vec<Arc<zeta_prompt::Event>>,
63 pub related_files: Option<Vec<RelatedFile>>,
64}
65
66#[derive(Clone, Debug, Serialize, Deserialize)]
67pub struct ExamplePrompt {
68 pub input: String,
69 pub expected_output: String,
70 pub provider: PredictionProvider,
71}
72
73#[derive(Clone, Debug, Serialize, Deserialize)]
74pub struct ExamplePrediction {
75 pub actual_patch: String,
76 pub actual_output: String,
77 pub provider: PredictionProvider,
78}
79
80#[derive(Clone, Debug, Serialize, Deserialize)]
81pub struct ExampleScore {
82 pub delta_chr_f: f32,
83}
84
85impl Example {
86 pub fn repo_name(&self) -> Result<RepoName<'_>> {
87 // git@github.com:owner/repo.git
88 if self.spec.repository_url.contains('@') {
89 let (owner, repo) = self
90 .spec
91 .repository_url
92 .split_once(':')
93 .context("expected : in git url")?
94 .1
95 .split_once('/')
96 .context("expected / in git url")?;
97 Ok(RepoName {
98 owner: Cow::Borrowed(owner),
99 name: Cow::Borrowed(repo.trim_end_matches(".git")),
100 })
101 // http://github.com/owner/repo.git
102 } else {
103 let url = Url::parse(&self.spec.repository_url)?;
104 let mut segments = url.path_segments().context("empty http url")?;
105 let owner = segments
106 .next()
107 .context("expected owner path segment")?
108 .to_string();
109 let repo = segments
110 .next()
111 .context("expected repo path segment")?
112 .trim_end_matches(".git")
113 .to_string();
114 assert!(segments.next().is_none());
115
116 Ok(RepoName {
117 owner: Cow::Owned(owner),
118 name: Cow::Owned(repo),
119 })
120 }
121 }
122}
123
/// The owner and name of a repository, either borrowed from a URL string or
/// owned.
pub struct RepoName<'a> {
    /// Repository owner (the first path component of the repository URL).
    pub owner: Cow<'a, str>,
    /// Repository name.
    pub name: Cow<'a, str>,
}
128
129impl RepoName<'_> {
130 pub fn worktree_path(&self) -> PathBuf {
131 WORKTREES_DIR
132 .join(self.owner.as_ref())
133 .join(self.name.as_ref())
134 }
135}
136
137pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
138 let mut examples = Vec::new();
139
140 for path in inputs {
141 let is_stdin = path.as_path() == Path::new("-");
142 let content = if is_stdin {
143 let mut buffer = String::new();
144 std::io::stdin()
145 .read_to_string(&mut buffer)
146 .expect("Failed to read from stdin");
147 buffer
148 } else {
149 std::fs::read_to_string(path)
150 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
151 };
152 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
153 let ext = if !is_stdin {
154 path.extension()
155 .map(|ext| ext.to_string_lossy().to_string())
156 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
157 } else {
158 "jsonl".to_string()
159 };
160
161 match ext.as_ref() {
162 "json" => {
163 let mut example =
164 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
165 panic!("Failed to parse example file: {}\n{error}", path.display())
166 });
167 if example.spec.name.is_empty() {
168 example.spec.name = filename;
169 }
170 examples.push(example);
171 }
172 "jsonl" => examples.extend(
173 content
174 .lines()
175 .enumerate()
176 .map(|(line_ix, line)| {
177 let mut example =
178 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
179 panic!(
180 "Failed to parse example on {}:{}\n{error}",
181 path.display(),
182 line_ix + 1
183 )
184 });
185 if example.spec.name.is_empty() {
186 example.spec.name = format!("{filename}-{line_ix}")
187 }
188 example
189 })
190 .collect::<Vec<Example>>(),
191 ),
192 "md" => {
193 let mut example = parse_markdown_example(&content).unwrap();
194 if example.spec.name.is_empty() {
195 example.spec.name = filename;
196 }
197 examples.push(example);
198 }
199 ext => {
200 panic!("{} has invalid example extension `{ext}`", path.display())
201 }
202 }
203 }
204
205 examples
206}
207
208pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
209 examples.sort_by(|a, b| {
210 a.spec
211 .repository_url
212 .cmp(&b.spec.repository_url)
213 .then(b.spec.revision.cmp(&a.spec.revision))
214 });
215}
216
217pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
218 let mut examples_by_repo = HashMap::default();
219 for example in examples.iter_mut() {
220 examples_by_repo
221 .entry(example.spec.repository_url.clone())
222 .or_insert_with(Vec::new)
223 .push(example);
224 }
225 examples_by_repo.into_values().collect()
226}
227
228fn parse_markdown_example(input: &str) -> Result<Example> {
229 let spec = ExampleSpec::from_markdown(input)?;
230 Ok(Example {
231 spec,
232 prompt_inputs: None,
233 prompt: None,
234 predictions: Vec::new(),
235 score: Vec::new(),
236 state: None,
237 })
238}