1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::PredictEditsGitInfo;
13
/// Request payload for planning context retrieval: an excerpt of a file plus buffer
/// change [`Event`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Lines spanned by the excerpt within that file.
    pub excerpt_line_range: Range<Line>,
    // NOTE(review): presumably the last row of the file containing the cursor — confirm.
    pub cursor_file_max_row: Line,
    /// Buffer change events; each renders as a diff via `Event`'s `Display` impl.
    pub events: Vec<Event>,
}
22
/// Request for edit predictions.
///
/// Collection fields marked `skip_serializing_if = "Vec::is_empty"` are left out of the
/// serialized form when empty and deserialize to empty when missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Within file
    // NOTE(review): units not established here — presumably byte offsets; confirm.
    pub excerpt_range: Range<usize>,
    /// Lines spanned by the excerpt within the file.
    pub excerpt_line_range: Range<Line>,
    /// Position of the cursor (presumably in file coordinates, not excerpt — confirm).
    pub cursor_point: Point,
    /// Within `signatures`
    // Index into `signatures`; presumably the signature enclosing the excerpt.
    pub excerpt_parent: Option<usize>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    pub events: Vec<Event>,
    /// Deserializes to `false` when missing.
    #[serde(default)]
    pub can_collect_data: bool,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Omitted from the serialized form when `false`; deserializes to `false` when missing.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    // Only available to staff
    // Deserializes to `false` when missing; presumably requests
    // `PredictEditsResponse::debug_info` — confirm.
    #[serde(default)]
    pub debug_info: bool,
    /// Presumably an upper bound on prompt size in bytes — confirm. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Deserializes to `PromptFormat::default()` (i.e. `PromptFormat::DEFAULT`) when missing.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
