1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
/// Request payload for planning context retrieval around an excerpt.
///
/// NOTE(review): the endpoint that consumes this type is not visible in this file; the field
/// notes below describe only what the declarations themselves show.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path associated with the excerpt (presumably the file it was taken from).
    pub excerpt_path: Arc<Path>,
    /// Line range of the excerpt (presumably within the file at `excerpt_path` — TODO confirm).
    pub excerpt_line_range: Range<Line>,
    /// Presumably the last row of the file that contains the cursor — TODO confirm.
    pub cursor_file_max_row: Line,
    /// Events sent along with the request; see [`Event`].
    pub events: Vec<Event>,
}
22
/// Request payload for an edit prediction.
///
/// Fields marked `#[serde(default)]` may be absent on the wire and then take their type's
/// default value. Fields with `skip_serializing_if` are additionally left out of the
/// serialized form when empty / `None` / default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path associated with the excerpt (presumably the file it was taken from).
    pub excerpt_path: Arc<Path>,
    /// Within file
    ///
    /// NOTE(review): presumably a byte-offset range — the unit is not visible here.
    pub excerpt_range: Range<usize>,
    /// Line range of the excerpt (presumably the line-based counterpart of `excerpt_range`).
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    pub cursor_point: Point,
    /// Within `signatures`
    pub excerpt_parent: Option<usize>,
    /// Additional files sent with the request; omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures that `excerpt_parent` and the various `parent_index` fields index into;
    /// omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations sent with the request; omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Events sent along with the request; see [`Event`]. This field has no serde default,
    /// so it must be present when deserializing.
    pub events: Vec<Event>,
    /// Defaults to `false` when absent. See `git_info`, which is documented as only present
    /// when this flag is true.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups as raw JSON values; omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Presumably set when `diagnostic_groups` was cut short — TODO confirm. Omitted from the
    /// serialized form when `false` (the `bool` default).
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff.
    ///
    /// NOTE(review): presumably requests that the response's `debug_info` be populated.
    #[serde(default)]
    pub debug_info: bool,
    /// Presumably an upper bound on the prompt size in bytes — TODO confirm. Omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt format to use; falls back to `PromptFormat::default()` when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
