1use chrono::Duration;
2use serde::{Deserialize, Serialize};
3use std::{
4 borrow::Cow,
5 fmt::{Display, Write as _},
6 ops::{Add, Range, Sub},
7 path::Path,
8 sync::Arc,
9};
10use strum::EnumIter;
11use uuid::Uuid;
12
13use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PlanContextRetrievalRequest {
17 pub excerpt: String,
18 pub excerpt_path: Arc<Path>,
19 pub excerpt_line_range: Range<Line>,
20 pub cursor_file_max_row: Line,
21 pub events: Vec<Arc<Event>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct PredictEditsRequest {
26 pub excerpt: String,
27 pub excerpt_path: Arc<Path>,
28 /// Within file
29 pub excerpt_range: Range<usize>,
30 pub excerpt_line_range: Range<Line>,
31 pub cursor_point: Point,
32 /// Within `signatures`
33 pub excerpt_parent: Option<usize>,
34 #[serde(skip_serializing_if = "Vec::is_empty", default)]
35 pub related_files: Vec<RelatedFile>,
36 pub events: Vec<Arc<Event>>,
37 #[serde(default)]
38 pub can_collect_data: bool,
39 /// Info about the git repository state, only present when can_collect_data is true.
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub git_info: Option<PredictEditsGitInfo>,
42 // Only available to staff
43 #[serde(default)]
44 pub debug_info: bool,
45 #[serde(skip_serializing_if = "Option::is_none", default)]
46 pub prompt_max_bytes: Option<usize>,
47 #[serde(default)]
48 pub prompt_format: PromptFormat,
49 #[serde(default)]
50 pub trigger: PredictEditsRequestTrigger,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct RelatedFile {
55 pub path: Arc<Path>,
56 pub max_row: Line,
57 pub excerpts: Vec<Excerpt>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct Excerpt {
62 pub start_line: Line,
63 pub text: Arc<str>,
64}
65
66#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
67pub enum PromptFormat {
68 /// XML old_tex/new_text
69 OldTextNewText,
70 /// Prompt format intended for use via edit_prediction_cli
71 OnlySnippets,
72 /// One-sentence instructions used in fine-tuned models
73 Minimal,
74 /// One-sentence instructions + FIM-like template
75 MinimalQwen,
76 /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
77 SeedCoder1120,
78}
79
80impl PromptFormat {
81 pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
82}
83
84impl Default for PromptFormat {
85 fn default() -> Self {
86 Self::DEFAULT
87 }
88}
89
90impl PromptFormat {
91 pub fn iter() -> impl Iterator<Item = Self> {
92 <Self as strum::IntoEnumIterator>::iter()
93 }
94}
95
96impl std::fmt::Display for PromptFormat {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
100 PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
101 PromptFormat::Minimal => write!(f, "Minimal"),
102 PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
103 PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
110#[serde(tag = "event")]
111pub enum Event {
112 BufferChange {
113 path: Arc<Path>,
114 old_path: Arc<Path>,
115 diff: String,
116 predicted: bool,
117 in_open_source_repo: bool,
118 },
119}
120
121impl Display for Event {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Event::BufferChange {
125 path,
126 old_path,
127 diff,
128 predicted,
129 ..
130 } => {
131 if *predicted {
132 write!(
133 f,
134 "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
135 DiffPathFmt(old_path),
136 DiffPathFmt(path)
137 )
138 } else {
139 write!(
140 f,
141 "--- a/{}\n+++ b/{}\n{diff}",
142 DiffPathFmt(old_path),
143 DiffPathFmt(path)
144 )
145 }
146 }
147 }
148 }
149}
150
/// Formats a [`Path`] for use in diffs, always using `/` as the path separator regardless of
/// the host platform's native separator.
///
/// Components are written one by one, joined with `/`. A root component is itself written
/// as a single `/`, so an absolute path such as `/foo/bar` renders as `/foo/bar` rather than
/// with a doubled (or, on Windows, backslash) root.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether a `/` has to be written before the next non-root component.
        let mut needs_separator = false;
        for component in self.0.components() {
            if matches!(component, std::path::Component::RootDir) {
                // The root's native text is the platform separator (`\` on Windows), and it
                // already separates what follows; emit exactly one `/` in its place.
                f.write_str("/")?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_str("/")?;
            }
            // Lossy conversion: invalid Unicode is shown as U+FFFD, as with `display()`.
            f.write_str(&component.as_os_str().to_string_lossy())?;
            needs_separator = true;
        }
        Ok(())
    }
}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct PredictEditsResponse {
171 pub request_id: Uuid,
172 pub edits: Vec<Edit>,
173 pub debug_info: Option<DebugInfo>,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct DebugInfo {
178 pub prompt: String,
179 pub prompt_planning_time: Duration,
180 pub model_response: String,
181 pub inference_time: Duration,
182 pub parsing_time: Duration,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct Edit {
187 pub path: Arc<Path>,
188 pub range: Range<Line>,
189 pub content: String,
190}
191
192#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
193pub struct Point {
194 pub line: Line,
195 pub column: u32,
196}
197
198#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
199#[serde(transparent)]
200pub struct Line(pub u32);
201
202impl Add for Line {
203 type Output = Self;
204
205 fn add(self, rhs: Self) -> Self::Output {
206 Self(self.0 + rhs.0)
207 }
208}
209
210impl Sub for Line {
211 type Output = Self;
212
213 fn sub(self, rhs: Self) -> Self::Output {
214 Self(self.0 - rhs.0)
215 }
216}
217
218#[derive(Debug, Deserialize, Serialize)]
219pub struct RawCompletionRequest {
220 pub model: String,
221 pub prompt: String,
222 #[serde(skip_serializing_if = "Option::is_none")]
223 pub max_tokens: Option<u32>,
224 #[serde(skip_serializing_if = "Option::is_none")]
225 pub temperature: Option<f32>,
226 pub stop: Vec<Cow<'static, str>>,
227}
228
229#[derive(Debug, Deserialize, Serialize)]
230pub struct RawCompletionResponse {
231 pub id: String,
232 pub object: String,
233 pub created: u64,
234 pub model: String,
235 pub choices: Vec<RawCompletionChoice>,
236 pub usage: RawCompletionUsage,
237}
238
239#[derive(Debug, Deserialize, Serialize)]
240pub struct RawCompletionChoice {
241 pub text: String,
242 pub finish_reason: Option<String>,
243}
244
245#[derive(Debug, Serialize, Deserialize)]
246pub struct RawCompletionUsage {
247 pub prompt_tokens: u32,
248 pub completion_tokens: u32,
249 pub total_tokens: u32,
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event carrying the two-line deletion hunk shared by every case.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        let cases = [
            // Single-component path, unchanged.
            (
                buffer_change("untitled", "untitled", false),
                indoc! {"
                    --- a/untitled
                    +++ b/untitled
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Nested path, unchanged.
            (
                buffer_change("foo/bar.txt", "foo/bar.txt", false),
                indoc! {"
                    --- a/foo/bar.txt
                    +++ b/foo/bar.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Old and new paths differ.
            (
                buffer_change("123.txt", "abc.txt", false),
                indoc! {"
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
            // Predicted events get the extra leading comment line.
            (
                buffer_change("123.txt", "abc.txt", true),
                indoc! {"
                    // User accepted prediction:
                    --- a/123.txt
                    +++ b/abc.txt
                    @@ -1,2 +1,2 @@
                    -a
                    -b
                "},
            ),
        ];

        for (event, expected) in cases {
            assert_eq!(event.to_string(), expected);
        }
    }
}