use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

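// Process-wide LLM clients, created lazily and reused across predictions.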
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

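/// Runs an edit prediction for `example` with the provider selected in `args`, repeating it
/// `args.repetitions` times. Existing predictions from the same provider are kept as-is;
/// predictions from a different provider are cleared and regenerated.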
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

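    // Remaining providers (Zeta1, Zeta2, Sweep, Mercury) are served through the
    // EditPredictionStore below.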
    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 format, configure the raw endpoint.
        // The ZED_ZETA_MODEL and ZED_ZETA_ENVIRONMENT env vars are optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

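    // Listen for debug events from the store so the raw prompt and model output of each
    // run can be written into the run directory as they arrive.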
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

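    // Request one prediction per repetition; each run gets its own numbered subdirectory
    // when more than one repetition was requested.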
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

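/// Dispatches a teacher prediction to the Anthropic or OpenAI backend, depending on the
/// selected teacher model.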
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

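/// Requests teacher predictions from the Anthropic API (optionally via its batching client)
/// and parses each response into a patch with `TeacherPrompt::parse`.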
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

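/// Requests teacher predictions from the OpenAI API (optionally via its batching client)
/// and parses each response into a patch with `TeacherPrompt::parse`.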
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

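/// Sends the formatted prompt to a Baseten-hosted Zeta model through its raw completions
/// endpoint and records the completion as a prediction on the example.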
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
    };

    example.predictions.push(prediction);
    Ok(())
}

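/// Syncs pending batch requests for the selected teacher backend (Anthropic or OpenAI);
/// a no-op for every other provider.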
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}