use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

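// Process-wide LLM clients, created lazily on first use and shared across examples.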
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

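/// Runs edit prediction for a single example. Reuses an existing prediction when it
/// matches the requested provider; otherwise dispatches to the teacher, Baseten, or
/// built-in (Zeta/Mercury via `EditPredictionStore`) prediction paths.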
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend)
    | PredictionProvider::TeacherMultiRegion(backend)
    | PredictionProvider::TeacherNonBatching(backend)
    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

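    // Sign in at most once per process before requesting Zeta predictions; the shared
    // task lets every example await the same authentication attempt.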
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 format, configure the raw endpoint.
        // The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional overrides.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
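    // Consume debug events in the background: write each run's prompt and raw model
    // output to the run directory and record the output on the example.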
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: None,
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
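        // Point the latest-run symlink at this run's directory.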
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
                cumulative_logprob: None,
                avg_logprob: None,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

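/// Dispatches a teacher prediction to the Anthropic or OpenAI client based on the
/// selected teacher backend.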
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

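/// Requests completions from an Anthropic teacher model, parses each response into a
/// patch, and records it on the example. Requests that are only stashed for batched
/// processing are skipped until the batch results are synced.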
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

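/// Requests completions from an OpenAI teacher model, parses each response into a
/// patch, and records it on the example. Requests that are only stashed for batched
/// processing are skipped until the batch results are synced.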
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
            cumulative_logprob: None,
            avg_logprob: None,
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

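/// Requests a raw completion from a Baseten-hosted Zeta model (selected via the
/// `ZED_ZETA_MODEL` and `BASETEN_API_KEY` environment variables), prepends the prompt
/// prefill, and records the parsed prediction on the example.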
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
        cumulative_logprob: None,
        avg_logprob: None,
    };

    example.predictions.push(prediction);
    Ok(())
}

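/// Syncs pending batch requests with the relevant LLM client when a batched teacher
/// provider is selected; a no-op for all other providers.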
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}

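/// Re-runs teacher prediction for examples that already have a prompt but no
/// prediction yet, picking up results that completed while waiting on a batch.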
pub async fn reprocess_after_batch_wait(
    examples: &mut [Example],
    args: &PredictArgs,
) -> anyhow::Result<()> {
    let Some(PredictionProvider::Teacher(backend)) = args.provider else {
        return Ok(());
    };

    let mut reprocessed = 0;
    for example in examples.iter_mut() {
        let has_prediction = example
            .predictions
            .iter()
            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
        if has_prediction || example.prompt.is_none() {
            continue;
        }

        let example_progress = Progress::global().start_group(&example.spec.name);
        let step_progress = example_progress.start(Step::Predict);
        predict_teacher(
            example,
            backend,
            true,
            args.repetitions,
            false,
            &step_progress,
        )
        .await?;
        reprocessed += 1;
    }

    if reprocessed > 0 {
        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
    }

    Ok(())
}

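/// Polls until all pending batch requests for the given provider have completed,
/// syncing batch state between polls.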
pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    let poll_interval = std::time::Duration::from_secs(30);

    loop {
        let pending = pending_batch_count(provider)?;
        if pending == 0 {
            break;
        }

        eprintln!(
            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
            pending,
            poll_interval.as_secs()
        );
        std::thread::sleep(poll_interval);

        sync_batches(provider).await?;
    }

    Ok(())
}

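/// Returns the number of pending batch requests for the given teacher backend, or
/// zero for providers that do not use batching.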
fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client.pending_batch_count()
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client.pending_batch_count()
            }
        },
        _ => Ok(0),
    }
}