1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 ops::Range,
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8use uuid::Uuid;
9
10use crate::PredictEditsGitInfo;
11
12// TODO: snippet ordering within file / relative to excerpt
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PredictEditsRequest {
16 pub excerpt: String,
17 pub excerpt_path: Arc<Path>,
18 /// Within file
19 pub excerpt_range: Range<usize>,
20 /// Within `excerpt`
21 pub cursor_offset: usize,
22 /// Within `signatures`
23 pub excerpt_parent: Option<usize>,
24 pub signatures: Vec<Signature>,
25 pub referenced_declarations: Vec<ReferencedDeclaration>,
26 pub events: Vec<Event>,
27 #[serde(default)]
28 pub can_collect_data: bool,
29 #[serde(skip_serializing_if = "Vec::is_empty", default)]
30 pub diagnostic_groups: Vec<DiagnosticGroup>,
31 #[serde(skip_serializing_if = "is_default", default)]
32 pub diagnostic_groups_truncated: bool,
33 /// Info about the git repository state, only present when can_collect_data is true.
34 #[serde(skip_serializing_if = "Option::is_none", default)]
35 pub git_info: Option<PredictEditsGitInfo>,
36 // Only available to staff
37 #[serde(default)]
38 pub debug_info: bool,
39 #[serde(skip_serializing_if = "Option::is_none", default)]
40 pub prompt_max_bytes: Option<usize>,
41 #[serde(default)]
42 pub prompt_format: PromptFormat,
43}
44
45#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
46pub enum PromptFormat {
47 #[default]
48 MarkedExcerpt,
49 LabeledSections,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "event")]
54pub enum Event {
55 BufferChange {
56 path: Option<PathBuf>,
57 old_path: Option<PathBuf>,
58 diff: String,
59 predicted: bool,
60 },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct Signature {
65 pub text: String,
66 pub text_is_truncated: bool,
67 #[serde(skip_serializing_if = "Option::is_none", default)]
68 pub parent_index: Option<usize>,
69 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
70 /// file is implicitly the file that contains the descendant declaration or excerpt.
71 pub range: Range<usize>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ReferencedDeclaration {
76 pub path: Arc<Path>,
77 pub text: String,
78 pub text_is_truncated: bool,
79 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
80 pub range: Range<usize>,
81 /// Range within `text`
82 pub signature_range: Range<usize>,
83 /// Index within `signatures`.
84 #[serde(skip_serializing_if = "Option::is_none", default)]
85 pub parent_index: Option<usize>,
86 pub score_components: ScoreComponents,
87 pub signature_score: f32,
88 pub declaration_score: f32,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct ScoreComponents {
93 pub is_same_file: bool,
94 pub is_referenced_nearby: bool,
95 pub is_referenced_in_breadcrumb: bool,
96 pub reference_count: usize,
97 pub same_file_declaration_count: usize,
98 pub declaration_count: usize,
99 pub reference_line_distance: u32,
100 pub declaration_line_distance: u32,
101 pub declaration_line_distance_rank: usize,
102 pub containing_range_vs_item_jaccard: f32,
103 pub containing_range_vs_signature_jaccard: f32,
104 pub adjacent_vs_item_jaccard: f32,
105 pub adjacent_vs_signature_jaccard: f32,
106 pub containing_range_vs_item_weighted_overlap: f32,
107 pub containing_range_vs_signature_weighted_overlap: f32,
108 pub adjacent_vs_item_weighted_overlap: f32,
109 pub adjacent_vs_signature_weighted_overlap: f32,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(transparent)]
114pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct PredictEditsResponse {
118 pub request_id: Uuid,
119 pub edits: Vec<Edit>,
120 pub debug_info: Option<DebugInfo>,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct DebugInfo {
125 pub prompt: String,
126 pub prompt_planning_time: Duration,
127 pub model_response: String,
128 pub inference_time: Duration,
129 pub parsing_time: Duration,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct Edit {
134 pub path: Arc<Path>,
135 pub range: Range<usize>,
136 pub content: String,
137}
138
/// Reports whether `value` is equal to `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that fields still at their
/// default value are left out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let default_value = T::default();
    value == &default_value
}