1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 ops::Range,
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8use strum::EnumIter;
9use uuid::Uuid;
10
11use crate::PredictEditsGitInfo;
12
13// TODO: snippet ordering within file / relative to excerpt
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PredictEditsRequest {
17 pub excerpt: String,
18 pub excerpt_path: Arc<Path>,
19 /// Within file
20 pub excerpt_range: Range<usize>,
21 /// Within `excerpt`
22 pub cursor_offset: usize,
23 /// Within `signatures`
24 pub excerpt_parent: Option<usize>,
25 pub signatures: Vec<Signature>,
26 pub referenced_declarations: Vec<ReferencedDeclaration>,
27 pub events: Vec<Event>,
28 #[serde(default)]
29 pub can_collect_data: bool,
30 #[serde(skip_serializing_if = "Vec::is_empty", default)]
31 pub diagnostic_groups: Vec<DiagnosticGroup>,
32 #[serde(skip_serializing_if = "is_default", default)]
33 pub diagnostic_groups_truncated: bool,
34 /// Info about the git repository state, only present when can_collect_data is true.
35 #[serde(skip_serializing_if = "Option::is_none", default)]
36 pub git_info: Option<PredictEditsGitInfo>,
37 // Only available to staff
38 #[serde(default)]
39 pub debug_info: bool,
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub prompt_max_bytes: Option<usize>,
42 #[serde(default)]
43 pub prompt_format: PromptFormat,
44}
45
46#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum PromptFormat {
48 #[default]
49 MarkedExcerpt,
50 LabeledSections,
51}
52
53impl PromptFormat {
54 pub fn iter() -> impl Iterator<Item = Self> {
55 <Self as strum::IntoEnumIterator>::iter()
56 }
57}
58
59impl std::fmt::Display for PromptFormat {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
63 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
70#[serde(tag = "event")]
71pub enum Event {
72 BufferChange {
73 path: Option<PathBuf>,
74 old_path: Option<PathBuf>,
75 diff: String,
76 predicted: bool,
77 },
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Signature {
82 pub text: String,
83 pub text_is_truncated: bool,
84 #[serde(skip_serializing_if = "Option::is_none", default)]
85 pub parent_index: Option<usize>,
86 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
87 /// file is implicitly the file that contains the descendant declaration or excerpt.
88 pub range: Range<usize>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct ReferencedDeclaration {
93 pub path: Arc<Path>,
94 pub text: String,
95 pub text_is_truncated: bool,
96 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
97 pub range: Range<usize>,
98 /// Range within `text`
99 pub signature_range: Range<usize>,
100 /// Index within `signatures`.
101 #[serde(skip_serializing_if = "Option::is_none", default)]
102 pub parent_index: Option<usize>,
103 pub score_components: ScoreComponents,
104 pub signature_score: f32,
105 pub declaration_score: f32,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ScoreComponents {
110 pub is_same_file: bool,
111 pub is_referenced_nearby: bool,
112 pub is_referenced_in_breadcrumb: bool,
113 pub reference_count: usize,
114 pub same_file_declaration_count: usize,
115 pub declaration_count: usize,
116 pub reference_line_distance: u32,
117 pub declaration_line_distance: u32,
118 pub declaration_line_distance_rank: usize,
119 pub containing_range_vs_item_jaccard: f32,
120 pub containing_range_vs_signature_jaccard: f32,
121 pub adjacent_vs_item_jaccard: f32,
122 pub adjacent_vs_signature_jaccard: f32,
123 pub containing_range_vs_item_weighted_overlap: f32,
124 pub containing_range_vs_signature_weighted_overlap: f32,
125 pub adjacent_vs_item_weighted_overlap: f32,
126 pub adjacent_vs_signature_weighted_overlap: f32,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130#[serde(transparent)]
131pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct PredictEditsResponse {
135 pub request_id: Uuid,
136 pub edits: Vec<Edit>,
137 pub debug_info: Option<DebugInfo>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct DebugInfo {
142 pub prompt: String,
143 pub prompt_planning_time: Duration,
144 pub model_response: String,
145 pub inference_time: Duration,
146 pub parsing_time: Duration,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct Edit {
151 pub path: Arc<Path>,
152 pub range: Range<usize>,
153 pub content: String,
154}
155
/// Returns `true` when `value` compares equal to `T::default()`.
///
/// Referenced by name from `#[serde(skip_serializing_if = "is_default")]`, so that fields still
/// holding their default value are left out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    value == &baseline
}