predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::Display,
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
 14// TODO: snippet ordering within file / relative to excerpt
 15
 16#[derive(Debug, Clone, Serialize, Deserialize)]
 17pub struct PredictEditsRequest {
 18    pub excerpt: String,
 19    pub excerpt_path: Arc<Path>,
 20    /// Within file
 21    pub excerpt_range: Range<usize>,
 22    pub excerpt_line_range: Range<Line>,
 23    pub cursor_point: Point,
 24    /// Within `signatures`
 25    pub excerpt_parent: Option<usize>,
 26    pub signatures: Vec<Signature>,
 27    pub referenced_declarations: Vec<ReferencedDeclaration>,
 28    pub events: Vec<Event>,
 29    #[serde(default)]
 30    pub can_collect_data: bool,
 31    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 32    pub diagnostic_groups: Vec<DiagnosticGroup>,
 33    #[serde(skip_serializing_if = "is_default", default)]
 34    pub diagnostic_groups_truncated: bool,
 35    /// Info about the git repository state, only present when can_collect_data is true.
 36    #[serde(skip_serializing_if = "Option::is_none", default)]
 37    pub git_info: Option<PredictEditsGitInfo>,
 38    // Only available to staff
 39    #[serde(default)]
 40    pub debug_info: bool,
 41    #[serde(skip_serializing_if = "Option::is_none", default)]
 42    pub prompt_max_bytes: Option<usize>,
 43    #[serde(default)]
 44    pub prompt_format: PromptFormat,
 45}
 46
 47#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 48pub enum PromptFormat {
 49    MarkedExcerpt,
 50    LabeledSections,
 51    NumLinesUniDiff,
 52    /// Prompt format intended for use via zeta_cli
 53    OnlySnippets,
 54}
 55
 56impl PromptFormat {
 57    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 58}
 59
 60impl Default for PromptFormat {
 61    fn default() -> Self {
 62        Self::DEFAULT
 63    }
 64}
 65
 66impl PromptFormat {
 67    pub fn iter() -> impl Iterator<Item = Self> {
 68        <Self as strum::IntoEnumIterator>::iter()
 69    }
 70}
 71
 72impl std::fmt::Display for PromptFormat {
 73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 74        match self {
 75            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
 76            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
 77            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
 78            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
 79        }
 80    }
 81}
 82
 83#[derive(Debug, Clone, Serialize, Deserialize)]
 84#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
 85#[serde(tag = "event")]
 86pub enum Event {
 87    BufferChange {
 88        path: Option<PathBuf>,
 89        old_path: Option<PathBuf>,
 90        diff: String,
 91        predicted: bool,
 92    },
 93}
 94
 95impl Display for Event {
 96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 97        match self {
 98            Event::BufferChange {
 99                path,
100                old_path,
101                diff,
102                predicted,
103            } => {
104                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
105                let old_path = old_path.as_deref().unwrap_or(new_path);
106
107                if *predicted {
108                    write!(
109                        f,
110                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
111                        old_path.display(),
112                        new_path.display()
113                    )
114                } else {
115                    write!(
116                        f,
117                        "--- a/{}\n+++ b/{}\n{diff}",
118                        old_path.display(),
119                        new_path.display()
120                    )
121                }
122            }
123        }
124    }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct Signature {
129    pub text: String,
130    pub text_is_truncated: bool,
131    #[serde(skip_serializing_if = "Option::is_none", default)]
132    pub parent_index: Option<usize>,
133    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
134    /// file is implicitly the file that contains the descendant declaration or excerpt.
135    pub range: Range<Line>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct ReferencedDeclaration {
140    pub path: Arc<Path>,
141    pub text: String,
142    pub text_is_truncated: bool,
143    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
144    pub range: Range<Line>,
145    /// Range within `text`
146    pub signature_range: Range<usize>,
147    /// Index within `signatures`.
148    #[serde(skip_serializing_if = "Option::is_none", default)]
149    pub parent_index: Option<usize>,
150    pub score_components: DeclarationScoreComponents,
151    pub signature_score: f32,
152    pub declaration_score: f32,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct DeclarationScoreComponents {
157    pub is_same_file: bool,
158    pub is_referenced_nearby: bool,
159    pub is_referenced_in_breadcrumb: bool,
160    pub reference_count: usize,
161    pub same_file_declaration_count: usize,
162    pub declaration_count: usize,
163    pub reference_line_distance: u32,
164    pub declaration_line_distance: u32,
165    pub excerpt_vs_item_jaccard: f32,
166    pub excerpt_vs_signature_jaccard: f32,
167    pub adjacent_vs_item_jaccard: f32,
168    pub adjacent_vs_signature_jaccard: f32,
169    pub excerpt_vs_item_weighted_overlap: f32,
170    pub excerpt_vs_signature_weighted_overlap: f32,
171    pub adjacent_vs_item_weighted_overlap: f32,
172    pub adjacent_vs_signature_weighted_overlap: f32,
173    pub path_import_match_count: usize,
174    pub wildcard_path_import_match_count: usize,
175    pub import_similarity: f32,
176    pub max_import_similarity: f32,
177    pub normalized_import_similarity: f32,
178    pub wildcard_import_similarity: f32,
179    pub normalized_wildcard_import_similarity: f32,
180    pub included_by_others: usize,
181    pub includes_others: usize,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185#[serde(transparent)]
186pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct PredictEditsResponse {
190    pub request_id: Uuid,
191    pub edits: Vec<Edit>,
192    pub debug_info: Option<DebugInfo>,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct DebugInfo {
197    pub prompt: String,
198    pub prompt_planning_time: Duration,
199    pub model_response: String,
200    pub inference_time: Duration,
201    pub parsing_time: Duration,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct Edit {
206    pub path: Arc<Path>,
207    pub range: Range<Line>,
208    pub content: String,
209}
210
/// Returns `true` when `value` equals `T::default()`.
///
/// Intended as a serde `skip_serializing_if` predicate.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
214
215#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
216pub struct Point {
217    pub line: Line,
218    pub column: u32,
219}
220
221#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
222#[serde(transparent)]
223pub struct Line(pub u32);
224
225impl Add for Line {
226    type Output = Self;
227
228    fn add(self, rhs: Self) -> Self::Output {
229        Self(self.0 + rhs.0)
230    }
231}
232
233impl Sub for Line {
234    type Output = Self;
235
236    fn sub(self, rhs: Self) -> Self::Output {
237        Self(self.0 - rhs.0)
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Hunk shared by every case below; only the header lines vary between cases.
    const DIFF: &str = "@@ -1,2 +1,2 @@\n-a\n-b\n";

    /// Builds a `BufferChange` event carrying the shared `DIFF` hunk.
    fn buffer_change(path: Option<&str>, old_path: Option<&str>, predicted: bool) -> Event {
        Event::BufferChange {
            path: path.map(PathBuf::from),
            old_path: old_path.map(PathBuf::from),
            diff: DIFF.into(),
            predicted,
        }
    }

    #[test]
    fn test_event_display() {
        // Neither path known: both headers fall back to `untitled`.
        assert_eq!(
            buffer_change(None, None, false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same path on both sides.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), Some("foo/bar.txt"), false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Rename: old path in the `---` header, new path in the `+++` header.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted predictions get a leading marker line.
        assert_eq!(
            buffer_change(Some("abc.txt"), Some("123.txt"), true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }

    #[test]
    fn test_event_display_single_path_fallbacks() {
        // Only the new path is known: the old-path header reuses it.
        assert_eq!(
            buffer_change(Some("foo/bar.txt"), None, false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Only the old path is known: the new-path header falls back to `untitled`.
        assert_eq!(
            buffer_change(None, Some("old.txt"), false).to_string(),
            indoc! {"
                --- a/old.txt
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}