predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    fmt::{Display, Write as _},
  5    ops::{Add, Range, Sub},
  6    path::Path,
  7    sync::Arc,
  8};
  9use strum::EnumIter;
 10use uuid::Uuid;
 11
 12use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
 13
/// Request payload for planning context retrieval around an excerpt.
///
/// NOTE(review): the handler that consumes this request is not visible in
/// this file — confirm its exact role there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Text of the excerpt.
    pub excerpt: String,
    /// Path of the file the excerpt belongs to (presumably the cursor's file).
    pub excerpt_path: Arc<Path>,
    /// Lines covered by `excerpt`.
    /// NOTE(review): whether `end` is exclusive is not established here.
    pub excerpt_line_range: Range<Line>,
    /// Last row of the file containing the cursor.
    pub cursor_file_max_row: Line,
    /// Recorded events; `Event` currently only has the `BufferChange` variant.
    pub events: Vec<Arc<Event>>,
}
 22
/// Request body for an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Text of the excerpt sent for prediction.
    pub excerpt: String,
    /// Path of the file the excerpt was taken from.
    pub excerpt_path: Arc<Path>,
    /// Within file
    ///
    /// Offsets of the excerpt within its file.
    /// NOTE(review): the unit (bytes vs. chars) is not established in this
    /// file — confirm with the code that fills it in.
    pub excerpt_range: Range<usize>,
    /// Lines of the file covered by the excerpt.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair.
    pub cursor_point: Point,
    /// Within `signatures`
    ///
    /// NOTE(review): this struct has no `signatures` field, so this index
    /// appears to refer to a removed or external list — confirm or update.
    pub excerpt_parent: Option<usize>,
    /// Other files with excerpts. Omitted from the serialized form when empty;
    /// deserializes to an empty list when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub related_files: Vec<RelatedFile>,
    /// Recorded events (currently only `Event::BufferChange`).
    pub events: Vec<Arc<Event>>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    // Only available to staff
    /// Asks for debug output; defaults to `false` when absent.
    /// NOTE(review): presumably controls `PredictEditsResponse::debug_info` — confirm.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional prompt size limit (bytes, per the name). Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt template to use. Falls back to `PromptFormat::default()` when absent.
    #[serde(default)]
    pub prompt_format: PromptFormat,
    /// What triggered the request. Falls back to
    /// `PredictEditsRequestTrigger::default()` when absent.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
 51
/// A file other than the excerpt's own, sent with a set of excerpts from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// Last row of the related file.
    pub max_row: Line,
    /// Snippets taken from the file.
    pub excerpts: Vec<Excerpt>,
}
 58
/// A snippet of text anchored at a line in its file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line on which `text` starts.
    pub start_line: Line,
    /// The snippet text; `Arc<str>` keeps clones cheap.
    pub text: Arc<str>,
}
 64
