1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::Display,
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
14// TODO: snippet ordering within file / relative to excerpt
15
/// Context gathered for an edit-prediction request.
///
/// Fields annotated with `skip_serializing_if` are left out of the serialized form while they
/// hold their empty/default value; their `default` attribute lets deserialization accept them
/// being absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt taken from the file at `excerpt_path`.
    pub excerpt: String,
    /// Path of the file the excerpt comes from.
    pub excerpt_path: Arc<Path>,
    /// Range of the excerpt within the file.
    ///
    /// NOTE(review): the unit is not visible here (likely byte offsets) — confirm.
    pub excerpt_range: Range<usize>,
    /// Line range of the excerpt within the file.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position. NOTE(review): confirm whether this is file- or excerpt-relative.
    pub cursor_point: Point,
    /// Index within `signatures` of the excerpt's parent signature, if any.
    pub excerpt_parent: Option<usize>,
    /// Additional files (with excerpts) included as context; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures that `excerpt_parent` and `ReferencedDeclaration::parent_index` index into;
    /// omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations included as context; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Editor events included as context. This field has no `#[serde(default)]`, so it must be
    /// present when deserializing.
    pub events: Vec<Event>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups kept as raw JSON; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Whether `diagnostic_groups` was truncated; omitted from output while `false`.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Defaults to `false` when absent.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit (presumably in bytes, per the name); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout to use; falls back to `PromptFormat::default()` when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
50
/// A file whose excerpts are included as request context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the included file.
    pub path: Arc<Path>,
    /// Largest row of the file. NOTE(review): confirm whether inclusive / zero-based.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
57
/// A piece of file text together with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line on which `text` begins.
    pub start_line: Line,
    /// The excerpt text; `Arc<str>` keeps clones cheap.
    pub text: Arc<str>,
}
63
/// Prompt layouts that can be selected for a request.
///
/// The declaration order of the variants is also the order `EnumIter` yields them in, so
/// reordering them changes iteration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    NumLinesUniDiff,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
}
72
73impl PromptFormat {
74 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
75}
76
77impl Default for PromptFormat {
78 fn default() -> Self {
79 Self::DEFAULT
80 }
81}
82
83impl PromptFormat {
84 pub fn iter() -> impl Iterator<Item = Self> {
85 <Self as strum::IntoEnumIterator>::iter()
86 }
87}
88
89impl std::fmt::Display for PromptFormat {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
93 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
94 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
95 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
96 }
97 }
98}
99
/// An editor event included with a request.
///
/// Serialized with an internal `"event"` tag holding the variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as a diff.
    BufferChange {
        /// Current path of the buffer; `Display` shows `untitled` when `None`.
        path: Option<PathBuf>,
        /// Previous path (presumably set on rename); `Display` falls back to `path` when `None`.
        old_path: Option<PathBuf>,
        /// Diff hunks, emitted by `Display` after the `---`/`+++` header lines.
        diff: String,
        /// Whether the change came from a prediction the user accepted.
        predicted: bool,
    },
}
111
112impl Display for Event {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 Event::BufferChange {
116 path,
117 old_path,
118 diff,
119 predicted,
120 } => {
121 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
122 let old_path = old_path.as_deref().unwrap_or(new_path);
123
124 if *predicted {
125 write!(
126 f,
127 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
128 old_path.display(),
129 new_path.display()
130 )
131 } else {
132 write!(
133 f,
134 "--- a/{}\n+++ b/{}\n{diff}",
135 old_path.display(),
136 new_path.display()
137 )
138 }
139 }
140 }
141 }
142}
143
/// A signature included as context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Index of a parent signature, if any; omitted when `None`.
    /// NOTE(review): presumably an index into the request's `signatures`, like
    /// `ReferencedDeclaration::parent_index` — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
