predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    ops::{Add, Range, Sub},
  5    path::{Path, PathBuf},
  6    sync::Arc,
  7};
  8use strum::EnumIter;
  9use uuid::Uuid;
 10
 11use crate::PredictEditsGitInfo;
 12
 13// TODO: snippet ordering within file / relative to excerpt
 14
/// Request for an edit prediction.
///
/// All fields round-trip through serde. Fields annotated with `#[serde(default)]` may be absent
/// when deserializing, in which case they take their type's default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Excerpt text; presumably taken from the file at `excerpt_path` — confirm.
    pub excerpt: String,
    /// Path of the file the excerpt belongs to.
    pub excerpt_path: Arc<Path>,
    /// Offset range of the excerpt within its file (unit not established here; presumably
    /// bytes — confirm).
    pub excerpt_range: Range<usize>,
    /// Line range of the excerpt within its file.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column [`Point`].
    /// NOTE(review): whether this is relative to the file or to the excerpt is not established
    /// here.
    pub cursor_point: Point,
    /// Index within `signatures` of the excerpt's parent signature, if any.
    pub excerpt_parent: Option<usize>,
    /// Signatures that `excerpt_parent` and [`ReferencedDeclaration::parent_index`] index into
    /// (and, presumably, [`Signature::parent_index`]).
    pub signatures: Vec<Signature>,
    /// Referenced declarations, each carrying its scoring data.
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Editing events; see [`Event`].
    pub events: Vec<Event>,
    /// Whether data from this request may be collected. Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups, passed through as raw JSON. Omitted from the serialized output when
    /// empty; defaults to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Presumably set when `diagnostic_groups` was cut short — confirm. Omitted from the
    /// serialized output when `false`; defaults to `false` when absent.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when `can_collect_data` is true.
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Requests debug information, presumably returned in
    /// [`PredictEditsResponse::debug_info`] — confirm. Only available to staff.
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit in bytes (per the field name). Omitted from the serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout to use; defaults to [`PromptFormat::DEFAULT`] (via `Default`) when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
}
 45
/// Selects how the prediction prompt is laid out.
///
/// Serialized by variant name (serde's default representation for unit variants). The derived
/// `EnumIter` yields variants in declaration order, which is the order returned by
/// [`PromptFormat::iter`]; reordering variants changes that order.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// Displayed as "Marked Excerpt".
    MarkedExcerpt,
    /// Displayed as "Labeled Sections".
    LabeledSections,
    /// Displayed as "Numbered Lines / Unified Diff". This is [`PromptFormat::DEFAULT`].
    NumLinesUniDiff,
    /// Prompt format intended for use via zeta_cli
    OnlySnippets,
}
 54
 55impl PromptFormat {
 56    pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
 57}
 58
 59impl Default for PromptFormat {
 60    fn default() -> Self {
 61        Self::DEFAULT
 62    }
 63}
 64
 65impl PromptFormat {
 66    pub fn iter() -> impl Iterator<Item = Self> {
 67        <Self as strum::IntoEnumIterator>::iter()
 68    }
 69}
 70
 71impl std::fmt::Display for PromptFormat {
 72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 73        match self {
 74            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
 75            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
 76            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
 77            PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
 78        }
 79    }
 80}
 81
/// An entry in [`PredictEditsRequest::events`].
///
/// Internally tagged: the serialized form stores the variant name under the `"event"` key next
/// to the variant's fields. `PartialEq` is derived only for tests and the `test-support`
/// feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer.
    BufferChange {
        /// Path of the changed buffer, if it has one.
        path: Option<PathBuf>,
        /// Previous path, if any (presumably set when the buffer was renamed — confirm).
        old_path: Option<PathBuf>,
        /// Diff text describing the change; its format is not specified here.
        diff: String,
        /// Presumably whether the change originated from a prediction — confirm with the
        /// producer of these events.
        predicted: bool,
    },
}
 93
/// A signature sent as context, optionally linked to a parent signature via `parent_index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Index of the parent signature, presumably within [`PredictEditsRequest::signatures`] —
    /// confirm. Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<Line>,
}
104
/// A referenced declaration together with the scoring data recorded for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file containing the declaration.
    pub path: Arc<Path>,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<Line>,
    /// Range within `text`, presumably covering the signature portion (unit not established
    /// here — confirm bytes vs. chars).
    pub signature_range: Range<usize>,
    /// Index within [`PredictEditsRequest::signatures`]. Omitted from the serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Individual scoring features for this declaration.
    pub score_components: DeclarationScoreComponents,
    /// Score attributed to the signature (per the field name).
    pub signature_score: f32,
    /// Score attributed to the whole declaration (per the field name).
    pub declaration_score: f32,
}
121
/// Individual features recorded for scoring a [`ReferencedDeclaration`].
///
/// NOTE(review): field semantics are defined by the code that populates this struct; the
/// groupings below are inferred from field names only. Field order is also the serialized
/// field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    // Where the declaration and its references are located.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Reference / declaration counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Line distances.
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    // Jaccard similarities: excerpt/adjacent text vs. the item/its signature.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    // Weighted overlaps: excerpt/adjacent text vs. the item/its signature.
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
    // Import-based matching and similarity.
    pub path_import_match_count: usize,
    pub wildcard_path_import_match_count: usize,
    pub import_similarity: f32,
    pub max_import_similarity: f32,
    pub normalized_import_similarity: f32,
    pub wildcard_import_similarity: f32,
    pub normalized_wildcard_import_similarity: f32,
    // Inclusion relationships with other declarations.
    pub included_by_others: usize,
    pub includes_others: usize,
}
150
/// A diagnostic group carried as unparsed JSON.
///
/// `#[serde(transparent)]` makes this (de)serialize exactly as the inner
/// [`serde_json::value::RawValue`], so the JSON text is passed through without being
/// interpreted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
154
/// Response to a [`PredictEditsRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response answers.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Debug details; presumably populated only when [`PredictEditsRequest::debug_info`] was
    /// set — confirm.
    pub debug_info: Option<DebugInfo>,
}
161
/// Diagnostic details about how a prediction was produced. Durations are `chrono::Duration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// Time spent planning the prompt.
    pub prompt_planning_time: Duration,
    /// Model response text.
    pub model_response: String,
    /// Time spent on inference.
    pub inference_time: Duration,
    /// Time spent parsing (presumably `model_response` — confirm).
    pub parsing_time: Duration,
}
170
/// A predicted edit to a file.
///
/// NOTE(review): presumably `content` replaces the lines in `range`; whether `range.end` is
/// exclusive is not established here — confirm against the code that applies edits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Line range targeted by the edit.
    pub range: Range<Line>,
    /// New text for the range.
    pub content: String,
}
177
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a `#[serde(skip_serializing_if = "is_default")]` predicate so fields holding their
/// default value are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
181
/// A line/column position.
///
/// The derived ordering compares `line` first, then `column` (field declaration order).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line of the position.
    pub line: Line,
    /// Column within the line (unit not established here — confirm bytes vs. chars).
    pub column: u32,
}
187
/// A line number.
///
/// `#[serde(transparent)]` serializes it as a bare `u32`. Whether numbering is zero-based is not
/// established here.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
191
192impl Add for Line {
193    type Output = Self;
194
195    fn add(self, rhs: Self) -> Self::Output {
196        Self(self.0 + rhs.0)
197    }
198}
199
200impl Sub for Line {
201    type Output = Self;
202
203    fn sub(self, rhs: Self) -> Self::Output {
204        Self(self.0 - rhs.0)
205    }
206}