predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    parse_output::parse_prediction_output,
 10    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 11    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 12    retrieve_context::run_context_retrieval,
 13};
 14use anyhow::Context as _;
 15use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
 16use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
 17use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
 18use gpui::{AppContext as _, AsyncApp, Task};
 19use http_client::{AsyncBody, HttpClient, Method};
 20use reqwest_client::ReqwestClient;
 21use std::{
 22    fs,
 23    sync::{
 24        Arc, Mutex, OnceLock,
 25        atomic::{AtomicUsize, Ordering::SeqCst},
 26    },
 27};
 28use zeta_prompt::ZetaFormat;
 29
 // Process-wide LLM clients shared by the teacher prediction paths and `sync_batches`.
// Each is created on first use. Whichever caller initializes it first decides whether it
// is a batch client (constructed with `LLM_CACHE_DB`) or a plain client; every later
// caller reuses that same instance regardless of its own `batched` flag.
 30static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
 31static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 32
 33pub async fn run_prediction(
 34    example: &mut Example,
 35    args: &PredictArgs,
 36    app_state: Arc<EpAppState>,
 37    example_progress: &ExampleProgress,
 38    mut cx: AsyncApp,
 39) -> anyhow::Result<()> {
 40    let repetition_count = args.repetitions;
 41
 42    if let Some(existing_prediction) = example.predictions.first() {
 43        let has_prediction = existing_prediction.actual_patch.is_some()
 44            || !existing_prediction.actual_output.is_empty();
 45        if has_prediction {
 46            match args.provider {
 47                None => return Ok(()),
 48                Some(provider) if existing_prediction.provider == provider => return Ok(()),
 49                Some(_) => example.predictions.clear(),
 50            }
 51        }
 52    }
 53
 54    let Some(provider) = args.provider else {
 55        anyhow::bail!(
 56            "No existing predictions found. Use --provider to specify which model to use for prediction."
 57        );
 58    };
 59
 60    if let PredictionProvider::Teacher(backend)
 61    | PredictionProvider::TeacherMultiRegion(backend)
 62    | PredictionProvider::TeacherNonBatching(backend)
 63    | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
 64    {
 65        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 66        run_format_prompt(
 67            example,
 68            &FormatPromptArgs { provider },
 69            app_state.clone(),
 70            example_progress,
 71            cx,
 72        )
 73        .await?;
 74
 75        let step_progress = example_progress.start(Step::Predict);
 76        let batched = matches!(
 77            provider,
 78            PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
 79        );
 80        return predict_teacher(
 81            example,
 82            backend,
 83            batched,
 84            repetition_count,
 85            args.cache_only,
 86            &step_progress,
 87        )
 88        .await;
 89    }
 90
 91    if let PredictionProvider::Baseten(format) = provider {
 92        run_format_prompt(
 93            example,
 94            &FormatPromptArgs {
 95                provider: PredictionProvider::Zeta2(format),
 96            },
 97            app_state.clone(),
 98            example_progress,
 99            cx,
100        )
101        .await?;
102
103        let step_progress = example_progress.start(Step::Predict);
104        return predict_baseten(example, format, &step_progress).await;
105    }
106
107    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
108    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
109
110    let step_progress = example_progress.start(Step::Predict);
111
112    if matches!(
113        provider,
114        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
115    ) {
116        step_progress.set_substatus("authenticating");
117        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
118        AUTHENTICATED
119            .get_or_init(|| {
120                let client = app_state.client.clone();
121                cx.spawn(async move |cx| {
122                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
123                        eprintln!("Authentication failed: {}", e);
124                    }
125                })
126                .shared()
127            })
128            .clone()
129            .await;
130    }
131
132    let ep_store = cx
133        .update(|cx| EditPredictionStore::try_global(cx))
134        .context("EditPredictionStore not initialized")?;
135
136    ep_store.update(&mut cx, |store, _cx| {
137        let model = match provider {
138            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
139            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
140            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
141            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
142            PredictionProvider::Teacher(..)
143            | PredictionProvider::TeacherMultiRegion(..)
144            | PredictionProvider::TeacherNonBatching(..)
145            | PredictionProvider::TeacherMultiRegionNonBatching(..)
146            | PredictionProvider::Repair
147            | PredictionProvider::Baseten(_) => {
148                unreachable!()
149            }
150        };
151        store.set_edit_prediction_model(model);
152
153        // If user specified a non-default Zeta2 version, configure raw endpoint.
154        // ZED_ZETA_MODEL env var is optional.
155        if let PredictionProvider::Zeta2(format) = provider {
156            if format != ZetaFormat::default() {
157                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
158                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
159                store.set_zeta2_raw_config(Zeta2RawConfig {
160                    model_id,
161                    environment,
162                    format,
163                });
164            }
165        }
166    });
167    step_progress.set_substatus("configuring model");
168    let state = example.state.as_ref().context("state must be set")?;
169    let run_dir = RUN_DIR.join(&example.spec.name);
170
171    let updated_example = Arc::new(Mutex::new(example.clone()));
172    let current_run_ix = Arc::new(AtomicUsize::new(0));
173
174    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
175    let debug_task = cx.background_spawn({
176        let updated_example = updated_example.clone();
177        let current_run_ix = current_run_ix.clone();
178        let run_dir = run_dir.clone();
179        async move {
180            while let Some(event) = debug_rx.next().await {
181                let run_ix = current_run_ix.load(SeqCst);
182                let mut updated_example = updated_example.lock().unwrap();
183
184                let run_dir = if repetition_count > 1 {
185                    run_dir.join(format!("{:03}", run_ix))
186                } else {
187                    run_dir.clone()
188                };
189
190                match event {
191                    DebugEvent::EditPredictionStarted(request) => {
192                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
193
194                        if let Some(prompt) = request.prompt {
195                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
196                            if matches!(provider, PredictionProvider::Zeta2(_)) {
197                                updated_example.prompt.get_or_insert(ExamplePrompt {
198                                    input: prompt,
199                                    expected_output: String::new(),
200                                    rejected_output: None,
201                                    provider,
202                                    prefill: None,
203                                });
204                            }
205                        }
206                    }
207                    DebugEvent::EditPredictionFinished(request) => {
208                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
209
210                        if let Some(output) = request.model_output {
211                            fs::write(run_dir.join("prediction_response.md"), &output)?;
212                            updated_example
213                                .predictions
214                                .last_mut()
215                                .unwrap()
216                                .actual_output = output;
217                        }
218                        if run_ix >= repetition_count {
219                            break;
220                        }
221                    }
222                    _ => {}
223                }
224            }
225            anyhow::Ok(())
226        }
227    });
228
229    for ix in 0..repetition_count {
230        current_run_ix.store(ix, SeqCst);
231        let run_dir = if repetition_count > 1 {
232            run_dir.join(format!("{:03}", ix))
233        } else {
234            run_dir.clone()
235        };
236
237        if repetition_count > 1 {
238            step_progress.set_substatus(format!(
239                "running prediction {}/{}",
240                ix + 1,
241                repetition_count
242            ));
243        } else {
244            step_progress.set_substatus("running prediction");
245        }
246
247        fs::create_dir_all(&run_dir)?;
248        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
249            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
250        }
251        #[cfg(unix)]
252        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
253        #[cfg(windows)]
254        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
255
256        updated_example
257            .lock()
258            .unwrap()
259            .predictions
260            .push(ExamplePrediction {
261                actual_patch: None,
262                actual_output: String::new(),
263                actual_cursor: None,
264                error: None,
265                provider,
266            });
267
268        step_progress.set_substatus("requesting prediction");
269        let prediction = ep_store
270            .update(&mut cx, |store, cx| {
271                store.request_prediction(
272                    &state.project,
273                    &state.buffer,
274                    state.cursor_position,
275                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
276                    cx,
277                )
278            })
279            .await?;
280
281        let actual_patch = prediction.and_then(|prediction| {
282            let prediction = prediction.prediction.ok()?;
283            prediction
284                .edit_preview
285                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
286        });
287
288        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
289
290        updated_example
291            .lock()
292            .unwrap()
293            .predictions
294            .last_mut()
295            .unwrap()
296            .actual_patch = actual_patch;
297
298        if ix == repetition_count - 1 {
299            let (info, style) = if has_prediction {
300                ("predicted", InfoStyle::Normal)
301            } else {
302                ("no prediction", InfoStyle::Warning)
303            };
304            step_progress.set_info(info, style);
305        }
306    }
307
308    ep_store.update(&mut cx, |store, _| {
309        store.remove_project(&state.project);
310    });
311    debug_task.await?;
312
313    *example = Arc::into_inner(updated_example)
314        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
315        .into_inner()
316        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
317    Ok(())
318}
319
320async fn predict_teacher(
321    example: &mut Example,
322    backend: TeacherBackend,
323    batched: bool,
324    repetition_count: usize,
325    cache_only: bool,
326    step_progress: &crate::progress::StepProgress,
327) -> anyhow::Result<()> {
328    match backend {
329        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
330            predict_anthropic(
331                example,
332                backend,
333                batched,
334                repetition_count,
335                cache_only,
336                step_progress,
337            )
338            .await
339        }
340        TeacherBackend::Gpt52 => {
341            predict_openai(
342                example,
343                backend,
344                batched,
345                repetition_count,
346                cache_only,
347                step_progress,
348            )
349            .await
350        }
351    }
352}
353
354async fn predict_anthropic(
355    example: &mut Example,
356    backend: TeacherBackend,
357    batched: bool,
358    repetition_count: usize,
359    cache_only: bool,
360    step_progress: &crate::progress::StepProgress,
361) -> anyhow::Result<()> {
362    let llm_model_name = backend.model_name();
363    let max_tokens = 16384;
364    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
365        let client = if batched {
366            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
367        } else {
368            AnthropicClient::plain()
369        };
370        client.expect("Failed to create Anthropic client")
371    });
372
373    let prompt = example.prompt.as_ref().context("Prompt is required")?;
374
375    for ix in 0..repetition_count {
376        if repetition_count > 1 {
377            step_progress.set_substatus(format!(
378                "running prediction {}/{}",
379                ix + 1,
380                repetition_count
381            ));
382        } else {
383            step_progress.set_substatus("running prediction");
384        }
385
386        let messages = vec![anthropic::Message {
387            role: anthropic::Role::User,
388            content: vec![anthropic::RequestContent::Text {
389                text: prompt.input.clone(),
390                cache_control: None,
391            }],
392        }];
393
394        let seed = if repetition_count > 1 { Some(ix) } else { None };
395        let Some(response) = llm_client
396            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
397            .await?
398        else {
399            // Request stashed for batched processing
400            continue;
401        };
402
403        let actual_output = response
404            .content
405            .into_iter()
406            .filter_map(|content| match content {
407                anthropic::ResponseContent::Text { text } => Some(text),
408                _ => None,
409            })
410            .collect::<Vec<String>>()
411            .join("\n");
412
413        let parser_provider = if batched {
414            example
415                .prompt
416                .as_ref()
417                .map(|prompt| prompt.provider)
418                .unwrap_or(PredictionProvider::Teacher(backend))
419        } else {
420            match example.prompt.as_ref().map(|prompt| prompt.provider) {
421                Some(PredictionProvider::TeacherMultiRegion(_))
422                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
423                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
424                }
425                _ => PredictionProvider::TeacherNonBatching(backend),
426            }
427        };
428
429        let (actual_patch, actual_cursor) = match parser_provider {
430            PredictionProvider::TeacherMultiRegion(_)
431            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
432                TeacherMultiRegionPrompt::parse(example, &actual_output)?
433            }
434            _ => TeacherPrompt::parse(example, &actual_output)?,
435        };
436
437        let prediction = ExamplePrediction {
438            actual_patch: Some(actual_patch),
439            actual_output,
440            actual_cursor,
441            error: None,
442            provider: if batched {
443                match example.prompt.as_ref().map(|prompt| prompt.provider) {
444                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
445                        PredictionProvider::TeacherMultiRegion(backend)
446                    }
447                    _ => PredictionProvider::Teacher(backend),
448                }
449            } else {
450                match example.prompt.as_ref().map(|prompt| prompt.provider) {
451                    Some(PredictionProvider::TeacherMultiRegion(_))
452                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
453                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
454                    }
455                    _ => PredictionProvider::TeacherNonBatching(backend),
456                }
457            },
458        };
459
460        example.predictions.push(prediction);
461    }
462    Ok(())
463}
464
465async fn predict_openai(
466    example: &mut Example,
467    backend: TeacherBackend,
468    batched: bool,
469    repetition_count: usize,
470    cache_only: bool,
471    step_progress: &crate::progress::StepProgress,
472) -> anyhow::Result<()> {
473    let llm_model_name = backend.model_name();
474    let max_tokens = 16384;
475    let llm_client = OPENAI_CLIENT.get_or_init(|| {
476        let client = if batched {
477            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
478        } else {
479            OpenAiClient::plain()
480        };
481        client.expect("Failed to create OpenAI client")
482    });
483
484    let prompt = example.prompt.as_ref().context("Prompt is required")?;
485
486    for ix in 0..repetition_count {
487        if repetition_count > 1 {
488            step_progress.set_substatus(format!(
489                "running prediction {}/{}",
490                ix + 1,
491                repetition_count
492            ));
493        } else {
494            step_progress.set_substatus("running prediction");
495        }
496
497        let messages = vec![open_ai::RequestMessage::User {
498            content: open_ai::MessageContent::Plain(prompt.input.clone()),
499        }];
500
501        let seed = if repetition_count > 1 { Some(ix) } else { None };
502        let Some(response) = llm_client
503            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
504            .await?
505        else {
506            // Request stashed for batched processing
507            continue;
508        };
509
510        let actual_output = response
511            .choices
512            .into_iter()
513            .filter_map(|choice| match choice.message {
514                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
515                    open_ai::MessageContent::Plain(text) => text,
516                    open_ai::MessageContent::Multipart(parts) => parts
517                        .into_iter()
518                        .filter_map(|p| match p {
519                            open_ai::MessagePart::Text { text } => Some(text),
520                            _ => None,
521                        })
522                        .collect::<Vec<_>>()
523                        .join(""),
524                }),
525                _ => None,
526            })
527            .collect::<Vec<String>>()
528            .join("\n");
529
530        let parser_provider = if batched {
531            example
532                .prompt
533                .as_ref()
534                .map(|prompt| prompt.provider)
535                .unwrap_or(PredictionProvider::Teacher(backend))
536        } else {
537            match example.prompt.as_ref().map(|prompt| prompt.provider) {
538                Some(PredictionProvider::TeacherMultiRegion(_))
539                | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
540                    PredictionProvider::TeacherMultiRegionNonBatching(backend)
541                }
542                _ => PredictionProvider::TeacherNonBatching(backend),
543            }
544        };
545
546        let (actual_patch, actual_cursor) = match parser_provider {
547            PredictionProvider::TeacherMultiRegion(_)
548            | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
549                TeacherMultiRegionPrompt::parse(example, &actual_output)?
550            }
551            _ => TeacherPrompt::parse(example, &actual_output)?,
552        };
553
554        let prediction = ExamplePrediction {
555            actual_patch: Some(actual_patch),
556            actual_output,
557            actual_cursor,
558            error: None,
559            provider: if batched {
560                match example.prompt.as_ref().map(|prompt| prompt.provider) {
561                    Some(PredictionProvider::TeacherMultiRegion(_)) => {
562                        PredictionProvider::TeacherMultiRegion(backend)
563                    }
564                    _ => PredictionProvider::Teacher(backend),
565                }
566            } else {
567                match example.prompt.as_ref().map(|prompt| prompt.provider) {
568                    Some(PredictionProvider::TeacherMultiRegion(_))
569                    | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
570                        PredictionProvider::TeacherMultiRegionNonBatching(backend)
571                    }
572                    _ => PredictionProvider::TeacherNonBatching(backend),
573                }
574            },
575        };
576
577        example.predictions.push(prediction);
578    }
579    Ok(())
580}
581
582pub async fn predict_baseten(
583    example: &mut Example,
584    format: ZetaFormat,
585    step_progress: &StepProgress,
586) -> anyhow::Result<()> {
587    let model_id =
588        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
589
590    let api_key =
591        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
592
593    let prompt = example.prompt.as_ref().context("Prompt is required")?;
594    let prompt_text = prompt.input.clone();
595    let prefill = prompt.prefill.clone().unwrap_or_default();
596
597    step_progress.set_substatus("running prediction via baseten");
598
599    let environment: String = <&'static str>::from(&format).to_lowercase();
600    let url = format!(
601        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
602    );
603
604    let request_body = RawCompletionRequest {
605        model: model_id,
606        prompt: prompt_text.clone(),
607        max_tokens: Some(2048),
608        temperature: Some(0.),
609        stop: vec![],
610        environment: None,
611    };
612
613    let body_bytes =
614        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
615
616    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
617    let request = http_client::Request::builder()
618        .method(Method::POST)
619        .uri(&url)
620        .header("Content-Type", "application/json")
621        .header("Authorization", format!("Api-Key {api_key}"))
622        .body(AsyncBody::from(body_bytes))?;
623
624    let mut response = http_client.send(request).await?;
625    let status = response.status();
626
627    let mut body = String::new();
628    response
629        .body_mut()
630        .read_to_string(&mut body)
631        .await
632        .context("Failed to read Baseten response body")?;
633
634    if !status.is_success() {
635        anyhow::bail!("Baseten API returned {status}: {body}");
636    }
637
638    let completion: RawCompletionResponse =
639        serde_json::from_str(&body).context("Failed to parse Baseten response")?;
640
641    let actual_output = completion
642        .choices
643        .into_iter()
644        .next()
645        .map(|choice| choice.text)
646        .unwrap_or_default();
647
648    let actual_output = format!("{prefill}{actual_output}");
649
650    let (actual_patch, actual_cursor) =
651        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
652
653    let prediction = ExamplePrediction {
654        actual_patch: Some(actual_patch),
655        actual_output,
656        actual_cursor,
657        error: None,
658        provider: PredictionProvider::Baseten(format),
659    };
660
661    example.predictions.push(prediction);
662    Ok(())
663}
664
665pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
666    match provider {
667        Some(PredictionProvider::Teacher(backend))
668        | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
669            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
670                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
671                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
672                        .expect("Failed to create Anthropic client")
673                });
674                llm_client
675                    .sync_batches()
676                    .await
677                    .context("Failed to sync Anthropic batches")?;
678            }
679            TeacherBackend::Gpt52 => {
680                let llm_client = OPENAI_CLIENT.get_or_init(|| {
681                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
682                        .expect("Failed to create OpenAI client")
683                });
684                llm_client
685                    .sync_batches()
686                    .await
687                    .context("Failed to sync OpenAI batches")?;
688            }
689        },
690        _ => (),
691    };
692    Ok(())
693}