predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::Path,
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
 13
 14#[derive(Debug, Clone, Serialize, Deserialize)]
 15pub struct PlanContextRetrievalRequest {
 16    pub excerpt: String,
 17    pub excerpt_path: Arc<Path>,
 18    pub excerpt_line_range: Range<Line>,
 19    pub cursor_file_max_row: Line,
 20    pub events: Vec<Arc<Event>>,
 21}
 22
 23#[derive(Debug, Clone, Serialize, Deserialize)]
 24pub struct PredictEditsRequest {
 25    pub excerpt: String,
 26    pub excerpt_path: Arc<Path>,
 27    /// Within file
 28    pub excerpt_range: Range<usize>,
 29    pub excerpt_line_range: Range<Line>,
 30    pub cursor_point: Point,
 31    /// Within `signatures`
 32    pub excerpt_parent: Option<usize>,
 33    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 34    pub included_files: Vec<IncludedFile>,
 35    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 36    pub signatures: Vec<Signature>,
 37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 38    pub referenced_declarations: Vec<ReferencedDeclaration>,
 39    pub events: Vec<Arc<Event>>,
 40    #[serde(default)]
 41    pub can_collect_data: bool,
 42    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 43    pub diagnostic_groups: Vec<DiagnosticGroup>,
 44    #[serde(skip_serializing_if = "is_default", default)]
 45    pub diagnostic_groups_truncated: bool,
 46    /// Info about the git repository state, only present when can_collect_data is true.
 47    #[serde(skip_serializing_if = "Option::is_none", default)]
 48    pub git_info: Option<PredictEditsGitInfo>,
 49    // Only available to staff
 50    #[serde(default)]
 51    pub debug_info: bool,
 52    #[serde(skip_serializing_if = "Option::is_none", default)]
 53    pub prompt_max_bytes: Option<usize>,
 54    #[serde(default)]
 55    pub prompt_format: PromptFormat,
 56    #[serde(default)]
 57    pub trigger: PredictEditsRequestTrigger,
 58}
 59
 60#[derive(Debug, Clone, Serialize, Deserialize)]
 61pub struct IncludedFile {
 62    pub path: Arc<Path>,
 63    pub max_row: Line,
 64    pub excerpts: Vec<Excerpt>,
 65}
 66
 67#[derive(Debug, Clone, Serialize, Deserialize)]
 68pub struct Excerpt {
 69    pub start_line: Line,
 70    pub text: Arc<str>,
 71}
 72
 73#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 74pub enum PromptFormat {
 75    MarkedExcerpt,
 76    LabeledSections,
 77    NumLinesUniDiff,
 78    OldTextNewText,
 79    /// Prompt format intended for use via zeta_cli
 80    OnlySnippets,
 81    /// One-sentence instructions used in fine-tuned models
 82    Minimal,
 83    /// One-sentence instructions + FIM-like template
 84    MinimalQwen,
 85    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
 86    SeedCoder1120,
 87}
 88
 89impl PromptFormat {
 90    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 91}
 92
 93impl Default for PromptFormat {
 94    fn default() -> Self {
 95        Self::DEFAULT
 96    }
 97}
 98
 99impl PromptFormat {
100    pub fn iter() -> impl Iterator<Item = Self> {
101        <Self as strum::IntoEnumIterator>::iter()
102    }
103}
104
105impl std::fmt::Display for PromptFormat {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        match self {
108            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
109            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
110            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
111            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
112            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
113            PromptFormat::Minimal => write!(f, "Minimal"),
114            PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
115            PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
116        }
117    }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
122#[serde(tag = "event")]
123pub enum Event {
124    BufferChange {
125        path: Arc<Path>,
126        old_path: Arc<Path>,
127        diff: String,
128        predicted: bool,
129        in_open_source_repo: bool,
130    },
131}
132
133impl Display for Event {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            Event::BufferChange {
137                path,
138                old_path,
139                diff,
140                predicted,
141                ..
142            } => {
143                if *predicted {
144                    write!(
145                        f,
146                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
147                        DiffPathFmt(old_path),
148                        DiffPathFmt(path)
149                    )
150                } else {
151                    write!(
152                        f,
153                        "--- a/{}\n+++ b/{}\n{diff}",
154                        DiffPathFmt(old_path),
155                        DiffPathFmt(path)
156                    )
157                }
158            }
159        }
160    }
161}
162
/// Formats a [`Path`] for diff headers, always using `/` as the separator regardless of
/// the host platform's native separator.
///
/// Path components are joined with `/`. A root component is rendered as a single `/`
/// and is not followed by another separator, so an absolute path such as `/foo/bar` stays
/// `/foo/bar` rather than becoming `//foo/bar`. A Windows prefix (e.g. `C:`) is written
/// as-is, without a trailing separator. Non-UTF-8 components are rendered lossily
/// (invalid sequences become `U+FFFD`).
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::path::Component;

        // Set after writing a component that must be followed by `/` before the next one.
        // Prefix and root components already end at a boundary, so they leave it unset.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                Component::Prefix(prefix) => {
                    f.write_str(&prefix.as_os_str().to_string_lossy())?;
                    needs_separator = false;
                }
                Component::RootDir => {
                    // Normalize the platform's root separator (e.g. `\` on Windows) to `/`.
                    f.write_str("/")?;
                    needs_separator = false;
                }
                other => {
                    if needs_separator {
                        f.write_str("/")?;
                    }
                    f.write_str(&other.as_os_str().to_string_lossy())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct Signature {
183    pub text: String,
184    pub text_is_truncated: bool,
185    #[serde(skip_serializing_if = "Option::is_none", default)]
186    pub parent_index: Option<usize>,
187    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
188    /// file is implicitly the file that contains the descendant declaration or excerpt.
189    pub range: Range<Line>,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct ReferencedDeclaration {
194    pub path: Arc<Path>,
195    pub text: String,
196    pub text_is_truncated: bool,
197    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
198    pub range: Range<Line>,
199    /// Range within `text`
200    pub signature_range: Range<usize>,
201    /// Index within `signatures`.
202    #[serde(skip_serializing_if = "Option::is_none", default)]
203    pub parent_index: Option<usize>,
204    pub score_components: DeclarationScoreComponents,
205    pub signature_score: f32,
206    pub declaration_score: f32,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct DeclarationScoreComponents {
211    pub is_same_file: bool,
212    pub is_referenced_nearby: bool,
213    pub is_referenced_in_breadcrumb: bool,
214    pub reference_count: usize,
215    pub same_file_declaration_count: usize,
216    pub declaration_count: usize,
217    pub reference_line_distance: u32,
218    pub declaration_line_distance: u32,
219    pub excerpt_vs_item_jaccard: f32,
220    pub excerpt_vs_signature_jaccard: f32,
221    pub adjacent_vs_item_jaccard: f32,
222    pub adjacent_vs_signature_jaccard: f32,
223    pub excerpt_vs_item_weighted_overlap: f32,
224    pub excerpt_vs_signature_weighted_overlap: f32,
225    pub adjacent_vs_item_weighted_overlap: f32,
226    pub adjacent_vs_signature_weighted_overlap: f32,
227    pub path_import_match_count: usize,
228    pub wildcard_path_import_match_count: usize,
229    pub import_similarity: f32,
230    pub max_import_similarity: f32,
231    pub normalized_import_similarity: f32,
232    pub wildcard_import_similarity: f32,
233    pub normalized_wildcard_import_similarity: f32,
234    pub included_by_others: usize,
235    pub includes_others: usize,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239#[serde(transparent)]
240pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
241
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct PredictEditsResponse {
244    pub request_id: Uuid,
245    pub edits: Vec<Edit>,
246    pub debug_info: Option<DebugInfo>,
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct DebugInfo {
251    pub prompt: String,
252    pub prompt_planning_time: Duration,
253    pub model_response: String,
254    pub inference_time: Duration,
255    pub parsing_time: Duration,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct Edit {
260    pub path: Arc<Path>,
261    pub range: Range<Line>,
262    pub content: String,
263}
264
/// Returns `true` when `value` compares equal to `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so fields holding their default value are
/// left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default = T::default();
    value == &default
}
268
269#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
270pub struct Point {
271    pub line: Line,
272    pub column: u32,
273}
274
275#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
276#[serde(transparent)]
277pub struct Line(pub u32);
278
279impl Add for Line {
280    type Output = Self;
281
282    fn add(self, rhs: Self) -> Self::Output {
283        Self(self.0 + rhs.0)
284    }
285}
286
287impl Sub for Line {
288    type Output = Self;
289
290    fn sub(self, rhs: Self) -> Self::Output {
291        Self(self.0 - rhs.0)
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event with the two-line deletion hunk shared by every case.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            (
                buffer_change("untitled", "untitled", false),
                indoc! {"
                    --- a/untitled
                    +++ b/untitled
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            (
                buffer_change("foo/bar.txt", "foo/bar.txt", false),
                indoc! {"
                    --- a/foo/bar.txt
                    +++ b/foo/bar.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            (
                buffer_change("123.txt", "abc.txt", false),
                indoc! {"
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            (
                buffer_change("123.txt", "abc.txt", true),
                indoc! {"
                    // User accepted prediction:
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}