use crate::{
    BatchProvider, FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    llm_client::{LlmClient, model_for_backend},
    load_project::run_load_project,
    openai_client::OpenAiClient,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    qa,
    repair::{build_repair_prompt, needs_repair, parse_repair_response},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

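/// Runs edit prediction for a single example, dispatching on the requested
/// provider: teacher backends call an external LLM directly, while Zeta,
/// Sweep, and Mercury go through the in-process `EditPredictionStore`.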
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

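    // Reuse an existing prediction when it came from the requested provider
    // (or when no provider was specified); otherwise discard stale predictions.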
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

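    // Teacher providers format a prompt and query the backing LLM directly,
    // without loading the project into an `EditPredictionStore`.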
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count).await;
    }

    if let PredictionProvider::RepairedTeacher(backend) = provider {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        return predict_repaired_teacher(example, backend, repetition_count).await;
    }

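    // The remaining providers are served through the `EditPredictionStore`,
    // which needs the example's project loaded first.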
    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

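    // Zeta models require a signed-in client; authenticate once per process
    // and share the task across examples.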
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::RepairedTeacher(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

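    // Share the example behind a mutex so the debug listener below can record
    // prompts and raw model output while predictions are requested.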
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

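    // Listen for debug events from the store and persist each run's prompt and
    // response alongside its other artifacts.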
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

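    // Request one prediction per repetition, writing each run's artifacts into
    // its own numbered subdirectory when more than one repetition is requested.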
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

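    // Detach the project from the store and wait for the debug listener to
    // finish before copying the recorded results back into `example`.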
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

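/// Dispatches a teacher prediction to the client for the selected backend.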
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 => {
            predict_anthropic(example, backend, batched, repetition_count).await
        }
        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
    }
}

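/// Generates teacher predictions via the Anthropic API, optionally routing
/// requests through the cached batch client. Returns early without recording
/// a prediction when a batched request is only stashed for later processing.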
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

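/// Generates teacher predictions via the OpenAI API, optionally routing
/// requests through the cached batch client. Mirrors `predict_anthropic`,
/// including the early return while a batched request is pending.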
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

/// Default confidence threshold for repair
const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;

/// Predict using teacher model, then run QA evaluation, and optionally repair
/// if QA indicates issues (reverts_edits=true or low confidence).
///
/// This is a non-batched flow that processes each step synchronously.
async fn predict_repaired_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    repetition_count: usize,
) -> anyhow::Result<()> {
    // Step 1: Run teacher prediction (non-batched for immediate results)
    predict_teacher(example, backend, false, repetition_count).await?;

    // Only proceed with QA/repair for the first prediction
    let Some(prediction) = example.predictions.first() else {
        return Ok(());
    };

    // Skip QA if no actual patch was generated
    if prediction.actual_patch.is_none() {
        return Ok(());
    }

    // Step 2: Run QA evaluation
    let batch_provider = match backend {
        TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
        TeacherBackend::Gpt52 => BatchProvider::Openai,
    };
    let qa_client = LlmClient::new(batch_provider, false)?;
    let qa_model = model_for_backend(batch_provider);

    let qa_result = if let Some(qa_prompt) = qa::build_prompt(example) {
        match qa_client.generate(qa_model, 1024, &qa_prompt).await? {
            Some(response_text) => Some(qa::parse_response(&response_text)),
            None => None,
        }
    } else {
        None
    };

    // Store QA result
    example.qa = vec![qa_result.clone()];

    // Step 3: Check if repair is needed and run repair if so
    if needs_repair(example, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
        let repair_client = LlmClient::new(batch_provider, false)?;

        if let Some(repair_prompt) = build_repair_prompt(example) {
            if let Some(response_text) = repair_client
                .generate(qa_model, 16384, &repair_prompt)
                .await?
            {
                match parse_repair_response(example, &response_text) {
                    Ok(mut repaired_prediction) => {
                        // Mark the prediction as coming from repaired-teacher
                        repaired_prediction.provider = PredictionProvider::RepairedTeacher(backend);
                        example.predictions.push(repaired_prediction);
                    }
                    Err(e) => {
                        // Add error prediction if parsing failed
                        example.predictions.push(ExamplePrediction {
                            actual_patch: None,
                            actual_output: response_text,
                            error: Some(format!("Failed to parse repair response: {}", e)),
                            provider: PredictionProvider::RepairedTeacher(backend),
                        });
                    }
                }
            }
        }
    }

    Ok(())
}

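/// Syncs pending batch requests with the backend's batch API when a batched
/// teacher provider is in use; all other providers are a no-op.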
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}