/// Prompt template requested for an edit prediction
/// (see `PredictEditsRequest::prompt_format`).
///
/// Variant order matters: `#[derive(EnumIter)]` yields variants in
/// declaration order, which is the order `PromptFormat::iter` returns.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// XML old_text/new_text
    OldTextNewText,
    /// Prompt format intended for use via edit_prediction_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
    SeedCoder1120,
}
 78
 79impl PromptFormat {
 80    pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
 81}
 82
 83impl Default for PromptFormat {
 84    fn default() -> Self {
 85        Self::DEFAULT
 86    }
 87}
 88
 89impl PromptFormat {
 90    pub fn iter() -> impl Iterator<Item = Self> {
 91        <Self as strum::IntoEnumIterator>::iter()
 92    }
 93}
 94
 95impl std::fmt::Display for PromptFormat {
 96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 97        match self {
 98            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
 99            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
100            PromptFormat::Minimal => write!(f, "Minimal"),
101            PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
102            PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
103        }
104    }
105}
106
/// An editing event recorded and sent along with prediction requests.
///
/// Serialized with serde's internally-tagged representation
/// (`#[serde(tag = "event")]`): the variant name is stored under an `"event"`
/// key next to the variant's fields. `PartialEq` is only derived for tests and
/// the `test-support` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, expressed as a diff.
    BufferChange {
        /// Path after the change; rendered in the `+++ b/` header by `Display`.
        path: Arc<Path>,
        /// Path before the change; rendered in the `--- a/` header by `Display`.
        /// May differ from `path` (e.g. presumably on rename — confirm).
        old_path: Arc<Path>,
        /// Diff hunks only. `Display` prepends the `---`/`+++` header lines.
        diff: String,
        /// When `true`, `Display` prefixes a `// User accepted prediction:` line.
        predicted: bool,
        /// Not read in this file. NOTE(review): presumably marks buffers from
        /// open-source repositories — confirm with producers/consumers.
        in_open_source_repo: bool,
    },
}
119
120impl Display for Event {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Event::BufferChange {
124                path,
125                old_path,
126                diff,
127                predicted,
128                ..
129            } => {
130                if *predicted {
131                    write!(
132                        f,
133                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
134                        DiffPathFmt(old_path),
135                        DiffPathFmt(path)
136                    )
137                } else {
138                    write!(
139                        f,
140                        "--- a/{}\n+++ b/{}\n{diff}",
141                        DiffPathFmt(old_path),
142                        DiffPathFmt(path)
143                    )
144                }
145            }
146        }
147    }
148}
149
/// Formats a [`Path`] as it should appear in a diff header: as a Unix-style
/// path using `/` as the separator, regardless of the host platform.
///
/// Components come from [`Path::components`], so repeated separators and
/// interior `.` segments are dropped (`a//./b` renders as `a/b`).
///
/// A root directory renders as exactly one leading `/`, so `/a/b` stays
/// `/a/b`. Previously the root was followed by an extra separator, giving
/// `//a/b`. Backslashes in a Windows prefix (e.g. a UNC share) are rewritten
/// to `/`.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl std::fmt::Display for DiffPathFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Whether a `/` must be written before the next component.
        let mut needs_separator = false;
        for component in self.0.components() {
            if let std::path::Component::RootDir = component {
                // The root already *is* a separator (`/` on Unix, `\` on
                // Windows): emit a single `/` and don't add another before the
                // next component.
                f.write_str("/")?;
                needs_separator = false;
                continue;
            }
            if needs_separator {
                f.write_str("/")?;
            }
            needs_separator = true;
            match component {
                // Windows prefixes (`C:`, `\\server\share`, ...) may contain
                // backslashes, which must not leak into the diff header.
                std::path::Component::Prefix(prefix) => f.write_str(
                    &prefix.as_os_str().to_string_lossy().replace('\\', "/"),
                )?,
                other => write!(f, "{}", other.as_os_str().display())?,
            }
        }
        Ok(())
    }
}
167
/// Response to a `PredictEditsRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier assigned to the request.
    pub request_id: Uuid,
    /// Predicted edits.
    pub edits: Vec<Edit>,
    /// Diagnostic details. NOTE(review): presumably only populated when the
    /// request set `debug_info` — confirm server-side.
    pub debug_info: Option<DebugInfo>,
}
174
/// Diagnostic data attached to a prediction response.
///
/// The durations are `chrono::Duration` values. The stage each one measures is
/// inferred from its field name only — confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Prompt text that was used.
    pub prompt: String,
    /// Time reported for prompt planning.
    pub prompt_planning_time: Duration,
    /// Raw model output.
    pub model_response: String,
    /// Time reported for inference.
    pub inference_time: Duration,
    /// Time reported for parsing the model response.
    pub parsing_time: Duration,
}
183
/// A single predicted edit to a file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// File the edit applies to.
    pub path: Arc<Path>,
    /// Lines affected by the edit.
    /// NOTE(review): whether `end` is exclusive is not established here.
    pub range: Range<Line>,
    /// Text for `range` (presumably the replacement content — confirm).
    pub content: String,
}
190
/// A line/column position.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order (by
/// `line`, then `column`), so the field order here is significant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line of the position.
    pub line: Line,
    /// Column within the line. NOTE(review): the unit (bytes vs. chars) is
    /// not established in this file.
    pub column: u32,
}
196
/// Newtype for a line number, kept distinct from other `u32` values.
///
/// `#[serde(transparent)]` makes it serialize as a bare `u32`.
/// NOTE(review): whether lines are 0- or 1-based is not established here.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
200
201impl Add for Line {
202    type Output = Self;
203
204    fn add(self, rhs: Self) -> Self::Output {
205        Self(self.0 + rhs.0)
206    }
207}
208
209impl Sub for Line {
210    type Output = Self;
211
212    fn sub(self, rhs: Self) -> Self::Output {
213        Self(self.0 - rhs.0)
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event carrying the two-line removal hunk
    /// shared by every case below.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        // Same path before and after the change.
        assert_eq!(
            buffer_change("untitled", "untitled", false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Nested path: components are joined with `/`.
        assert_eq!(
            buffer_change("foo/bar.txt", "foo/bar.txt", false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old/new paths land in the `a/` and `b/` headers respectively.
        assert_eq!(
            buffer_change("123.txt", "abc.txt", false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Accepted predictions get a marker line before the headers.
        assert_eq!(
            buffer_change("123.txt", "abc.txt", true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}