predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

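// Process-wide LLM clients, created lazily. Note that `get_or_init` means the
// first caller decides batch vs. plain mode for the rest of the process.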
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

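/// Runs edit prediction for `example`, dispatching to a teacher backend,
/// Baseten, or the in-process `EditPredictionStore` providers based on `args`.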
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend)
    | PredictionProvider::TeacherMultiRegion(backend)
    | PredictionProvider::TeacherNonBatching(backend)
    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

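    // Baseten serves a Zeta model, so format the prompt with the matching
    // Zeta2 format before hitting the raw completion endpoint.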
    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

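    // All remaining providers run through the in-process EditPredictionStore,
    // which needs the project loaded and context retrieved first.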
    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

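    // Zeta providers require a signed-in client; authenticate at most once
    // per process and share the task across examples.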
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {e}");
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

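    // Map the CLI provider to the store's model implementation. The teacher
    // and Baseten variants returned above; Repair is expected never to reach
    // this point.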
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_) => {
                edit_prediction::EditPredictionModel::Zeta
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 format, configure the raw
        // endpoint. The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

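    // Stream debug events on a background task: it writes the raw prompt and
    // model output for whichever repetition `current_run_ix` points at.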
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // `run_ix` is zero-based, so the final repetition is
                        // `repetition_count - 1`; stop once it has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

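    // One prediction per repetition; outputs land in per-run subdirectories
    // when more than one repetition was requested.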
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

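        // Reserve an empty prediction slot up front; the debug task and the
        // code below both fill in the slot for the current repetition.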
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
                cumulative_logprob: None,
                avg_logprob: None,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

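    // Tear down: release the project, wait for the debug task to flush, then
    // move the accumulated predictions back into `example`.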
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

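/// Dispatches a teacher prediction to the backend-specific implementation
/// (Anthropic for the Sonnet backends, OpenAI for the GPT backend).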
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

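/// Runs the formatted teacher prompt against Anthropic. In batched mode a
/// request may be stashed for later batch submission, in which case no
/// prediction is recorded on this pass (see `sync_batches`).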
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

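        // Pick the parser matching the prompt's provider: multi-region
        // prompts need the multi-region parser, everything else the default.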
        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

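        // Record the result, tagging it with the batched or non-batched
        // teacher variant that actually produced it.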
        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

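/// OpenAI counterpart of `predict_anthropic`: the same flow, with OpenAI's
/// message and response types and the OpenAI batch cache.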
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

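        // Same parser and provider selection as in `predict_anthropic`.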
        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

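/// Sends the formatted prompt to a Baseten-hosted model through its raw
/// completions endpoint. Requires the ZED_ZETA_MODEL and BASETEN_API_KEY
/// environment variables.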
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id = std::env::var("ZED_ZETA_MODEL")
        .context("ZED_ZETA_MODEL environment variable must be set")?;

    let api_key = std::env::var("BASETEN_API_KEY")
        .context("BASETEN_API_KEY environment variable must be set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via Baseten");

    // The Baseten environment name is derived from the Zeta format.
    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text,
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    // The model continues from the prefill, so prepend it before parsing.
    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
        cumulative_logprob: None,
        avg_logprob: None,
    };

    example.predictions.push(prediction);
    Ok(())
}

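/// Syncs stashed batch requests for the given teacher provider's batch cache;
/// a no-op for non-batching providers.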
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    }
    Ok(())
}

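/// After a batch wait, re-runs prediction for every example that has a
/// formatted prompt but no recorded prediction, so the now-cached batch
/// results can be parsed into patches.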
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &PredictArgs,
) -> anyhow::Result<()> {
    let Some(PredictionProvider::Teacher(backend)) = args.provider else {
        return Ok(());
    };

    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        let has_prediction = example
            .predictions
            .iter()
            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
        if has_prediction || example.prompt.is_none() {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        let step_progress = example_progress.start(Step::Predict);
        predict_teacher(
            example,
            backend,
            true,
            args.repetitions,
            false,
            &step_progress,
        )
        .await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {reprocessed} example(s) with batch results");
    }

    Ok(())
}

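/// Blocks until every pending batch request for the provider has completed,
/// re-syncing with the batch API between polls.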
pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(provider)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(provider).await?;
    }

    Ok(())
}

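/// Number of batch requests still awaiting results for the given provider.
/// Note: unlike `sync_batches`, this only counts plain `Teacher` batches.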
fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client.pending_batch_count()
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client.pending_batch_count()
            }
        },
        _ => Ok(0),
    }
}