1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::Path,
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
13
/// Request payload for planning context retrieval: the excerpt around the cursor, where it
/// lives, and the recent edit events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Line range of `excerpt` within the file at `excerpt_path`.
    pub excerpt_line_range: Range<Line>,
    // NOTE(review): presumably the last row of the file containing the cursor — confirm.
    pub cursor_file_max_row: Line,
    /// Recent edit events (see [`Event`]).
    pub events: Vec<Arc<Event>>,
}
22
/// Request payload for an edit prediction: an excerpt of the cursor's file plus optional
/// context (included files, signatures, referenced declarations, diagnostics, recent events).
///
/// Collection fields marked `skip_serializing_if = "Vec::is_empty"` are omitted from the
/// serialized form when empty and default to empty when absent on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt around the cursor.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Range of `excerpt` within the file (offsets, presumably bytes — confirm).
    pub excerpt_range: Range<usize>,
    /// Line range of `excerpt` within the file.
    pub excerpt_line_range: Range<Line>,
    // NOTE(review): presumably file-relative rather than excerpt-relative — confirm.
    pub cursor_point: Point,
    /// Index within `signatures` of the excerpt's parent signature, if any.
    pub excerpt_parent: Option<usize>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Recent edit events (see [`Event`]); always serialized, even when empty.
    pub events: Vec<Arc<Event>>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Omitted from the serialized form when `false` (its default).
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Defaults to `false` when absent.
    // NOTE(review): presumably requests that the response carry `DebugInfo` — confirm.
    #[serde(default)]
    pub debug_info: bool,
    // NOTE(review): presumably an upper bound on prompt size in bytes — confirm. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Defaults to `PromptFormat::DEFAULT` (`NumLinesUniDiff`) when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
    /// Uses the trigger type's `Default` when absent.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
59
/// A file contributing one or more excerpts to the request context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    pub path: Arc<Path>,
    // NOTE(review): presumably the last row of the file — confirm.
    pub max_row: Line,
    /// Excerpts taken from `path`.
    pub excerpts: Vec<Excerpt>,
}
66
/// A run of text from a file, beginning at `start_line`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line of the file on which `text` begins.
    pub start_line: Line,
    pub text: Arc<str>,
}
72
/// Prompt template used to render a prediction request.
///
/// The default is [`PromptFormat::DEFAULT`] (`NumLinesUniDiff`). Variant order is the order
/// yielded by `PromptFormat::iter()` (derived via strum's `EnumIter`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    NumLinesUniDiff,
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
    SeedCoder1120,
}
88
89impl PromptFormat {
90 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
91}
92
93impl Default for PromptFormat {
94 fn default() -> Self {
95 Self::DEFAULT
96 }
97}
98
99impl PromptFormat {
100 pub fn iter() -> impl Iterator<Item = Self> {
101 <Self as strum::IntoEnumIterator>::iter()
102 }
103}
104
105impl std::fmt::Display for PromptFormat {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 match self {
108 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
109 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
110 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
111 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
112 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
113 PromptFormat::Minimal => write!(f, "Minimal"),
114 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
115 PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
116 }
117 }
118}
119
/// An edit event sent as prediction context.
///
/// Serialized as an internally tagged enum: the variant name is stored under the `"event"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    BufferChange {
        /// Path after the change; rendered in the `+++ b/` header.
        path: Arc<Path>,
        /// Path before the change; rendered in the `--- a/` header.
        old_path: Arc<Path>,
        /// Diff hunks, written verbatim after the `---`/`+++` headers when displayed.
        diff: String,
        /// `true` when the change came from an accepted prediction.
        predicted: bool,
        // NOTE(review): presumably whether the buffer belongs to an open-source repo — confirm.
        in_open_source_repo: bool,
    },
}
132
133impl Display for Event {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Event::BufferChange {
137 path,
138 old_path,
139 diff,
140 predicted,
141 ..
142 } => {
143 if *predicted {
144 write!(
145 f,
146 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
147 DiffPathFmt(old_path),
148 DiffPathFmt(path)
149 )
150 } else {
151 write!(
152 f,
153 "--- a/{}\n+++ b/{}\n{diff}",
154 DiffPathFmt(old_path),
155 DiffPathFmt(path)
156 )
157 }
158 }
159 }
160 }
161}
162
/// Formats a [`Path`] as a Unix-style path for diff headers, always using `/` as the separator
/// regardless of the host platform.
///
/// A root component is rendered as a single leading `/`, so `/foo/bar` stays `/foo/bar` instead
/// of becoming `//foo/bar`. A Windows prefix such as `C:` is emitted verbatim with no separator
/// inserted after it.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether a `/` must precede the next regular component. It is false at the start and
        // right after a prefix or root, because those already terminate the preceding segment.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                std::path::Component::Prefix(prefix) => {
                    write!(f, "{}", prefix.as_os_str().display())?;
                    needs_separator = false;
                }
                std::path::Component::RootDir => {
                    // The native root may be `\`; always render it as `/`.
                    f.write_char('/')?;
                    needs_separator = false;
                }
                other => {
                    if needs_separator {
                        f.write_char('/')?;
                    }
                    write!(f, "{}", other.as_os_str().display())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
180
/// A signature included as context, possibly truncated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    pub text: String,
    /// Whether `text` was cut short.
    pub text_is_truncated: bool,
    // NOTE(review): presumably an index into the request's `signatures` — confirm.
    // Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
191
/// A declaration referenced from the excerpt, with the scores and score components used to
/// rank it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// File containing the declaration.
    pub path: Arc<Path>,
    pub text: String,
    /// Whether `text` was cut short.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    pub score_components: DeclarationScoreComponents,
    pub signature_score: f32,
    pub declaration_score: f32,
}
208
/// Raw features recorded for a [`ReferencedDeclaration`].
///
/// NOTE(review): presumably the inputs from which `signature_score`/`declaration_score` are
/// derived — confirm against the scoring code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Token-set similarity features (excerpt / adjacent text vs. item / signature).
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-based features.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
237
/// A diagnostic group held as raw, unparsed JSON.
///
/// `#[serde(transparent)]` means it (de)serializes exactly as the inner JSON value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
241
/// Response to a [`PredictEditsRequest`]: the predicted edits plus optional debug details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    pub request_id: Uuid,
    pub edits: Vec<Edit>,
    // NOTE(review): presumably only populated when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
