predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
/// Request payload for planning context retrieval around an excerpt.
///
/// NOTE(review): the summary above is inferred from the type and field names;
/// confirm against the code that consumes this request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path associated with the excerpt (presumably the file it was taken from — confirm).
    pub excerpt_path: Arc<Path>,
    /// Line range of the excerpt (presumably within the file at `excerpt_path` — confirm).
    pub excerpt_line_range: Range<Line>,
    /// NOTE(review): looks like the last row of the file containing the cursor — confirm.
    pub cursor_file_max_row: Line,
    /// Events sent along with the request; see [`Event`].
    pub events: Vec<Event>,
}
 22
/// Request payload asking for edit predictions.
///
/// Several fields are optional on the wire: those marked `skip_serializing_if`
/// are left out of the serialized form when empty/`None`/default, and those
/// marked `default` fall back to their `Default` value when absent on
/// deserialization.
///
/// NOTE(review): field meanings not spelled out by existing comments are
/// inferred from names — confirm against the producer/consumer of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path associated with the excerpt (presumably the file it was taken from — confirm).
    pub excerpt_path: Arc<Path>,
    /// Within file
    pub excerpt_range: Range<usize>,
    /// Line range of the excerpt (presumably the line-based counterpart of `excerpt_range` — confirm).
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    pub cursor_point: Point,
    /// Within `signatures`
    pub excerpt_parent: Option<usize>,
    /// Additional files and their excerpts. Omitted when empty; defaults to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub included_files: Vec<IncludedFile>,
    /// Signatures that `excerpt_parent` and the `parent_index` fields index into.
    /// Omitted when empty; defaults to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub signatures: Vec<Signature>,
    /// Declarations referenced by the request, with their scores.
    /// Omitted when empty; defaults to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Events sent along with the request; see [`Event`].
    pub events: Vec<Event>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups, each carried as raw JSON.
    /// Omitted when empty; defaults to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// NOTE(review): presumably set when `diagnostic_groups` was cut short — confirm.
    /// Omitted when `false`; defaults to `false` when absent.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Only available to staff. Defaults to `false` when absent.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional byte limit for the prompt. Omitted when `None`; defaults to `None` when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt format to use; falls back to `PromptFormat::default()` when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
 57
/// A file included in a request, together with the excerpts taken from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncludedFile {
    /// Path of the file.
    pub path: Arc<Path>,
    /// NOTE(review): looks like the last row of the file — confirm.
    pub max_row: Line,
    /// Excerpts from this file, each carrying its own starting line.
    pub excerpts: Vec<Excerpt>,
}
 64
/// A piece of text together with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line at which `text` begins (presumably within the owning file — confirm).
    pub start_line: Line,
    /// Text of the excerpt.
    pub text: Arc<str>,
}
 70
/// Selects how the prompt is laid out.
///
/// The human-readable name of each variant comes from the `Display` impl, and
/// `EnumIter` is derived so that the variants can be enumerated.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff".
    NumLinesUniDiff,
    /// Displayed as "Old Text / New Text".
    OldTextNewText,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
}
 84
 85impl PromptFormat {
 86    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 87}
 88
 89impl Default for PromptFormat {
 90    fn default() -> Self {
 91        Self::DEFAULT
 92    }
 93}
 94
 95impl PromptFormat {
 96    pub fn iter() -> impl Iterator<Item = Self> {
 97        <Self as strum::IntoEnumIterator>::iter()
 98    }
 99}