57
/// A file sent with a [`PredictEditsRequest`] (via `included_files`), as a list of excerpts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// Presumably the last row of the file — TODO confirm.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
64
/// A piece of text together with the line it starts on; used by [`IncludedFile`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line at which `text` starts (presumably within the owning file — TODO confirm).
    pub start_line: Line,
    /// The excerpt's text.
    pub text: Arc<str>,
}
70
/// Selects the prompt format for a prediction request.
///
/// The variants can be enumerated (`EnumIter` is derived) and each has a human-readable
/// label through its `Display` impl. `PromptFormat::DEFAULT` names the default variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff".
    NumLinesUniDiff,
    /// Displayed as "Old Text / New Text".
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    ///
    /// Displayed as "Only Snippets".
    OnlySnippets,
}
80
81impl PromptFormat {
82 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
83}
84
85impl Default for PromptFormat {
86 fn default() -> Self {
87 Self::DEFAULT
88 }
89}
90
91impl PromptFormat {
92 pub fn iter() -> impl Iterator<Item = Self> {
93 <Self as strum::IntoEnumIterator>::iter()
94 }
95}
96
97impl std::fmt::Display for PromptFormat {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
101 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
102 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
103 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
104 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
105 }
106 }
107}
108
/// An event attached to a request.
///
/// Serialized as an internally tagged enum: the variant name is stored under the `event`
/// key alongside the variant's fields. `PartialEq` is only derived for tests and the
/// `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer; its `Display` impl renders it as a diff with `---`/`+++` header
    /// lines followed by `diff`.
    BufferChange {
        /// Path used for the `+++ b/` header line; rendered as `untitled` when `None`.
        path: Option<PathBuf>,
        /// Path used for the `--- a/` header line; falls back to `path` when `None`.
        old_path: Option<PathBuf>,
        /// Diff text, emitted verbatim after the header lines.
        diff: String,
        /// When true, the rendered text is prefixed with a `// User accepted prediction:` line.
        predicted: bool,
    },
}
120
121impl Display for Event {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Event::BufferChange {
125 path,
126 old_path,
127 diff,
128 predicted,
129 } => {
130 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
131 let old_path = old_path.as_deref().unwrap_or(new_path);
132
133 if *predicted {
134 write!(
135 f,
136 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
137 DiffPathFmt(old_path),
138 DiffPathFmt(new_path)
139 )
140 } else {
141 write!(
142 f,
143 "--- a/{}\n+++ b/{}\n{diff}",
144 DiffPathFmt(old_path),
145 DiffPathFmt(new_path)
146 )
147 }
148 }
149 }
150 }
151}
152
/// Formats a [`Path`] for use in diffs, always using `/` as the separator regardless of the
/// host platform.
///
/// The path is rendered component by component, so redundant separators and interior `.`
/// components are dropped by [`Path::components`]. A root component is written as a single
/// `/`: `/foo/bar` renders as `/foo/bar`, not `//foo/bar`.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a component has been written that must be separated from the next one.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root *is* the separator. Emitting its own text (`/`, or `\` on Windows)
                // and then another `/` before the next component would double it up.
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            // Lossy conversion: bytes that are not valid UTF-8 are shown as U+FFFD.
            f.write_str(&component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
170
/// A signature entry in a request's `signatures` list; other items refer to entries by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` has been truncated.
    pub text_is_truncated: bool,
    /// Presumably the index of the parent entry within the same `signatures` list — TODO
    /// confirm. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
181
/// A declaration sent with a request, together with its scores and the components behind
/// them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` has been truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    ///
    /// NOTE(review): presumably the byte range of the signature portion — TODO confirm unit.
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// The individual scoring signals recorded for this declaration.
    pub score_components: DeclarationScoreComponents,
    /// Score for the signature. NOTE(review): how it is computed is not visible in this file.
    pub signature_score: f32,
    /// Score for the declaration. NOTE(review): how it is computed is not visible in this file.
    pub declaration_score: f32,
}
198
/// Individual scoring signals for a declaration, carried in
/// `ReferencedDeclaration::score_components`.
///
/// NOTE(review): this file only declares the fields. How each value is computed and how the
/// values combine into `signature_score` / `declaration_score` is defined elsewhere, so the
/// field names are the only description available here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
227
/// A diagnostic group held as a raw, unparsed JSON value.
///
/// `#[serde(transparent)]` makes the wrapper serialize exactly as the inner value does.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
231
/// Response payload for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request (presumably assigned by the server — TODO confirm).
    pub request_id: Uuid,
    /// The predicted edits.
    pub edits: Vec<Edit>,
    /// Debugging details, when available. NOTE(review): presumably populated only when the
    /// request set `debug_info` — confirm against the server.
    pub debug_info: Option<DebugInfo>,
}
238
/// Debugging details optionally returned in a [`PredictEditsResponse`].
///
/// NOTE(review): the exact phases the three durations measure are not visible in this file;
/// the notes below follow the field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// Presumably the time spent planning/building the prompt.
    pub prompt_planning_time: Duration,
    /// The model's response text.
    pub model_response: String,
    /// Presumably the time spent on model inference.
    pub inference_time: Duration,
    /// Presumably the time spent parsing the model response.
    pub parsing_time: Duration,
}
247
/// A single edit in a [`PredictEditsResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file the edit applies to.
    pub path: Arc<Path>,
    /// Line range the edit refers to (presumably the lines to be replaced — TODO confirm).
    pub range: Range<Line>,
    /// Text of the edit (presumably the replacement for `range` — TODO confirm).
    pub content: String,
}
254
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Used as a `skip_serializing_if` predicate so default-valued fields are left out of the
/// serialized form.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    *value == baseline
}
258
/// A position expressed as a line and a column.
///
/// The derived ordering compares `line` first and `column` second, following field order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line of the position.
    pub line: Line,
    /// Column within the line. NOTE(review): the unit (bytes, UTF-16 units, characters) is
    /// not visible in this file.
    pub column: u32,
}
264
/// A line number, wrapped in a newtype so it cannot be confused with other integers.
///
/// `#[serde(transparent)]` makes it serialize as the bare `u32`. NOTE(review): whether line
/// numbers are zero- or one-based is not visible in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
268
269impl Add for Line {
270 type Output = Self;
271
272 fn add(self, rhs: Self) -> Self::Output {
273 Self(self.0 + rhs.0)
274 }
275}
276
277impl Sub for Line {
278 type Output = Self;
279
280 fn sub(self, rhs: Self) -> Self::Output {
281 Self(self.0 - rhs.0)
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event; every case uses the same hunk, so only the paths and
    /// the `predicted` flag vary.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths: both header lines fall back to `untitled`.
        let rendered = buffer_change(None, None, false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same old and new path.
        let rendered = buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths: old goes on the `---` line, new on the `+++` line.
        let rendered = buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted prediction: same rendering with a leading marker line.
        let rendered = buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}