1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::{
13 borrow::Cow,
14 collections::VecDeque,
15 io::Read,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19use zeta_prompt::RelatedFile;
20
/// A single edit-prediction example, plus any artifacts produced while
/// processing it (prompt inputs, prompt, predictions, and scores).
///
/// Optional fields and empty vectors are omitted from the serialized form,
/// and `state` is never serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example specification. Its fields are (de)serialized inline at the
    /// top level of the example rather than under a nested key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// Scores measuring how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. This is runtime
    /// only: it is skipped by serde in both directions.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
48
/// Live application state attached to an [`Example`] while it is being
/// processed.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// The project the example's repository was opened in.
    pub project: Entity<Project>,
    /// The buffer in which the edit is being predicted.
    pub buffer: Entity<Buffer>,
    /// The cursor location within `buffer`.
    pub cursor_position: Anchor,
    // NOTE(review): never read (leading underscore); presumably retained so the
    // buffers opened for this example stay alive as long as this state — confirm.
    pub _open_buffers: OpenedBuffers,
}
56
/// The raw inputs used to build a prompt for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// The full text of the file in which the edit is being predicted.
    pub content: String,
    // NOTE(review): row/column indexing base (0- or 1-based) is not visible here — confirm.
    /// Cursor row within `content`.
    pub cursor_row: u32,
    /// Cursor column within the cursor row.
    pub cursor_column: u32,
    // NOTE(review): presumably a byte offset into `content` — confirm.
    /// Cursor position expressed as an offset into `content`.
    pub cursor_offset: usize,
    /// The edit history events supplied to the prompt.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplied as additional context, if any were collected.
    pub related_files: Option<Vec<RelatedFile>>,
}
66
/// The prompt sent to a prediction model and the output it is expected to
/// produce.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The rendered prompt text.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prediction provider this prompt was built for.
    pub provider: PredictionProvider,
}
73
/// One prediction actually produced by a model for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted edit, as a patch.
    pub actual_patch: String,
    /// The model's raw output.
    pub actual_output: String,
    /// The prediction provider that produced this prediction.
    pub provider: PredictionProvider,
}
80
/// A score comparing an actual prediction against the expected one.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): name suggests a chrF-based delta metric; its exact
    // definition and range are computed elsewhere — confirm.
    pub delta_chr_f: f32,
}
85
86impl Example {
87 pub fn repo_name(&self) -> Result<RepoName<'_>> {
88 // git@github.com:owner/repo.git
89 if self.spec.repository_url.contains('@') {
90 let (owner, repo) = self
91 .spec
92 .repository_url
93 .split_once(':')
94 .context("expected : in git url")?
95 .1
96 .split_once('/')
97 .context("expected / in git url")?;
98 Ok(RepoName {
99 owner: Cow::Borrowed(owner),
100 name: Cow::Borrowed(repo.trim_end_matches(".git")),
101 })
102 // http://github.com/owner/repo.git
103 } else {
104 let url = Url::parse(&self.spec.repository_url)?;
105 let mut segments = url.path_segments().context("empty http url")?;
106 let owner = segments
107 .next()
108 .context("expected owner path segment")?
109 .to_string();
110 let repo = segments
111 .next()
112 .context("expected repo path segment")?
113 .trim_end_matches(".git")
114 .to_string();
115 assert!(segments.next().is_none());
116
117 Ok(RepoName {
118 owner: Cow::Owned(owner),
119 name: Cow::Owned(repo),
120 })
121 }
122 }
123}
124
/// The owner and name of a git repository, e.g. `zed-industries` / `zed`.
///
/// The fields may borrow from the URL they were parsed from, or own their data
/// when borrowing is not possible.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The repository owner (user or organization).
    pub owner: Cow<'a, str>,
    /// The repository name, without any `.git` suffix.
    pub name: Cow<'a, str>,
}
129
130impl RepoName<'_> {
131 pub fn worktree_path(&self) -> PathBuf {
132 WORKTREES_DIR
133 .join(self.owner.as_ref())
134 .join(self.name.as_ref())
135 }
136}
137
138pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
139 let mut examples = Vec::new();
140
141 for path in inputs {
142 let is_stdin = path.as_path() == Path::new("-");
143 let content = if is_stdin {
144 let mut buffer = String::new();
145 std::io::stdin()
146 .read_to_string(&mut buffer)
147 .expect("Failed to read from stdin");
148 buffer
149 } else {
150 std::fs::read_to_string(path)
151 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
152 };
153 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
154 let ext = if !is_stdin {
155 path.extension()
156 .map(|ext| ext.to_string_lossy().to_string())
157 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
158 } else {
159 "jsonl".to_string()
160 };
161
162 match ext.as_ref() {
163 "json" => {
164 let mut example =
165 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
166 panic!("Failed to parse example file: {}\n{error}", path.display())
167 });
168 if example.spec.name.is_empty() {
169 example.spec.name = filename;
170 }
171 examples.push(example);
172 }
173 "jsonl" => examples.extend(
174 content
175 .lines()
176 .enumerate()
177 .map(|(line_ix, line)| {
178 let mut example =
179 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
180 panic!(
181 "Failed to parse example on {}:{}\n{error}",
182 path.display(),
183 line_ix + 1
184 )
185 });
186 if example.spec.name.is_empty() {
187 example.spec.name = format!("{filename}-{line_ix}")
188 }
189 example
190 })
191 .collect::<Vec<Example>>(),
192 ),
193 "md" => {
194 let mut example = parse_markdown_example(&content).unwrap();
195 if example.spec.name.is_empty() {
196 example.spec.name = filename;
197 }
198 examples.push(example);
199 }
200 ext => {
201 panic!("{} has invalid example extension `{ext}`", path.display())
202 }
203 }
204 }
205
206 examples
207}
208
209pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
210 examples.sort_by(|a, b| {
211 a.spec
212 .repository_url
213 .cmp(&b.spec.repository_url)
214 .then(b.spec.revision.cmp(&a.spec.revision))
215 });
216}
217
218pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
219 let mut examples_by_repo = HashMap::default();
220 for example in examples {
221 examples_by_repo
222 .entry(example.spec.repository_url.clone())
223 .or_insert_with(Vec::new)
224 .push(example);
225 }
226 examples_by_repo.into_values().collect()
227}
228
229fn parse_markdown_example(input: &str) -> Result<Example> {
230 let spec = ExampleSpec::from_markdown(input)?;
231 Ok(Example {
232 spec,
233 prompt_inputs: None,
234 prompt: None,
235 predictions: Vec::new(),
236 score: Vec::new(),
237 state: None,
238 })
239}