1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
/// Request payload that, judging by its name, asks the server to plan which
/// context to retrieve for an excerpt of the file under the cursor.
// NOTE(review): the consumer of this request is not visible in this file —
// confirm its exact purpose against the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file containing `excerpt`.
    pub excerpt_path: Arc<Path>,
    /// Lines of the file covered by `excerpt`.
    pub excerpt_line_range: Range<Line>,
    /// Maximum row of the cursor's file — presumably its last line; TODO
    /// confirm whether this is inclusive and 0-based.
    pub cursor_file_max_row: Line,
    /// Recent edit events (each renders as a diff via `Event`'s `Display`).
    pub events: Vec<Event>,
}
22
/// Request payload for an edit prediction.
///
/// It carries an excerpt of the cursor's file plus optional context: other
/// files' excerpts, signatures, referenced declarations, diagnostics, and
/// recent edit events. Most optional fields are skipped when serializing if
/// they are empty, `None`, or `false`, and fall back to their defaults when
/// missing on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt sent for prediction.
    pub excerpt: String,
    /// Path of the file containing `excerpt`.
    pub excerpt_path: Arc<Path>,
    /// Within file: range of the excerpt inside its file.
    // NOTE(review): units (byte offsets vs. chars) are not visible here — confirm.
    pub excerpt_range: Range<usize>,
    /// Lines of the file covered by `excerpt`.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line and column.
    pub cursor_point: Point,
    /// Within `signatures`: index of the signature whose item contains the
    /// excerpt, if any.
    pub excerpt_parent: Option<usize>,
    /// Excerpts from other files; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures referenced by `excerpt_parent` and by
    /// `ReferencedDeclaration::parent_index`; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations sent as additional context; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Recent edit events.
    pub events: Vec<Event>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostics kept as raw JSON; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Presumably set when `diagnostic_groups` was cut short; omitted from the
    /// JSON while `false` (via `is_default`).
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    // Only available to staff
    // NOTE(review): presumably asks the server to fill
    // `PredictEditsResponse::debug_info` — confirm.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional cap on the prompt size (in bytes, per the name); omitted when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout; falls back to `PromptFormat::default()` (i.e.
    /// `PromptFormat::DEFAULT`) when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
57
/// An entry of `PredictEditsRequest::included_files`: one file and the
/// excerpts taken from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// Maximum row of the file — presumably its last line; TODO confirm
    /// inclusivity.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
