predict_edits_v3.rs

  1use chrono::Duration;
  2use serde::{Deserialize, Serialize};
  3use std::{
  4    borrow::Cow,
  5    fmt::{Display, Write as _},
  6    ops::{Add, Range, Sub},
  7    path::Path,
  8    sync::Arc,
  9};
 10use strum::EnumIter;
 11use uuid::Uuid;
 12
 13use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
 14
/// Serializable request carrying an excerpt, its location, and a list of events.
///
/// NOTE(review): judging by the name, this is the payload for the
/// context-retrieval planning step — confirm against the code that consumes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanContextRetrievalRequest {
    /// Excerpt text (presumably taken from the file at `excerpt_path` — confirm).
    pub excerpt: String,
    /// Path associated with the excerpt.
    pub excerpt_path: Arc<Path>,
    /// Line range associated with the excerpt.
    /// NOTE(review): whether lines are zero- or one-based is not visible here.
    pub excerpt_line_range: Range<Line>,
    /// NOTE(review): presumably the last row of the file containing the cursor — confirm.
    pub cursor_file_max_row: Line,
    /// Events attached to the request; see [`Event`].
    pub events: Vec<Arc<Event>>,
}
 23
/// Serializable request for an edit prediction.
///
/// Fields marked `#[serde(default)]` may be absent from the serialized form and
/// then take their type's `Default` value when deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
    /// Excerpt text (presumably taken from the file at `excerpt_path` — confirm).
    pub excerpt: String,
    /// Path associated with the excerpt.
    pub excerpt_path: Arc<Path>,
    /// Within file
    ///
    /// NOTE(review): presumably a byte-offset range — the unit is not visible here.
    pub excerpt_range: Range<usize>,
    /// Line range associated with the excerpt.
    pub excerpt_line_range: Range<Line>,
    /// Cursor position as a line/column pair; see [`Point`].
    pub cursor_point: Point,
    /// Within `signatures`
    ///
    /// NOTE(review): this struct has no `signatures` field; confirm what this
    /// index refers to, or whether the comment is stale.
    pub excerpt_parent: Option<usize>,
    /// Related files with their excerpts. Omitted from the serialized output when
    /// empty, and deserialized as an empty list when missing.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub related_files: Vec<RelatedFile>,
    /// Events attached to the request; see [`Event`].
    pub events: Vec<Arc<Event>>,
    /// Deserializes to `false` when missing.
    #[serde(default)]
    pub can_collect_data: bool,
    /// Info about the git repository state, only present when can_collect_data is true.
    ///
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub git_info: Option<PredictEditsGitInfo>,
    // Only available to staff
    // Deserializes to `false` when missing. See `PredictEditsResponse::debug_info`
    // for the corresponding response field.
    #[serde(default)]
    pub debug_info: bool,
    /// Optional limit on the prompt size in bytes. Omitted from the serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_max_bytes: Option<usize>,
    /// Prompt layout to use; deserializes to [`PromptFormat::DEFAULT`] when missing.
    #[serde(default)]
    pub prompt_format: PromptFormat,
    /// What triggered the request; deserializes to the trigger type's `Default`
    /// value when missing.
    #[serde(default)]
    pub trigger: PredictEditsRequestTrigger,
}
 52
/// A file referenced by [`PredictEditsRequest::related_files`], together with
/// excerpts from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// NOTE(review): presumably the last row of the file — confirm with the producer.
    pub max_row: Line,
    /// Excerpts from this file.
    pub excerpts: Vec<Excerpt>,
}
 59
/// A piece of text paired with the line it starts on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Excerpt {
    /// Line at which `text` begins.
    /// NOTE(review): whether lines are zero- or one-based is not visible here.
    pub start_line: Line,
    /// The excerpt's text; stored as `Arc<str>`, so clones share the same allocation.
    pub text: Arc<str>,
}
 65
