predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    ops::Range,
  5    path::{Path, PathBuf},
  6    sync::Arc,
  7};
  8use strum::EnumIter;
  9use uuid::Uuid;
 10
 11use crate::PredictEditsGitInfo;
 12
// TODO: Decide how snippets are ordered within a file and relative to the excerpt.
 14
 15#[derive(Debug, Clone, Serialize, Deserialize)]
 16pub struct PredictEditsRequest {
 17    pub excerpt: String,
 18    pub excerpt_path: Arc<Path>,
 19    /// Within file
 20    pub excerpt_range: Range<usize>,
 21    /// Within `excerpt`
 22    pub cursor_offset: usize,
 23    /// Within `signatures`
 24    pub excerpt_parent: Option<usize>,
 25    pub signatures: Vec<Signature>,
 26    pub referenced_declarations: Vec<ReferencedDeclaration>,
 27    pub events: Vec<Event>,
 28    #[serde(default)]
 29    pub can_collect_data: bool,
 30    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 31    pub diagnostic_groups: Vec<DiagnosticGroup>,
 32    #[serde(skip_serializing_if = "is_default", default)]
 33    pub diagnostic_groups_truncated: bool,
 34    /// Info about the git repository state, only present when can_collect_data is true.
 35    #[serde(skip_serializing_if = "Option::is_none", default)]
 36    pub git_info: Option<PredictEditsGitInfo>,
 37    // Only available to staff
 38    #[serde(default)]
 39    pub debug_info: bool,
 40    #[serde(skip_serializing_if = "Option::is_none", default)]
 41    pub prompt_max_bytes: Option<usize>,
 42    #[serde(default)]
 43    pub prompt_format: PromptFormat,
 44}
 45
 46#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum PromptFormat {
 48    #[default]
 49    MarkedExcerpt,
 50    LabeledSections,
 51}
 52
 53impl PromptFormat {
 54    pub fn iter() -> impl Iterator<Item = Self> {
 55        <Self as strum::IntoEnumIterator>::iter()
 56    }
 57}
 58
 59impl std::fmt::Display for PromptFormat {
 60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 61        match self {
 62            PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
 63            PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
 64        }
 65    }
 66}
 67
 68#[derive(Debug, Clone, Serialize, Deserialize)]
 69#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
 70#[serde(tag = "event")]
 71pub enum Event {
 72    BufferChange {
 73        path: Option<PathBuf>,
 74        old_path: Option<PathBuf>,
 75        diff: String,
 76        predicted: bool,
 77    },
 78}
 79
 80#[derive(Debug, Clone, Serialize, Deserialize)]
 81pub struct Signature {
 82    pub text: String,
 83    pub text_is_truncated: bool,
 84    #[serde(skip_serializing_if = "Option::is_none", default)]
 85    pub parent_index: Option<usize>,
 86    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
 87    /// file is implicitly the file that contains the descendant declaration or excerpt.
 88    pub range: Range<usize>,
 89}
 90
 91#[derive(Debug, Clone, Serialize, Deserialize)]
 92pub struct ReferencedDeclaration {
 93    pub path: Arc<Path>,
 94    pub text: String,
 95    pub text_is_truncated: bool,
 96    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
 97    pub range: Range<usize>,
 98    /// Range within `text`
 99    pub signature_range: Range<usize>,
100    /// Index within `signatures`.
101    #[serde(skip_serializing_if = "Option::is_none", default)]
102    pub parent_index: Option<usize>,
103    pub score_components: ScoreComponents,
104    pub signature_score: f32,
105    pub declaration_score: f32,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ScoreComponents {
110    pub is_same_file: bool,
111    pub is_referenced_nearby: bool,
112    pub is_referenced_in_breadcrumb: bool,
113    pub reference_count: usize,
114    pub same_file_declaration_count: usize,
115    pub declaration_count: usize,
116    pub reference_line_distance: u32,
117    pub declaration_line_distance: u32,
118    pub declaration_line_distance_rank: usize,
119    pub containing_range_vs_item_jaccard: f32,
120    pub containing_range_vs_signature_jaccard: f32,
121    pub adjacent_vs_item_jaccard: f32,
122    pub adjacent_vs_signature_jaccard: f32,
123    pub containing_range_vs_item_weighted_overlap: f32,
124    pub containing_range_vs_signature_weighted_overlap: f32,
125    pub adjacent_vs_item_weighted_overlap: f32,
126    pub adjacent_vs_signature_weighted_overlap: f32,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130#[serde(transparent)]
131pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct PredictEditsResponse {
135    pub request_id: Uuid,
136    pub edits: Vec<Edit>,
137    pub debug_info: Option<DebugInfo>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct DebugInfo {
142    pub prompt: String,
143    pub prompt_planning_time: Duration,
144    pub model_response: String,
145    pub inference_time: Duration,
146    pub parsing_time: Duration,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct Edit {
151    pub path: Arc<Path>,
152    pub range: Range<usize>,
153    pub content: String,
154}
155
/// Returns `true` when `value` equals its type's `Default` value.
///
/// It has the `fn(&T) -> bool` shape serde expects for `skip_serializing_if`, which is how
/// this file uses it to drop default-valued fields from the serialized form.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    value == &baseline
}