1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
/// Request payload describing the cursor's excerpt and recent events, presumably sent to the
/// server to plan context retrieval (purpose inferred from the type name — confirm with the
/// handler).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt taken from the file at `excerpt_path`.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Line range the excerpt covers within its file.
    pub excerpt_line_range: Range<Line>,
    /// NOTE(review): presumably the last row of the cursor's file — confirm.
    pub cursor_file_max_row: Line,
    /// Buffer-change events; see [`Event`].
    pub events: Vec<Event>,
}
22
/// Request payload for an edit prediction.
///
/// Fields marked `skip_serializing_if` are left out of the serialized form when they hold
/// their empty/`None`/default value. Fields marked `default` deserialize to their type's
/// default when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Excerpt text taken from the file at `excerpt_path`.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Within file
    ///
    /// NOTE(review): offset units are not visible here (presumably bytes) — confirm.
    pub excerpt_range: Range<usize>,
    /// Line range the excerpt covers within its file.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position. NOTE(review): presumably file-relative — confirm.
    pub cursor_point: Point,
    /// Within `signatures`
    ///
    /// Index into `signatures` of the excerpt's parent signature, if any.
    pub excerpt_parent: Option<usize>,
    /// Files whose excerpts are attached to the request; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures that `excerpt_parent` and the `parent_index` fields index into; omitted
    /// when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations supplied as context; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Buffer-change events; see [`Event`].
    pub events: Vec<Event>,
    /// Deserializes as `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Opaque, pre-serialized diagnostic groups; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Serialized only when `true` (see `is_default`). Presumably signals that
    /// `diagnostic_groups` was truncated — confirm with the producer.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Deserializes as `false` when absent.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Falls back to `PromptFormat::default()` (i.e. `PromptFormat::DEFAULT`) when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
57
/// A file whose excerpts are attached to a [`PredictEditsRequest`] via `included_files`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// NOTE(review): presumably the file's last row — confirm.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
64
/// A contiguous piece of text taken from an [`IncludedFile`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line at which `text` begins within its file (indexing base not shown here — confirm).
    pub start_line: Line,
    /// The excerpt text; shared via `Arc` so clones are cheap.
    pub text: Arc<str>,
}
70
/// Selects how the prompt is laid out.
///
/// There are no `rename` attributes, so each variant serializes by its Rust name.
/// `EnumIter` yields variants in the declaration order below, so reordering them changes
/// iteration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    NumLinesUniDiff,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
}
79
80impl PromptFormat {
81 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
82}
83
84impl Default for PromptFormat {
85 fn default() -> Self {
86 Self::DEFAULT
87 }
88}
89
90impl PromptFormat {
91 pub fn iter() -> impl Iterator<Item = Self> {
92 <Self as strum::IntoEnumIterator>::iter()
93 }
94}
95
96impl std::fmt::Display for PromptFormat {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
100 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
101 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
102 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
103 }
104 }
105}
106
/// An event reported alongside a prediction request.
///
/// Serialized as an internally tagged enum: the variant name is stored in an `"event"`
/// field next to the variant's own fields. `PartialEq` is derived only for tests and the
/// `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as diff hunk text.
    BufferChange {
        /// Current path of the buffer; rendered as `untitled` by `Display` when `None`.
        path: Option<PathBuf>,
        /// Previous path; `Display` falls back to `path` when `None`.
        old_path: Option<PathBuf>,
        /// Hunk text, appended verbatim after the `---`/`+++` headers by `Display`.
        diff: String,
        /// Whether the change came from an accepted prediction; `Display` prefixes such
        /// events with a `// User accepted prediction:` line.
        predicted: bool,
    },
}
118
119impl Display for Event {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 Event::BufferChange {
123 path,
124 old_path,
125 diff,
126 predicted,
127 } => {
128 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
129 let old_path = old_path.as_deref().unwrap_or(new_path);
130
131 if *predicted {
132 write!(
133 f,
134 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
135 DiffPathFmt(old_path),
136 DiffPathFmt(new_path)
137 )
138 } else {
139 write!(
140 f,
141 "--- a/{}\n+++ b/{}\n{diff}",
142 DiffPathFmt(old_path),
143 DiffPathFmt(new_path)
144 )
145 }
146 }
147 }
148 }
149}
150
/// Formats a [`Path`] for diff headers, always using `/` as the separator regardless of the
/// host platform's native separator.
///
/// Components are written in order and joined with `/`:
/// - A root-directory component becomes a single leading `/`, so `/foo/bar` stays
///   `/foo/bar` rather than `//foo/bar`.
/// - A Windows path prefix (e.g. `C:`) is written verbatim with no separator after it, so
///   `C:\foo` becomes `C:/foo` and the drive-relative `C:foo` stays `C:foo`.
/// - Components that are not valid UTF-8 are rendered lossily.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::path::Component;

        // Whether a `/` must be written before the next component.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                // The root *is* the separator for whatever follows it; emitting another
                // `/` after it would produce `//`.
                Component::RootDir => {
                    f.write_char('/')?;
                    needs_separator = false;
                }
                _ => {
                    if needs_separator {
                        f.write_char('/')?;
                    }
                    write!(f, "{}", Path::new(component.as_os_str()).display())?;
                    // A prefix is followed either by the root (which supplies its own `/`)
                    // or directly by a drive-relative component, so no separator after it.
                    needs_separator = !matches!(component, Component::Prefix(_));
                }
            }
        }
        Ok(())
    }
}
168
/// A signature entry in [`PredictEditsRequest::signatures`]; other entries refer to it by
/// index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// NOTE(review): presumably the index within `signatures` of this signature's parent,
    /// mirroring `ReferencedDeclaration::parent_index` — confirm. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
