1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::Path,
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
/// Request payload for context-retrieval planning.
///
/// NOTE(review): the field descriptions below are inferred from the field names and from the
/// same-named fields on `PredictEditsRequest`; confirm against the code that builds this request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file that `excerpt` belongs to.
    pub excerpt_path: Arc<Path>,
    /// Line range associated with `excerpt`.
    pub excerpt_line_range: Range<Line>,
    /// Presumably the last row of the file that contains the cursor — TODO confirm.
    pub cursor_file_max_row: Line,
    /// Events sent along with the request; each one is shared via `Arc`.
    pub events: Vec<Arc<Event>>,
}
22
/// Request payload for an edit prediction.
///
/// Fields marked `skip_serializing_if` are left out of the serialized form when they hold their
/// empty/`None`/default value, and fields marked `default` fall back to `Default::default()`
/// when they are missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file that `excerpt` belongs to.
    pub excerpt_path: Arc<Path>,
    /// Range of the excerpt within the file.
    // NOTE(review): presumably byte offsets — confirm with the producer of this request.
    pub excerpt_range: Range<usize>,
    /// Line range associated with `excerpt`.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    pub cursor_point: Point,
    /// Index within `signatures`, if any.
    pub excerpt_parent: Option<usize>,
    /// Additional files sent with the request. Not serialized when empty; empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures that `excerpt_parent` and the `parent_index` fields index into. Not serialized
    /// when empty; empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations sent with the request. Not serialized when empty; empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Events sent along with the request; each one is shared via `Arc`.
    pub events: Vec<Arc<Event>>,
    /// Deserializes to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups as raw JSON. Not serialized when empty; empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Not serialized when `false` (the `bool` default); `false` when absent.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Deserializes to `false` when absent.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit in bytes. Not serialized when `None`; `None` when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt format to use; falls back to `PromptFormat::default()` when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
