1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{ops::Range, path::PathBuf};
4use uuid::Uuid;
5
6use crate::PredictEditsGitInfo;
7
8// TODO: snippet ordering within file / relative to excerpt
9
/// Request payload for an edit prediction.
///
/// Carries an excerpt of the file being edited (with the cursor position inside it), the
/// signatures that contain the excerpt or the referenced declarations, the declarations
/// considered relevant to the excerpt, and a list of edit events.
///
/// Field order here is the field order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt taken from the file being edited.
    pub excerpt: String,
    /// Path of the file that `excerpt` was taken from.
    pub excerpt_path: PathBuf,
    /// Range of `excerpt` within its file.
    /// NOTE(review): units (bytes vs. chars) are not established here — presumably byte offsets; confirm.
    pub excerpt_range: Range<usize>,
    /// Cursor position, as an offset within `excerpt`.
    pub cursor_offset: usize,
    /// Index within `signatures` of the signature that contains the excerpt, if any.
    pub excerpt_parent: Option<usize>,
    /// Signatures referenced by index from `excerpt_parent` and
    /// `ReferencedDeclaration::parent_index`.
    pub signatures: Vec<Signature>,
    /// Declarations considered relevant to the excerpt, each with its scoring data.
    pub referenced_declarations: Vec<ReferencedDeclaration>,
    /// Edit events sent with the request (see `Event`).
    pub events: Vec<Event>,
    /// Whether data collection is permitted for this request. Defaults to `false` when the
    /// field is missing from the input.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Diagnostic groups, passed through as raw JSON. Omitted from the serialized output when
    /// empty; defaults to an empty list when missing from the input.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub diagnostic_groups: Vec<DiagnosticGroup>,
    /// Whether `diagnostic_groups` was truncated. Omitted from the serialized output when
    /// `false`; defaults to `false` when missing from the input.
    #[serde(skip_serializing_if = "is_default", default)]
    pub diagnostic_groups_truncated: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    /// Whether debug information is requested (presumably returned via
    /// `PredictEditsResponse::debug_info` — confirm). Only available to staff.
    /// Defaults to `false` when missing from the input.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional upper bound on the prompt size, in bytes.
    /// serde deserializes a missing `Option` field as `None`; `None` is serialized as `null`.
    pub prompt_max_bytes: Option<usize>,
}
37
/// An editing event sent as part of a `PredictEditsRequest`.
///
/// Internally tagged: the variant name is written to an `"event"` field alongside the
/// variant's own fields, e.g. `{"event": "BufferChange", "path": ..., ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, described as a diff.
    BufferChange {
        /// Path of the changed buffer, if it has one.
        path: Option<PathBuf>,
        /// Previous path of the buffer, if any.
        /// NOTE(review): presumably set when the buffer was renamed/moved — confirm.
        old_path: Option<PathBuf>,
        /// Textual diff of the change. The diff format is not specified in this file.
        diff: String,
        /// Presumably whether this change originated from a prediction rather than manual
        /// typing — confirm against the producer.
        predicted: bool,
    },
}
48
/// Signature text of a declaration that contains a descendant declaration or the excerpt.
///
/// Field order here is the field order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// The signature text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Index of this signature's own parent signature, if any.
    /// NOTE(review): presumably an index into `PredictEditsRequest::signatures`, like
    /// `ReferencedDeclaration::parent_index` — confirm.
    /// Omitted from the serialized output when `None`; defaults to `None` when missing.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
    /// file is implicitly the file that contains the descendant declaration or excerpt.
    pub range: Range<usize>,
}
59
/// A declaration considered relevant to the excerpt, with its text and scoring data.
///
/// Field order here is the field order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferencedDeclaration {
    /// Path of the file that contains the declaration.
    pub path: PathBuf,
    /// Declaration text, possibly truncated (see `text_is_truncated`).
    pub text: String,
    /// Whether `text` was truncated.
    pub text_is_truncated: bool,
    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
    pub range: Range<usize>,
    /// Range of the declaration's signature within `text`.
    pub signature_range: Range<usize>,
    /// Index within `signatures`.
    /// Omitted from the serialized output when `None`; defaults to `None` when missing.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub parent_index: Option<usize>,
    /// Individual features describing this declaration (see `ScoreComponents`).
    pub score_components: ScoreComponents,
    /// Score computed for the signature. How it is derived is not defined in this file.
    pub signature_score: f32,
    /// Score computed for the full declaration. How it is derived is not defined in this file.
    pub declaration_score: f32,
}
76
/// Features describing how a referenced declaration relates to the excerpt.
///
/// The features are presumably used to compute `ReferencedDeclaration::signature_score` /
/// `declaration_score` — confirm.
///
/// NOTE(review): serde_json serializes non-finite floats (NaN, ±inf) as `null`, and `null`
/// does not deserialize back into `f32`. Confirm that producers never emit a NaN similarity
/// (e.g. from an empty-set Jaccard computation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreComponents {
    // Location/reference flags.
    pub is_same_file: bool,
    pub is_referenced_nearby: bool,
    pub is_referenced_in_breadcrumb: bool,
    // Counts.
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    // Line distances, and the rank of this declaration by line distance.
    pub reference_line_distance: u32,
    pub declaration_line_distance: u32,
    pub declaration_line_distance_rank: usize,
    // Jaccard similarities between the named regions.
    pub containing_range_vs_item_jaccard: f32,
    pub containing_range_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    // Weighted-overlap similarities between the same pairs of regions.
    pub containing_range_vs_item_weighted_overlap: f32,
    pub containing_range_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
}
97
/// A diagnostic group carried as unparsed JSON.
///
/// `#[serde(transparent)]` makes this newtype (de)serialize exactly like the wrapped
/// `RawValue`. The JSON is therefore passed through verbatim and never interpreted here.
/// Deserializing a `RawValue` requires serde_json's `raw_value` feature and a serde_json
/// deserializer.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
101
/// Response to a `PredictEditsRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier of the request this response belongs to.
    pub request_id: Uuid,
    /// The predicted edits (see `Edit`).
    pub edits: Vec<Edit>,
    /// Debug information, if any.
    /// NOTE(review): presumably populated only when `PredictEditsRequest::debug_info` was set —
    /// confirm. Serialized as `null` when `None`.
    pub debug_info: Option<DebugInfo>,
}
108
/// Diagnostic details about how a prediction was produced.
///
/// NOTE(review): the timing fields are `chrono::Duration`, so their wire format is whatever
/// chrono's serde implementation produces — confirm clients decode it the same way.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt that was built for the model.
    pub prompt: String,
    /// Time spent planning/building the prompt.
    pub prompt_planning_time: Duration,
    /// Raw response text returned by the model.
    pub model_response: String,
    /// Time spent on model inference.
    pub inference_time: Duration,
    /// Time spent parsing the model response.
    pub parsing_time: Duration,
}
117
/// A single predicted edit to the file at `path`.
///
/// NOTE(review): presumably means "replace `range` with `content`", with `range` in byte
/// offsets — confirm against the code that applies edits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file to edit.
    pub path: PathBuf,
    /// Range of the file affected by the edit.
    pub range: Range<usize>,
    /// New text for `range`.
    pub content: String,
}
124
/// Reports whether `value` is equal to its type's `Default` value.
///
/// Intended for `#[serde(skip_serializing_if = "is_default")]`, so that fields still holding
/// their default are left out of the serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let baseline = T::default();
    // Keep `value` on the left-hand side so the same `PartialEq::eq(value, &baseline)` runs.
    value == &baseline
}