1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::Display,
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
 14// TODO: snippet ordering within file / relative to excerpt
 15
/// Request payload for predicting edits: an excerpt of a file plus optional supporting context
/// (included files, signatures, referenced declarations, events, diagnostics) and prompt options.
///
/// Serde behavior, as configured by the field attributes below:
/// - fields marked `default` may be missing from the input and then take their `Default` value;
/// - fields marked `skip_serializing_if` are left out of the output when the named predicate
///   holds (empty `Vec`, `None`, or equal to the default value).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path associated with the excerpt.
    /// NOTE(review): presumably the file the excerpt was taken from — confirm against the caller.
    pub excerpt_path: Arc<Path>,
    /// Range of the excerpt within the file.
    /// NOTE(review): the offset unit (bytes vs. chars) is not established in this file.
    pub excerpt_range: Range<usize>,
    /// Line range of the excerpt.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    pub cursor_point: Point,
    /// Index within `signatures`.
    pub excerpt_parent: Option<usize>,
    /// Omitted from the output when empty; defaults to empty when absent from the input.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Omitted from the output when empty; defaults to empty when absent from the input.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Omitted from the output when empty; defaults to empty when absent from the input.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Unlike the other lists this field carries no serde attributes, so it is always
    /// serialized and has no `default` fallback when deserializing.
    pub events: Vec<Event>,
    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Omitted from the output when empty; defaults to empty when absent from the input.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Omitted from the output when it equals its default (`false`), via `is_default`.
    /// NOTE(review): presumably set when `diagnostic_groups` was cut short — confirm.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Defaults to `false` when absent from the input.
    #[serde(default)]
    pub debug_info: bool,
    /// Omitted from the output when `None`.
    /// NOTE(review): presumably an upper bound on the prompt size in bytes — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Falls back to `PromptFormat::default()` when absent from the input.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
 50
/// A file path together with a list of excerpts, as carried in
/// `PredictEditsRequest::included_files`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// NOTE(review): presumably the last row (line) of the file — confirm against the producer.
    pub max_row: Line,
    /// Excerpts belonging to this file.
    /// NOTE(review): ordering and overlap guarantees are not established in this file.
    pub excerpts: Vec<Excerpt>,
}
 57
/// A piece of text tagged with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line at which `text` begins.
    /// NOTE(review): whether this is 0- or 1-based is not established in this file.
    pub start_line: Line,
    /// The excerpt's text.
    pub text: Arc<str>,
}
 63
/// Selects the prompt format for a request.
///
/// `EnumIter` is derived so that all variants can be enumerated. The default variant is
/// `NumLinesUniDiff` (see `PromptFormat::DEFAULT`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff".
    NumLinesUniDiff,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
}
 72
 73impl PromptFormat {
 74    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 75}
 76
 77impl Default for PromptFormat {
 78    fn default() -> Self {
 79        Self::DEFAULT
 80    }
 81}
 82
 83impl PromptFormat {
 84    pub fn iter() -> impl Iterator<Item = Self> {
 85        <Self as strum::IntoEnumIterator>::iter()
 86    }
 87}
 88
 89impl std::fmt::Display for PromptFormat {
 90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 91        match self {
 92            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
 93            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
 94            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
 95            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
 96        }
 97    }
 98}
 99
/// An event attached to a request (see `PredictEditsRequest::events`).
///
/// Serialized with serde's internally tagged representation: the variant name is stored under
/// the `"event"` key next to the variant's own fields. `PartialEq` is only derived for tests
/// and when the `test-support` feature is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as diff text.
    BufferChange {
        /// New path; the `Display` impl shows `untitled` when this is `None`.
        path: Option<PathBuf>,
        /// Previous path; the `Display` impl falls back to the new path when this is `None`.
        old_path: Option<PathBuf>,
        /// Diff body, written verbatim after the `---`/`+++` header lines by `Display`.
        diff: String,
        /// When `true`, `Display` prefixes the output with `// User accepted prediction:`.
        predicted: bool,
    },
}
111
112impl Display for Event {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            Event::BufferChange {
116                path,
117                old_path,
118                diff,
119                predicted,
120            } => {
121                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
122                let old_path = old_path.as_deref().unwrap_or(new_path);
123
124                if *predicted {
125                    write!(
126                        f,
127                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
128                        old_path.display(),
129                        new_path.display()
130                    )
131                } else {
132                    write!(
133                        f,
134                        "--- a/{}\n+++ b/{}\n{diff}",
135                        old_path.display(),
136                        new_path.display()
137                    )
138                }
139            }
140        }
141    }
142}
143
/// A signature entry, as carried in `PredictEditsRequest::signatures`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text; may be truncated, see `text_is_truncated`.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Omitted from the output when `None`.
    /// NOTE(review): presumably an index into the same `signatures` list, mirroring
    /// `ReferencedDeclaration::parent_index` — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
