use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

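/// Runs edit prediction for a single example using the provider from `args`,
/// dispatching to the teacher (Anthropic/OpenAI), Baseten, or built-in
/// (Zeta1/Zeta2/Mercury) paths and recording the results on the example.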
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if matches!(
        provider,
        PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
    ) {
        anyhow::bail!("Teacher multi-region providers are not supported for prediction.");
    }

    if let PredictionProvider::Teacher(backend, _)
    | PredictionProvider::TeacherNonBatching(backend, _) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If a non-default Zeta2 format was requested, configure the raw endpoint.
        // The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional overrides.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

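    // Subscribe to the store's debug events so each run's prompt and raw model
    // output can be written to disk and recorded on the in-progress prediction.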
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: None,
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

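    // Run the prediction `repetition_count` times, writing each run's artifacts
    // into its own numbered subdirectory when more than one repetition is requested.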
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
                cumulative_logprob: None,
                avg_logprob: None,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

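/// Routes a teacher prediction to the Anthropic or OpenAI client based on the
/// selected backend.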
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

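/// Sends the example's formatted prompt to an Anthropic model (optionally via the
/// batching client), parses each response, and appends the resulting predictions.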
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend, ZetaFormat::default()))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend, ZetaFormat::default()),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

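/// Same as `predict_anthropic`, but requests completions from an OpenAI model
/// via the OpenAI client.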
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend, ZetaFormat::default()))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend, ZetaFormat::default()),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend, ZetaFormat::default()),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

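/// Requests a raw completion from a Baseten-hosted Zeta model. Requires the
/// ZED_ZETA_MODEL and BASETEN_API_KEY environment variables.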
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
        cumulative_logprob: None,
        avg_logprob: None,
    };

    example.predictions.push(prediction);
    Ok(())
}

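/// Synchronizes pending batched LLM requests with the backend's batch API for the
/// batching teacher providers; a no-op for all other providers.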
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend, _))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}

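/// After batch results have been synced, re-runs prediction for any example that
/// has a prompt but still no output, so completed batch responses are picked up.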
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &PredictArgs,
) -> anyhow::Result<()> {
    let Some(PredictionProvider::Teacher(backend, _)) = args.provider else {
        return Ok(());
    };

    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        let has_prediction = example
            .predictions
            .iter()
            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
        if has_prediction || example.prompt.is_none() {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        let step_progress = example_progress.start(Step::Predict);
        predict_teacher(
            example,
            backend,
            true,
            args.repetitions,
            false,
            &step_progress,
        )
        .await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
    }

    Ok(())
}

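/// Blocks until no batch requests remain pending for the given provider, syncing
/// and re-checking every 30 seconds.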
pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(provider)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(provider).await?;
    }

    Ok(())
}

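/// Returns the number of batch requests still awaiting results for the given
/// provider, or zero for providers that do not use batching.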
fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
    match provider {
        Some(PredictionProvider::Teacher(backend, _)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client.pending_batch_count()
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client.pending_batch_count()
            }
        },
        _ => Ok(0),
    }
}