use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

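// Process-wide LLM clients, created lazily on first use and reused across examples.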
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

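/// Runs edit prediction for a single example.
///
/// Returns early when the example already holds a prediction from the
/// requested provider (or when no provider was requested). Teacher and
/// Baseten providers take dedicated paths below; everything else goes through
/// the shared `EditPredictionStore` flow, repeated `args.repetitions` times.
///
/// A sketch of a call site (argument values are illustrative):
///
/// ```ignore
/// run_prediction(&mut example, &args, app_state.clone(), &progress, cx).await?;
/// ```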
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend)
    | PredictionProvider::TeacherMultiRegion(backend)
    | PredictionProvider::TeacherNonBatching(backend)
    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(
            provider,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
        );
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
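        // Sign in at most once per process; concurrent examples share the task.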
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherMultiRegion(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::TeacherMultiRegionNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
        // If the user specified a non-default Zeta2 format, configure the raw
        // endpoint. The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are
        // optional overrides.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

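    // Mirror debug events from the prediction store into per-run files and the
    // in-memory copy of the example while the loop below drives requests.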
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // `run_ix` is zero-based, so the final run has index
                        // `repetition_count - 1`; exit once it has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
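        // Keep the "latest run" symlink pointing at the current run directory.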
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

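/// Dispatches a teacher prediction to the backend-specific client:
/// Anthropic for the Sonnet backends, OpenAI for GPT.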
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

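/// Generates teacher predictions via the Anthropic API. In batched mode a
/// request may be stashed for later batch processing, in which case no
/// prediction is recorded on this pass.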
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

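/// Generates teacher predictions via the OpenAI API; mirrors
/// `predict_anthropic` except for the message and response shapes.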
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parser_provider = if batched {
            example
                .prompt
                .as_ref()
                .map(|prompt| prompt.provider)
                .unwrap_or(PredictionProvider::Teacher(backend))
        } else {
            match example.prompt.as_ref().map(|prompt| prompt.provider) {
                Some(PredictionProvider::TeacherMultiRegion(_))
                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
                }
                _ => PredictionProvider::TeacherNonBatching(backend),
            }
        };

        let (actual_patch, actual_cursor) = match parser_provider {
            PredictionProvider::TeacherMultiRegion(_)
            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
                TeacherMultiRegionPrompt::parse(example, &actual_output)?
            }
            _ => TeacherPrompt::parse(example, &actual_output)?,
        };

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
                        PredictionProvider::TeacherMultiRegion(backend)
                    }
                    _ => PredictionProvider::Teacher(backend),
                }
            } else {
                match example.prompt.as_ref().map(|prompt| prompt.provider) {
                    Some(PredictionProvider::TeacherMultiRegion(_))
                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
                    }
                    _ => PredictionProvider::TeacherNonBatching(backend),
                }
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

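/// Requests a completion from a Baseten-hosted Zeta model over its raw
/// completions endpoint and records the parsed prediction. Requires the
/// `ZED_ZETA_MODEL` and `BASETEN_API_KEY` environment variables.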
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
    };

    example.predictions.push(prediction);
    Ok(())
}

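/// Syncs any pending batched teacher requests with the backing service.
/// No-op for non-batching providers.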
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend))
        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}