154
/// A declaration included as context, together with its scoring data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// File containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Individual features used in scoring this declaration.
    pub score_components: DeclarationScoreComponents,
    /// NOTE(review): how these two scores are derived from `score_components` is not visible
    /// here — confirm before relying on their scale.
    pub signature_score: f32,
    pub declaration_score: f32,
}
171
/// Raw scoring features for a declaration, stored in
/// `ReferencedDeclaration::score_components`.
///
/// NOTE(review): field meanings below are inferred from the names only — confirm against the
/// code that fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Location / reference flags and counts.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Similarity of excerpt / adjacent text against the declaration's item or signature text.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-based features.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Containment relationships with other declarations.
    pub included_by_others: usize,
    pub includes_others: usize,
}
200
/// A diagnostic group kept as unparsed JSON.
///
/// `#[serde(transparent)]` makes it (de)serialize as exactly the wrapped JSON value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
204
/// Result of an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response answers.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Debug details; presumably only populated when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
211
/// Debug details about how a prediction was produced.
///
/// The `*_time` fields are `chrono::Duration` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Prompt text.
    pub prompt: String,
    pub prompt_planning_time: Duration,
    /// Raw model output.
    pub model_response: String,
    pub inference_time: Duration,
    pub parsing_time: Duration,
}
220
/// A predicted edit to a file.
///
/// NOTE(review): presumably `content` replaces the lines in `range` (half-open, per Rust
/// `Range` convention) — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File to edit.
    pub path: Arc<Path>,
    /// Lines affected by the edit.
    pub range: Range<Line>,
    /// New text for the affected lines.
    pub content: String,
}
227
/// Returns `true` when `value` equals `T::default()`.
///
/// Serves as a serde `skip_serializing_if` predicate so that fields holding their default value
/// are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
231
/// A line/column position.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order: `line` first, then
/// `column`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    /// Column within the line. NOTE(review): unit (bytes vs. chars) not visible here — confirm.
    pub column: u32,
}
237
/// A line number newtype.
///
/// `#[serde(transparent)]` makes it serialize as a bare `u32`.
/// NOTE(review): confirm whether line numbers are zero-based.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
241
242impl Add for Line {
243 type Output = Self;
244
245 fn add(self, rhs: Self) -> Self::Output {
246 Self(self.0 + rhs.0)
247 }
248}
249
250impl Sub for Line {
251 type Output = Self;
252
253 fn sub(self, rhs: Self) -> Self::Output {
254 Self(self.0 - rhs.0)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_event_display() {
        let ev = Event::BufferChange {
            path: None,
            old_path: None,
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        let ev = Event::BufferChange {
            path: Some(PathBuf::from("foo/bar.txt")),
            old_path: Some(PathBuf::from("foo/bar.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        let ev = Event::BufferChange {
            path: Some(PathBuf::from("abc.txt")),
            old_path: Some(PathBuf::from("123.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        let ev = Event::BufferChange {
            path: Some(PathBuf::from("abc.txt")),
            old_path: Some(PathBuf::from("123.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: true,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }

    /// A missing `old_path` must fall back to the new path in the `---` header.
    #[test]
    fn test_event_display_missing_old_path_uses_new_path() {
        let ev = Event::BufferChange {
            path: Some(PathBuf::from("foo/bar.txt")),
            old_path: None,
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }

    /// A missing `path` is shown as `untitled`, while an explicit `old_path` is kept.
    #[test]
    fn test_event_display_missing_path_keeps_old_path() {
        let ev = Event::BufferChange {
            path: None,
            old_path: Some(PathBuf::from("old.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/old.txt
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }

    #[test]
    fn test_line_arithmetic_and_point_ordering() {
        assert_eq!(Line(3) + Line(4), Line(7));
        assert_eq!(Line(7) - Line(4), Line(3));

        // Derived ordering on `Point` compares `line` before `column`.
        let earlier = Point {
            line: Line(1),
            column: 80,
        };
        let later = Point {
            line: Line(2),
            column: 0,
        };
        assert!(earlier < later);
    }

    #[test]
    fn test_prompt_format_default_and_iter() {
        assert_eq!(PromptFormat::default(), PromptFormat::NumLinesUniDiff);
        assert_eq!(PromptFormat::default(), PromptFormat::DEFAULT);
        assert_eq!(PromptFormat::iter().count(), 4);
        assert!(PromptFormat::iter().any(|format| format == PromptFormat::DEFAULT));
    }
}