57
/// A file contributed to the request as a set of excerpts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    // NOTE(review): presumably the file's last row — confirm.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
64
/// A contiguous piece of file text together with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line of the containing file on which `text` begins.
    pub start_line: Line,
    /// The excerpt text.
    pub text: Arc<str>,
}
70
/// Selects the prompt format.
///
/// Unit variants serialize as their variant name (e.g. `"NumLinesUniDiff"`), so renaming a
/// variant changes the wire format. `EnumIter` yields the variants in declaration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    NumLinesUniDiff,
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
}
82
83impl PromptFormat {
84 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
85}
86
87impl Default for PromptFormat {
88 fn default() -> Self {
89 Self::DEFAULT
90 }
91}
92
93impl PromptFormat {
94 pub fn iter() -> impl Iterator<Item = Self> {
95 <Self as strum::IntoEnumIterator>::iter()
96 }
97}
98
99impl std::fmt::Display for PromptFormat {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
103 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
104 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
105 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
106 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
107 PromptFormat::Minimal => write!(f, "Minimal"),
108 }
109 }
110}
111
/// An event sent with requests (see `PredictEditsRequest::events`).
///
/// Serialized internally tagged: the variant name is stored under an `"event"` key next to
/// the variant's fields. `PartialEq` is derived only for tests and the `test-support`
/// feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as diff text.
    BufferChange {
        /// Current path of the buffer; `Display` renders `None` as `untitled`.
        path: Option<PathBuf>,
        /// Previous path; `Display` falls back to the current path when `None`.
        old_path: Option<PathBuf>,
        /// Diff body; `Display` writes the `---`/`+++` header lines before it.
        diff: String,
        /// Whether the change came from accepting a prediction.
        predicted: bool,
    },
}
123
124impl Display for Event {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 Event::BufferChange {
128 path,
129 old_path,
130 diff,
131 predicted,
132 } => {
133 let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
134 let old_path = old_path.as_deref().unwrap_or(new_path);
135
136 if *predicted {
137 write!(
138 f,
139 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
140 DiffPathFmt(old_path),
141 DiffPathFmt(new_path)
142 )
143 } else {
144 write!(
145 f,
146 "--- a/{}\n+++ b/{}\n{diff}",
147 DiffPathFmt(old_path),
148 DiffPathFmt(new_path)
149 )
150 }
151 }
152 }
153 }
154}
155
/// Formats a [`Path`] for diff headers, always using `/` as the separator regardless of the
/// host platform's native separator.
///
/// Components are rendered lossily (invalid UTF-8 becomes U+FFFD). A root directory is
/// emitted as a single leading `/` (`/foo/bar`, not `//foo/bar`). Windows prefixes such as
/// `C:` have any backslashes converted too (`C:\foo` becomes `C:/foo`).
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::path::Component;

        // Whether a `/` must be written before the next normal component.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                // The root is itself a separator. Write it once and don't add another
                // before the following component.
                Component::RootDir => {
                    f.write_str("/")?;
                    needs_separator = false;
                }
                // Windows-only (`C:`, `\\server\share`). Normalize embedded backslashes. A
                // drive-relative path like `C:foo` must stay `C:foo`, so no separator follows.
                Component::Prefix(prefix) => {
                    let text = prefix.as_os_str().to_string_lossy();
                    f.write_str(&text.replace('\\', "/"))?;
                    needs_separator = false;
                }
                other => {
                    if needs_separator {
                        f.write_str("/")?;
                    }
                    f.write_str(&other.as_os_str().to_string_lossy())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
173
/// Signature text sent with a request, possibly truncated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// The signature text.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Index within the request's `signatures`, presumably of the enclosing signature;
    /// omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
184
/// A declaration sent with the request, along with its scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Inputs recorded alongside the two scores below.
    pub score_components: DeclarationScoreComponents,
    pub signature_score: f32,
    pub declaration_score: f32,
}
201
/// Score components reported for a [`ReferencedDeclaration`] next to its `signature_score`
/// and `declaration_score`.
///
/// NOTE(review): the formula combining these into the final scores is not in this file;
/// the grouping comments below are read from field names only — confirm against the scorer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Location / reference flags.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Line distances.
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Overlap between excerpt/adjacent text and the declaration's item/signature.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-based matching and similarity.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Relationships with other declarations.
    pub included_by_others: usize,
    pub includes_others: usize,
}
230
/// A diagnostic group kept as raw, unparsed JSON. `serde(transparent)` makes it serialize
/// exactly as the wrapped JSON value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
234
/// Response to a [`PredictEditsRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response belongs to.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Diagnostic details; presumably only set when the request's `debug_info` was true —
    /// confirm. Serialized as `null` when `None` (no `skip_serializing_if`).
    pub debug_info: Option<DebugInfo>,
}
241
/// Debug details about how a prediction was produced. Durations are `chrono::Duration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    pub prompt_planning_time: Duration,
    /// Raw model output text.
    pub model_response: String,
    pub inference_time: Duration,
    pub parsing_time: Duration,
}
250
/// A predicted edit to a file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Lines affected by the edit (presumably replaced by `content` — confirm).
    pub range: Range<Line>,
    /// New text for `range`.
    pub content: String,
}
257
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Intended for `#[serde(skip_serializing_if = "is_default")]` so default-valued fields are
/// left out of the serialized form.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
261
/// A line/column position.
///
/// The derived `Ord` compares `line` first, then `column`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    // NOTE(review): column units (bytes vs. chars) are not established here — confirm.
    pub column: u32,
}
267
/// A line number newtype. `serde(transparent)` makes it serialize as a bare `u32`.
// NOTE(review): whether lines are 0- or 1-based is not established in this file — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
271
272impl Add for Line {
273 type Output = Self;
274
275 fn add(self, rhs: Self) -> Self::Output {
276 Self(self.0 + rhs.0)
277 }
278}
279
280impl Sub for Line {
281 type Output = Self;
282
283 fn sub(self, rhs: Self) -> Self::Output {
284 Self(self.0 - rhs.0)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Diff body shared by every event in these tests.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` carrying [`DIFF`] with the given paths and prediction flag.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.to_string(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths at all: both header lines fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Identical old and new paths.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Differing paths: `---` shows the old path, `+++` the new one.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted prediction: a marker line precedes the headers.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}