1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::{Display, Write as _},
5 ops::{Add, Range, Sub},
6 path::Path,
7 sync::Arc,
8};
9use strum::EnumIter;
10use uuid::Uuid;
11
12use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PlanContextRetrievalRequest {
16 pub excerpt: String,
17 pub excerpt_path: Arc<Path>,
18 pub excerpt_line_range: Range<Line>,
19 pub cursor_file_max_row: Line,
20 pub events: Vec<Arc<Event>>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PredictEditsRequest {
25 pub excerpt: String,
26 pub excerpt_path: Arc<Path>,
27 /// Within file
28 pub excerpt_range: Range<usize>,
29 pub excerpt_line_range: Range<Line>,
30 pub cursor_point: Point,
31 /// Within `signatures`
32 pub excerpt_parent: Option<usize>,
33 #[serde(skip_serializing_if = "Vec::is_empty", default)]
34 pub related_files: Vec<RelatedFile>,
35 pub events: Vec<Arc<Event>>,
36 #[serde(default)]
37 pub can_collect_data: bool,
38 /// Info about the git repository state, only present when can_collect_data is true.
39 #[serde(skip_serializing_if = "Option::is_none", default)]
40 pub git_info: Option<PredictEditsGitInfo>,
41 // Only available to staff
42 #[serde(default)]
43 pub debug_info: bool,
44 #[serde(skip_serializing_if = "Option::is_none", default)]
45 pub prompt_max_bytes: Option<usize>,
46 #[serde(default)]
47 pub prompt_format: PromptFormat,
48 #[serde(default)]
49 pub trigger: PredictEditsRequestTrigger,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct RelatedFile {
54 pub path: Arc<Path>,
55 pub max_row: Line,
56 pub excerpts: Vec<Excerpt>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct Excerpt {
61 pub start_line: Line,
62 pub text: Arc<str>,
63}
64
65#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
66pub enum PromptFormat {
67 /// XML old_tex/new_text
68 OldTextNewText,
69 /// Prompt format intended for use via edit_prediction_cli
70 OnlySnippets,
71 /// One-sentence instructions used in fine-tuned models
72 Minimal,
73 /// One-sentence instructions + FIM-like template
74 MinimalQwen,
75 /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
76 SeedCoder1120,
77}
78
79impl PromptFormat {
80 pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
81}
82
83impl Default for PromptFormat {
84 fn default() -> Self {
85 Self::DEFAULT
86 }
87}
88
89impl PromptFormat {
90 pub fn iter() -> impl Iterator<Item = Self> {
91 <Self as strum::IntoEnumIterator>::iter()
92 }
93}
94
95impl std::fmt::Display for PromptFormat {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 match self {
98 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
99 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
100 PromptFormat::Minimal => write!(f, "Minimal"),
101 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
102 PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
103 }
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
109#[serde(tag = "event")]
110pub enum Event {
111 BufferChange {
112 path: Arc<Path>,
113 old_path: Arc<Path>,
114 diff: String,
115 predicted: bool,
116 in_open_source_repo: bool,
117 },
118}
119
120impl Display for Event {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 Event::BufferChange {
124 path,
125 old_path,
126 diff,
127 predicted,
128 ..
129 } => {
130 if *predicted {
131 write!(
132 f,
133 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
134 DiffPathFmt(old_path),
135 DiffPathFmt(path)
136 )
137 } else {
138 write!(
139 f,
140 "--- a/{}\n+++ b/{}\n{diff}",
141 DiffPathFmt(old_path),
142 DiffPathFmt(path)
143 )
144 }
145 }
146 }
147 }
148}
149
/// Formats a [`Path`] for diff headers, always using `/` as the separator regardless of
/// the platform's native one.
///
/// Components are written in order and joined with `/`. A root component is written as a
/// single `/`, so an absolute path such as `/foo/bar` renders as `/foo/bar` rather than
/// `//foo/bar`. Non-UTF-8 components are rendered lossily (invalid sequences become
/// U+FFFD).
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether a `/` must be written before the next non-root component.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root *is* the separator: emit exactly one `/` (never the native
                // `\`), and do not add another one before the first named component.
                f.write_char('/')?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_char('/')?;
            }
            write!(f, "{}", component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct PredictEditsResponse {
170 pub request_id: Uuid,
171 pub edits: Vec<Edit>,
172 pub debug_info: Option<DebugInfo>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct DebugInfo {
177 pub prompt: String,
178 pub prompt_planning_time: Duration,
179 pub model_response: String,
180 pub inference_time: Duration,
181 pub parsing_time: Duration,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct Edit {
186 pub path: Arc<Path>,
187 pub range: Range<Line>,
188 pub content: String,
189}
190
191#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
192pub struct Point {
193 pub line: Line,
194 pub column: u32,
195}
196
197#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
198#[serde(transparent)]
199pub struct Line(pub u32);
200
201impl Add for Line {
202 type Output = Self;
203
204 fn add(self, rhs: Self) -> Self::Output {
205 Self(self.0 + rhs.0)
206 }
207}
208
209impl Sub for Line {
210 type Output = Self;
211
212 fn sub(self, rhs: Self) -> Self::Output {
213 Self(self.0 - rhs.0)
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event with the two-line removal diff shared by every case.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        // Same path on both sides, no directory component.
        let rendered = buffer_change("untitled", "untitled", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same path on both sides, with a directory component.
        let rendered = buffer_change("foo/bar.txt", "foo/bar.txt", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Differing paths: the old path goes on the `---` line, the new one on `+++`.
        let rendered = buffer_change("123.txt", "abc.txt", false).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // An accepted prediction gets the extra leading comment line.
        let rendered = buffer_change("123.txt", "abc.txt", true).to_string();
        assert_eq!(
            rendered,
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}