predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    openai_client::OpenAiClient,
    parse_output::parse_prediction_output,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use http_client::{AsyncBody, HttpClient, Method};
use reqwest_client::ReqwestClient;
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaFormat;

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

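/// Runs edit prediction for `example` with the provider selected in `args`,
/// appending one `ExamplePrediction` per repetition. Existing predictions
/// from the same provider are kept as-is; teacher and Baseten providers take
/// the dedicated paths below, and everything else goes through the global
/// `EditPredictionStore`.
///
/// A rough sketch of how a caller might drive this (the loader and
/// surrounding harness names are illustrative assumptions, not part of this
/// module):
///
/// ```ignore
/// let mut example = load_example(&spec)?; // hypothetical loader
/// run_prediction(&mut example, &args, app_state.clone(), &progress, cx.clone()).await?;
/// ```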
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    if let PredictionProvider::Baseten(format) = provider {
        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider: PredictionProvider::Zeta2(format),
            },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        return predict_baseten(example, format, &step_progress).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
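        // Sign in at most once per process: the shared task is initialized on
        // first use, and every subsequent example awaits the same future.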
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_) => {
                edit_prediction::EditPredictionModel::Zeta
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair
            | PredictionProvider::Baseten(_) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If the user specified a non-default Zeta2 version, configure the raw
        // endpoint. The ZED_ZETA_MODEL env var is optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig {
                    model_id,
                    environment,
                    format,
                });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
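    // Mirror debug events from the store into the run directory and into the
    // in-memory copy of the example while predictions are requested below.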
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // `run_ix` is zero-based; stop once the final
                        // repetition's response has been recorded.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

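    // With multiple repetitions, each run writes into its own zero-padded
    // subdirectory; a single run writes directly into the example's run dir.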
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        fs::create_dir_all(&run_dir)?;
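        // Repoint the "latest run" symlink at this run's directory.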
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

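/// Routes a teacher prediction to the backend-specific client: Anthropic for
/// the Sonnet backends, OpenAI for GPT.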
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
            predict_anthropic(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
        TeacherBackend::Gpt52 => {
            predict_openai(
                example,
                backend,
                batched,
                repetition_count,
                cache_only,
                step_progress,
            )
            .await
        }
    }
}

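/// Runs the formatted teacher prompt against Anthropic once per repetition.
/// In batched mode `generate` may return `None`, meaning the request was
/// stashed for later batch processing and no prediction is recorded for that
/// repetition.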
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

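        // Use the repetition index as the seed so repeated runs are issued
        // (and presumably cached) as distinct requests.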
        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

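/// OpenAI counterpart of `predict_anthropic`: same repetition, seeding, and
/// batching behavior, with response text flattened from plain or multipart
/// assistant messages.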
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
    cache_only: bool,
    step_progress: &crate::progress::StepProgress,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
            .await?
        else {
            // Request stashed for batched processing
            continue;
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            actual_cursor,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

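/// Sends the formatted prompt to a Baseten-hosted Zeta model via its raw
/// completions endpoint. Requires the `ZED_ZETA_MODEL` and `BASETEN_API_KEY`
/// environment variables, and prepends any prompt prefill to the returned
/// completion before parsing.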
pub async fn predict_baseten(
    example: &mut Example,
    format: ZetaFormat,
    step_progress: &StepProgress,
) -> anyhow::Result<()> {
    let model_id =
        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable not set")?;

    let api_key =
        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;
    let prompt_text = prompt.input.clone();
    let prefill = prompt.prefill.clone().unwrap_or_default();

    step_progress.set_substatus("running prediction via baseten");

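    // Baseten routes requests by deployment environment; the environment name
    // here is the lowercased name of the prompt format.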
    let environment: String = <&'static str>::from(&format).to_lowercase();
    let url = format!(
        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
    );

    let request_body = RawCompletionRequest {
        model: model_id,
        prompt: prompt_text.clone(),
        max_tokens: Some(2048),
        temperature: Some(0.),
        stop: vec![],
        environment: None,
    };

    let body_bytes =
        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;

    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
    let request = http_client::Request::builder()
        .method(Method::POST)
        .uri(&url)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Api-Key {api_key}"))
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;
    let status = response.status();

    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .context("Failed to read Baseten response body")?;

    if !status.is_success() {
        anyhow::bail!("Baseten API returned {status}: {body}");
    }

    let completion: RawCompletionResponse =
        serde_json::from_str(&body).context("Failed to parse Baseten response")?;

    let actual_output = completion
        .choices
        .into_iter()
        .next()
        .map(|choice| choice.text)
        .unwrap_or_default();

    let actual_output = format!("{prefill}{actual_output}");

    let (actual_patch, actual_cursor) =
        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        actual_cursor,
        error: None,
        provider: PredictionProvider::Baseten(format),
    };

    example.predictions.push(prediction);
    Ok(())
}

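/// Synchronizes any batched teacher requests for the given provider with the
/// backing service. Providers other than `Teacher(..)` (including `None`) are
/// a no-op.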
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}