1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 ops::{Add, Range, Sub},
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8use strum::EnumIter;
9use uuid::Uuid;
10
11use crate::PredictEditsGitInfo;
12
13// TODO: snippet ordering within file / relative to excerpt
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PredictEditsRequest {
17 pub excerpt: String,
18 pub excerpt_path: Arc<Path>,
19 /// Within file
20 pub excerpt_range: Range<usize>,
21 pub excerpt_line_range: Range<Line>,
22 pub cursor_point: Point,
23 /// Within `signatures`
24 pub excerpt_parent: Option<usize>,
25 pub signatures: Vec<Signature>,
26 pub referenced_declarations: Vec<ReferencedDeclaration>,
27 pub events: Vec<Event>,
28 #[serde(default)]
29 pub can_collect_data: bool,
30 #[serde(skip_serializing_if = "Vec::is_empty", default)]
31 pub diagnostic_groups: Vec<DiagnosticGroup>,
32 #[serde(skip_serializing_if = "is_default", default)]
33 pub diagnostic_groups_truncated: bool,
34 /// Info about the git repository state, only present when can_collect_data is true.
35 #[serde(skip_serializing_if = "Option::is_none", default)]
36 pub git_info: Option<PredictEditsGitInfo>,
37 // Only available to staff
38 #[serde(default)]
39 pub debug_info: bool,
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub prompt_max_bytes: Option<usize>,
42 #[serde(default)]
43 pub prompt_format: PromptFormat,
44}
45
46#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum PromptFormat {
48 MarkedExcerpt,
49 LabeledSections,
50 NumLinesUniDiff,
51 /// Prompt format intended for use via zeta_cli
52 OnlySnippets,
53}
54
55impl PromptFormat {
56 pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
57}
58
59impl Default for PromptFormat {
60 fn default() -> Self {
61 Self::DEFAULT
62 }
63}
64
65impl PromptFormat {
66 pub fn iter() -> impl Iterator<Item = Self> {
67 <Self as strum::IntoEnumIterator>::iter()
68 }
69}
70
71impl std::fmt::Display for PromptFormat {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
75 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
76 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
77 PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
78 }
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
83#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
84#[serde(tag = "event")]
85pub enum Event {
86 BufferChange {
87 path: Option<PathBuf>,
88 old_path: Option<PathBuf>,
89 diff: String,
90 predicted: bool,
91 },
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct Signature {
96 pub text: String,
97 pub text_is_truncated: bool,
98 #[serde(skip_serializing_if = "Option::is_none", default)]
99 pub parent_index: Option<usize>,
100 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
101 /// file is implicitly the file that contains the descendant declaration or excerpt.
102 pub range: Range<Line>,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ReferencedDeclaration {
107 pub path: Arc<Path>,
108 pub text: String,
109 pub text_is_truncated: bool,
110 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
111 pub range: Range<Line>,
112 /// Range within `text`
113 pub signature_range: Range<usize>,
114 /// Index within `signatures`.
115 #[serde(skip_serializing_if = "Option::is_none", default)]
116 pub parent_index: Option<usize>,
117 pub score_components: DeclarationScoreComponents,
118 pub signature_score: f32,
119 pub declaration_score: f32,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct DeclarationScoreComponents {
124 pub is_same_file: bool,
125 pub is_referenced_nearby: bool,
126 pub is_referenced_in_breadcrumb: bool,
127 pub reference_count: usize,
128 pub same_file_declaration_count: usize,
129 pub declaration_count: usize,
130 pub reference_line_distance: u32,
131 pub declaration_line_distance: u32,
132 pub excerpt_vs_item_jaccard: f32,
133 pub excerpt_vs_signature_jaccard: f32,
134 pub adjacent_vs_item_jaccard: f32,
135 pub adjacent_vs_signature_jaccard: f32,
136 pub excerpt_vs_item_weighted_overlap: f32,
137 pub excerpt_vs_signature_weighted_overlap: f32,
138 pub adjacent_vs_item_weighted_overlap: f32,
139 pub adjacent_vs_signature_weighted_overlap: f32,
140 pub path_import_match_count: usize,
141 pub wildcard_path_import_match_count: usize,
142 pub import_similarity: f32,
143 pub max_import_similarity: f32,
144 pub normalized_import_similarity: f32,
145 pub wildcard_import_similarity: f32,
146 pub normalized_wildcard_import_similarity: f32,
147 pub included_by_others: usize,
148 pub includes_others: usize,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152#[serde(transparent)]
153pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct PredictEditsResponse {
157 pub request_id: Uuid,
158 pub edits: Vec<Edit>,
159 pub debug_info: Option<DebugInfo>,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct DebugInfo {
164 pub prompt: String,
165 pub prompt_planning_time: Duration,
166 pub model_response: String,
167 pub inference_time: Duration,
168 pub parsing_time: Duration,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct Edit {
173 pub path: Arc<Path>,
174 pub range: Range<Line>,
175 pub content: String,
176}
177
/// Returns `true` when `value` equals `T::default()`.
///
/// Used with `#[serde(skip_serializing_if = "is_default")]` to omit default-valued fields.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    value == &T::default()
}
181
182#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
183pub struct Point {
184 pub line: Line,
185 pub column: u32,
186}
187
188#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
189#[serde(transparent)]
190pub struct Line(pub u32);
191
192impl Add for Line {
193 type Output = Self;
194
195 fn add(self, rhs: Self) -> Self::Output {
196 Self(self.0 + rhs.0)
197 }
198}
199
200impl Sub for Line {
201 type Output = Self;
202
203 fn sub(self, rhs: Self) -> Self::Output {
204 Self(self.0 - rhs.0)
205 }
206}