teacher.rs

  1use crate::{
  2    example::Example,
  3    source_location::SourceLocation,
  4    training::{
  5        context::{ContextType, collect_context, strip_special_tags},
  6        llm_client::LlmClient,
  7    },
  8};
  9use anthropic::{Message, RequestContent, ResponseContent, Role};
 10use anyhow::Result;
 11
/// Teacher model: drives an LLM to produce reference edit predictions
/// for training examples (see `predict`).
pub struct TeacherModel {
    /// Model identifier forwarded to the LLM client on every request.
    pub llm_name: String,
    /// Strategy used to collect source context around the cursor.
    pub context: ContextType,
    /// Client used to issue generation requests.
    pub client: LlmClient,
}
 17
/// Result of a single teacher prediction, serializable for downstream use.
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
    // Editable-region content extracted from the LLM response.
    parsed_output: String,
    // Full prompt that was sent to the LLM.
    prompt: String,
    // Raw response text (all text parts joined with newlines).
    raw_llm_response: String,
    // Context string the prompt was built from (still contains special tags).
    context: String,
    // Unified diff of the tag-stripped context before vs. after the edit.
    diff: String,
}
 26
 27impl TeacherModel {
 28    const PROMPT: &str = include_str!("teacher.prompt.md");
 29    pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
 30    pub(crate) const REGION_END: &str = "<|editable_region_end|>";
 31    pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
 32
 33    /// Number of lines to include before the cursor position
 34    pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
 35
 36    /// Number of lines to include after the cursor position
 37    pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
 38
 39    /// Truncate edit history to this number of last lines
 40    const MAX_HISTORY_LINES: usize = 128;
 41
 42    pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
 43        TeacherModel {
 44            llm_name,
 45            context,
 46            client,
 47        }
 48    }
 49
 50    pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
 51        let name = input.unique_name();
 52        let worktree_dir = input.setup_worktree(name).await?;
 53        let cursor: SourceLocation = input
 54            .cursor_position
 55            .parse()
 56            .expect("Failed to parse cursor position");
 57
 58        let context = collect_context(&self.context, &worktree_dir, cursor.clone());
 59        let edit_history = Self::format_edit_history(&input.edit_history);
 60
 61        let prompt = Self::PROMPT
 62            .replace("{{context}}", &context)
 63            .replace("{{edit_history}}", &edit_history);
 64
 65        let messages = vec![Message {
 66            role: Role::User,
 67            content: vec![RequestContent::Text {
 68                text: prompt.clone(),
 69                cache_control: None,
 70            }],
 71        }];
 72
 73        let Some(response) = self
 74            .client
 75            .generate(self.llm_name.clone(), 16384, messages)
 76            .await?
 77        else {
 78            return Ok(None);
 79        };
 80
 81        let response_text = response
 82            .content
 83            .into_iter()
 84            .filter_map(|content| match content {
 85                ResponseContent::Text { text } => Some(text),
 86                _ => None,
 87            })
 88            .collect::<Vec<String>>()
 89            .join("\n");
 90
 91        let parsed_output = self.parse_response(&response_text);
 92
 93        let original_editable_region = Self::extract_editable_region(&context);
 94        let context_after_edit = context.replace(&original_editable_region, &parsed_output);
 95        let context_after_edit = strip_special_tags(&context_after_edit);
 96        let context_before_edit = strip_special_tags(&context);
 97        let diff = language::unified_diff(&context_before_edit, &context_after_edit);
 98
 99        // zeta distill --batch batch_results.txt
100        // zeta distill
101        // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
102        //  - store LLM requests in a batch, don't actual send the request
103        //  - send the batch (2000 requests) after all inputs are processed
104        // 2. `zeta send-batches`
105        //   - upload the batch to Anthropic
106
107        // https://platform.claude.com/docs/en/build-with-claude/batch-processing
108        // https://crates.io/crates/anthropic-sdk-rust
109
110        //   - poll for results
111        //   - when ready, store results in cache (a database)
112        // 3. `zeta distill` again
113        //    - use the cached results this time
114
115        Ok(Some(TeacherOutput {
116            parsed_output,
117            prompt,
118            raw_llm_response: response_text,
119            context,
120            diff,
121        }))
122    }
123
124    fn parse_response(&self, content: &str) -> String {
125        let codeblock = Self::extract_last_codeblock(content);
126        let editable_region = Self::extract_editable_region(&codeblock);
127
128        editable_region
129    }
130
131    /// Extract content from the last code-fenced block if any, or else return content as is
132    fn extract_last_codeblock(text: &str) -> String {
133        let mut last_block = None;
134        let mut search_start = 0;
135
136        while let Some(start) = text[search_start..].find("```") {
137            let start = start + search_start;
138            let bytes = text.as_bytes();
139            let mut backtick_end = start;
140
141            while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
142                backtick_end += 1;
143            }
144
145            let backtick_count = backtick_end - start;
146            let closing_backticks = "`".repeat(backtick_count);
147
148            if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
149                let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
150                last_block = Some(code_block.to_string());
151                search_start = backtick_end + end_pos + backtick_count;
152            } else {
153                break;
154            }
155        }
156
157        last_block.unwrap_or_else(|| text.to_string())
158    }
159
160    fn extract_editable_region(text: &str) -> String {
161        let start = text
162            .find(Self::REGION_START)
163            .map_or(0, |pos| pos + Self::REGION_START.len());
164        let end = text.find(Self::REGION_END).unwrap_or(text.len());
165
166        text[start..end].to_string()
167    }
168
169    /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
170    fn format_edit_history(edit_history: &str) -> String {
171        let lines = edit_history
172            .lines()
173            .filter(|&s| Self::is_content_line(s))
174            .collect::<Vec<_>>();
175
176        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
177            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
178        } else {
179            &lines
180        };
181        history_lines.join("\n")
182    }
183
184    fn is_content_line(s: &str) -> bool {
185        s.starts_with("-")
186            || s.starts_with("+")
187            || s.starts_with(" ")
188            || s.starts_with("---")
189            || s.starts_with("+++")
190            || s.starts_with("@@")
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a teacher wired to the dummy client; no requests are sent.
    fn test_teacher() -> TeacherModel {
        TeacherModel::new(
            "test".to_string(),
            ContextType::CurrentFile,
            LlmClient::dummy(),
        )
    }

    #[test]
    fn test_parse_response() {
        let model = test_teacher();

        // Unfenced responses pass through untouched.
        let plain = "This is a test response.";
        assert_eq!(model.parse_response(plain), plain.to_string());

        // A fenced block is unwrapped, whatever the fence length.
        let fenced = indoc::indoc! {"
            Some thinking

            `````
            actual response
            `````
            "};
        assert_eq!(model.parse_response(fenced), "actual response");
    }

    #[test]
    fn test_extract_last_code_block() {
        // When several fenced blocks are present, the last one wins.
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````
            last block
            `````
            "};
        assert_eq!(TeacherModel::extract_last_codeblock(text), "last block");
    }

    #[test]
    fn test_extract_editable_region() {
        let model = test_teacher();
        let response = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three

            "};
        assert_eq!(model.parse_response(response), expected);
    }
}