predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::Path,
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
 14#[derive(Debug, Clone, Serialize, Deserialize)]
 15pub struct PlanContextRetrievalRequest {
 16    pub excerpt: String,
 17    pub excerpt_path: Arc<Path>,
 18    pub excerpt_line_range: Range<Line>,
 19    pub cursor_file_max_row: Line,
 20    pub events: Vec<Arc<Event>>,
 21}
 22
 23#[derive(Debug, Clone, Serialize, Deserialize)]
 24pub struct PredictEditsRequest {
 25    pub excerpt: String,
 26    pub excerpt_path: Arc<Path>,
 27    /// Within file
 28    pub excerpt_range: Range<usize>,
 29    pub excerpt_line_range: Range<Line>,
 30    pub cursor_point: Point,
 31    /// Within `signatures`
 32    pub excerpt_parent: Option<usize>,
 33    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 34    pub included_files: Vec<IncludedFile>,
 35    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 36    pub signatures: Vec<Signature>,
 37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 38    pub referenced_declarations: Vec<ReferencedDeclaration>,
 39    pub events: Vec<Arc<Event>>,
 40    #[serde(default)]
 41    pub can_collect_data: bool,
 42    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 43    pub diagnostic_groups: Vec<DiagnosticGroup>,
 44    #[serde(skip_serializing_if = "is_default", default)]
 45    pub diagnostic_groups_truncated: bool,
 46    /// Info about the git repository state, only present when can_collect_data is true.
 47    #[serde(skip_serializing_if = "Option::is_none", default)]
 48    pub git_info: Option<PredictEditsGitInfo>,
 49    // Only available to staff
 50    #[serde(default)]
 51    pub debug_info: bool,
 52    #[serde(skip_serializing_if = "Option::is_none", default)]
 53    pub prompt_max_bytes: Option<usize>,
 54    #[serde(default)]
 55    pub prompt_format: PromptFormat,
 56}
 57
 58#[derive(Debug, Clone, Serialize, Deserialize)]
 59pub struct IncludedFile {
 60    pub path: Arc<Path>,
 61    pub max_row: Line,
 62    pub excerpts: Vec<Excerpt>,
 63}
 64
 65#[derive(Debug, Clone, Serialize, Deserialize)]
 66pub struct Excerpt {
 67    pub start_line: Line,
 68    pub text: Arc<str>,
 69}
 70
 71#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 72pub enum PromptFormat {
 73    MarkedExcerpt,
 74    LabeledSections,
 75    NumLinesUniDiff,
 76    OldTextNewText,
 77    /// Prompt format intended for use via zeta_cli
 78    OnlySnippets,
 79    /// One-sentence instructions used in fine-tuned models
 80    Minimal,
 81    /// One-sentence instructions + FIM-like template
 82    MinimalQwen,
 83    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
 84    SeedCoder1120,
 85}
 86
 87impl PromptFormat {
 88    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 89}
 90
 91impl Default for PromptFormat {
 92    fn default() -> Self {
 93        Self::DEFAULT
 94    }
 95}
 96
 97impl PromptFormat {
 98    pub fn iter() -> impl Iterator<Item = Self> {
 99        <Self as strum::IntoEnumIterator>::iter()
100    }
101}
102
103impl std::fmt::Display for PromptFormat {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
107            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
108            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
109            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
110            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
111            PromptFormat::Minimal => write!(f, "Minimal"),
112            PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
113            PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
114        }
115    }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
120#[serde(tag = "event")]
121pub enum Event {
122    BufferChange {
123        path: Arc<Path>,
124        old_path: Arc<Path>,
125        diff: String,
126        predicted: bool,
127        in_open_source_repo: bool,
128    },
129}
130
131impl Display for Event {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match self {
134            Event::BufferChange {
135                path,
136                old_path,
137                diff,
138                predicted,
139                ..
140            } => {
141                if *predicted {
142                    write!(
143                        f,
144                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
145                        DiffPathFmt(old_path),
146                        DiffPathFmt(path)
147                    )
148                } else {
149                    write!(
150                        f,
151                        "--- a/{}\n+++ b/{}\n{diff}",
152                        DiffPathFmt(old_path),
153                        DiffPathFmt(path)
154                    )
155                }
156            }
157        }
158    }
159}
160
/// Formats a [`Path`] for diff headers, always using `/` as the separator regardless of
/// the host platform.
///
/// Components are joined with `/`. A root component (the leading `/` of an absolute Unix
/// path, or the `\` after a Windows drive prefix) is itself emitted as the single separator,
/// so `/foo/bar` renders as `/foo/bar` rather than `//foo/bar`, and `C:\foo` as `C:/foo`.
/// Non-UTF-8 components are rendered lossily (invalid sequences become U+FFFD).
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether a `/` must be written before the next component. It is false at the start,
        // and right after a root or prefix, because those already end on a separator boundary.
        let mut needs_separator = false;
        for component in self.0.components() {
            if let std::path::Component::RootDir = component {
                // The root *is* the separator; routing it through the generic branch below
                // would yield `//` (or `/\` on Windows).
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            f.write_str(&component.as_os_str().to_string_lossy())?;
            // A Windows prefix such as `C:` is followed directly by the root (if any), or by
            // a drive-relative component (`C:foo`), so it must not force a separator.
            needs_separator = !matches!(component, std::path::Component::Prefix(_));
        }
        Ok(())
    }
}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct Signature {
181    pub text: String,
182    pub text_is_truncated: bool,
183    #[serde(skip_serializing_if = "Option::is_none", default)]
184    pub parent_index: Option<usize>,
185    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
186    /// file is implicitly the file that contains the descendant declaration or excerpt.
187    pub range: Range<Line>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct ReferencedDeclaration {
192    pub path: Arc<Path>,
193    pub text: String,
194    pub text_is_truncated: bool,
195    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
196    pub range: Range<Line>,
197    /// Range within `text`
198    pub signature_range: Range<usize>,
199    /// Index within `signatures`.
200    #[serde(skip_serializing_if = "Option::is_none", default)]
201    pub parent_index: Option<usize>,
202    pub score_components: DeclarationScoreComponents,
203    pub signature_score: f32,
204    pub declaration_score: f32,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct DeclarationScoreComponents {
209    pub is_same_file: bool,
210    pub is_referenced_nearby: bool,
211    pub is_referenced_in_breadcrumb: bool,
212    pub reference_count: usize,
213    pub same_file_declaration_count: usize,
214    pub declaration_count: usize,
215    pub reference_line_distance: u32,
216    pub declaration_line_distance: u32,
217    pub excerpt_vs_item_jaccard: f32,
218    pub excerpt_vs_signature_jaccard: f32,
219    pub adjacent_vs_item_jaccard: f32,
220    pub adjacent_vs_signature_jaccard: f32,
221    pub excerpt_vs_item_weighted_overlap: f32,
222    pub excerpt_vs_signature_weighted_overlap: f32,
223    pub adjacent_vs_item_weighted_overlap: f32,
224    pub adjacent_vs_signature_weighted_overlap: f32,
225    pub path_import_match_count: usize,
226    pub wildcard_path_import_match_count: usize,
227    pub import_similarity: f32,
228    pub max_import_similarity: f32,
229    pub normalized_import_similarity: f32,
230    pub wildcard_import_similarity: f32,
231    pub normalized_wildcard_import_similarity: f32,
232    pub included_by_others: usize,
233    pub includes_others: usize,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
237#[serde(transparent)]
238pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct PredictEditsResponse {
242    pub request_id: Uuid,
243    pub edits: Vec<Edit>,
244    pub debug_info: Option<DebugInfo>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct DebugInfo {
249    pub prompt: String,
250    pub prompt_planning_time: Duration,
251    pub model_response: String,
252    pub inference_time: Duration,
253    pub parsing_time: Duration,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct Edit {
258    pub path: Arc<Path>,
259    pub range: Range<Line>,
260    pub content: String,
261}
262
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate, so default-valued fields are omitted.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    PartialEq::eq(value, &T::default())
}
266
267#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
268pub struct Point {
269    pub line: Line,
270    pub column: u32,
271}
272
273#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
274#[serde(transparent)]
275pub struct Line(pub u32);
276
277impl Add for Line {
278    type Output = Self;
279
280    fn add(self, rhs: Self) -> Self::Output {
281        Self(self.0 + rhs.0)
282    }
283}
284
285impl Sub for Line {
286    type Output = Self;
287
288    fn sub(self, rhs: Self) -> Self::Output {
289        Self(self.0 - rhs.0)
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` whose diff deletes two lines; only the paths and the
    /// `predicted` flag vary between cases.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            (
                buffer_change("untitled", "untitled", false),
                indoc! {"
                    --- a/untitled
                    +++ b/untitled
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            (
                buffer_change("foo/bar.txt", "foo/bar.txt", false),
                indoc! {"
                    --- a/foo/bar.txt
                    +++ b/foo/bar.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Differing paths: the old path goes on the `---` line, the new one on `+++`.
            (
                buffer_change("123.txt", "abc.txt", false),
                indoc! {"
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Accepted predictions carry an extra marker line before the diff header.
            (
                buffer_change("123.txt", "abc.txt", true),
                indoc! {"
                    // User accepted prediction:
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}