1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 ops::Range,
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8use strum::EnumIter;
9use uuid::Uuid;
10
11use crate::PredictEditsGitInfo;
12
13// TODO: snippet ordering within file / relative to excerpt
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PredictEditsRequest {
17 pub excerpt: String,
18 pub excerpt_path: Arc<Path>,
19 /// Within file
20 pub excerpt_range: Range<usize>,
21 /// Within `excerpt`
22 pub cursor_offset: usize,
23 /// Within `signatures`
24 pub excerpt_parent: Option<usize>,
25 pub signatures: Vec<Signature>,
26 pub referenced_declarations: Vec<ReferencedDeclaration>,
27 pub events: Vec<Event>,
28 #[serde(default)]
29 pub can_collect_data: bool,
30 #[serde(skip_serializing_if = "Vec::is_empty", default)]
31 pub diagnostic_groups: Vec<DiagnosticGroup>,
32 #[serde(skip_serializing_if = "is_default", default)]
33 pub diagnostic_groups_truncated: bool,
34 /// Info about the git repository state, only present when can_collect_data is true.
35 #[serde(skip_serializing_if = "Option::is_none", default)]
36 pub git_info: Option<PredictEditsGitInfo>,
37 // Only available to staff
38 #[serde(default)]
39 pub debug_info: bool,
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub prompt_max_bytes: Option<usize>,
42 #[serde(default)]
43 pub prompt_format: PromptFormat,
44}
45
46#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum PromptFormat {
48 #[default]
49 MarkedExcerpt,
50 LabeledSections,
51 /// Prompt format intended for use via zeta_cli
52 OnlySnippets,
53}
54
55impl PromptFormat {
56 pub fn iter() -> impl Iterator<Item = Self> {
57 <Self as strum::IntoEnumIterator>::iter()
58 }
59}
60
61impl std::fmt::Display for PromptFormat {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
65 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
66 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
67 }
68 }
69}
70
/// An entry in a `PredictEditsRequest`'s `events` list.
///
/// Serialized as an internally tagged enum: the variant name is stored under the
/// `"event"` key next to the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, described by `diff`.
    BufferChange {
        /// Path of the buffer, if it has one.
        path: Option<PathBuf>,
        /// Presumably the buffer's previous path, differing from `path` after a
        /// rename — TODO confirm.
        old_path: Option<PathBuf>,
        /// Textual diff of the change (format not specified here — TODO confirm).
        diff: String,
        /// Presumably whether this change came from a prediction — TODO confirm.
        predicted: bool,
    },
}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Signature {
85 pub text: String,
86 pub text_is_truncated: bool,
87 #[serde(skip_serializing_if = "Option::is_none", default)]
88 pub parent_index: Option<usize>,
89 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
90 /// file is implicitly the file that contains the descendant declaration or excerpt.
91 pub range: Range<usize>,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ReferencedDeclaration {
96 pub path: Arc<Path>,
97 pub text: String,
98 pub text_is_truncated: bool,
99 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
100 pub range: Range<usize>,
101 /// Range within `text`
102 pub signature_range: Range<usize>,
103 /// Index within `signatures`.
104 #[serde(skip_serializing_if = "Option::is_none", default)]
105 pub parent_index: Option<usize>,
106 pub score_components: DeclarationScoreComponents,
107 pub signature_score: f32,
108 pub declaration_score: f32,
109}
110
/// Per-feature values recorded for a `ReferencedDeclaration` (its `score_components`).
///
/// Field meanings below are read off the field names; confirm against the code that
/// fills them in before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeclarationScoreComponents {
    /// Presumably: the declaration lives in the excerpt's file.
    pub is_same_file: bool,
    /// Presumably: the declaration is referenced near the cursor/excerpt.
    pub is_referenced_nearby: bool,
    /// Presumably: the declaration is referenced in the excerpt's breadcrumb/ancestors.
    pub is_referenced_in_breadcrumb: bool,
    pub reference_count: usize,
    pub same_file_declaration_count: usize,
    pub declaration_count: usize,
    /// Distance in lines (per the name) from the reference.
    pub reference_line_distance: u32,
    /// Distance in lines (per the name) from the declaration.
    pub declaration_line_distance: u32,
    pub declaration_line_distance_rank: usize,
    // Similarity features, by naming: `excerpt` / `adjacent` context compared against the
    // declaration's full `item` text or its `signature`, as Jaccard similarity or as a
    // weighted overlap. TODO confirm the exact definitions with the producer.
    pub excerpt_vs_item_jaccard: f32,
    pub excerpt_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub excerpt_vs_item_weighted_overlap: f32,
    pub excerpt_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
}
131
/// An opaque, pre-serialized JSON value holding one group of diagnostics.
///
/// `#[serde(transparent)]` makes this serialize exactly as the wrapped raw JSON, with no
/// extra wrapper object; the contents are not parsed here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
135
/// Response payload for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier assigned to the request.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Debugging details. Presumably only populated when the request set
    /// `debug_info` — TODO confirm. Serialized as `null` when `None`.
    pub debug_info: Option<DebugInfo>,
}
142
/// Diagnostic details about how a prediction was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// The prompt text.
    pub prompt: String,
    /// Time spent planning the prompt.
    pub prompt_planning_time: Duration,
    /// Raw model output, presumably before parsing into `Edit`s — TODO confirm.
    pub model_response: String,
    /// Time spent on model inference.
    pub inference_time: Duration,
    /// Time spent parsing the model response.
    pub parsing_time: Duration,
}
151
/// A single predicted text edit.
///
/// Presumably replaces `range` of the file at `path` with `content` (the range's units,
/// e.g. bytes, are not specified here — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Range being edited.
    pub range: Range<usize>,
    /// Text for the edited range.
    pub content: String,
}
158
/// Reports whether `value` equals `T::default()`.
///
/// Referenced from `#[serde(skip_serializing_if = "is_default")]`, so fields holding their
/// default value are left out of the serialized output.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let fallback = T::default();
    value == &fallback
}