1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::Display,
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
14// TODO: snippet ordering within file / relative to excerpt
15
/// Request payload for an edit prediction.
///
/// Carries the excerpt around the cursor, the signatures enclosing it, declarations offered
/// as extra context, recent buffer-change events, and optional diagnostics / git state.
/// Field order and serde attributes define the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt surrounding the cursor.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Within file
    /// (range of `excerpt` in the file; NOTE(review): unit presumably byte offsets — confirm).
    pub excerpt_range: Range<usize>,
    /// Line range of `excerpt` within the file.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position (NOTE(review): presumably file-relative rather than excerpt-relative — confirm).
    pub cursor_point: Point,
    /// Within `signatures`: index of the signature that is the parent of the excerpt, if any.
    pub excerpt_parent: Option<usize>,
    /// Signatures referenced by index from `excerpt_parent` and the `parent_index` fields.
    pub signatures: Vec<Signature>,
    /// Candidate declarations for context, each carrying its score components.
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Buffer-change events (rendered as diffs by `Event`'s `Display` impl).
    pub events: Vec<Event>,
    /// Whether data collection is permitted; deserializes as `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups as raw JSON; omitted from serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Whether `diagnostic_groups` was truncated; omitted from serialized output when `false`.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. NOTE(review): presumably asks the server to populate
    /// `PredictEditsResponse::debug_info` — confirm.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit (in bytes, per the field name); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout; deserializes as `PromptFormat::default()` (i.e. `PromptFormat::DEFAULT`)
    /// when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
46
/// Layout used when rendering the prompt.
///
/// `strum::EnumIter` yields variants in declaration order, so reordering the variants changes
/// the order produced by `PromptFormat::iter`. Serde (de)serializes variants by their names.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff"; this is `PromptFormat::DEFAULT`.
    NumLinesUniDiff,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
}
55
56impl PromptFormat {
57 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
58}
59
60impl Default for PromptFormat {
61 fn default() -> Self {
62 Self::DEFAULT
63 }
64}
65
66impl PromptFormat {
67 pub fn iter() -> impl Iterator<Item = Self> {
68 <Self as strum::IntoEnumIterator>::iter()
69 }
70}
71
72impl std::fmt::Display for PromptFormat {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
76 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
77 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
78 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
79 }
80 }
81}
82
/// An entry in the request's `events` history.
///
/// Serialized with an internal `"event"` tag that holds the variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, expressed as a diff.
    BufferChange {
        /// Current path of the buffer; `Display` renders `None` as `untitled`.
        path: Option<PathBuf>,
        /// Path before the change; when `None`, `Display` reuses `path`.
        old_path: Option<PathBuf>,
        /// Diff body without `---`/`+++` file headers (`Display` prepends those).
        diff: String,
        /// Whether this change came from the user accepting a prediction.
        predicted: bool,
    },
}
94
95impl Display for Event {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 match self {
98 Event::BufferChange {
99 path,
100 old_path,
101 diff,
102 predicted,
103 } => {
104 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
105 let old_path = old_path.as_deref().unwrap_or(new_path);
106
107 if *predicted {
108 write!(
109 f,
110 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
111 old_path.display(),
112 new_path.display()
113 )
114 } else {
115 write!(
116 f,
117 "--- a/{}\n+++ b/{}\n{diff}",
118 old_path.display(),
119 new_path.display()
120 )
121 }
122 }
123 }
124 }
125}
126
/// A signature (e.g. of an enclosing item) sent alongside the excerpt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Index of this signature's parent within the request's `signatures`, if any;
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
137
/// A declaration offered to the model as additional context, with its scoring data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// File containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    /// (the portion of `text` that forms the declaration's signature).
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Raw features the scores below were derived from.
    pub score_components: DeclarationScoreComponents,
    /// NOTE(review): presumably the score when only the signature is included — confirm.
    pub signature_score: f32,
    /// NOTE(review): presumably the score when the full declaration is included — confirm.
    pub declaration_score: f32,
}
154
/// Raw features used to score a [`ReferencedDeclaration`].
///
/// NOTE(review): these values are computed elsewhere; the per-group comments below describe
/// intent as suggested by field names and should be confirmed against the scorer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Locality / reference flags.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Occurrence counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Distances in lines (presumably from the cursor / excerpt — confirm).
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Jaccard similarities between excerpt/adjacent text and the item/signature.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    // Weighted-overlap variants of the similarities above.
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-based matching features.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Containment relations with other candidate declarations.
    pub included_by_others: usize,
    pub includes_others: usize,
}
183
/// An opaque diagnostic group carried as raw JSON.
///
/// `serde(transparent)` makes it (de)serialize as the inner JSON value itself, passed through
/// without being parsed into a typed structure.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
187
/// Response to a [`PredictEditsRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response answers.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Diagnostic details; NOTE(review): presumably only populated when the request set
    /// `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
194
/// Debugging details for a prediction: the prompt, the raw model output, and stage timings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Prompt sent to the model.
    pub prompt: String,
    /// Time spent planning/building the prompt.
    pub prompt_planning_time: Duration,
    /// Raw text returned by the model.
    pub model_response: String,
    /// Time spent on model inference.
    pub inference_time: Duration,
    /// Time spent parsing the model response into edits.
    pub parsing_time: Duration,
}
203
/// A predicted edit to a file.
///
/// NOTE(review): presumably replaces the lines in `range` of `path` with `content`; confirm
/// whether `range.end` is treated as exclusive by consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Line range affected by the edit.
    pub range: Range<Line>,
    /// New text for the range.
    pub content: String,
}
210
/// Returns `true` when `value` equals `T::default()`.
///
/// Used with `#[serde(skip_serializing_if = "is_default")]` to omit default-valued fields.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let baseline = T::default();
    value == &baseline
}
214
/// A position given as a line and a column.
///
/// The derived `Ord` compares `line` first and then `column`, because that is the field
/// declaration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    /// NOTE(review): column unit (bytes vs. chars) and zero/one-based origin are not
    /// established here — confirm with producers.
    pub column: u32,
}
220
/// Line number newtype.
///
/// `serde(transparent)` makes it (de)serialize as a bare integer. NOTE(review): whether
/// lines are zero- or one-based is not established here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
224
225impl Add for Line {
226 type Output = Self;
227
228 fn add(self, rhs: Self) -> Self::Output {
229 Self(self.0 + rhs.0)
230 }
231}
232
233impl Sub for Line {
234 type Output = Self;
235
236 fn sub(self, rhs: Self) -> Self::Output {
237 Self(self.0 - rhs.0)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event that uses the diff shared by every case below.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // With no paths at all, both headers fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // An unchanged path appears in both headers.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // A rename shows the old path in `---` and the new path in `+++`.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted predictions get a marker line before the headers.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}