179
/// A declaration supplied as context in [`PredictEditsRequest::referenced_declarations`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Raw scoring features; see [`DeclarationScoreComponents`].
    pub score_components: DeclarationScoreComponents,
    /// NOTE(review): how these two scores are computed is not visible here — confirm with
    /// the scoring code.
    pub signature_score: f32,
    pub declaration_score: f32,
}
196
/// Raw feature values attached to a [`ReferencedDeclaration`].
///
/// NOTE(review): the exact meaning and units of each field are defined by the scoring code,
/// which is not in this file. The names suggest reference/declaration counts, line
/// distances, Jaccard and weighted-overlap similarities, and import-match statistics —
/// confirm there before relying on any interpretation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
225
/// An opaque, already-serialized JSON value.
///
/// `RawValue` keeps the JSON text unparsed. `#[serde(transparent)]` makes this wrapper
/// (de)serialize exactly as that inner JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
229
/// Response to a [`PredictEditsRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response answers.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// NOTE(review): presumably populated only when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
236
/// Diagnostic details returned in [`PredictEditsResponse::debug_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// NOTE(review): per the field name, time spent planning the prompt — confirm.
    pub prompt_planning_time: Duration,
    /// The model's raw response text.
    pub model_response: String,
    /// NOTE(review): per the field name, time spent in model inference — confirm.
    pub inference_time: Duration,
    /// NOTE(review): per the field name, time spent parsing the model response — confirm.
    pub parsing_time: Duration,
}
245
/// A single predicted edit returned in [`PredictEditsResponse::edits`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file to edit.
    pub path: Arc<Path>,
    /// NOTE(review): presumably the line range replaced by `content` — confirm with the
    /// consumer (including whether the end is exclusive).
    pub range: Range<Line>,
    /// Replacement text.
    pub content: String,
}
252
/// Returns `true` when `value` equals its type's [`Default`] value.
///
/// Used as a serde `skip_serializing_if` predicate so default-valued fields are omitted.
fn is_default<T: PartialEq + Default>(value: &T) -> bool {
    let default_value = T::default();
    value.eq(&default_value)
}
256
/// A line/column position.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order: `line` first, then
/// `column`. Reordering the fields would change ordering semantics.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    /// NOTE(review): column units (bytes vs. chars) are not visible here — confirm.
    pub column: u32,
}
262
/// A line number wrapped in a newtype.
///
/// `#[serde(transparent)]` makes it (de)serialize as a bare integer. NOTE(review): whether
/// lines are zero- or one-based is not established here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
266
267impl Add for Line {
268 type Output = Self;
269
270 fn add(self, rhs: Self) -> Self::Output {
271 Self(self.0 + rhs.0)
272 }
273}
274
275impl Sub for Line {
276 type Output = Self;
277
278 fn sub(self, rhs: Self) -> Self::Output {
279 Self(self.0 - rhs.0)
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Hunk text shared by every case; only the paths and the `predicted` flag vary.
    const HUNK: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event carrying [`HUNK`] and the given paths.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: HUNK.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            // No paths at all: both headers fall back to `untitled`.
            (
                buffer_change(None, None, false),
                indoc! {"
                    --- a/untitled
                    +++ b/untitled
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Same old and new path.
            (
                buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false),
                indoc! {"
                    --- a/foo/bar.txt
                    +++ b/foo/bar.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Rename: old path goes in the `a/` header, new path in the `b/` header.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), false),
                indoc! {"
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Accepted prediction gets the marker line first.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), true),
                indoc! {"
                    // User accepted prediction:
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}