predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
 14#[derive(Debug, Clone, Serialize, Deserialize)]
 15pub struct PlanContextRetrievalRequest {
 16    pub excerpt: String,
 17    pub excerpt_path: Arc<Path>,
 18    pub excerpt_line_range: Range<Line>,
 19    pub cursor_file_max_row: Line,
 20    pub events: Vec<Event>,
 21}
 22
 23#[derive(Debug, Clone, Serialize, Deserialize)]
 24pub struct PredictEditsRequest {
 25    pub excerpt: String,
 26    pub excerpt_path: Arc<Path>,
 27    /// Within file
 28    pub excerpt_range: Range<usize>,
 29    pub excerpt_line_range: Range<Line>,
 30    pub cursor_point: Point,
 31    /// Within `signatures`
 32    pub excerpt_parent: Option<usize>,
 33    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 34    pub included_files: Vec<IncludedFile>,
 35    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 36    pub signatures: Vec<Signature>,
 37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 38    pub referenced_declarations: Vec<ReferencedDeclaration>,
 39    pub events: Vec<Event>,
 40    #[serde(default)]
 41    pub can_collect_data: bool,
 42    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 43    pub diagnostic_groups: Vec<DiagnosticGroup>,
 44    #[serde(skip_serializing_if = "is_default", default)]
 45    pub diagnostic_groups_truncated: bool,
 46    /// Info about the git repository state, only present when can_collect_data is true.
 47    #[serde(skip_serializing_if = "Option::is_none", default)]
 48    pub git_info: Option<PredictEditsGitInfo>,
 49    // Only available to staff
 50    #[serde(default)]
 51    pub debug_info: bool,
 52    #[serde(skip_serializing_if = "Option::is_none", default)]
 53    pub prompt_max_bytes: Option<usize>,
 54    #[serde(default)]
 55    pub prompt_format: PromptFormat,
 56}
 57
 58#[derive(Debug, Clone, Serialize, Deserialize)]
 59pub struct IncludedFile {
 60    pub path: Arc<Path>,
 61    pub max_row: Line,
 62    pub excerpts: Vec<Excerpt>,
 63}
 64
 65#[derive(Debug, Clone, Serialize, Deserialize)]
 66pub struct Excerpt {
 67    pub start_line: Line,
 68    pub text: Arc<str>,
 69}
 70
 71#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 72pub enum PromptFormat {
 73    MarkedExcerpt,
 74    LabeledSections,
 75    NumLinesUniDiff,
 76    OldTextNewText,
 77    /// Prompt format intended for use via zeta_cli
 78    OnlySnippets,
 79}
 80
 81impl PromptFormat {
 82    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 83}
 84
 85impl Default for PromptFormat {
 86    fn default() -> Self {
 87        Self::DEFAULT
 88    }
 89}
 90
 91impl PromptFormat {
 92    pub fn iter() -> impl Iterator<Item = Self> {
 93        <Self as strum::IntoEnumIterator>::iter()
 94    }
 95}
 96
 97impl std::fmt::Display for PromptFormat {
 98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 99        match self {
100            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
101            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
102            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
103            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
104            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
105        }
106    }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
111#[serde(tag = "event")]
112pub enum Event {
113    BufferChange {
114        path: Option<PathBuf>,
115        old_path: Option<PathBuf>,
116        diff: String,
117        predicted: bool,
118    },
119}
120
121impl Display for Event {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            Event::BufferChange {
125                path,
126                old_path,
127                diff,
128                predicted,
129            } => {
130                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
131                let old_path = old_path.as_deref().unwrap_or(new_path);
132
133                if *predicted {
134                    write!(
135                        f,
136                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
137                        DiffPathFmt(old_path),
138                        DiffPathFmt(new_path)
139                    )
140                } else {
141                    write!(
142                        f,
143                        "--- a/{}\n+++ b/{}\n{diff}",
144                        DiffPathFmt(old_path),
145                        DiffPathFmt(new_path)
146                    )
147                }
148            }
149        }
150    }
151}
152
/// Formats a [`Path`] for diffs, always using `/` as the path separator
/// regardless of the host platform.
///
/// Components are written in order and joined by a single `/`. A root
/// component is itself rendered as one `/`, so an absolute path such as
/// `/foo/bar` is written as `/foo/bar` rather than with a doubled slash.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether the next non-root component has to be preceded by a `/`.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root doubles as the separator for whatever follows it. Writing
                // its native text (`/` or `\`) and then another `/` would double it.
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            write!(f, "{}", component.as_os_str().display())?;
            needs_separator = true;
        }
        Ok(())
    }
}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct Signature {
173    pub text: String,
174    pub text_is_truncated: bool,
175    #[serde(skip_serializing_if = "Option::is_none", default)]
176    pub parent_index: Option<usize>,
177    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
178    /// file is implicitly the file that contains the descendant declaration or excerpt.
179    pub range: Range<Line>,
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct ReferencedDeclaration {
184    pub path: Arc<Path>,
185    pub text: String,
186    pub text_is_truncated: bool,
187    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
188    pub range: Range<Line>,
189    /// Range within `text`
190    pub signature_range: Range<usize>,
191    /// Index within `signatures`.
192    #[serde(skip_serializing_if = "Option::is_none", default)]
193    pub parent_index: Option<usize>,
194    pub score_components: DeclarationScoreComponents,
195    pub signature_score: f32,
196    pub declaration_score: f32,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct DeclarationScoreComponents {
201    pub is_same_file: bool,
202    pub is_referenced_nearby: bool,
203    pub is_referenced_in_breadcrumb: bool,
204    pub reference_count: usize,
205    pub same_file_declaration_count: usize,
206    pub declaration_count: usize,
207    pub reference_line_distance: u32,
208    pub declaration_line_distance: u32,
209    pub excerpt_vs_item_jaccard: f32,
210    pub excerpt_vs_signature_jaccard: f32,
211    pub adjacent_vs_item_jaccard: f32,
212    pub adjacent_vs_signature_jaccard: f32,
213    pub excerpt_vs_item_weighted_overlap: f32,
214    pub excerpt_vs_signature_weighted_overlap: f32,
215    pub adjacent_vs_item_weighted_overlap: f32,
216    pub adjacent_vs_signature_weighted_overlap: f32,
217    pub path_import_match_count: usize,
218    pub wildcard_path_import_match_count: usize,
219    pub import_similarity: f32,
220    pub max_import_similarity: f32,
221    pub normalized_import_similarity: f32,
222    pub wildcard_import_similarity: f32,
223    pub normalized_wildcard_import_similarity: f32,
224    pub included_by_others: usize,
225    pub includes_others: usize,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229#[serde(transparent)]
230pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct PredictEditsResponse {
234    pub request_id: Uuid,
235    pub edits: Vec<Edit>,
236    pub debug_info: Option<DebugInfo>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct DebugInfo {
241    pub prompt: String,
242    pub prompt_planning_time: Duration,
243    pub model_response: String,
244    pub inference_time: Duration,
245    pub parsing_time: Duration,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct Edit {
250    pub path: Arc<Path>,
251    pub range: Range<Line>,
252    pub content: String,
253}
254
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Referenced by name from `#[serde(skip_serializing_if = "is_default")]`, so
/// the signature must keep taking a single `&T`.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    PartialEq::eq(value, &baseline)
}
258
259#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
260pub struct Point {
261    pub line: Line,
262    pub column: u32,
263}
264
265#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
266#[serde(transparent)]
267pub struct Line(pub u32);
268
269impl Add for Line {
270    type Output = Self;
271
272    fn add(self, rhs: Self) -> Self::Output {
273        Self(self.0 + rhs.0)
274    }
275}
276
277impl Sub for Line {
278    type Output = Self;
279
280    fn sub(self, rhs: Self) -> Self::Output {
281        Self(self.0 - rhs.0)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Diff body shared by every case in `test_event_display`.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event carrying `DIFF` with the given paths.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths at all: both headers fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same old and new path.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different paths: the old path goes on the `---` line.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Predicted change: an extra comment line precedes the headers.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}