/// Prompt layouts that can be selected through `PredictEditsRequest::prompt_format`.
///
/// Derives `EnumIter`, so all variants can be enumerated.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
    /// XML old_text/new_text
    OldTextNewText,
    /// Prompt format intended for use via edit_prediction_cli
    OnlySnippets,
    /// One-sentence instructions used in fine-tuned models
    Minimal,
    /// One-sentence instructions + FIM-like template
    MinimalQwen,
    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
    SeedCoder1120,
}
 79
 80impl PromptFormat {
 81    pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
 82}
 83
 84impl Default for PromptFormat {
 85    fn default() -> Self {
 86        Self::DEFAULT
 87    }
 88}
 89
 90impl PromptFormat {
 91    pub fn iter() -> impl Iterator<Item = Self> {
 92        <Self as strum::IntoEnumIterator>::iter()
 93    }
 94}
 95
 96impl std::fmt::Display for PromptFormat {
 97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 98        match self {
 99            PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
100            PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
101            PromptFormat::Minimal => write!(f, "Minimal"),
102            PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
103            PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
104        }
105    }
106}
107
/// An event attached to a request.
///
/// Serialized with serde's internally tagged representation: the variant name is
/// stored under the `"event"` key next to the variant's fields. `PartialEq` is
/// only derived in test builds or when the `test-support` feature is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, described by a diff.
    BufferChange {
        /// Path rendered on the `+++ b/` header line by the `Display` impl.
        path: Arc<Path>,
        /// Path rendered on the `--- a/` header line by the `Display` impl.
        old_path: Arc<Path>,
        /// Diff body; the `Display` impl appends it verbatim after the two header lines.
        diff: String,
        /// When `true`, the `Display` impl prefixes the output with a
        /// `// User accepted prediction:` line.
        predicted: bool,
        /// Not used by the `Display` impl.
        /// NOTE(review): presumably records whether the change happened in an
        /// open-source repository — confirm with the producer.
        in_open_source_repo: bool,
    },
}
120
121impl Display for Event {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            Event::BufferChange {
125                path,
126                old_path,
127                diff,
128                predicted,
129                ..
130            } => {
131                if *predicted {
132                    write!(
133                        f,
134                        "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
135                        DiffPathFmt(old_path),
136                        DiffPathFmt(path)
137                    )
138                } else {
139                    write!(
140                        f,
141                        "--- a/{}\n+++ b/{}\n{diff}",
142                        DiffPathFmt(old_path),
143                        DiffPathFmt(path)
144                    )
145                }
146            }
147        }
148    }
149}
150
/// Formats a [`Path`] for diff headers, always using `/` between components
/// regardless of the platform's native separator.
///
/// A root-directory component is written as a single `/`, and a path prefix
/// (e.g. a Windows drive such as `C:`) is written verbatim; neither is followed by
/// an extra separator. An absolute path such as `/foo/bar` therefore renders as
/// `/foo/bar` instead of `//foo/bar`. Relative paths render as their components
/// joined by `/`. Components that are not valid Unicode are rendered lossily.
pub struct DiffPathFmt<'a>(pub &'a Path);

