predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
/// Request payload carrying the excerpt around the editing location plus recent editor events.
/// NOTE(review): the name suggests it is used to plan which extra context to retrieve — confirm
/// against the server-side consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Excerpt text (presumably the text around the cursor — confirm).
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Range of lines the excerpt spans within its file.
    pub excerpt_line_range: Range<Line>,
    /// NOTE(review): presumably the last row of the cursor's file — confirm.
    pub cursor_file_max_row: Line,
    /// Recent editor events (see [`Event`]).
    pub events: Vec<Event>,
}
 22
/// Request body for an edit prediction.
///
/// Fields marked `skip_serializing_if = "Vec::is_empty"` are left out of the JSON when empty and
/// deserialize to an empty `Vec` when missing; fields marked `#[serde(default)]` fall back to
/// their `Default` value when missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Excerpt of the cursor's file (presumably the text around the cursor — confirm).
    pub excerpt: String,
    /// Path of the file the excerpt comes from.
    pub excerpt_path: Arc<Path>,
    /// Range of the excerpt within the file, as `usize` offsets (presumably bytes — confirm).
    pub excerpt_range: Range<usize>,
    /// Range of lines the excerpt spans within the file.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    /// NOTE(review): presumably file-relative rather than excerpt-relative — confirm.
    pub cursor_point: Point,
    /// Index into `signatures` of the excerpt's parent signature, if any.
    pub excerpt_parent: Option<usize>,
    /// Additional file excerpts sent as context.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signature table; `excerpt_parent` and `ReferencedDeclaration::parent_index` index into it.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations sent as context, each carrying its scoring breakdown.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Recent editor events (see [`Event`]).
    pub events: Vec<Event>,
    /// Defaults to `false` when missing.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups, forwarded as raw JSON (see [`DiagnosticGroup`]).
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Set when `diagnostic_groups` was truncated (per the name); only serialized when `true`.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff.
    /// NOTE(review): presumably asks for `PredictEditsResponse::debug_info` to be filled — confirm.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size cap (in bytes, per the name); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout; falls back to `PromptFormat::default()` when missing.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
 57
/// A file contributing one or more excerpts to the request context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file the excerpts come from.
    pub path: Arc<Path>,
    /// NOTE(review): presumably the file's last row, so consumers know how far the file extends
    /// beyond the excerpts — confirm.
    pub max_row: Line,
    /// Excerpts taken from this file.
    pub excerpts: Vec<Excerpt>,
}
 64
/// A chunk of file text together with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line on which `text` begins.
    pub start_line: Line,
    /// Excerpt text; `Arc<str>` makes clones a refcount bump rather than a copy.
    pub text: Arc<str>,
}
 70