154
/// A declaration entry, as carried in `PredictEditsRequest::referenced_declarations`, together
/// with its scores and the components recorded alongside them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text; may be truncated, see `text_is_truncated`.
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Individual scoring inputs.
    /// NOTE(review): how they relate to the two scores below is not shown in this file.
    pub score_components: DeclarationScoreComponents,
    /// NOTE(review): presumably the score for including only the signature — confirm.
    pub signature_score: f32,
    /// NOTE(review): presumably the score for including the whole declaration — confirm.
    pub declaration_score: f32,
}
171
/// Flat record of the individual measurements attached to a `ReferencedDeclaration` via its
/// `score_components` field.
///
/// NOTE(review): this file only declares the fields; how each is computed and how they are
/// combined into `signature_score` / `declaration_score` must be checked at the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Boolean flags.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Line distances.
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Jaccard measures (excerpt / adjacent vs. item / signature).
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    // Weighted-overlap measures (excerpt / adjacent vs. item / signature).
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-related counts and similarities.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Inclusion counts.
    pub included_by_others: usize,
    pub includes_others: usize,
}
200
/// A diagnostic group kept as raw JSON.
///
/// The payload is a `serde_json::value::RawValue`, so its structure is not declared in this
/// file; `#[serde(transparent)]` makes the wrapper serialize and deserialize exactly like the
/// inner value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
204
/// Response payload: a request identifier, a list of edits and optional debug information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request.
    /// NOTE(review): which side assigns it is not shown in this file.
    pub request_id: Uuid,
    /// The edits.
    pub edits: Vec<Edit>,
    /// Optional debug details.
    /// NOTE(review): presumably populated only when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
211
/// Debug details optionally attached to a `PredictEditsResponse`: the prompt, the model's
/// response and three timings expressed as `chrono::Duration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// NOTE(review): presumably the time spent planning/building the prompt — confirm.
    pub prompt_planning_time: Duration,
    /// The model's response text.
    pub model_response: String,
    /// NOTE(review): presumably the time spent on model inference — confirm.
    pub inference_time: Duration,
    /// NOTE(review): presumably the time spent parsing the model response — confirm.
    pub parsing_time: Duration,
}
220
/// A single edit: a file path, a line range and text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file the edit refers to.
    pub path: Arc<Path>,
    /// Line range the edit refers to.
    pub range: Range<Line>,
    /// NOTE(review): presumably the replacement text for `range` — confirm against the consumer.
    pub content: String,
}
227
/// Returns `true` when `value` equals `T::default()`.
///
/// Serves as a `skip_serializing_if` predicate so that fields still holding their default
/// value are left out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    value == &baseline
}
231
/// A position expressed as a line and a column.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare `line` first and only
/// then `column`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line component.
    pub line: Line,
    /// Column component.
    /// NOTE(review): the unit (bytes, chars, UTF-16 units) is not established in this file.
    pub column: u32,
}
237
/// Newtype around a `u32` line number.
///
/// `#[serde(transparent)]` makes it serialize and deserialize exactly like the inner `u32`.
/// NOTE(review): whether line numbers are 0- or 1-based is not established in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
241
242impl Add for Line {
243    type Output = Self;
244
245    fn add(self, rhs: Self) -> Self::Output {
246        Self(self.0 + rhs.0)
247    }
248}
249
250impl Sub for Line {
251    type Output = Self;
252
253    fn sub(self, rhs: Self) -> Self::Output {
254        Self(self.0 - rhs.0)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Diff body shared by every case in this module.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event that carries the shared diff body.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths at all: both header lines fall back to "untitled".
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same old and new path.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths: old goes on the `---` line, new on the `+++` line.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Predicted change: same as above with the acceptance comment in front.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}