1use crate::paths::WORKTREES_DIR;
2use crate::{PredictionProvider, PromptFormat};
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::ops::Range;
13use std::{
14 borrow::Cow,
15 io::Read,
16 path::{Path, PathBuf},
17};
18use zeta_prompt::RelatedFile;
19
/// A single edit-prediction example together with the artifacts attached to
/// it: buffer contents, retrieved context, prompt, predictions and scores.
///
/// Every optional or list-valued artifact is omitted from the serialized form
/// while it is `None` / empty, as configured by the `skip_serializing_if`
/// attributes below.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The specification of the example. Because of `#[serde(flatten)]`, its
    /// fields are serialized at the same level as the fields of `Example`
    /// rather than under a nested `spec` key.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer: Option<ExampleBuffer>,

    /// The context retrieved for the prediction. This requires the worktree to
    /// be loaded and the language server to be started.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ExampleContext>,

    /// The input and expected output from the edit prediction model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example.
    ///
    /// Marked `#[serde(skip)]`: it is never written out and is not read back
    /// when an example is deserialized.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
52
/// In-memory application state used while processing an [`Example`].
///
/// This type does not derive `Serialize`/`Deserialize`; the `state` field that
/// holds it on [`Example`] is skipped by serde.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity associated with the example.
    pub project: Entity<Project>,
    /// Handle to the buffer entity associated with the example.
    // NOTE(review): presumably the buffer in which the edit is predicted — confirm.
    pub buffer: Entity<Buffer>,
    /// The cursor position, expressed as an anchor.
    // NOTE(review): presumably an anchor into `buffer` — confirm.
    pub cursor_position: Anchor,
    /// Buffers opened for this example.
    // NOTE(review): the leading underscore suggests this is held only to keep
    // the buffers alive and is not read directly — confirm against callers.
    pub _open_buffers: OpenedBuffers,
}
60
/// The context retrieved for a prediction, stored in [`Example::context`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
    /// The related files that make up the retrieved context.
    pub files: Vec<RelatedFile>,
}
65
/// The content of the file where an edit is being predicted, plus the cursor
/// location and the ranges of interest within that content.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
    /// The full text of the file.
    pub content: String,
    /// Row of the cursor.
    // NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_row: u32,
    /// Column of the cursor.
    // NOTE(review): presumably zero-based — confirm against the producer.
    pub cursor_column: u32,
    /// Offset of the cursor.
    // NOTE(review): presumably a byte offset into `content` — confirm.
    pub cursor_offset: usize,
    /// Range used as context.
    // NOTE(review): presumably offsets into `content` — confirm.
    pub context_range: Range<usize>,
    /// Range that is editable.
    // NOTE(review): presumably offsets into `content` — confirm.
    pub editable_range: Range<usize>,
}
75
/// The input and expected output for the edit prediction model, stored in
/// [`Example::prompt`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text given to the model.
    pub input: String,
    /// The output the model is expected to produce for `input`.
    pub expected_output: String,
    /// The prompt format associated with this prompt.
    pub format: PromptFormat,
}
82
/// One actual prediction from the model, stored in [`Example::predictions`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted change in patch form.
    // NOTE(review): presumably a unified diff — confirm against the producer.
    pub actual_patch: String,
    /// The output of the model for this prediction.
    // NOTE(review): presumably the raw, unprocessed model output — confirm.
    pub actual_output: String,
    /// The provider associated with this prediction.
    pub provider: PredictionProvider,
}
89
/// A score for how well an actual prediction matches the expected prediction,
/// stored in [`Example::score`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    /// The score value.
    // NOTE(review): the name suggests a chrF-based delta metric — confirm the
    // exact definition and value range where the score is computed.
    pub delta_chr_f: f32,
}
94
95impl Example {
96 pub fn repo_name(&self) -> Result<RepoName<'_>> {
97 // git@github.com:owner/repo.git
98 if self.spec.repository_url.contains('@') {
99 let (owner, repo) = self
100 .spec
101 .repository_url
102 .split_once(':')
103 .context("expected : in git url")?
104 .1
105 .split_once('/')
106 .context("expected / in git url")?;
107 Ok(RepoName {
108 owner: Cow::Borrowed(owner),
109 name: Cow::Borrowed(repo.trim_end_matches(".git")),
110 })
111 // http://github.com/owner/repo.git
112 } else {
113 let url = Url::parse(&self.spec.repository_url)?;
114 let mut segments = url.path_segments().context("empty http url")?;
115 let owner = segments
116 .next()
117 .context("expected owner path segment")?
118 .to_string();
119 let repo = segments
120 .next()
121 .context("expected repo path segment")?
122 .trim_end_matches(".git")
123 .to_string();
124 assert!(segments.next().is_none());
125
126 Ok(RepoName {
127 owner: Cow::Owned(owner),
128 name: Cow::Owned(repo),
129 })
130 }
131 }
132}
133
/// The owner and name of a repository, as produced by `Example::repo_name`.
///
/// Both parts are `Cow`s so that they can either borrow from the repository
/// URL string or own a freshly allocated string.
pub struct RepoName<'a> {
    /// The owner part of the repository URL (`owner` in `owner/repo`).
    pub owner: Cow<'a, str>,
    /// The repository part of the URL (`repo` in `owner/repo`).
    pub name: Cow<'a, str>,
}
138
139impl RepoName<'_> {
140 pub fn worktree_path(&self) -> PathBuf {
141 WORKTREES_DIR
142 .join(self.owner.as_ref())
143 .join(self.name.as_ref())
144 }
145}
146
147pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
148 let mut examples = Vec::new();
149
150 for path in inputs {
151 let is_stdin = path.as_path() == Path::new("-");
152 let content = if is_stdin {
153 let mut buffer = String::new();
154 std::io::stdin()
155 .read_to_string(&mut buffer)
156 .expect("Failed to read from stdin");
157 buffer
158 } else {
159 std::fs::read_to_string(path)
160 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
161 };
162 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
163 let ext = if !is_stdin {
164 path.extension()
165 .map(|ext| ext.to_string_lossy().to_string())
166 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
167 } else {
168 "jsonl".to_string()
169 };
170
171 match ext.as_ref() {
172 "json" => {
173 let mut example =
174 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
175 panic!("Failed to parse example file: {}\n{error}", path.display())
176 });
177 if example.spec.name.is_empty() {
178 example.spec.name = filename;
179 }
180 examples.push(example);
181 }
182 "jsonl" => examples.extend(
183 content
184 .lines()
185 .enumerate()
186 .map(|(line_ix, line)| {
187 let mut example =
188 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
189 panic!(
190 "Failed to parse example on {}:{}\n{error}",
191 path.display(),
192 line_ix + 1
193 )
194 });
195 if example.spec.name.is_empty() {
196 example.spec.name = format!("{filename}-{line_ix}")
197 }
198 example
199 })
200 .collect::<Vec<Example>>(),
201 ),
202 "md" => {
203 let mut example = parse_markdown_example(&content).unwrap();
204 if example.spec.name.is_empty() {
205 example.spec.name = filename;
206 }
207 examples.push(example);
208 }
209 ext => {
210 panic!("{} has invalid example extension `{ext}`", path.display())
211 }
212 }
213 }
214
215 examples
216}
217
218pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
219 examples.sort_by(|a, b| {
220 a.spec
221 .repository_url
222 .cmp(&b.spec.repository_url)
223 .then(b.spec.revision.cmp(&a.spec.revision))
224 });
225}
226
227pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
228 let mut examples_by_repo = HashMap::default();
229 for example in examples.iter_mut() {
230 examples_by_repo
231 .entry(example.spec.repository_url.clone())
232 .or_insert_with(Vec::new)
233 .push(example);
234 }
235 examples_by_repo.into_values().collect()
236}
237
238fn parse_markdown_example(input: &str) -> Result<Example> {
239 let spec = ExampleSpec::from_markdown(input)?;
240 Ok(Example {
241 spec,
242 buffer: None,
243 context: None,
244 prompt: None,
245 predictions: Vec::new(),
246 score: Vec::new(),
247 state: None,
248 })
249}