64
/// A chunk of a file's text that begins at `start_line`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line on which `text` begins.
    pub start_line: Line,
    /// The excerpt text. It is `Arc<str>`, so clones share one allocation.
    pub text: Arc<str>,
}
70
/// Layout of the prompt built for a prediction request (see
/// `PredictEditsRequest::prompt_format`).
///
/// Serialized by variant name. The default is `PromptFormat::DEFAULT`.
/// `PromptFormat::iter()` yields the variants in declaration order (via
/// strum's `EnumIter`), so reordering the variants below changes that order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff".
    NumLinesUniDiff,
    /// Displayed as "Old Text / New Text".
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
}
84
85impl PromptFormat {
86 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
87}
88
89impl Default for PromptFormat {
90 fn default() -> Self {
91 Self::DEFAULT
92 }
93}
94
95impl PromptFormat {
96 pub fn iter() -> impl Iterator<Item = Self> {
97 <Self as strum::IntoEnumIterator>::iter()
98 }
99}
100
101impl std::fmt::Display for PromptFormat {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
105 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
106 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
107 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
108 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
109 PromptFormat::Minimal => write!(f, "Minimal"),
110 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
111 }
112 }
113}
114
/// An edit event sent along with prediction requests.
///
/// Serialized as an internally tagged enum: the variant name is stored in an
/// `"event"` field next to the variant's own fields. `PartialEq` is derived
/// only for tests and the `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, expressed as diff text.
    BufferChange {
        /// Current path of the buffer. `Display` renders `None` as `untitled`.
        path: Option<PathBuf>,
        /// Previous path of the buffer. `Display` falls back to `path` when
        /// this is `None`.
        old_path: Option<PathBuf>,
        /// Diff body without the `---`/`+++` header (`Display` adds that).
        diff: String,
        /// Whether the change came from an accepted prediction. `Display`
        /// prefixes such events with a `// User accepted prediction:` line.
        predicted: bool,
    },
}
126
127impl Display for Event {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 match self {
130 Event::BufferChange {
131 path,
132 old_path,
133 diff,
134 predicted,
135 } => {
136 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
137 let old_path = old_path.as_deref().unwrap_or(new_path);
138
139 if *predicted {
140 write!(
141 f,
142 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
143 DiffPathFmt(old_path),
144 DiffPathFmt(new_path)
145 )
146 } else {
147 write!(
148 f,
149 "--- a/{}\n+++ b/{}\n{diff}",
150 DiffPathFmt(old_path),
151 DiffPathFmt(new_path)
152 )
153 }
154 }
155 }
156 }
157}
158
/// Formats a `Path` for diff headers, always using `/` as the separator
/// regardless of the host platform.
///
/// - Relative paths are joined component-by-component with `/`.
/// - A root directory renders as exactly one leading `/`, so `/foo/bar` stays
///   `/foo/bar` instead of becoming `//foo/bar`.
/// - A Windows prefix (e.g. `C:`) is emitted with any `\` turned into `/`,
///   and no extra separator follows it (`C:\foo` -> `C:/foo`).
/// - Components that are not valid UTF-8 are rendered lossily (U+FFFD).
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::path::Component;

        // Whether a `/` must precede the next component. It is false at the
        // start and directly after a root or prefix. A root already ends with
        // the separator, and a drive-relative prefix (`C:foo`) takes none.
        let mut needs_separator = false;
        for component in self.0.components() {
            let text = component.as_os_str().to_string_lossy();
            match component {
                Component::Prefix(_) => {
                    f.write_str(&text.replace('\\', "/"))?;
                    needs_separator = false;
                }
                Component::RootDir => {
                    // On Windows the root's text is `\`, so write `/` explicitly.
                    f.write_char('/')?;
                    needs_separator = false;
                }
                Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                    if needs_separator {
                        f.write_char('/')?;
                    }
                    f.write_str(&text)?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
176
/// A signature sent as context in `PredictEditsRequest::signatures`.
///
/// The excerpt (`PredictEditsRequest::excerpt_parent`) and referenced
/// declarations (`ReferencedDeclaration::parent_index`) point into that list
/// by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Parent signature, presumably also an index into
    /// `PredictEditsRequest::signatures` — TODO confirm. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
187
/// A declaration sent as additional context, together with its scoring
/// features and scores.
// NOTE(review): how `signature_score` / `declaration_score` are derived from
// `score_components` is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// File containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Individual features that feed the scores below.
    pub score_components: DeclarationScoreComponents,
    /// Score for the declaration's signature (per the name).
    pub signature_score: f32,
    /// Score for the whole declaration (per the name).
    pub declaration_score: f32,
}
204
/// Raw features computed for one `ReferencedDeclaration` and carried in its
/// `score_components` field.
///
/// The fields are serialized verbatim.
// NOTE(review): the exact definitions (e.g. what counts as "nearby",
// "adjacent", or how the Jaccard/overlap values are tokenized) live with the
// producer and are not visible here — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
233
/// A diagnostic group kept as raw, unparsed JSON.
///
/// `#[serde(transparent)]` over `RawValue` passes the JSON through verbatim
/// on (de)serialization. The group's schema is not defined in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
237
/// Response payload for an edit prediction (presumably answering a
/// `PredictEditsRequest`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response belongs to.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Optional debug details. This field has no `skip_serializing_if`, so
    /// `None` serializes as `null`.
    // NOTE(review): presumably populated only when the request set
    // `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
244
/// Debug details carried in `PredictEditsResponse::debug_info`.
// NOTE(review): the field meanings below are inferred from their names;
// confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text (presumably as sent to the model).
    pub prompt: String,
    /// Time spent building the prompt.
    pub prompt_planning_time: Duration,
    /// Raw text returned by the model.
    pub model_response: String,
    /// Time spent in model inference.
    pub inference_time: Duration,
    /// Time spent parsing `model_response` into edits.
    pub parsing_time: Duration,
}
253
/// A predicted edit: new `content` for the line `range` of the file at
/// `path`.
// NOTE(review): whether `range.end` is exclusive, and whether `content`
// replaces those lines wholesale, is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Lines affected by the edit.
    pub range: Range<Line>,
    /// Replacement text.
    pub content: String,
}
260
/// Returns `true` when `value` equals `T::default()`.
///
/// Serde's `skip_serializing_if` uses this as a predicate, so fields that hold
/// their default value are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
264
/// A position in a file, given as a line plus a column.
///
/// The derived ordering compares `line` first, then `column`.
// NOTE(review): column units (bytes vs. chars) and 0- vs. 1-based indexing
// are not visible here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    pub column: u32,
}
270
/// A line number (or, when produced by `Line - Line`, a line count) stored as
/// `u32`.
///
/// `#[serde(transparent)]` makes it serialize as a bare integer.
// NOTE(review): 0- vs. 1-based numbering is not visible here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
274
275impl Add for Line {
276 type Output = Self;
277
278 fn add(self, rhs: Self) -> Self::Output {
279 Self(self.0 + rhs.0)
280 }
281}
282
283impl Sub for Line {
284 type Output = Self;
285
286 fn sub(self, rhs: Self) -> Self::Output {
287 Self(self.0 - rhs.0)
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` that always carries the same two-line deletion
    /// hunk, varying only the paths and the `predicted` flag.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            // No paths at all: both sides fall back to `untitled`.
            (
                buffer_change(None, None, false),
                "--- a/untitled\n+++ b/untitled\n@@ -1,2 +1,2 @@\n-a\n-b\n",
            ),
            // Same old and new path.
            (
                buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false),
                "--- a/foo/bar.txt\n+++ b/foo/bar.txt\n@@ -1,2 +1,2 @@\n-a\n-b\n",
            ),
            // Renamed file: old path on the `---` line, new path on `+++`.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), false),
                "--- a/123.txt\n+++ b/abc.txt\n@@ -1,2 +1,2 @@\n-a\n-b\n",
            ),
            // Accepted prediction gets a marker line before the header.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), true),
                "// User accepted prediction:\n--- a/123.txt\n+++ b/abc.txt\n@@ -1,2 +1,2 @@\n-a\n-b\n",
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}