100
101impl std::fmt::Display for PromptFormat {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match self {
104            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
105            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
106            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
107            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
108            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
109            PromptFormat::Minimal => write!(f, "Minimal"),
110            PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
111        }
112    }
113}
114
/// An event attached to a request.
///
/// Serialized as an internally tagged enum: the variant name is stored under
/// the `"event"` key. `PartialEq` is only derived for tests and for the
/// `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, described by a diff.
    BufferChange {
        /// Path after the change; the `Display` impl prints `untitled` when this is `None`.
        path: Option<PathBuf>,
        /// Path before the change; the `Display` impl falls back to `path` when this is `None`.
        old_path: Option<PathBuf>,
        /// Diff body, printed verbatim after the `---`/`+++` header lines.
        diff: String,
        /// When `true`, the `Display` impl prefixes the output with
        /// `// User accepted prediction:`.
        predicted: bool,
    },
}
126
127impl Display for Event {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        match self {
130            Event::BufferChange {
131                path,
132                old_path,
133                diff,
134                predicted,
135            } => {
136                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
137                let old_path = old_path.as_deref().unwrap_or(new_path);
138
139                if *predicted {
140                    write!(
141                        f,
142                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
143                        DiffPathFmt(old_path),
144                        DiffPathFmt(new_path)
145                    )
146                } else {
147                    write!(
148                        f,
149                        "--- a/{}\n+++ b/{}\n{diff}",
150                        DiffPathFmt(old_path),
151                        DiffPathFmt(new_path)
152                    )
153                }
154            }
155        }
156    }
157}
158
/// Always formats the wrapped [`Path`] as a unix-style path, using `/` as the
/// separator, for use in diffs.
///
/// The path is split with [`Path::components`] and the components are joined
/// with `/`. A root-directory component is written as a single `/` and never
/// receives an extra separator, so `/foo/bar` renders as `/foo/bar` rather
/// than `//foo/bar`. Components that are not valid Unicode are rendered
/// lossily.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a component has been written that must be followed by a
        // `/` before the next one.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root already is a separator: emit exactly one `/`,
                // whether it starts the path or follows a Windows prefix.
                f.write_str("/")?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_str("/")?;
            }
            f.write_str(&component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
176
/// A signature, identified by its text and line range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Text of the signature, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` (and therefore `range`) was truncated.
    pub text_is_truncated: bool,
    /// NOTE(review): presumably an index of the parent within the same list of
    /// signatures — confirm. Omitted when `None`; defaults to `None` when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
187
/// A declaration referenced by a request, together with its scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Text of the declaration, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` (and therefore `range`) was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    /// Omitted when `None`; defaults to `None` when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// The individual signals recorded for this declaration.
    pub score_components: DeclarationScoreComponents,
    /// NOTE(review): presumably the score of the signature alone — confirm how it is computed.
    pub signature_score: f32,
    /// NOTE(review): presumably the score of the full declaration — confirm how it is computed.
    pub declaration_score: f32,
}
204
/// Individual signals recorded for a referenced declaration; carried in
/// `ReferencedDeclaration::score_components`.
///
/// NOTE(review): this file only declares the fields. How each value is
/// measured, and how the values combine into `signature_score` /
/// `declaration_score`, is defined elsewhere — consult the scoring code before
/// relying on any particular meaning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    pub included_by_others: usize,
    pub includes_others: usize,
}
233
/// A diagnostic group carried as raw, uninterpreted JSON.
///
/// `#[serde(transparent)]` makes the wrapper serialize and deserialize exactly
/// like the inner `RawValue`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
237
/// Response carrying the predicted edits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request (presumably assigned by the server — confirm).
    pub request_id: Uuid,
    /// The predicted edits.
    pub edits: Vec<Edit>,
    /// Optional debugging details.
    /// NOTE(review): presumably populated only when the request set `debug_info` — confirm.
    pub debug_info: Option<DebugInfo>,
}
244
/// Debugging details that may accompany a `PredictEditsResponse`.
///
/// All timings are `chrono::Duration` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// NOTE(review): presumably the time spent planning/building the prompt — confirm.
    pub prompt_planning_time: Duration,
    /// The model's response text (presumably unparsed — confirm).
    pub model_response: String,
    /// NOTE(review): presumably the time spent on model inference — confirm.
    pub inference_time: Duration,
    /// NOTE(review): presumably the time spent parsing `model_response` into edits — confirm.
    pub parsing_time: Duration,
}
253
/// A single predicted edit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file the edit applies to.
    pub path: Arc<Path>,
    /// Line range affected by the edit.
    /// NOTE(review): presumably the lines that `content` replaces — confirm.
    pub range: Range<Line>,
    /// Text of the edit.
    pub content: String,
}
260
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a `skip_serializing_if` predicate so that fields still holding
/// their default value are left out of the serialized form.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    *value == baseline
}
264
/// A position expressed as a line and a column.
///
/// The derived ordering compares `line` first and `column` second, following
/// field declaration order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line of the position.
    pub line: Line,
    /// Column within the line.
    /// NOTE(review): unit (bytes vs. characters) is not visible here — confirm.
    pub column: u32,
}
270
/// A line number, wrapped in a newtype so it cannot be confused with other
/// integers.
///
/// `#[serde(transparent)]` makes it serialize and deserialize as a bare `u32`.
/// NOTE(review): whether lines are zero- or one-based is not visible here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
274
275impl Add for Line {
276    type Output = Self;
277
278    fn add(self, rhs: Self) -> Self::Output {
279        Self(self.0 + rhs.0)
280    }
281}
282
283impl Sub for Line {
284    type Output = Self;
285
286    fn sub(self, rhs: Self) -> Self::Output {
287        Self(self.0 - rhs.0)
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Diff body shared by every case; only the header lines vary.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event over the shared diff body.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // No paths at all: both header lines fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same old and new path.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths: old goes on `---`, new on `+++`.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted prediction: a banner line precedes the header.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}