predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    ops::Range,
  5    path::{Path, PathBuf},
  6    sync::Arc,
  7};
  8use uuid::Uuid;
  9
 10use crate::PredictEditsGitInfo;
 11
// TODO: snippet ordering within a file and relative to the excerpt.
 13
 14#[derive(Debug, Clone, Serialize, Deserialize)]
 15pub struct PredictEditsRequest {
 16    pub excerpt: String,
 17    pub excerpt_path: Arc<Path>,
 18    /// Within file
 19    pub excerpt_range: Range<usize>,
 20    /// Within `excerpt`
 21    pub cursor_offset: usize,
 22    /// Within `signatures`
 23    pub excerpt_parent: Option<usize>,
 24    pub signatures: Vec<Signature>,
 25    pub referenced_declarations: Vec<ReferencedDeclaration>,
 26    pub events: Vec<Event>,
 27    #[serde(default)]
 28    pub can_collect_data: bool,
 29    #[serde(skip_serializing_if = "Vec::is_empty", default)]
 30    pub diagnostic_groups: Vec<DiagnosticGroup>,
 31    #[serde(skip_serializing_if = "is_default", default)]
 32    pub diagnostic_groups_truncated: bool,
 33    /// Info about the git repository state, only present when can_collect_data is true.
 34    #[serde(skip_serializing_if = "Option::is_none", default)]
 35    pub git_info: Option<PredictEditsGitInfo>,
 36    // Only available to staff
 37    #[serde(default)]
 38    pub debug_info: bool,
 39    #[serde(skip_serializing_if = "Option::is_none", default)]
 40    pub prompt_max_bytes: Option<usize>,
 41    #[serde(default)]
 42    pub prompt_format: PromptFormat,
 43}
 44
 45#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
 46pub enum PromptFormat {
 47    #[default]
 48    MarkedExcerpt,
 49    LabeledSections,
 50}
 51
 52#[derive(Debug, Clone, Serialize, Deserialize)]
 53#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
 54#[serde(tag = "event")]
 55pub enum Event {
 56    BufferChange {
 57        path: Option<PathBuf>,
 58        old_path: Option<PathBuf>,
 59        diff: String,
 60        predicted: bool,
 61    },
 62}
 63
 64#[derive(Debug, Clone, Serialize, Deserialize)]
 65pub struct Signature {
 66    pub text: String,
 67    pub text_is_truncated: bool,
 68    #[serde(skip_serializing_if = "Option::is_none", default)]
 69    pub parent_index: Option<usize>,
 70    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
 71    /// file is implicitly the file that contains the descendant declaration or excerpt.
 72    pub range: Range<usize>,
 73}
 74
 75#[derive(Debug, Clone, Serialize, Deserialize)]
 76pub struct ReferencedDeclaration {
 77    pub path: Arc<Path>,
 78    pub text: String,
 79    pub text_is_truncated: bool,
 80    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
 81    pub range: Range<usize>,
 82    /// Range within `text`
 83    pub signature_range: Range<usize>,
 84    /// Index within `signatures`.
 85    #[serde(skip_serializing_if = "Option::is_none", default)]
 86    pub parent_index: Option<usize>,
 87    pub score_components: ScoreComponents,
 88    pub signature_score: f32,
 89    pub declaration_score: f32,
 90}
 91
 92#[derive(Debug, Clone, Serialize, Deserialize)]
 93pub struct ScoreComponents {
 94    pub is_same_file: bool,
 95    pub is_referenced_nearby: bool,
 96    pub is_referenced_in_breadcrumb: bool,
 97    pub reference_count: usize,
 98    pub same_file_declaration_count: usize,
 99    pub declaration_count: usize,
100    pub reference_line_distance: u32,
101    pub declaration_line_distance: u32,
102    pub declaration_line_distance_rank: usize,
103    pub containing_range_vs_item_jaccard: f32,
104    pub containing_range_vs_signature_jaccard: f32,
105    pub adjacent_vs_item_jaccard: f32,
106    pub adjacent_vs_signature_jaccard: f32,
107    pub containing_range_vs_item_weighted_overlap: f32,
108    pub containing_range_vs_signature_weighted_overlap: f32,
109    pub adjacent_vs_item_weighted_overlap: f32,
110    pub adjacent_vs_signature_weighted_overlap: f32,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(transparent)]
115pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct PredictEditsResponse {
119    pub request_id: Uuid,
120    pub edits: Vec<Edit>,
121    pub debug_info: Option<DebugInfo>,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct DebugInfo {
126    pub prompt: String,
127    pub prompt_planning_time: Duration,
128    pub model_response: String,
129    pub inference_time: Duration,
130    pub parsing_time: Duration,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct Edit {
135    pub path: Arc<Path>,
136    pub range: Range<usize>,
137    pub content: String,
138}
139
/// Returns `true` when `value` equals `T::default()`.
///
/// Serves as a serde `skip_serializing_if` predicate, so fields that hold their
/// default value are left out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    value == &baseline
}