zeta_prompt.rs

  1use serde::{Deserialize, Serialize};
  2use std::fmt::Write;
  3use std::ops::Range;
  4use std::path::Path;
  5use std::sync::Arc;
  6
/// Marker inserted into the `<editable_region>` of the prompt at the cursor position
/// (by `write_cursor_excerpt_section`).
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
  8
/// Everything needed to render a Zeta prompt with [`format_zeta_prompt`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the file containing the cursor; written as the `path` attribute of the
    /// `<file>` element in the `<cursor_excerpt>` section.
    pub cursor_path: Arc<Path>,
    /// Excerpt of the cursor's file that is rendered inside the `<file>` element.
    pub cursor_excerpt: Arc<str>,
    /// Byte range within `cursor_excerpt` that is wrapped in `<editable_region>`.
    /// Both ends must lie on UTF-8 character boundaries, since they are used to
    /// slice the excerpt.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset within `cursor_excerpt` at which [`CURSOR_MARKER`] is inserted.
    /// Expected to lie within `editable_range_in_excerpt`; slicing panics otherwise.
    pub cursor_offset_in_excerpt: usize,
    /// Edit history, rendered in order in the `<edit_history>` section.
    pub events: Vec<Arc<Event>>,
    /// Additional files rendered in the `<related_files>` section.
    pub related_files: Arc<[RelatedFile]>,
}
 18
/// An entry in the edit history.
///
/// Serialized with serde's internally tagged representation: the variant name is
/// stored under the `"event"` key.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered by [`write_event`] as a
    /// `--- a/<old_path>` / `+++ b/<path>` header followed by `diff`.
    BufferChange {
        /// Current path of the buffer; written in the `+++ b` header line.
        path: Arc<Path>,
        /// Previous path of the buffer (presumably differs from `path` only after a
        /// rename — confirm with producers); written in the `--- a` header line.
        old_path: Arc<Path>,
        /// Diff body, appended to the prompt verbatim.
        diff: String,
        /// When `true`, the change is prefixed with `// User accepted prediction:`.
        predicted: bool,
        /// Ignored by [`write_event`].
        in_open_source_repo: bool,
    },
}
 30
 31pub fn write_event(prompt: &mut String, event: &Event) {
 32    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
 33        for component in path.components() {
 34            prompt.push('/');
 35            write!(prompt, "{}", component.as_os_str().display()).ok();
 36        }
 37    }
 38    match event {
 39        Event::BufferChange {
 40            path,
 41            old_path,
 42            diff,
 43            predicted,
 44            in_open_source_repo: _,
 45        } => {
 46            if *predicted {
 47                prompt.push_str("// User accepted prediction:\n");
 48            }
 49            prompt.push_str("--- a");
 50            write_path_as_unix_str(prompt, old_path.as_ref());
 51            prompt.push_str("\n+++ b");
 52            write_path_as_unix_str(prompt, path.as_ref());
 53            prompt.push('\n');
 54            prompt.push_str(diff);
 55        }
 56    }
 57}
 58
/// A file whose excerpts are included as extra context in the `<related_files>` section.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Written as the `path` attribute of the `<related_file>` element.
    pub path: Arc<Path>,
    /// Not read by `write_related_files`; presumably the file's last row number —
    /// TODO confirm with producers.
    pub max_row: u32,
    /// Excerpts of the file, each rendered as a `<related_excerpt>` element.
    pub excerpts: Vec<RelatedExcerpt>,
}
 65
/// An excerpt of a [`RelatedFile`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Row range of the excerpt; `write_related_files` renders it as the `lines`
    /// attribute `"{start + 1}-{end + 1}"`. NOTE(review): this assumes 0-based rows;
    /// whether `end` is inclusive or exclusive is not established here — confirm.
    pub row_range: Range<u32>,
    /// Excerpt text; a newline is appended after it when rendered.
    pub text: String,
}
 71
 72pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
 73    let mut prompt = String::new();
 74    write_related_files(&mut prompt, &input.related_files);
 75    write_edit_history_section(&mut prompt, input);
 76    write_cursor_excerpt_section(&mut prompt, input);
 77    prompt
 78}
 79
 80pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
 81    push_delimited(prompt, "related_files", &[], |prompt| {
 82        for file in related_files {
 83            let path_str = file.path.to_string_lossy();
 84            push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
 85                for excerpt in &file.excerpts {
 86                    push_delimited(
 87                        prompt,
 88                        "related_excerpt",
 89                        &[(
 90                            "lines",
 91                            &format!(
 92                                "{}-{}",
 93                                excerpt.row_range.start + 1,
 94                                excerpt.row_range.end + 1
 95                            ),
 96                        )],
 97                        |prompt| {
 98                            prompt.push_str(&excerpt.text);
 99                            prompt.push('\n');
100                        },
101                    );
102                }
103            });
104        }
105    });
106}
107
108fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
109    push_delimited(prompt, "edit_history", &[], |prompt| {
110        if input.events.is_empty() {
111            prompt.push_str("(No edit history)");
112        } else {
113            for event in &input.events {
114                write_event(prompt, event);
115            }
116        }
117    });
118}
119
120fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
121    push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
122        let path_str = input.cursor_path.to_string_lossy();
123        push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
124            prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
125            push_delimited(prompt, "editable_region", &[], |prompt| {
126                prompt.push_str(
127                    &input.cursor_excerpt
128                        [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
129                );
130                prompt.push_str(CURSOR_MARKER);
131                prompt.push_str(
132                    &input.cursor_excerpt
133                        [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
134                );
135            });
136            prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
137        });
138    });
139}
140
/// Appends a tagged section to `prompt`, shaped like:
///
/// ```text
/// <tag name="value" ...>
/// ...whatever `cb` writes...
/// </tag>
/// ```
///
/// The opening tag always starts on a fresh line. A `'\n'` is inserted first if
/// `prompt` does not already end with one, which includes the case where `prompt`
/// is empty. Likewise, a `'\n'` is inserted before the closing tag if `cb` did not
/// leave the buffer ending in one. Attribute values are written verbatim, without
/// escaping.
fn push_delimited(
    prompt: &mut String,
    tag: &'static str,
    arguments: &[(&str, &str)],
    cb: impl FnOnce(&mut String),
) {
    /// Ensures `buf` ends with a newline.
    fn start_new_line(buf: &mut String) {
        if !buf.ends_with('\n') {
            buf.push('\n');
        }
    }

    start_new_line(prompt);
    write!(prompt, "<{tag}").ok();
    for &(name, value) in arguments {
        write!(prompt, " {name}=\"{value}\"").ok();
    }
    prompt.push_str(">\n");

    cb(prompt);

    start_new_line(prompt);
    writeln!(prompt, "</{tag}>").ok();
}