use crate::{
    BatchProvider, FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    llm_client::{LlmClient, model_for_backend},
    load_project::run_load_project,
    openai_client::OpenAiClient,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    qa,
    repair::{build_repair_prompt_for_prediction, needs_repair_qa, parse_repair_response},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

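    // Reuse an existing prediction when possible: return early if it matches the requested
    // provider (or no provider was requested); otherwise discard it and re-predict.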
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

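    // Teacher providers bypass the local edit-prediction store: they format a prompt for
    // the teacher model and call the LLM backend directly.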
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count).await;
    }

    if let PredictionProvider::RepairedTeacher(backend) = provider {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        return predict_repaired_teacher(example, backend, repetition_count).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

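    // Zeta1/Zeta2 need the client to be signed in; the shared task in a OnceLock ensures
    // sign-in runs at most once per process.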
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

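    // Map the CLI provider onto the store's model. Teacher and repair providers never
    // reach this point.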
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::RepairedTeacher(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

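    // Listen to the store's debug events in the background, recording the raw prompt and
    // model output for the current run into the run directory and onto the example.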
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

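    // Request one prediction per repetition; when repeating, each run gets its own
    // numbered subdirectory and its own entry in `predictions`.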
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

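        // Render the predicted edits as a unified diff against the buffer snapshot.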
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

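    // Clean up the store's per-project state and wait for the debug listener to finish.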
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

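/// Dispatch teacher predictions to the backend-specific path.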
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 => {
            predict_anthropic(example, backend, batched, repetition_count).await
        }
        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
    }
}

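/// Run teacher predictions against the Anthropic API, optionally through the batching
/// client backed by the LLM cache database.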
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

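/// Run teacher predictions against the OpenAI API, optionally through the batching
/// client backed by the LLM cache database.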
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

/// Default confidence threshold used when deciding whether a QA'd prediction needs repair.
const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;

/// Predict using the teacher model, then run QA evaluation on all predictions,
/// and replace the predictions that need repair.
///
/// This is a non-batched flow that processes each step synchronously.
/// - Predictions that pass QA keep their original (non-batched) teacher provider
/// - Predictions that fail QA are replaced with repaired versions (RepairedTeacher provider)
/// - QA results are stored alongside the predictions; for repaired predictions they
///   describe the original, pre-repair output
async fn predict_repaired_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    repetition_count: usize,
) -> anyhow::Result<()> {
    // Step 1: Run teacher prediction (non-batched for immediate results)
    predict_teacher(example, backend, false, repetition_count).await?;

    if example.predictions.is_empty() {
        return Ok(());
    }

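    // QA and repair requests go through the shared LlmClient for the same teacher backend,
    // issued one at a time (non-batched).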
    let batch_provider = match backend {
        TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
        TeacherBackend::Gpt52 => BatchProvider::Openai,
    };
    let llm_client = LlmClient::new(batch_provider, false)?;
    let model = model_for_backend(batch_provider);

    // Step 2: Run QA for all predictions and repair those that need it
    let mut final_predictions = Vec::with_capacity(example.predictions.len());
    let mut final_qa = Vec::with_capacity(example.predictions.len());

    for prediction in &example.predictions {
        // Skip QA if no actual patch was generated
        if prediction.actual_patch.is_none() {
            final_predictions.push(prediction.clone());
            final_qa.push(None);
            continue;
        }

        // Run QA evaluation for this prediction
        let qa_result =
            if let Some(qa_prompt) = qa::build_prompt_for_prediction(example, prediction) {
                match llm_client.generate(model, 1024, &qa_prompt).await? {
                    Some(response_text) => Some(qa::parse_response(&response_text)),
                    None => None,
                }
            } else {
                None
            };

        // Check if repair is needed
        let needs_repair = qa_result
            .as_ref()
            .map(|qa| needs_repair_qa(qa, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD))
            .unwrap_or(false);

        if needs_repair {
            let qa = qa_result
                .as_ref()
                .expect("qa_result must be Some if needs_repair is true");
            // Step 3: Run repair for this prediction
            if let Some(repair_prompt) = build_repair_prompt_for_prediction(example, prediction, qa)
            {
                if let Some(response_text) =
                    llm_client.generate(model, 16384, &repair_prompt).await?
                {
                    match parse_repair_response(example, &response_text) {
                        Ok(mut repaired_prediction) => {
                            repaired_prediction.provider =
                                PredictionProvider::RepairedTeacher(backend);
                            final_predictions.push(repaired_prediction);
                            final_qa.push(qa_result);
                        }
                        Err(e) => {
                            final_predictions.push(ExamplePrediction {
                                actual_patch: None,
                                actual_output: response_text,
                                error: Some(format!("Failed to parse repair response: {}", e)),
                                provider: PredictionProvider::RepairedTeacher(backend),
                            });
                            final_qa.push(qa_result);
                        }
                    }
                } else {
                    // Repair generation returned None, keep original
                    final_predictions.push(prediction.clone());
                    final_qa.push(qa_result);
                }
            } else {
                // Couldn't build repair prompt, keep original
                final_predictions.push(prediction.clone());
                final_qa.push(qa_result);
            }
        } else {
            // No repair needed, keep original (with Teacher provider)
            final_predictions.push(prediction.clone());
            final_qa.push(qa_result);
        }
    }

    example.predictions = final_predictions;
    example.qa = final_qa;

    Ok(())
}

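/// Flush any pending LLM batch requests for the batched Teacher provider; all other
/// providers are a no-op.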
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}