1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 ops::Range,
5 path::{Path, PathBuf},
6 sync::Arc,
7};
8use strum::EnumIter;
9use uuid::Uuid;
10
11use crate::PredictEditsGitInfo;
12
// TODO: Decide how snippets are ordered within a file and relative to the excerpt.
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PredictEditsRequest {
17 pub excerpt: String,
18 pub excerpt_path: Arc<Path>,
19 /// Within file
20 pub excerpt_range: Range<usize>,
21 /// Within `excerpt`
22 pub cursor_offset: usize,
23 /// Within `signatures`
24 pub excerpt_parent: Option<usize>,
25 pub signatures: Vec<Signature>,
26 pub referenced_declarations: Vec<ReferencedDeclaration>,
27 pub events: Vec<Event>,
28 #[serde(default)]
29 pub can_collect_data: bool,
30 #[serde(skip_serializing_if = "Vec::is_empty", default)]
31 pub diagnostic_groups: Vec<DiagnosticGroup>,
32 #[serde(skip_serializing_if = "is_default", default)]
33 pub diagnostic_groups_truncated: bool,
34 /// Info about the git repository state, only present when can_collect_data is true.
35 #[serde(skip_serializing_if = "Option::is_none", default)]
36 pub git_info: Option<PredictEditsGitInfo>,
37 // Only available to staff
38 #[serde(default)]
39 pub debug_info: bool,
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub prompt_max_bytes: Option<usize>,
42 #[serde(default)]
43 pub prompt_format: PromptFormat,
44}
45
46#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum PromptFormat {
48 MarkedExcerpt,
49 LabeledSections,
50 /// Prompt format intended for use via zeta_cli
51 OnlySnippets,
52}
53
54impl PromptFormat {
55 pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
56}
57
58impl Default for PromptFormat {
59 fn default() -> Self {
60 Self::DEFAULT
61 }
62}
63
64impl PromptFormat {
65 pub fn iter() -> impl Iterator<Item = Self> {
66 <Self as strum::IntoEnumIterator>::iter()
67 }
68}
69
70impl std::fmt::Display for PromptFormat {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
74 PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
75 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
82#[serde(tag = "event")]
83pub enum Event {
84 BufferChange {
85 path: Option<PathBuf>,
86 old_path: Option<PathBuf>,
87 diff: String,
88 predicted: bool,
89 },
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Signature {
94 pub text: String,
95 pub text_is_truncated: bool,
96 #[serde(skip_serializing_if = "Option::is_none", default)]
97 pub parent_index: Option<usize>,
98 /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
99 /// file is implicitly the file that contains the descendant declaration or excerpt.
100 pub range: Range<usize>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ReferencedDeclaration {
105 pub path: Arc<Path>,
106 pub text: String,
107 pub text_is_truncated: bool,
108 /// Range of `text` within file, possibly truncated according to `text_is_truncated`
109 pub range: Range<usize>,
110 /// Range within `text`
111 pub signature_range: Range<usize>,
112 /// Index within `signatures`.
113 #[serde(skip_serializing_if = "Option::is_none", default)]
114 pub parent_index: Option<usize>,
115 pub score_components: DeclarationScoreComponents,
116 pub signature_score: f32,
117 pub declaration_score: f32,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct DeclarationScoreComponents {
122 pub is_same_file: bool,
123 pub is_referenced_nearby: bool,
124 pub is_referenced_in_breadcrumb: bool,
125 pub reference_count: usize,
126 pub same_file_declaration_count: usize,
127 pub declaration_count: usize,
128 pub reference_line_distance: u32,
129 pub declaration_line_distance: u32,
130 pub declaration_line_distance_rank: usize,
131 pub excerpt_vs_item_jaccard: f32,
132 pub excerpt_vs_signature_jaccard: f32,
133 pub adjacent_vs_item_jaccard: f32,
134 pub adjacent_vs_signature_jaccard: f32,
135 pub excerpt_vs_item_weighted_overlap: f32,
136 pub excerpt_vs_signature_weighted_overlap: f32,
137 pub adjacent_vs_item_weighted_overlap: f32,
138 pub adjacent_vs_signature_weighted_overlap: f32,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(transparent)]
143pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct PredictEditsResponse {
147 pub request_id: Uuid,
148 pub edits: Vec<Edit>,
149 pub debug_info: Option<DebugInfo>,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct DebugInfo {
154 pub prompt: String,
155 pub prompt_planning_time: Duration,
156 pub model_response: String,
157 pub inference_time: Duration,
158 pub parsing_time: Duration,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct Edit {
163 pub path: Arc<Path>,
164 pub range: Range<usize>,
165 pub content: String,
166}
167
/// Returns `true` when `value` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so default-valued fields are left out.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default = T::default();
    value == &default
}