1use crate::paths::WORKTREES_DIR;
2use crate::{PredictionProvider, PromptFormat};
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::ops::Range;
13use std::sync::Arc;
14use std::{
15 borrow::Cow,
16 io::Read,
17 path::{Path, PathBuf},
18};
19use zeta_prompt::RelatedFile;
20
21#[derive(Clone, Debug, Serialize, Deserialize)]
22pub struct Example {
23 #[serde(flatten)]
24 pub spec: ExampleSpec,
25
26 /// The full content of the file where an edit is being predicted, and the
27 /// actual cursor offset.
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub buffer: Option<ExampleBuffer>,
30
31 /// The context retrieved for the prediction. This requires the worktree to
32 /// be loaded and the language server to be started.
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub context: Option<ExampleContext>,
35
36 /// The input and expected output from the edit prediction model.
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub prompt: Option<ExamplePrompt>,
39
40 /// The actual predictions from the model.
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
42 pub predictions: Vec<ExamplePrediction>,
43
44 /// The scores, for how well the actual predictions match the expected
45 /// predictions.
46 #[serde(default, skip_serializing_if = "Vec::is_empty")]
47 pub score: Vec<ExampleScore>,
48
49 /// The application state used to process this example.
50 #[serde(skip)]
51 pub state: Option<ExampleState>,
52}
53
54#[derive(Clone, Debug)]
55pub struct ExampleState {
56 pub project: Entity<Project>,
57 pub buffer: Entity<Buffer>,
58 pub cursor_position: Anchor,
59 pub _open_buffers: OpenedBuffers,
60}
61
62#[derive(Clone, Debug, Serialize, Deserialize)]
63pub struct ExampleContext {
64 pub files: Arc<[RelatedFile]>,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ExampleBuffer {
69 pub content: String,
70 pub cursor_row: u32,
71 pub cursor_column: u32,
72 pub cursor_offset: usize,
73 pub context_range: Range<usize>,
74 pub editable_range: Range<usize>,
75}
76
77#[derive(Clone, Debug, Serialize, Deserialize)]
78pub struct ExamplePrompt {
79 pub input: String,
80 pub expected_output: String,
81 pub format: PromptFormat,
82}
83
84#[derive(Clone, Debug, Serialize, Deserialize)]
85pub struct ExamplePrediction {
86 pub actual_patch: String,
87 pub actual_output: String,
88 pub provider: PredictionProvider,
89}
90
91#[derive(Clone, Debug, Serialize, Deserialize)]
92pub struct ExampleScore {
93 pub delta_chr_f: f32,
94}
95
96impl Example {
97 pub fn repo_name(&self) -> Result<RepoName<'_>> {
98 // git@github.com:owner/repo.git
99 if self.spec.repository_url.contains('@') {
100 let (owner, repo) = self
101 .spec
102 .repository_url
103 .split_once(':')
104 .context("expected : in git url")?
105 .1
106 .split_once('/')
107 .context("expected / in git url")?;
108 Ok(RepoName {
109 owner: Cow::Borrowed(owner),
110 name: Cow::Borrowed(repo.trim_end_matches(".git")),
111 })
112 // http://github.com/owner/repo.git
113 } else {
114 let url = Url::parse(&self.spec.repository_url)?;
115 let mut segments = url.path_segments().context("empty http url")?;
116 let owner = segments
117 .next()
118 .context("expected owner path segment")?
119 .to_string();
120 let repo = segments
121 .next()
122 .context("expected repo path segment")?
123 .trim_end_matches(".git")
124 .to_string();
125 assert!(segments.next().is_none());
126
127 Ok(RepoName {
128 owner: Cow::Owned(owner),
129 name: Cow::Owned(repo),
130 })
131 }
132 }
133}
134
/// Owner and name of a repository.
///
/// Both parts are `Cow` strings, so they may either borrow from existing text
/// or own their data.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The owner part of the repository.
    pub owner: Cow<'a, str>,
    /// The name part of the repository.
    pub name: Cow<'a, str>,
}
139
140impl RepoName<'_> {
141 pub fn worktree_path(&self) -> PathBuf {
142 WORKTREES_DIR
143 .join(self.owner.as_ref())
144 .join(self.name.as_ref())
145 }
146}
147
148pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
149 let mut examples = Vec::new();
150
151 for path in inputs {
152 let is_stdin = path.as_path() == Path::new("-");
153 let content = if is_stdin {
154 let mut buffer = String::new();
155 std::io::stdin()
156 .read_to_string(&mut buffer)
157 .expect("Failed to read from stdin");
158 buffer
159 } else {
160 std::fs::read_to_string(path)
161 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
162 };
163 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
164 let ext = if !is_stdin {
165 path.extension()
166 .map(|ext| ext.to_string_lossy().to_string())
167 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
168 } else {
169 "jsonl".to_string()
170 };
171
172 match ext.as_ref() {
173 "json" => {
174 let mut example =
175 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
176 panic!("Failed to parse example file: {}\n{error}", path.display())
177 });
178 if example.spec.name.is_empty() {
179 example.spec.name = filename;
180 }
181 examples.push(example);
182 }
183 "jsonl" => examples.extend(
184 content
185 .lines()
186 .enumerate()
187 .map(|(line_ix, line)| {
188 let mut example =
189 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
190 panic!(
191 "Failed to parse example on {}:{}\n{error}",
192 path.display(),
193 line_ix + 1
194 )
195 });
196 if example.spec.name.is_empty() {
197 example.spec.name = format!("{filename}-{line_ix}")
198 }
199 example
200 })
201 .collect::<Vec<Example>>(),
202 ),
203 "md" => {
204 let mut example = parse_markdown_example(&content).unwrap();
205 if example.spec.name.is_empty() {
206 example.spec.name = filename;
207 }
208 examples.push(example);
209 }
210 ext => {
211 panic!("{} has invalid example extension `{ext}`", path.display())
212 }
213 }
214 }
215
216 examples
217}
218
219pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
220 examples.sort_by(|a, b| {
221 a.spec
222 .repository_url
223 .cmp(&b.spec.repository_url)
224 .then(b.spec.revision.cmp(&a.spec.revision))
225 });
226}
227
228pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
229 let mut examples_by_repo = HashMap::default();
230 for example in examples.iter_mut() {
231 examples_by_repo
232 .entry(example.spec.repository_url.clone())
233 .or_insert_with(Vec::new)
234 .push(example);
235 }
236 examples_by_repo.into_values().collect()
237}
238
239fn parse_markdown_example(input: &str) -> Result<Example> {
240 let spec = ExampleSpec::from_markdown(input)?;
241 Ok(Example {
242 spec,
243 buffer: None,
244 context: None,
245 prompt: None,
246 predictions: Vec::new(),
247 score: Vec::new(),
248 state: None,
249 })
250}