predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::{Path, PathBuf},
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::PredictEditsGitInfo;
 13
 14#[derive(Debug, Clone, Serialize, Deserialize)]
 15pub struct PlanContextRetrievalRequest {
 16    pub excerpt: String,
 17    pub excerpt_path: Arc<Path>,
 18    pub excerpt_line_range: Range<Line>,
 19    pub cursor_file_max_row: Line,
 20    pub events: Vec<Event>,
 21}
 22
 23#[derive(Debug, Clone, Serialize, Deserialize)]
 24pub struct PredictEditsRequest {
 25    pub excerpt: String,
 26    pub excerpt_path: Arc<Path>,
 27    /// Within file
 28    pub excerpt_range: Range<usize>,
 29    pub excerpt_line_range: Range<Line>,
 30    pub cursor_point: Point,
 31    /// Within `signatures`
 32    pub excerpt_parent: Option<usize>,
 33    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 34    pub included_files: Vec<IncludedFile>,
 35    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 36    pub signatures: Vec<Signature>,
 37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 38    pub referenced_declarations: Vec<ReferencedDeclaration>,
 39    pub events: Vec<Event>,
 40    #[serde(default)]
 41    pub can_collect_data: bool,
 42    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 43    pub diagnostic_groups: Vec<DiagnosticGroup>,
 44    #[serde(skip_serializing_if = "is_default", default)]
 45    pub diagnostic_groups_truncated: bool,
 46    /// Info about the git repository state, only present when can_collect_data is true.
 47    #[serde(skip_serializing_if = "Option::is_none", default)]
 48    pub git_info: Option<PredictEditsGitInfo>,
 49    // Only available to staff
 50    #[serde(default)]
 51    pub debug_info: bool,
 52    #[serde(skip_serializing_if = "Option::is_none", default)]
 53    pub prompt_max_bytes: Option<usize>,
 54    #[serde(default)]
 55    pub prompt_format: PromptFormat,
 56}
 57
 58#[derive(Debug, Clone, Serialize, Deserialize)]
 59pub struct IncludedFile {
 60    pub path: Arc<Path>,
 61    pub max_row: Line,
 62    pub excerpts: Vec<Excerpt>,
 63}
 64
 65#[derive(Debug, Clone, Serialize, Deserialize)]
 66pub struct Excerpt {
 67    pub start_line: Line,
 68    pub text: Arc<str>,
 69}
 70
 71#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 72pub enum PromptFormat {
 73    MarkedExcerpt,
 74    LabeledSections,
 75    NumLinesUniDiff,
 76    /// Prompt format intended for use via zeta_cli
 77    OnlySnippets,
 78}
 79
 80impl PromptFormat {
 81    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 82}
 83
 84impl Default for PromptFormat {
 85    fn default() -> Self {
 86        Self::DEFAULT
 87    }
 88}
 89
 90impl PromptFormat {
 91    pub fn iter() -> impl Iterator<Item = Self> {
 92        <Self as strum::IntoEnumIterator>::iter()
 93    }
 94}
 95
 96impl std::fmt::Display for PromptFormat {
 97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 98        match self {
 99            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
100            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
101            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
102            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
103        }
104    }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
109#[serde(tag = "event")]
110pub enum Event {
111    BufferChange {
112        path: Option<PathBuf>,
113        old_path: Option<PathBuf>,
114        diff: String,
115        predicted: bool,
116    },
117}
118
119impl Display for Event {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Event::BufferChange {
123                path,
124                old_path,
125                diff,
126                predicted,
127            } => {
128                let new_path = path.as_deref().unwrap_or(Path::new("untitled"));
129                let old_path = old_path.as_deref().unwrap_or(new_path);
130
131                if *predicted {
132                    write!(
133                        f,
134                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
135                        DiffPathFmt(old_path),
136                        DiffPathFmt(new_path)
137                    )
138                } else {
139                    write!(
140                        f,
141                        "--- a/{}\n+++ b/{}\n{diff}",
142                        DiffPathFmt(old_path),
143                        DiffPathFmt(new_path)
144                    )
145                }
146            }
147        }
148    }
149}
150
/// Formats a [`Path`] for use in diffs, always using `/` as the path separator regardless of
/// the host platform.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    /// Writes the path's components joined by `/`.
    ///
    /// The root directory component is itself written as a single `/` and is never followed by
    /// an additional separator, so `/foo/bar` is rendered as `/foo/bar` (not `//foo/bar`), and a
    /// platform-specific root such as `\` never leaks into the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a component has been written that must be separated from the next one.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                std::path::Component::RootDir => {
                    // The root already acts as the separator for whatever follows it.
                    f.write_str("/")?;
                    needs_separator = false;
                }
                other => {
                    if needs_separator {
                        f.write_str("/")?;
                    }
                    write!(f, "{}", other.as_os_str().display())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct Signature {
171    pub text: String,
172    pub text_is_truncated: bool,
173    #[serde(skip_serializing_if = "Option::is_none", default)]
174    pub parent_index: Option<usize>,
175    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
176    /// file is implicitly the file that contains the descendant declaration or excerpt.
177    pub range: Range<Line>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct ReferencedDeclaration {
182    pub path: Arc<Path>,
183    pub text: String,
184    pub text_is_truncated: bool,
185    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
186    pub range: Range<Line>,
187    /// Range within `text`
188    pub signature_range: Range<usize>,
189    /// Index within `signatures`.
190    #[serde(skip_serializing_if = "Option::is_none", default)]
191    pub parent_index: Option<usize>,
192    pub score_components: DeclarationScoreComponents,
193    pub signature_score: f32,
194    pub declaration_score: f32,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct DeclarationScoreComponents {
199    pub is_same_file: bool,
200    pub is_referenced_nearby: bool,
201    pub is_referenced_in_breadcrumb: bool,
202    pub reference_count: usize,
203    pub same_file_declaration_count: usize,
204    pub declaration_count: usize,
205    pub reference_line_distance: u32,
206    pub declaration_line_distance: u32,
207    pub excerpt_vs_item_jaccard: f32,
208    pub excerpt_vs_signature_jaccard: f32,
209    pub adjacent_vs_item_jaccard: f32,
210    pub adjacent_vs_signature_jaccard: f32,
211    pub excerpt_vs_item_weighted_overlap: f32,
212    pub excerpt_vs_signature_weighted_overlap: f32,
213    pub adjacent_vs_item_weighted_overlap: f32,
214    pub adjacent_vs_signature_weighted_overlap: f32,
215    pub path_import_match_count: usize,
216    pub wildcard_path_import_match_count: usize,
217    pub import_similarity: f32,
218    pub max_import_similarity: f32,
219    pub normalized_import_similarity: f32,
220    pub wildcard_import_similarity: f32,
221    pub normalized_wildcard_import_similarity: f32,
222    pub included_by_others: usize,
223    pub includes_others: usize,
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize)]
227#[serde(transparent)]
228pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct PredictEditsResponse {
232    pub request_id: Uuid,
233    pub edits: Vec<Edit>,
234    pub debug_info: Option<DebugInfo>,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct DebugInfo {
239    pub prompt: String,
240    pub prompt_planning_time: Duration,
241    pub model_response: String,
242    pub inference_time: Duration,
243    pub parsing_time: Duration,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct Edit {
248    pub path: Arc<Path>,
249    pub range: Range<Line>,
250    pub content: String,
251}
252
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Used as a `skip_serializing_if` predicate so that default-valued fields are left out of the
/// serialized form.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    *value == T::default()
}
256
257#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
258pub struct Point {
259    pub line: Line,
260    pub column: u32,
261}
262
263#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
264#[serde(transparent)]
265pub struct Line(pub u32);
266
267impl Add for Line {
268    type Output = Self;
269
270    fn add(self, rhs: Self) -> Self::Output {
271        Self(self.0 + rhs.0)
272    }
273}
274
275impl Sub for Line {
276    type Output = Self;
277
278    fn sub(self, rhs: Self) -> Self::Output {
279        Self(self.0 - rhs.0)
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Exercises `Event`'s `Display` impl: path fallbacks, renames, and the predicted marker.
    #[test]
    fn test_event_display() {
        // No paths at all: both header lines fall back to `untitled`.
        let ev = Event::BufferChange {
            path: None,
            old_path: None,
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same old and new path.
        let ev = Event::BufferChange {
            path: Some(PathBuf::from("foo/bar.txt")),
            old_path: Some(PathBuf::from("foo/bar.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths: the old path goes on the `---` line.
        let ev = Event::BufferChange {
            path: Some(PathBuf::from("abc.txt")),
            old_path: Some(PathBuf::from("123.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: false,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Predicted change: the marker line precedes the header.
        let ev = Event::BufferChange {
            path: Some(PathBuf::from("abc.txt")),
            old_path: Some(PathBuf::from("123.txt")),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted: true,
        };
        assert_eq!(
            ev.to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}