1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PlanContextRetrievalRequest {
16 pub excerpt: String,
17 pub excerpt_path: Arc<Path>,
18 pub excerpt_line_range: Range<Line>,
19 pub cursor_file_max_row: Line,
20 pub events: Vec<Event>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PredictEditsRequest {
25 pub excerpt: String,
26 pub excerpt_path: Arc<Path>,
27 /// Within file
28 pub excerpt_range: Range<usize>,
29 pub excerpt_line_range: Range<Line>,
30 pub cursor_point: Point,
31 /// Within `signatures`
32 pub excerpt_parent: Option<usize>,
33 #[serde(skip_serializing_if = "Vec::is_empty", default)]
34 pub included_files: Vec<IncludedFile>,
35 #[serde(skip_serializing_if = "Vec::is_empty", default)]
36 pub signatures: Vec<Signature>,
37 #[serde(skip_serializing_if = "Vec::is_empty", default)]
38 pub referenced_declarations: Vec<ReferencedDeclaration>,
39 pub events: Vec<Event>,
40 #[serde(default)]
41 pub can_collect_data: bool,
42 #[serde(skip_serializing_if = "Vec::is_empty", default)]
43 pub diagnostic_groups: Vec<DiagnosticGroup>,
44 #[serde(skip_serializing_if = "is_default", default)]
45 pub diagnostic_groups_truncated: bool,
46 /// Info about the git repository state, only present when can_collect_data is true.
47 #[serde(skip_serializing_if = "Option::is_none", default)]
48 pub git_info: Option<PredictEditsGitInfo>,
49 // Only available to staff
50 #[serde(default)]
51 pub debug_info: bool,
52 #[serde(skip_serializing_if = "Option::is_none", default)]
53 pub prompt_max_bytes: Option<usize>,
54 #[serde(default)]
55 pub prompt_format: PromptFormat,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct IncludedFile {
60 pub path: Arc<Path>,
61 pub max_row: Line,
62 pub excerpts: Vec<Excerpt>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct Excerpt {
67 pub start_line: Line,
68 pub text: Arc<str>,
69}
70
71#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
72pub enum PromptFormat {
73 MarkedExcerpt,
74 LabeledSections,
75 NumLinesUniDiff,
76 OldTextNewText,
77 /// Prompt format intended for use via zeta_cli
78 OnlySnippets,
79 /// One-sentence instructions used in fine-tuned models
80 Minimal,
81 /// One-sentence instructions + FIM-like template
82 MinimalQwen,
83 /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
84 SeedCoder1120,
85}
86
87impl PromptFormat {
88 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
89}
90
91impl Default for PromptFormat {
92 fn default() -> Self {
93 Self::DEFAULT
94 }
95}
96
97impl PromptFormat {
98 pub fn iter() -> impl Iterator<Item = Self> {
99 <Self as strum::IntoEnumIterator>::iter()
100 }
101}
102
103impl std::fmt::Display for PromptFormat {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
107 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
108 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
109 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
110 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
111 PromptFormat::Minimal => write!(f, "Minimal"),
112 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
113 PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
114 }
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
120#[serde(tag = "event")]
121pub enum Event {
122 BufferChange {
123 path: Option<PathBuf>,
124 old_path: Option<PathBuf>,
125 diff: String,
126 predicted: bool,
127 },
128}
129
130impl Display for Event {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 match self {
133 Event::BufferChange {
134 path,
135 old_path,
136 diff,
137 predicted,
138 } => {
139 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
140 let old_path = old_path.as_deref().unwrap_or(new_path);
141
142 if *predicted {
143 write!(
144 f,
145 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
146 DiffPathFmt(old_path),
147 DiffPathFmt(new_path)
148 )
149 } else {
150 write!(
151 f,
152 "--- a/{}\n+++ b/{}\n{diff}",
153 DiffPathFmt(old_path),
154 DiffPathFmt(new_path)
155 )
156 }
157 }
158 }
159 }
160}
161
/// Formats a [`Path`] for use in diffs, always using `/` as the separator between components
/// regardless of the host platform.
///
/// Components come from [`Path::components`], so repeated separators and interior `.` components
/// are normalized away. A root component is rendered as a single `/`, so an absolute path such as
/// `/foo/bar` is printed as `/foo/bar` rather than `//foo/bar`. Non-UTF-8 components are rendered
/// lossily.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a component has been written that must be followed by `/` before the next.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root already acts as the separator. Write it as `/` (not the platform's
                // spelling) and do not emit a second one before the next component.
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            f.write_str(&component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct Signature {
182 pub text: String,
183 pub text_is_truncated: bool,
184 #[serde(skip_serializing_if = "Option::is_none", default)]
185 pub parent_index: Option<usize>,
186 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
187 /// file is implicitly the file that contains the descendant declaration or excerpt.
188 pub range: Range<Line>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ReferencedDeclaration {
193 pub path: Arc<Path>,
194 pub text: String,
195 pub text_is_truncated: bool,
196 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
197 pub range: Range<Line>,
198 /// Range within `text`
199 pub signature_range: Range<usize>,
200 /// Index within `signatures`.
201 #[serde(skip_serializing_if = "Option::is_none", default)]
202 pub parent_index: Option<usize>,
203 pub score_components: DeclarationScoreComponents,
204 pub signature_score: f32,
205 pub declaration_score: f32,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct DeclarationScoreComponents {
210 pub is_same_file: bool,
211 pub is_referenced_nearby: bool,
212 pub is_referenced_in_breadcrumb: bool,
213 pub reference_count: usize,
214 pub same_file_declaration_count: usize,
215 pub declaration_count: usize,
216 pub reference_line_distance: u32,
217 pub declaration_line_distance: u32,
218 pub excerpt_vs_item_jaccard: f32,
219 pub excerpt_vs_signature_jaccard: f32,
220 pub adjacent_vs_item_jaccard: f32,
221 pub adjacent_vs_signature_jaccard: f32,
222 pub excerpt_vs_item_weighted_overlap: f32,
223 pub excerpt_vs_signature_weighted_overlap: f32,
224 pub adjacent_vs_item_weighted_overlap: f32,
225 pub adjacent_vs_signature_weighted_overlap: f32,
226 pub path_import_match_count: usize,
227 pub wildcard_path_import_match_count: usize,
228 pub import_similarity: f32,
229 pub max_import_similarity: f32,
230 pub normalized_import_similarity: f32,
231 pub wildcard_import_similarity: f32,
232 pub normalized_wildcard_import_similarity: f32,
233 pub included_by_others: usize,
234 pub includes_others: usize,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238#[serde(transparent)]
239pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct PredictEditsResponse {
243 pub request_id: Uuid,
244 pub edits: Vec<Edit>,
245 pub debug_info: Option<DebugInfo>,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct DebugInfo {
250 pub prompt: String,
251 pub prompt_planning_time: Duration,
252 pub model_response: String,
253 pub inference_time: Duration,
254 pub parsing_time: Duration,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct Edit {
259 pub path: Arc<Path>,
260 pub range: Range<Line>,
261 pub content: String,
262}
263
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate, so a field holding its default value is left
/// out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    value == &T::default()
}
267
268#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
269pub struct Point {
270 pub line: Line,
271 pub column: u32,
272}
273
274#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
275#[serde(transparent)]
276pub struct Line(pub u32);
277
278impl Add for Line {
279 type Output = Self;
280
281 fn add(self, rhs: Self) -> Self::Output {
282 Self(self.0 + rhs.0)
283 }
284}
285
286impl Sub for Line {
287 type Output = Self;
288
289 fn sub(self, rhs: Self) -> Self::Output {
290 Self(self.0 - rhs.0)
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Diff text shared by every case in `test_event_display`.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event around `DIFF` with the given paths and `predicted` flag.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths at all: both header lines fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Old and new path are the same.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different paths: the old path goes on the `---` line, the new one on `+++`.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // A predicted change is prefixed with the accepted-prediction comment line.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}