use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

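/// Runs edit prediction for `example` with the provider selected in `args`,
/// pushing one `ExamplePrediction` per repetition. If a usable prediction
/// already exists (and no conflicting provider was requested), the step is skipped.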
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend)
    | PredictionProvider::TeacherMultiRegion(backend)
    | PredictionProvider::TeacherNonBatching(backend)
    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

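    // Zeta1/Zeta2 predictions require a signed-in client; authenticate at most
    // once per process and share the resulting task across examples.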
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))?
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 format, configure the raw Zeta2 endpoint.
        // The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional overrides.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

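    // Listen for the store's debug events so each run's prompt and raw model
    // output are written into the run directory and recorded on the example.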
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

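    // Request one prediction per repetition. Each run gets its own numbered
    // subdirectory when more than one repetition is requested, and
    // LATEST_EXAMPLE_RUN_DIR is re-pointed at the current run.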
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
                cumulative_logprob: None,
                avg_logprob: None,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

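/// Dispatches a teacher-model prediction to the Anthropic or OpenAI path
/// depending on which `TeacherBackend` was selected.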
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

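/// Runs the teacher prompt against an Anthropic model. In batched mode the
/// request may be stashed for later batch processing, in which case no
/// prediction is recorded on this pass.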
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

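/// Runs the teacher prompt against an OpenAI model; mirrors `predict_anthropic`
/// apart from the request and response types.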
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

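/// Sends the formatted Zeta prompt to a Baseten-hosted model via its raw
/// completions endpoint and records the parsed prediction on the example.
/// Requires the ZED_ZETA_MODEL and BASETEN_API_KEY environment variables.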
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
        cumulative_logprob: None,
        avg_logprob: None,
    };

    example.predictions.push(prediction);
    Ok(())
}

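/// Syncs pending batched teacher requests with the Anthropic or OpenAI batch
/// API for the selected provider; other providers are a no-op.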
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}

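/// Re-runs `predict_teacher` for examples that have a prompt but still no
/// prediction, picking up results that arrived via batch processing.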
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &PredictArgs,
) -> anyhow::Result<()> {
    let Some(PredictionProvider::Teacher(backend)) = args.provider else {
        return Ok(());
    };

    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        let has_prediction = example
            .predictions
            .iter()
            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
        if has_prediction || example.prompt.is_none() {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        let step_progress = example_progress.start(Step::Predict);
        predict_teacher(
            example,
            backend,
            true,
            args.repetitions,
            false,
            &step_progress,
        )
        .await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
    }

    Ok(())
}

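/// Polls until no batch requests remain pending for the selected provider,
/// syncing with the batch API every 30 seconds.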
pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(provider)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(provider).await?;
    }

    Ok(())
}

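/// Returns the number of batch requests still pending for the selected teacher
/// backend; other providers report zero.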
fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client.pending_batch_count()
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client.pending_batch_count()
            }
        },
        _ => Ok(0),
    }
}