/// Format of the prompt built for the model.
///
/// Serialized by variant name (serde's default for unit variants), so renaming a variant changes
/// the wire format. Declaration order is the order yielded by `PromptFormat::iter()`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff"; this is `PromptFormat::DEFAULT`.
    NumLinesUniDiff,
    /// Displayed as "Old Text / New Text".
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
}
 82
 83impl PromptFormat {
 84    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 85}
 86
 87impl Default for PromptFormat {
 88    fn default() -> Self {
 89        Self::DEFAULT
 90    }
 91}
 92
 93impl PromptFormat {
 94    pub fn iter() -> impl Iterator<Item = Self> {
 95        <Self as strum::IntoEnumIterator>::iter()
 96    }
 97}
 98
 99impl std::fmt::Display for PromptFormat {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
103            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
104            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
105            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
106            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
107            PromptFormat::Minimal => write!(f, "Minimal"),
108        }
109    }
110}
111
/// An editor event sent as history context with a request.
///
/// `#[serde(tag = "event")]` makes this internally tagged: each event serializes as one object
/// whose `"event"` field names the variant, alongside that variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered by `Display` as a unified-diff snippet.
    BufferChange {
        /// Current path of the buffer; `None` is displayed as `untitled`.
        path: Option<PathBuf>,
        /// Path before the change; when `None`, the new path is displayed in its place.
        old_path: Option<PathBuf>,
        /// Hunk text written after the `---`/`+++` headers.
        diff: String,
        /// Whether the change came from an accepted prediction; such events are displayed with a
        /// `// User accepted prediction:` header line.
        predicted: bool,
    },
}
123
124impl Display for Event {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            Event::BufferChange {
128                path,
129                old_path,
130                diff,
131                predicted,
132            } => {
133                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
134                let old_path = old_path.as_deref().unwrap_or(new_path);
135
136                if *predicted {
137                    write!(
138                        f,
139                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
140                        DiffPathFmt(old_path),
141                        DiffPathFmt(new_path)
142                    )
143                } else {
144                    write!(
145                        f,
146                        "--- a/{}\n+++ b/{}\n{diff}",
147                        DiffPathFmt(old_path),
148                        DiffPathFmt(new_path)
149                    )
150                }
151            }
152        }
153    }
154}
155
/// Formats a [`Path`] for the `--- a/…` / `+++ b/…` headers of a unified diff.
///
/// Components are always joined with `/`, regardless of the host platform's separator, so diff
/// headers look the same on every OS. A root directory component is rendered as a single leading
/// `/` (any Windows prefix such as `C:` is written verbatim before it), so absolute paths do not
/// gain a doubled separator like `//foo`. Non-UTF-8 components are rendered lossily, with U+FFFD
/// replacing invalid sequences.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::path::Component;

        // Whether a `/` must precede the next named component. It stays false right after a
        // prefix or root directory, because those already end where the next component begins.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                Component::Prefix(_) => {
                    f.write_str(&component.as_os_str().to_string_lossy())?;
                    needs_separator = false;
                }
                Component::RootDir => {
                    f.write_char('/')?;
                    needs_separator = false;
                }
                Component::CurDir | Component::ParentDir | Component::Normal(_) => {
                    if needs_separator {
                        f.write_char('/')?;
                    }
                    f.write_str(&component.as_os_str().to_string_lossy())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
173
/// A signature's text sent as context, referenced by index from other parts of the request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// NOTE(review): presumably the index of this signature's parent within the same
    /// `signatures` list — confirm. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
184
/// A declaration sent as context, together with the scores computed for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// File containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text` (presumably the declaration's signature, as byte offsets — confirm).
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Per-feature breakdown; how it combines into the two scores below is not defined here.
    pub score_components: DeclarationScoreComponents,
    pub signature_score: f32,
    pub declaration_score: f32,
}
201
/// Per-feature breakdown attached to each [`ReferencedDeclaration`] alongside its
/// `signature_score` and `declaration_score`.
///
/// NOTE(review): the weighting that turns these components into the two scores lives outside
/// this file; field meanings are only suggested by their names — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
230
/// A diagnostic group carried as raw, unparsed JSON.
///
/// `#[serde(transparent)]` makes this serialize exactly as the wrapped `RawValue`, so the JSON is
/// forwarded verbatim and this crate never defines its schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
234
/// Response body for an edit prediction (presumably answering a [`PredictEditsRequest`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier for this prediction request.
    pub request_id: Uuid,
    /// Predicted edits (see [`Edit`]).
    pub edits: Vec<Edit>,
    /// NOTE(review): presumably only `Some` when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
241
/// Debugging details about how a prediction was produced: the prompt, the raw model output, and
/// per-stage timings (as `chrono::Duration`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Prompt text given to the model.
    pub prompt: String,
    pub prompt_planning_time: Duration,
    /// Model output before parsing.
    pub model_response: String,
    pub inference_time: Duration,
    pub parsing_time: Duration,
}
250
/// A single predicted edit to a file.
///
/// NOTE(review): presumably replaces the half-open line range `range` of `path` with `content` —
/// confirm, along with whether lines are 0- or 1-based.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    pub path: Arc<Path>,
    pub range: Range<Line>,
    pub content: String,
}
257
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate, so fields still holding their default value
/// are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
261
/// A position in a file, given as a line plus a column.
///
/// The derived ordering compares `line` first, then `column`.
/// NOTE(review): column units (bytes vs. chars) are not specified here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    pub line: Line,
    pub column: u32,
}
267
/// Line-number newtype; `#[serde(transparent)]` makes it serialize as a bare `u32`.
/// NOTE(review): whether lines are 0- or 1-based is not specified here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
271
272impl Add for Line {
273    type Output = Self;
274
275    fn add(self, rhs: Self) -> Self::Output {
276        Self(self.0 + rhs.0)
277    }
278}
279
280impl Sub for Line {
281    type Output = Self;
282
283    fn sub(self, rhs: Self) -> Self::Output {
284        Self(self.0 - rhs.0)
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` carrying the same two-line deletion hunk that every case shares.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            // No paths at all: both headers fall back to the placeholder name.
            (
                buffer_change(None, None, false),
                indoc! {"
                    --- a/untitled
                    +++ b/untitled
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Same old and new path.
            (
                buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false),
                indoc! {"
                    --- a/foo/bar.txt
                    +++ b/foo/bar.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Differing paths: the old path goes on the `---` line, the new one on `+++`.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), false),
                indoc! {"
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Predicted edits get the extra leading comment line.
            (
                buffer_change(Some("abc.txt"), Some("123.txt"), true),
                indoc! {"
                    // User accepted prediction:
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}