impl<'a> std::fmt::Display for DiffPathFmt<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // True once a name-like component has been written, i.e. the next
        // name-like component must be preceded by a `/`.
        let mut needs_separator = false;
        for component in self.0.components() {
            match component {
                // The root already *is* the separator; do not emit a second one after it.
                std::path::Component::RootDir => {
                    f.write_str("/")?;
                    needs_separator = false;
                }
                // A prefix is followed either by the root (handled above) or directly
                // by a name (a drive-relative path), so it never takes a separator.
                std::path::Component::Prefix(prefix) => {
                    f.write_str(&prefix.as_os_str().to_string_lossy())?;
                    needs_separator = false;
                }
                std::path::Component::CurDir
                | std::path::Component::ParentDir
                | std::path::Component::Normal(_) => {
                    if needs_separator {
                        f.write_str("/")?;
                    }
                    f.write_str(&component.as_os_str().to_string_lossy())?;
                    needs_separator = true;
                }
            }
        }
        Ok(())
    }
}
168
/// Serializable response carrying a list of edits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// Identifier for the request.
    /// NOTE(review): where this id is generated is not visible here.
    pub request_id: Uuid,
    /// The edits carried by this response; see [`Edit`].
    pub edits: Vec<Edit>,
    /// Optional debugging details.
    /// NOTE(review): presumably only populated when the request set
    /// `debug_info` — confirm with the server implementation.
    pub debug_info: Option<DebugInfo>,
}
175
/// Debugging details optionally attached to a [`PredictEditsResponse`].
///
/// All durations are `chrono::Duration` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// NOTE(review): presumably the full prompt sent to the model — confirm.
    pub prompt: String,
    /// NOTE(review): presumably the time spent building the prompt — confirm.
    pub prompt_planning_time: Duration,
    /// NOTE(review): presumably the model's raw output — confirm.
    pub model_response: String,
    /// NOTE(review): presumably the time spent waiting for the model — confirm.
    pub inference_time: Duration,
    /// NOTE(review): presumably the time spent parsing the model output — confirm.
    pub parsing_time: Duration,
}
184
/// A single edit: a path, a line range, and text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
    /// Path of the file the edit applies to.
    pub path: Arc<Path>,
    /// Line range the edit refers to.
    /// NOTE(review): line base (zero vs. one) is not visible here.
    pub range: Range<Line>,
    /// NOTE(review): presumably the replacement text for `range` — confirm.
    pub content: String,
}
191
/// A position expressed as a line and a column.
///
/// The derived ordering compares `line` first and then `column`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
    /// Line component of the position.
    pub line: Line,
    /// Column component of the position.
    /// NOTE(review): the unit (bytes, UTF-16 units, chars) is not visible here.
    pub column: u32,
}
197
/// Newtype around a `u32` line number.
///
/// `#[serde(transparent)]` makes it serialize and deserialize as the bare integer.
/// NOTE(review): whether line numbers are zero- or one-based is not visible here.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);
201
202impl Add for Line {
203    type Output = Self;
204
205    fn add(self, rhs: Self) -> Self::Output {
206        Self(self.0 + rhs.0)
207    }
208}
209
210impl Sub for Line {
211    type Output = Self;
212
213    fn sub(self, rhs: Self) -> Self::Output {
214        Self(self.0 - rhs.0)
215    }
216}
217
/// Serializable request for a raw text completion.
///
/// NOTE(review): the field names resemble an OpenAI-style completions API —
/// confirm against the endpoint this is sent to.
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionRequest {
    /// Name of the model to use.
    pub model: String,
    /// Prompt text to complete.
    pub prompt: String,
    /// NOTE(review): presumably an upper bound on generated tokens — confirm.
    pub max_tokens: u32,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Stop strings; each entry is either a borrowed `'static` string or an owned one.
    pub stop: Vec<Cow<'static, str>>,
}
226
/// Serializable response to a [`RawCompletionRequest`].
///
/// NOTE(review): the field names resemble an OpenAI-style completions API — confirm.
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionResponse {
    /// Identifier of the response.
    pub id: String,
    /// NOTE(review): presumably an object-type tag supplied by the API — confirm.
    pub object: String,
    /// NOTE(review): presumably a creation timestamp; the unit is not visible here.
    pub created: u64,
    /// Name of the model.
    pub model: String,
    /// Completion choices; see [`RawCompletionChoice`].
    pub choices: Vec<RawCompletionChoice>,
    /// Token usage counters; see [`RawCompletionUsage`].
    pub usage: RawCompletionUsage,
}
236
/// One completion choice within a [`RawCompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionChoice {
    /// Text of this choice.
    pub text: String,
    /// Optional reason generation finished; the set of possible values is not
    /// visible here.
    pub finish_reason: Option<String>,
}
242
/// Token counters reported in a [`RawCompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct RawCompletionUsage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// NOTE(review): presumably `prompt_tokens + completion_tokens` — not enforced here.
    pub total_tokens: u32,
}
249
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    /// Builds a `BufferChange` event with a fixed two-line removal diff, varying
    /// only the paths and the `predicted` flag.
    fn buffer_change(old_path: &str, path: &str, predicted: bool) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(old_path).into(),
            diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(),
            predicted,
            in_open_source_repo: true,
        }
    }

    #[test]
    fn test_event_display() {
        // Same single-component path on both sides.
        assert_eq!(
            buffer_change("untitled", "untitled", false).to_string(),
            indoc! {"
                --- a/untitled
                +++ b/untitled
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Same nested path on both sides.
        assert_eq!(
            buffer_change("foo/bar.txt", "foo/bar.txt", false).to_string(),
            indoc! {"
                --- a/foo/bar.txt
                +++ b/foo/bar.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // Different old and new paths.
        assert_eq!(
            buffer_change("123.txt", "abc.txt", false).to_string(),
            indoc! {"
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );

        // A predicted change gets the marker line in front of the headers.
        assert_eq!(
            buffer_change("123.txt", "abc.txt", true).to_string(),
            indoc! {"
                // User accepted prediction:
                --- a/123.txt
                +++ b/abc.txt
                @@ -1,2 +1,2 @@
                -a
                -b
            "}
        );
    }
}