predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    parse_output::parse_prediction_output,
 10    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 11    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
 12    retrieve_context::run_context_retrieval,
 13};
 14use anyhow::Context as _;
 15use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
 16use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
 17use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
 18use gpui::{AppContext as _, AsyncApp, Task};
 19use http_client::{AsyncBody, HttpClient, Method};
 20use reqwest_client::ReqwestClient;
 21use std::{
 22    fs,
 23    sync::{
 24        Arc, Mutex, OnceLock,
 25        atomic::{AtomicUsize, Ordering::SeqCst},
 26    },
 27};
 28use zeta_prompt::ZetaFormat;
 29
/// Process-wide Anthropic client, created lazily on first use.
///
/// Whichever caller initializes it first decides its mode: `predict_anthropic`
/// creates a batch or plain client depending on its `batched` argument, while
/// `sync_batches` and `pending_batch_count` always create a batch client.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
/// Process-wide OpenAI client, created lazily on first use.
///
/// Whichever caller initializes it first decides its mode: `predict_openai`
/// creates a batch or plain client depending on its `batched` argument, while
/// `sync_batches` and `pending_batch_count` always create a batch client.
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 32
 33pub async fn run_prediction(
 34    example: &mut Example,
 35    args: &PredictArgs,
 36    app_state: Arc<EpAppState>,
 37    example_progress: &ExampleProgress,
 38    mut cx: AsyncApp,
 39) -> anyhow::Result<()> {
 40    let repetition_count = args.repetitions;
 41
 42    if let Some(existing_prediction) = example.predictions.first() {
 43        let has_prediction = existing_prediction.actual_patch.is_some()
 44            || !existing_prediction.actual_output.is_empty();
 45        if has_prediction {
 46            match args.provider {
 47                None => return Ok(()),
 48                Some(provider) if existing_prediction.provider == provider => return Ok(()),
 49                Some(_) => example.predictions.clear(),
 50            }
 51        }
 52    }
 53
 54    let Some(provider) = args.provider else {
 55        anyhow::bail!(
 56            "No existing predictions found. Use --provider to specify which model to use for prediction."
 57        );
 58    };
 59
 60    if let PredictionProvider::Teacher(backend)
 61    | PredictionProvider::TeacherMultiRegion(backend)
 62    | PredictionProvider::TeacherNonBatching(backend)
 63    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
 64    {
 65        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 66        run_format_prompt(
 67            example,
 68            &FormatPromptArgs { provider },
 69            app_state.clone(),
 70            example_progress,
 71            cx,
 72        )
 73        .await?;
 74
 75        let step_progress = example_progress.start(Step::Predict);
 76        let batched = matches!(
 77            provider,
 78            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
 79        );
 80        return predict_teacher(
 81            example,
 82            backend,
 83            batched,
 84            repetition_count,
 85            args.cache_only,
 86            &step_progress,
 87        )
 88        .await;
 89    }
 90
 91    if let PredictionProvider::Baseten(format) = provider {
 92        run_format_prompt(
 93            example,
 94            &FormatPromptArgs {
 95                provider: PredictionProvider::Zeta2(format),
 96            },
 97            app_state.clone(),
 98            example_progress,
 99            cx,
100        )
101        .await?;
102
103        let step_progress = example_progress.start(Step::Predict);
104        return predict_baseten(example, format, &step_progress).await;
105    }
106
107    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
108    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
109
110    let step_progress = example_progress.start(Step::Predict);
111
112    if matches!(
113        provider,
114        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
115    ) {
116        step_progress.set_substatus("authenticating");
117        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
118        AUTHENTICATED
119            .get_or_init(|| {
120                let client = app_state.client.clone();
121                cx.spawn(async move |cx| {
122                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
123                        eprintln!("Authentication failed: {}", e);
124                    }
125                })
126                .shared()
127            })
128            .clone()
129            .await;
130    }
131
132    let ep_store = cx
133        .update(|cx| EditPredictionStore::try_global(cx))
134        .context("EditPredictionStore not initialized")?;
135
136    ep_store.update(&mut cx, |store, _cx| {
137        let model = match provider {
138            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
139            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
140            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
141            PredictionProvider::Teacher(..)
142            | PredictionProvider::TeacherMultiRegion(..)
143            | PredictionProvider::TeacherNonBatching(..)
144            | PredictionProvider::TeacherMultiRegionNonBatching(..)
145            | PredictionProvider::Repair
146            | PredictionProvider::Baseten(_) => {
147                unreachable!()
148            }
149        };
150        store.set_edit_prediction_model(model);
151
152        // If user specified a non-default Zeta2 version, configure raw endpoint.
153        // ZED_ZETA_MODEL env var is optional.
154        if let PredictionProvider::Zeta2(format) = provider {
155            if format != ZetaFormat::default() {
156                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
157                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
158                store.set_zeta2_raw_config(Zeta2RawConfig {
159                    model_id,
160                    environment,
161                    format,
162                });
163            }
164        }
165    });
166    step_progress.set_substatus("configuring model");
167    let state = example.state.as_ref().context("state must be set")?;
168    let run_dir = RUN_DIR.join(&example.spec.name);
169
170    let updated_example = Arc::new(Mutex::new(example.clone()));
171    let current_run_ix = Arc::new(AtomicUsize::new(0));
172
173    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
174    let debug_task = cx.background_spawn({
175        let updated_example = updated_example.clone();
176        let current_run_ix = current_run_ix.clone();
177        let run_dir = run_dir.clone();
178        async move {
179            while let Some(event) = debug_rx.next().await {
180                let run_ix = current_run_ix.load(SeqCst);
181                let mut updated_example = updated_example.lock().unwrap();
182
183                let run_dir = if repetition_count > 1 {
184                    run_dir.join(format!("{:03}", run_ix))
185                } else {
186                    run_dir.clone()
187                };
188
189                match event {
190                    DebugEvent::EditPredictionStarted(request) => {
191                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
192
193                        if let Some(prompt) = request.prompt {
194                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
195                            if matches!(provider, PredictionProvider::Zeta2(_)) {
196                                updated_example.prompt.get_or_insert(ExamplePrompt {
197                                    input: prompt,
198                                    expected_output: None,
199                                    rejected_output: None,
200                                    provider,
201                                    prefill: None,
202                                });
203                            }
204                        }
205                    }
206                    DebugEvent::EditPredictionFinished(request) => {
207                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
208
209                        if let Some(output) = request.model_output {
210                            fs::write(run_dir.join("prediction_response.md"), &output)?;
211                            updated_example
212                                .predictions
213                                .last_mut()
214                                .unwrap()
215                                .actual_output = output;
216                        }
217                        if run_ix >= repetition_count {
218                            break;
219                        }
220                    }
221                    _ => {}
222                }
223            }
224            anyhow::Ok(())
225        }
226    });
227
228    for ix in 0..repetition_count {
229        current_run_ix.store(ix, SeqCst);
230        let run_dir = if repetition_count > 1 {
231            run_dir.join(format!("{:03}", ix))
232        } else {
233            run_dir.clone()
234        };
235
236        if repetition_count > 1 {
237            step_progress.set_substatus(format!(
238                "running prediction {}/{}",
239                ix + 1,
240                repetition_count
241            ));
242        } else {
243            step_progress.set_substatus("running prediction");
244        }
245
246        fs::create_dir_all(&run_dir)?;
247        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
248            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
249        }
250        #[cfg(unix)]
251        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
252        #[cfg(windows)]
253        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
254
255        updated_example
256            .lock()
257            .unwrap()
258            .predictions
259            .push(ExamplePrediction {
260                actual_patch: None,
261                actual_output: String::new(),
262                actual_cursor: None,
263                error: None,
264                provider,
265                cumulative_logprob: None,
266                avg_logprob: None,
267            });
268
269        step_progress.set_substatus("requesting prediction");
270        let prediction = ep_store
271            .update(&mut cx, |store, cx| {
272                store.request_prediction(
273                    &state.project,
274                    &state.buffer,
275                    state.cursor_position,
276                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
277                    cx,
278                )
279            })
280            .await?;
281
282        let actual_patch = prediction.and_then(|prediction| {
283            let prediction = prediction.prediction.ok()?;
284            prediction
285                .edit_preview
286                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
287        });
288
289        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
290
291        updated_example
292            .lock()
293            .unwrap()
294            .predictions
295            .last_mut()
296            .unwrap()
297            .actual_patch = actual_patch;
298
299        if ix == repetition_count - 1 {
300            let (info, style) = if has_prediction {
301                ("predicted", InfoStyle::Normal)
302            } else {
303                ("no prediction", InfoStyle::Warning)
304            };
305            step_progress.set_info(info, style);
306        }
307    }
308
309    ep_store.update(&mut cx, |store, _| {
310        store.remove_project(&state.project);
311    });
312    debug_task.await?;
313
314    *example = Arc::into_inner(updated_example)
315        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
316        .into_inner()
317        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
318    Ok(())
319}
320
321async fn predict_teacher(
322    example: &mut Example,
323    backend: TeacherBackend,
324    batched: bool,
325    repetition_count: usize,
326    cache_only: bool,
327    step_progress: &crate::progress::StepProgress,
328) -> anyhow::Result<()> {
329    match backend {
330        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
331            predict_anthropic(
332                example,
333                backend,
334                batched,
335                repetition_count,
336                cache_only,
337                step_progress,
338            )
339            .await
340        }
341        TeacherBackend::Gpt52 => {
342            predict_openai(
343                example,
344                backend,
345                batched,
346                repetition_count,
347                cache_only,
348                step_progress,
349            )
350            .await
351        }
352    }
353}
354
355async fn predict_anthropic(
356    example: &mut Example,
357    backend: TeacherBackend,
358    batched: bool,
359    repetition_count: usize,
360    cache_only: bool,
361    step_progress: &crate::progress::StepProgress,
362) -> anyhow::Result<()> {
363    let llm_model_name = backend.model_name();
364    let max_tokens = 16384;
365    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
366        let client = if batched {
367            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
368        } else {
369            AnthropicClient::plain()
370        };
371        client.expect("Failed to create Anthropic client")
372    });
373
374    let prompt = example.prompt.as_ref().context("Prompt is required")?;
375
376    for ix in 0..repetition_count {
377        if repetition_count > 1 {
378            step_progress.set_substatus(format!(
379                "running prediction {}/{}",
380                ix + 1,
381                repetition_count
382            ));
383        } else {
384            step_progress.set_substatus("running prediction");
385        }
386
387        let messages = vec![anthropic::Message {
388            role: anthropic::Role::User,
389            content: vec![anthropic::RequestContent::Text {
390                text: prompt.input.clone(),
391                cache_control: None,
392            }],
393        }];
394
395        let seed = if repetition_count > 1 { Some(ix) } else { None };
396        let Some(response) = llm_client
397            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
398            .await?
399        else {
400            // Request stashed for batched processing
401            continue;
402        };
403
404        let actual_output = response
405            .content
406            .into_iter()
407            .filter_map(|content| match content {
408                anthropic::ResponseContent::Text { text } => Some(text),
409                _ => None,
410            })
411            .collect::<Vec<String>>()
412            .join("\n");
413
414        let parser_provider = if batched {
415            example
416                .prompt
417                .as_ref()
418                .map(|prompt| prompt.provider)
419                .unwrap_or(PredictionProvider::Teacher(backend))
420        } else {
421            match example.prompt.as_ref().map(|prompt| prompt.provider) {
422                Some(PredictionProvider::TeacherMultiRegion(_))
423                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
424                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
425                }
426                _ => PredictionProvider::TeacherNonBatching(backend),
427            }
428        };
429
430        let (actual_patch, actual_cursor) = match parser_provider {
431            PredictionProvider::TeacherMultiRegion(_)
432            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
433                TeacherMultiRegionPrompt::parse(example, &actual_output)?
434            }
435            _ => TeacherPrompt::parse(example, &actual_output)?,
436        };
437
438        let prediction = ExamplePrediction {
439            actual_patch: Some(actual_patch),
440            actual_output,
441            actual_cursor,
442            error: None,
443            provider: if batched {
444                match example.prompt.as_ref().map(|prompt| prompt.provider) {
445                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
446                        PredictionProvider::TeacherMultiRegion(backend)
447                    }
448                    _ => PredictionProvider::Teacher(backend),
449                }
450            } else {
451                match example.prompt.as_ref().map(|prompt| prompt.provider) {
452                    Some(PredictionProvider::TeacherMultiRegion(_))
453                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
454                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
455                    }
456                    _ => PredictionProvider::TeacherNonBatching(backend),
457                }
458            },
459            cumulative_logprob: None,
460            avg_logprob: None,
461        };
462
463        example.predictions.push(prediction);
464    }
465    Ok(())
466}
467
468async fn predict_openai(
469    example: &mut Example,
470    backend: TeacherBackend,
471    batched: bool,
472    repetition_count: usize,
473    cache_only: bool,
474    step_progress: &crate::progress::StepProgress,
475) -> anyhow::Result<()> {
476    let llm_model_name = backend.model_name();
477    let max_tokens = 16384;
478    let llm_client = OPENAI_CLIENT.get_or_init(|| {
479        let client = if batched {
480            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
481        } else {
482            OpenAiClient::plain()
483        };
484        client.expect("Failed to create OpenAI client")
485    });
486
487    let prompt = example.prompt.as_ref().context("Prompt is required")?;
488
489    for ix in 0..repetition_count {
490        if repetition_count > 1 {
491            step_progress.set_substatus(format!(
492                "running prediction {}/{}",
493                ix + 1,
494                repetition_count
495            ));
496        } else {
497            step_progress.set_substatus("running prediction");
498        }
499
500        let messages = vec![open_ai::RequestMessage::User {
501            content: open_ai::MessageContent::Plain(prompt.input.clone()),
502        }];
503
504        let seed = if repetition_count > 1 { Some(ix) } else { None };
505        let Some(response) = llm_client
506            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
507            .await?
508        else {
509            // Request stashed for batched processing
510            continue;
511        };
512
513        let actual_output = response
514            .choices
515            .into_iter()
516            .filter_map(|choice| match choice.message {
517                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
518                    open_ai::MessageContent::Plain(text) => text,
519                    open_ai::MessageContent::Multipart(parts) => parts
520                        .into_iter()
521                        .filter_map(|p| match p {
522                            open_ai::MessagePart::Text { text } => Some(text),
523                            _ => None,
524                        })
525                        .collect::<Vec<_>>()
526                        .join(""),
527                }),
528                _ => None,
529            })
530            .collect::<Vec<String>>()
531            .join("\n");
532
533        let parser_provider = if batched {
534            example
535                .prompt
536                .as_ref()
537                .map(|prompt| prompt.provider)
538                .unwrap_or(PredictionProvider::Teacher(backend))
539        } else {
540            match example.prompt.as_ref().map(|prompt| prompt.provider) {
541                Some(PredictionProvider::TeacherMultiRegion(_))
542                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
543                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
544                }
545                _ => PredictionProvider::TeacherNonBatching(backend),
546            }
547        };
548
549        let (actual_patch, actual_cursor) = match parser_provider {
550            PredictionProvider::TeacherMultiRegion(_)
551            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
552                TeacherMultiRegionPrompt::parse(example, &actual_output)?
553            }
554            _ => TeacherPrompt::parse(example, &actual_output)?,
555        };
556
557        let prediction = ExamplePrediction {
558            actual_patch: Some(actual_patch),
559            actual_output,
560            actual_cursor,
561            error: None,
562            provider: if batched {
563                match example.prompt.as_ref().map(|prompt| prompt.provider) {
564                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
565                        PredictionProvider::TeacherMultiRegion(backend)
566                    }
567                    _ => PredictionProvider::Teacher(backend),
568                }
569            } else {
570                match example.prompt.as_ref().map(|prompt| prompt.provider) {
571                    Some(PredictionProvider::TeacherMultiRegion(_))
572                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
573                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
574                    }
575                    _ => PredictionProvider::TeacherNonBatching(backend),
576                }
577            },
578            cumulative_logprob: None,
579            avg_logprob: None,
580        };
581
582        example.predictions.push(prediction);
583    }
584    Ok(())
585}
586
587pub async fn predict_baseten(
588    example: &mut Example,
589    format: ZetaFormat,
590    step_progress: &StepProgress,
591) -> anyhow::Result<()> {
592    let model_id =
593        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
594
595    let api_key =
596        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
597
598    let prompt = example.prompt.as_ref().context("Prompt is required")?;
599    let prompt_text = prompt.input.clone();
600    let prefill = prompt.prefill.clone().unwrap_or_default();
601
602    step_progress.set_substatus("running prediction via baseten");
603
604    let environment: String = <&'static str>::from(&format).to_lowercase();
605    let url = format!(
606        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
607    );
608
609    let request_body = RawCompletionRequest {
610        model: model_id,
611        prompt: prompt_text.clone(),
612        max_tokens: Some(2048),
613        temperature: Some(0.),
614        stop: vec![],
615        environment: None,
616    };
617
618    let body_bytes =
619        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
620
621    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
622    let request = http_client::Request::builder()
623        .method(Method::POST)
624        .uri(&url)
625        .header("Content-Type", "application/json")
626        .header("Authorization", format!("Api-Key {api_key}"))
627        .body(AsyncBody::from(body_bytes))?;
628
629    let mut response = http_client.send(request).await?;
630    let status = response.status();
631
632    let mut body = String::new();
633    response
634        .body_mut()
635        .read_to_string(&mut body)
636        .await
637        .context("Failed to read Baseten response body")?;
638
639    if !status.is_success() {
640        anyhow::bail!("Baseten API returned {status}: {body}");
641    }
642
643    let completion: RawCompletionResponse =
644        serde_json::from_str(&body).context("Failed to parse Baseten response")?;
645
646    let actual_output = completion
647        .choices
648        .into_iter()
649        .next()
650        .map(|choice| choice.text)
651        .unwrap_or_default();
652
653    let actual_output = format!("{prefill}{actual_output}");
654
655    let (actual_patch, actual_cursor) =
656        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
657
658    let prediction = ExamplePrediction {
659        actual_patch: Some(actual_patch),
660        actual_output,
661        actual_cursor,
662        error: None,
663        provider: PredictionProvider::Baseten(format),
664        cumulative_logprob: None,
665        avg_logprob: None,
666    };
667
668    example.predictions.push(prediction);
669    Ok(())
670}
671
672pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
673    match provider {
674        Some(PredictionProvider::Teacher(backend))
675        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
676            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
677                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
678                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
679                        .expect("Failed to create Anthropic client")
680                });
681                llm_client
682                    .sync_batches()
683                    .await
684                    .context("Failed to sync Anthropic batches")?;
685            }
686            TeacherBackend::Gpt52 => {
687                let llm_client = OPENAI_CLIENT.get_or_init(|| {
688                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
689                        .expect("Failed to create OpenAI client")
690                });
691                llm_client
692                    .sync_batches()
693                    .await
694                    .context("Failed to sync OpenAI batches")?;
695            }
696        },
697        _ => (),
698    };
699    Ok(())
700}
701
702pub async fn reprocess_after_batch_wait(
703    examples: &mut [Example],
704    args: &PredictArgs,
705) -> anyhow::Result<()> {
706    let Some(PredictionProvider::Teacher(backend)) = args.provider else {
707        return Ok(());
708    };
709
710    let mut reprocessed = 0;
711    for example in examples.iter_mut() {
712        let has_prediction = example
713            .predictions
714            .iter()
715            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
716        if has_prediction || example.prompt.is_none() {
717            continue;
718        }
719
720        let example_progress = Progress::global().start_group(&example.spec.name);
721        let step_progress = example_progress.start(Step::Predict);
722        predict_teacher(
723            example,
724            backend,
725            true,
726            args.repetitions,
727            false,
728            &step_progress,
729        )
730        .await?;
731        reprocessed += 1;
732    }
733
734    if reprocessed > 0 {
735        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
736    }
737
738    Ok(())
739}
740
741pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
742    let poll_interval = std::time::Duration::from_secs(30);
743
744    loop {
745        let pending = pending_batch_count(provider)?;
746        if pending == 0 {
747            break;
748        }
749
750        eprintln!(
751            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
752            pending,
753            poll_interval.as_secs()
754        );
755        std::thread::sleep(poll_interval);
756
757        sync_batches(provider).await?;
758    }
759
760    Ok(())
761}
762
763fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
764    match provider {
765        Some(PredictionProvider::Teacher(backend)) => match backend {
766            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
767                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
768                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
769                        .expect("Failed to create Anthropic client")
770                });
771                llm_client.pending_batch_count()
772            }
773            TeacherBackend::Gpt52 => {
774                let llm_client = OPENAI_CLIENT.get_or_init(|| {
775                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
776                        .expect("Failed to create OpenAI client")
777                });
778                llm_client.pending_batch_count()
779            }
780        },
781        _ => Ok(0),
782    }
783}