predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

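// Process-wide LLM clients, initialized on first use and shared across all examples.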
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

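/// Runs edit prediction for `example`, dispatching to the teacher, Baseten, or
/// Zeta/Mercury code path based on `args.provider`. Examples that already have a
/// prediction from the requested provider are skipped.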
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if matches!(
        provider,
        PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
    ) {
        anyhow::bail!("Teacher multi-region providers are not supported for prediction.");
    }

    if let PredictionProvider::Teacher(backend, _)
    | PredictionProvider::TeacherNonBatching(backend, _) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

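    // Zeta providers go through the cloud client, so sign in once per process before
    // the first request.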
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

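    // Select the edit prediction model that matches the requested provider.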
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 prompt format, configure the raw
        // endpoint. The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional overrides.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

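    // Mirror debug events from the store into the example: the prompt and raw model
    // output for each run are written to the run directory as they arrive.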
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: None,
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop once the final repetition has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

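    // Run the prediction `repetition_count` times, giving each run its own
    // sub-directory when repeating.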
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
                cumulative_logprob: None,
                avg_logprob: None,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

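/// Dispatches a teacher prediction to the Anthropic or OpenAI backend.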
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

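/// Generates teacher predictions with the Anthropic API, optionally via the batching
/// client backed by the LLM cache database.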
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend, ZetaFormat::default()))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend, ZetaFormat::default()),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

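/// Generates teacher predictions with the OpenAI API, optionally via the batching
/// client backed by the LLM cache database.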
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend, ZetaFormat::default()))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend, ZetaFormat::default()),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

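/// Sends the formatted prompt to a Baseten-hosted Zeta model via its raw completions
/// endpoint and records the resulting prediction.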
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
        cumulative_logprob: None,
        avg_logprob: None,
    };

    example.predictions.push(prediction);
    Ok(())
}

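/// Syncs pending batch requests for the batching client that matches `provider`, if any.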
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend, _))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}

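/// Re-runs teacher prediction for examples that still have no prediction output,
/// picking up results that became available after waiting on batches.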
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &PredictArgs,
) -> anyhow::Result<()> {
    let Some(PredictionProvider::Teacher(backend, _)) = args.provider else {
        return Ok(());
    };

    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        let has_prediction = example
            .predictions
            .iter()
            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
        if has_prediction || example.prompt.is_none() {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        let step_progress = example_progress.start(Step::Predict);
        predict_teacher(
            example,
            backend,
            true,
            args.repetitions,
            false,
            &step_progress,
        )
        .await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
    }

    Ok(())
}

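/// Polls until all pending batch requests have completed, syncing results every 30 seconds.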
pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(provider)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(provider).await?;
    }

    Ok(())
}

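/// Returns the number of batch requests still pending for the backend matching `provider`.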
fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
    match provider {
        Some(PredictionProvider::Teacher(backend, _)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client.pending_batch_count()
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client.pending_batch_count()
            }
        },
        _ => Ok(0),
    }
}