use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;

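// Shared Anthropic client, created lazily on first use and reused for the rest
// of the process.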
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

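/// Runs edit prediction for a single example using the provider requested in
/// `args`, or keeps an existing prediction that already matches that provider.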
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

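    // Keep an existing prediction if it matches the requested provider (or if no
    // provider was given); otherwise clear it and predict again.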
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

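    // Teacher providers skip the local prediction store: format the teacher
    // prompt and send it directly to Anthropic (batched or not).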
    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

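    // Zeta predictions go through the cloud client, so make sure it is signed in;
    // this only happens once per process.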
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

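    // The remaining providers are served by the global EditPredictionStore.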
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

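    // Subscribe to the store's debug events so each run's prompt and raw model
    // output can be written to its run directory as they arrive.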
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

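    // Request one prediction per repetition, giving each run its own numbered
    // directory when more than one repetition was requested.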
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
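        // Refresh the "latest run" symlink so it points at this run's directory.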
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

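        // Render the predicted edits as a unified diff; `None` or an empty diff
        // means no usable prediction was produced for this run.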
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

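    // The debug task has finished, so this should be the last handle to the
    // example; recover ownership and hand the result back to the caller.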
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

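/// Sends the example's formatted teacher prompt to Anthropic and records the
/// response as a prediction. In batched mode the request may only be stashed for
/// later processing, in which case no prediction is recorded yet.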
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        error: None,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}

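/// Syncs any batched Anthropic requests. Only relevant for the batched `Teacher`
/// provider; for other providers this is a no-op.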
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(..)) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}