248
/// Diagnostic details about how a prediction was produced: the prompt, the raw model output,
/// and timings for each stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    pub prompt: String,
    pub prompt_planning_time: Duration,
    pub model_response: String,
    pub inference_time: Duration,
    pub parsing_time: Duration,
}
257
/// A single predicted edit.
///
/// NOTE(review): presumably replaces lines `range` of `path` with `content` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    pub path: Arc<Path>,
    pub range: Range<Line>,
    pub content: String,
}
264
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that fields holding their default value
/// are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
268
/// A position expressed as a line and a column.
///
/// The derived ordering compares `line` first and then `column`, so field order matters.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    // NOTE(review): column units (bytes vs. chars) are not established here — confirm.
    pub column: u32,
}
274
/// A line number newtype.
///
/// Serialized transparently as a bare `u32`. NOTE(review): zero- vs. one-based is not
/// established here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
278
279impl Add for Line {
280 type Output = Self;
281
282 fn add(self, rhs: Self) -> Self::Output {
283 Self(self.0 + rhs.0)
284 }
285}
286
287impl Sub for Line {
288 type Output = Self;
289
290 fn sub(self, rhs: Self) -> Self::Output {
291 Self(self.0 - rhs.0)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event carrying the two-line removal hunk shared by every case
    /// below, so each case only varies the paths and the `predicted` flag.
    fn buffer_change(path: &str, old_path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        // Unchanged single-component path.
        assert_eq!(
            buffer_change("untitled", "untitled", false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Nested path is joined with `/`.
        assert_eq!(
            buffer_change("foo/bar.txt", "foo/bar.txt", false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Renamed file: old path in the `a/` header, new path in the `b/` header.
        assert_eq!(
            buffer_change("abc.txt", "123.txt", false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted predictions are prefixed with a marker comment.
        assert_eq!(
            buffer_change("abc.txt", "123.txt", true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}