1use crate::{
2 example::Example,
3 source_location::SourceLocation,
4 training::{
5 context::{ContextType, collect_context, strip_special_tags},
6 llm_client::LlmClient,
7 },
8};
9use anthropic::{Message, RequestContent, ResponseContent, Role};
10use anyhow::Result;
11
/// A "teacher" LLM used to generate training targets: it prompts the model with
/// collected context plus an edit history and parses the predicted edit back out.
pub struct TeacherModel {
    /// Model identifier forwarded to the LLM client on each generate request.
    pub llm_name: String,
    /// Strategy used to collect prompt context around the cursor position.
    pub context: ContextType,
    /// Client through which generation requests are sent.
    pub client: LlmClient,
}
17
/// Everything produced by a single [`TeacherModel::predict`] call, serialized
/// for inspection of the generated training data.
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
    /// The editable region extracted from the LLM's response.
    parsed_output: String,
    /// The fully rendered prompt that was sent to the model.
    prompt: String,
    /// The model's response text before any parsing.
    raw_llm_response: String,
    /// The context block the prompt was built from.
    context: String,
    /// Unified diff between the context before and after applying the edit.
    diff: String,
}
26
27impl TeacherModel {
28 const PROMPT: &str = include_str!("teacher.prompt.md");
29 pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
30 pub(crate) const REGION_END: &str = "<|editable_region_end|>";
31 pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
32
33 /// Number of lines to include before the cursor position
34 pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
35
36 /// Number of lines to include after the cursor position
37 pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
38
39 /// Truncate edit history to this number of last lines
40 const MAX_HISTORY_LINES: usize = 128;
41
42 pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
43 TeacherModel {
44 llm_name,
45 context,
46 client,
47 }
48 }
49
50 pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
51 let name = input.unique_name();
52 let worktree_dir = input.setup_worktree(name).await?;
53 let cursor: SourceLocation = input
54 .cursor_position
55 .parse()
56 .expect("Failed to parse cursor position");
57
58 let context = collect_context(&self.context, &worktree_dir, cursor.clone());
59 let edit_history = Self::format_edit_history(&input.edit_history);
60
61 let prompt = Self::PROMPT
62 .replace("{{context}}", &context)
63 .replace("{{edit_history}}", &edit_history);
64
65 let messages = vec![Message {
66 role: Role::User,
67 content: vec![RequestContent::Text {
68 text: prompt.clone(),
69 cache_control: None,
70 }],
71 }];
72
73 let Some(response) = self
74 .client
75 .generate(self.llm_name.clone(), 16384, messages)
76 .await?
77 else {
78 return Ok(None);
79 };
80
81 let response_text = response
82 .content
83 .into_iter()
84 .filter_map(|content| match content {
85 ResponseContent::Text { text } => Some(text),
86 _ => None,
87 })
88 .collect::<Vec<String>>()
89 .join("\n");
90
91 let parsed_output = self.parse_response(&response_text);
92
93 let original_editable_region = Self::extract_editable_region(&context);
94 let context_after_edit = context.replace(&original_editable_region, &parsed_output);
95 let context_after_edit = strip_special_tags(&context_after_edit);
96 let context_before_edit = strip_special_tags(&context);
97 let diff = language::unified_diff(&context_before_edit, &context_after_edit);
98
99 // zeta distill --batch batch_results.txt
100 // zeta distill
101 // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
102 // - store LLM requests in a batch, don't actual send the request
103 // - send the batch (2000 requests) after all inputs are processed
104 // 2. `zeta send-batches`
105 // - upload the batch to Anthropic
106
107 // https://platform.claude.com/docs/en/build-with-claude/batch-processing
108 // https://crates.io/crates/anthropic-sdk-rust
109
110 // - poll for results
111 // - when ready, store results in cache (a database)
112 // 3. `zeta distill` again
113 // - use the cached results this time
114
115 Ok(Some(TeacherOutput {
116 parsed_output,
117 prompt,
118 raw_llm_response: response_text,
119 context,
120 diff,
121 }))
122 }
123
124 fn parse_response(&self, content: &str) -> String {
125 let codeblock = Self::extract_last_codeblock(content);
126 let editable_region = Self::extract_editable_region(&codeblock);
127
128 editable_region
129 }
130
131 /// Extract content from the last code-fenced block if any, or else return content as is
132 fn extract_last_codeblock(text: &str) -> String {
133 let mut last_block = None;
134 let mut search_start = 0;
135
136 while let Some(start) = text[search_start..].find("```") {
137 let start = start + search_start;
138 let bytes = text.as_bytes();
139 let mut backtick_end = start;
140
141 while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
142 backtick_end += 1;
143 }
144
145 let backtick_count = backtick_end - start;
146 let closing_backticks = "`".repeat(backtick_count);
147
148 if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
149 let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
150 last_block = Some(code_block.to_string());
151 search_start = backtick_end + end_pos + backtick_count;
152 } else {
153 break;
154 }
155 }
156
157 last_block.unwrap_or_else(|| text.to_string())
158 }
159
160 fn extract_editable_region(text: &str) -> String {
161 let start = text
162 .find(Self::REGION_START)
163 .map_or(0, |pos| pos + Self::REGION_START.len());
164 let end = text.find(Self::REGION_END).unwrap_or(text.len());
165
166 text[start..end].to_string()
167 }
168
169 /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
170 fn format_edit_history(edit_history: &str) -> String {
171 let lines = edit_history
172 .lines()
173 .filter(|&s| Self::is_content_line(s))
174 .collect::<Vec<_>>();
175
176 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
177 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
178 } else {
179 &lines
180 };
181 history_lines.join("\n")
182 }
183
184 fn is_content_line(s: &str) -> bool {
185 s.starts_with("-")
186 || s.starts_with("+")
187 || s.starts_with(" ")
188 || s.starts_with("---")
189 || s.starts_with("+++")
190 || s.starts_with("@@")
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TeacherModel` wired to a dummy client, sufficient for
    /// exercising the pure response-parsing helpers.
    fn dummy_teacher() -> TeacherModel {
        TeacherModel::new(
            "test".to_string(),
            ContextType::CurrentFile,
            LlmClient::dummy(),
        )
    }

    #[test]
    fn test_parse_response() {
        let teacher = dummy_teacher();

        // Without any code fence, the response passes through untouched.
        let plain = "This is a test response.";
        assert_eq!(teacher.parse_response(plain), plain.to_string());

        // With a fence, only the fenced contents survive.
        let fenced = indoc::indoc! {"
            Some thinking

            `````
            actual response
            `````
        "};
        assert_eq!(teacher.parse_response(fenced), "actual response");
    }

    #[test]
    fn test_extract_last_code_block() {
        // Two fenced blocks: only the later one should be returned.
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````
            last block
            `````
        "};
        assert_eq!(TeacherModel::extract_last_codeblock(text), "last block");
    }

    #[test]
    fn test_extract_editable_region() {
        let teacher = dummy_teacher();
        let response = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
        "};
        // Only the lines between the region markers survive.
        let expected = indoc::indoc! {"
            one
            two three

        "};
        assert_eq!(teacher.parse_response(response), expected);
    }
}