57
/// A file sent with a request, represented as a list of excerpts rather than its full text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// Presumably the last row of the file — TODO confirm.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
64
/// A piece of text together with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line on which `text` starts.
    pub start_line: Line,
    /// Text of the excerpt, shared via `Arc`.
    pub text: Arc<str>,
}
70
/// Selects how the prompt is laid out.
///
/// The human-readable label of each variant comes from the `Display` impl, and
/// `PromptFormat::DEFAULT` names the variant used when none is specified.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff". This is `PromptFormat::DEFAULT`.
    NumLinesUniDiff,
    /// Displayed as "Old Text / New Text".
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
    SeedCoder1120,
}
86
87impl PromptFormat {
88 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
89}
90
91impl Default for PromptFormat {
92 fn default() -> Self {
93 Self::DEFAULT
94 }
95}
96
97impl PromptFormat {
98 pub fn iter() -> impl Iterator<Item = Self> {
99 <Self as strum::IntoEnumIterator>::iter()
100 }
101}
102
103impl std::fmt::Display for PromptFormat {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
107 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
108 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
109 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
110 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
111 PromptFormat::Minimal => write!(f, "Minimal"),
112 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
113 PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
114 }
115 }
116}
117
/// An event attached to a request.
///
/// Serialized as an internally tagged enum: the variant name is stored under the `"event"` key
/// next to the variant's fields. `PartialEq` is only derived for tests and for the
/// `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as diff text.
    BufferChange {
        /// Path printed on the `+++ b/` line by the `Display` impl.
        path: Arc<Path>,
        /// Path printed on the `--- a/` line by the `Display` impl.
        old_path: Arc<Path>,
        /// Diff body, printed verbatim after the two header lines.
        diff: String,
        /// When true, `Display` prefixes the output with `// User accepted prediction:`.
        predicted: bool,
        /// Not used by `Display`.
        // NOTE(review): presumably marks changes made inside an open-source repository — confirm.
        in_open_source_repo: bool,
    },
}
130
131impl Display for Event {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 match self {
134 Event::BufferChange {
135 path,
136 old_path,
137 diff,
138 predicted,
139 ..
140 } => {
141 if *predicted {
142 write!(
143 f,
144 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
145 DiffPathFmt(old_path),
146 DiffPathFmt(path)
147 )
148 } else {
149 write!(
150 f,
151 "--- a/{}\n+++ b/{}\n{diff}",
152 DiffPathFmt(old_path),
153 DiffPathFmt(path)
154 )
155 }
156 }
157 }
158 }
159}
160
/// Formats a [`Path`] for use in diff headers, always joining its components with `/`
/// regardless of the host platform's native separator.
///
/// A root component is written as a single `/` and is not followed by an extra separator, so
/// `/foo/bar` renders as `/foo/bar` rather than `//foo/bar`. Components that are not valid
/// Unicode are rendered lossily.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once something has been written that the next component must be separated from.
        let mut needs_separator = false;
        for component in self.0.components() {
            if let std::path::Component::RootDir = component {
                // The root already *is* the separator: emit it once, in `/` form, and make sure
                // the following component is not prefixed with a second one.
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            f.write_str(&component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
178
/// A signature entry in `PredictEditsRequest::signatures`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` has been truncated.
    pub text_is_truncated: bool,
    /// Optional index of a parent entry. Not serialized when `None`; `None` when absent.
    // NOTE(review): presumably an index into the same `signatures` list — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
189
/// A declaration sent with a request, together with the scores computed for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` has been truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// The individual signals recorded for this declaration.
    pub score_components: DeclarationScoreComponents,
    /// Score for the signature portion of the declaration.
    // NOTE(review): how this is derived from `score_components` is not visible here.
    pub signature_score: f32,
    /// Score for the declaration as a whole.
    // NOTE(review): how this is derived from `score_components` is not visible here.
    pub declaration_score: f32,
}
206
/// The individual signals stored in `ReferencedDeclaration::score_components`.
///
/// Every field is a plain flag, count, distance, or `f32` measure; this type only carries the
/// values and does not compute them.
// NOTE(review): the exact definition of each signal lives in the code that fills this struct;
// the grouping comments below are based on field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Location / reference flags.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Line distances.
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Jaccard measures.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    // Weighted-overlap measures.
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-related signals.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Inclusion counts.
    pub included_by_others: usize,
    pub includes_others: usize,
}
235
/// A diagnostic group kept as raw JSON.
///
/// `#[serde(transparent)]` makes this serialize and deserialize exactly like the wrapped
/// `RawValue`, so the group's JSON is passed through without being parsed into a typed structure.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
239
/// Response to an edit-prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request.
    pub request_id: Uuid,
    /// The edits contained in the response.
    pub edits: Vec<Edit>,
    /// Optional debugging details.
    // NOTE(review): presumably only populated when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
246
/// Debugging details that can accompany a `PredictEditsResponse`.
///
/// The timing fields are `chrono::Duration` values (the `Duration` imported at the top of this
/// file), not `std::time::Duration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// Time attributed to prompt planning.
    pub prompt_planning_time: Duration,
    /// The model's response text.
    pub model_response: String,
    /// Time attributed to inference.
    pub inference_time: Duration,
    /// Time attributed to parsing.
    pub parsing_time: Duration,
}
255
/// A single edit in a `PredictEditsResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file the edit applies to.
    pub path: Arc<Path>,
    /// Line range the edit refers to.
    // NOTE(review): presumably the lines replaced by `content` — confirm with the consumer.
    pub range: Range<Line>,
    /// Text of the edit.
    pub content: String,
}
262
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Used as a `skip_serializing_if` predicate so default-valued fields are left out of the
/// serialized form.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    *value == baseline
}
266
/// A position expressed as a line and a column.
///
/// The derived ordering compares `line` first and `column` second, because the derives follow
/// field declaration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line of the position.
    pub line: Line,
    /// Column within the line.
    // NOTE(review): the unit (bytes vs. characters) is not visible in this file — confirm.
    pub column: u32,
}
272
/// A line number, wrapped in a newtype so it cannot be confused with other `u32` values.
///
/// `#[serde(transparent)]` makes it serialize and deserialize as the bare `u32`. `Add` and `Sub`
/// are implemented for it elsewhere in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
276
277impl Add for Line {
278 type Output = Self;
279
280 fn add(self, rhs: Self) -> Self::Output {
281 Self(self.0 + rhs.0)
282 }
283}
284
285impl Sub for Line {
286 type Output = Self;
287
288 fn sub(self, rhs: Self) -> Self::Output {
289 Self(self.0 - rhs.0)
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event with the fixed two-line deletion hunk shared by every case
    /// below; only the paths and the `predicted` flag vary.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    /// Checks the diff text produced by `Event`'s `Display` impl for same-path, nested-path,
    /// renamed, and accepted-prediction changes.
    #[test]
    fn test_event_display() {
        // Same single-component path on both sides.
        let rendered = buffer_change("untitled", "untitled", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same nested path on both sides.
        let rendered = buffer_change("foo/bar.txt", "foo/bar.txt", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths: the old path goes on the `---` line.
        let rendered = buffer_change("123.txt", "abc.txt", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted predictions get a marker line in front of the headers.
        let rendered = buffer_change("123.txt", "abc.txt", true).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}