use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

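/// Runs edit prediction for a single example. An existing prediction is reused
/// when it matches the requested provider (or when no provider is given);
/// otherwise the example is routed to the teacher, Baseten, or built-in
/// Zeta/Sweep/Mercury path and the result is recorded in `example.predictions`.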
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user picked a non-default Zeta2 format, override the raw Zeta2
        // configuration; the optional ZED_ZETA_MODEL env var selects the model id.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

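    // Collect debug events in the background so that each run's prompt and raw
    // model output are written into its run directory as they arrive.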
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

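    // When more than one repetition is requested, each run gets its own
    // zero-padded subdirectory; every run appends a fresh entry to
    // `example.predictions`.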
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

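/// Runs a teacher prediction, dispatching to the Anthropic or OpenAI client
/// based on the selected backend.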
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

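/// Generates teacher predictions via the Anthropic API, one per repetition,
/// parsing each response with `TeacherPrompt::parse` and appending it to
/// `example.predictions`. Returns early if a batched request was only stashed.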
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing (see `sync_batches`).
            return Ok(());
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

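/// Same as `predict_anthropic`, but issues the teacher prompt through the
/// OpenAI client and flattens plain or multipart assistant responses.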
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing (see `sync_batches`).
            return Ok(());
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

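/// Sends the formatted prompt to a Baseten-hosted Zeta model via the raw
/// completions endpoint and appends the parsed result to `example.predictions`.
/// Requires the ZED_ZETA_MODEL and BASETEN_API_KEY environment variables.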
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
    };

    example.predictions.push(prediction);
    Ok(())
}

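/// Syncs any outstanding LLM batches for the given teacher provider; a no-op
/// for every other provider.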
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}