1use crate::PredictionProvider;
2use crate::paths::WORKTREES_DIR;
3use anyhow::{Context as _, Result};
4use collections::HashMap;
5use edit_prediction::example_spec::ExampleSpec;
6use edit_prediction::udiff::OpenedBuffers;
7use gpui::Entity;
8use http_client::Url;
9use language::{Anchor, Buffer};
10use project::Project;
11use serde::{Deserialize, Serialize};
12use std::{
13 borrow::Cow,
14 io::Read,
15 ops::Range,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19use zeta_prompt::RelatedFile;
20
/// A single edit-prediction example: its [`ExampleSpec`] plus the optional
/// data attached to it (prompt inputs, prompt, predictions and scores).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
    /// The example's specification. Flattened, so its fields are
    /// (de)serialized inline alongside the fields below.
    #[serde(flatten)]
    pub spec: ExampleSpec,

    /// The full content of the file where an edit is being predicted, and the
    /// actual cursor offset. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_inputs: Option<ExamplePromptInputs>,

    /// The input and expected output from the edit prediction model.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<ExamplePrompt>,

    /// The actual predictions from the model. Defaults to empty when absent
    /// from the input and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predictions: Vec<ExamplePrediction>,

    /// The scores, for how well the actual predictions match the expected
    /// predictions. Defaults to empty when absent from the input and is
    /// omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub score: Vec<ExampleScore>,

    /// The application state used to process this example. Never serialized
    /// or deserialized.
    #[serde(skip)]
    pub state: Option<ExampleState>,
}
48
/// In-memory handles used while processing an [`Example`]. This type has no
/// serde derives; `Example::state` is marked `#[serde(skip)]`.
#[derive(Clone, Debug)]
pub struct ExampleState {
    /// Handle to the project entity.
    pub project: Entity<Project>,
    /// Handle to the buffer entity.
    pub buffer: Entity<Buffer>,
    /// Cursor position, expressed as an anchor.
    pub cursor_position: Anchor,
    // NOTE(review): presumably held only to keep the opened buffers alive
    // for the lifetime of this state (leading underscore) — confirm.
    pub _open_buffers: OpenedBuffers,
}
56
/// The inputs gathered for building a prompt: file content, cursor location,
/// ranges, edit history and related files.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePromptInputs {
    /// Text content of the file.
    pub content: String,
    /// Cursor row. NOTE(review): presumably zero-based — confirm.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): presumably zero-based — confirm.
    pub cursor_column: u32,
    /// Cursor offset. NOTE(review): presumably a byte offset into `content`
    /// — confirm.
    pub cursor_offset: usize,
    /// Range used as context. NOTE(review): presumably offsets into
    /// `content` — confirm.
    pub context_range: Range<usize>,
    /// Range that may be edited. NOTE(review): presumably offsets into
    /// `content` — confirm.
    pub editable_range: Range<usize>,
    /// Shared edit-history events.
    pub edit_history: Vec<Arc<zeta_prompt::Event>>,
    /// Related files, if any were collected.
    pub related_files: Option<Vec<RelatedFile>>,
}
68
/// A prompt for the edit prediction model together with the output that is
/// expected for it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
    /// The prompt text.
    pub input: String,
    /// The output the model is expected to produce.
    pub expected_output: String,
    /// The prediction provider associated with this prompt.
    pub provider: PredictionProvider,
}
75
/// One prediction actually produced for an example.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
    /// The predicted patch. NOTE(review): presumably in unified-diff form —
    /// confirm against the producer.
    pub actual_patch: String,
    /// The output text of the prediction.
    pub actual_output: String,
    /// The prediction provider associated with this prediction.
    pub provider: PredictionProvider,
}
82
/// Score of a prediction against the expected prediction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
    // NOTE(review): presumably a difference in chrF score — confirm the exact
    // definition against the scoring code.
    pub delta_chr_f: f32,
}
87
88impl Example {
89 pub fn repo_name(&self) -> Result<RepoName<'_>> {
90 // git@github.com:owner/repo.git
91 if self.spec.repository_url.contains('@') {
92 let (owner, repo) = self
93 .spec
94 .repository_url
95 .split_once(':')
96 .context("expected : in git url")?
97 .1
98 .split_once('/')
99 .context("expected / in git url")?;
100 Ok(RepoName {
101 owner: Cow::Borrowed(owner),
102 name: Cow::Borrowed(repo.trim_end_matches(".git")),
103 })
104 // http://github.com/owner/repo.git
105 } else {
106 let url = Url::parse(&self.spec.repository_url)?;
107 let mut segments = url.path_segments().context("empty http url")?;
108 let owner = segments
109 .next()
110 .context("expected owner path segment")?
111 .to_string();
112 let repo = segments
113 .next()
114 .context("expected repo path segment")?
115 .trim_end_matches(".git")
116 .to_string();
117 assert!(segments.next().is_none());
118
119 Ok(RepoName {
120 owner: Cow::Owned(owner),
121 name: Cow::Owned(repo),
122 })
123 }
124 }
125}
126
/// Owner and repository name taken from a repository URL.
///
/// Both parts are `Cow`s, so they may either borrow from the URL string or
/// own their text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RepoName<'a> {
    /// The account or organization that owns the repository.
    pub owner: Cow<'a, str>,
    /// The repository name.
    pub name: Cow<'a, str>,
}
131
132impl RepoName<'_> {
133 pub fn worktree_path(&self) -> PathBuf {
134 WORKTREES_DIR
135 .join(self.owner.as_ref())
136 .join(self.name.as_ref())
137 }
138}
139
140pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
141 let mut examples = Vec::new();
142
143 for path in inputs {
144 let is_stdin = path.as_path() == Path::new("-");
145 let content = if is_stdin {
146 let mut buffer = String::new();
147 std::io::stdin()
148 .read_to_string(&mut buffer)
149 .expect("Failed to read from stdin");
150 buffer
151 } else {
152 std::fs::read_to_string(path)
153 .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
154 };
155 let filename = path.file_stem().unwrap().to_string_lossy().to_string();
156 let ext = if !is_stdin {
157 path.extension()
158 .map(|ext| ext.to_string_lossy().to_string())
159 .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
160 } else {
161 "jsonl".to_string()
162 };
163
164 match ext.as_ref() {
165 "json" => {
166 let mut example =
167 serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
168 panic!("Failed to parse example file: {}\n{error}", path.display())
169 });
170 if example.spec.name.is_empty() {
171 example.spec.name = filename;
172 }
173 examples.push(example);
174 }
175 "jsonl" => examples.extend(
176 content
177 .lines()
178 .enumerate()
179 .map(|(line_ix, line)| {
180 let mut example =
181 serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
182 panic!(
183 "Failed to parse example on {}:{}\n{error}",
184 path.display(),
185 line_ix + 1
186 )
187 });
188 if example.spec.name.is_empty() {
189 example.spec.name = format!("{filename}-{line_ix}")
190 }
191 example
192 })
193 .collect::<Vec<Example>>(),
194 ),
195 "md" => {
196 let mut example = parse_markdown_example(&content).unwrap();
197 if example.spec.name.is_empty() {
198 example.spec.name = filename;
199 }
200 examples.push(example);
201 }
202 ext => {
203 panic!("{} has invalid example extension `{ext}`", path.display())
204 }
205 }
206 }
207
208 examples
209}
210
211pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
212 examples.sort_by(|a, b| {
213 a.spec
214 .repository_url
215 .cmp(&b.spec.repository_url)
216 .then(b.spec.revision.cmp(&a.spec.revision))
217 });
218}
219
220pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
221 let mut examples_by_repo = HashMap::default();
222 for example in examples.iter_mut() {
223 examples_by_repo
224 .entry(example.spec.repository_url.clone())
225 .or_insert_with(Vec::new)
226 .push(example);
227 }
228 examples_by_repo.into_values().collect()
229}
230
231fn parse_markdown_example(input: &str) -> Result<Example> {
232 let spec = ExampleSpec::from_markdown(input)?;
233 Ok(Example {
234 spec,
235 prompt_inputs: None,
236 prompt: None,
237 predictions: Vec::new(),
238 score: Vec::new(),
239 state: None,
240 })
241}