predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    parse_output::parse_prediction_output,
 10    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 11    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 12    retrieve_context::run_context_retrieval,
 13};
 14use anyhow::Context as _;
 15use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
 16use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
 17use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
 18use gpui::{AppContext as _, AsyncApp, Task};
 19use http_client::{AsyncBody, HttpClient, Method};
 20use reqwest_client::ReqwestClient;
 21use std::{
 22    fs,
 23    sync::{
 24        Arc, Mutex, OnceLock,
 25        atomic::{AtomicUsize, Ordering::SeqCst},
 26    },
 27};
 28use zeta_prompt::ZetaFormat;
 29
 // Process-wide teacher LLM clients, created lazily on first use.
 // Each is initialized by whichever `get_or_init` call in this file runs first
 // (`predict_anthropic` / `predict_openai` / `sync_batches`); that first call
 // also decides whether the client is the batch or the plain variant for the
 // rest of the process, because later `get_or_init` closures are not run.
 30static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
 31static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 32
 33pub async fn run_prediction(
 34    example: &mut Example,
 35    args: &PredictArgs,
 36    app_state: Arc<EpAppState>,
 37    example_progress: &ExampleProgress,
 38    mut cx: AsyncApp,
 39) -> anyhow::Result<()> {
 40    let repetition_count = args.repetitions;
 41
 42    if let Some(existing_prediction) = example.predictions.first() {
 43        let has_prediction = existing_prediction.actual_patch.is_some()
 44            || !existing_prediction.actual_output.is_empty();
 45        if has_prediction {
 46            match args.provider {
 47                None => return Ok(()),
 48                Some(provider) if existing_prediction.provider == provider => return Ok(()),
 49                Some(_) => example.predictions.clear(),
 50            }
 51        }
 52    }
 53
 54    let Some(provider) = args.provider else {
 55        anyhow::bail!(
 56            "No existing predictions found. Use --provider to specify which model to use for prediction."
 57        );
 58    };
 59
 60    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
 61        provider
 62    {
 63        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 64        run_format_prompt(
 65            example,
 66            &FormatPromptArgs { provider },
 67            app_state.clone(),
 68            example_progress,
 69            cx,
 70        )
 71        .await?;
 72
 73        let step_progress = example_progress.start(Step::Predict);
 74        let batched = matches!(provider, PredictionProvider::Teacher(..));
 75        return predict_teacher(
 76            example,
 77            backend,
 78            batched,
 79            repetition_count,
 80            args.cache_only,
 81            &step_progress,
 82        )
 83        .await;
 84    }
 85
 86    if let PredictionProvider::Baseten(format) = provider {
 87        run_format_prompt(
 88            example,
 89            &FormatPromptArgs {
 90                provider: PredictionProvider::Zeta2(format),
 91            },
 92            app_state.clone(),
 93            example_progress,
 94            cx,
 95        )
 96        .await?;
 97
 98        let step_progress = example_progress.start(Step::Predict);
 99        return predict_baseten(example, format, &step_progress).await;
100    }
101
102    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
103    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
104
105    let step_progress = example_progress.start(Step::Predict);
106
107    if matches!(
108        provider,
109        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
110    ) {
111        step_progress.set_substatus("authenticating");
112        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
113        AUTHENTICATED
114            .get_or_init(|| {
115                let client = app_state.client.clone();
116                cx.spawn(async move |cx| {
117                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
118                        eprintln!("Authentication failed: {}", e);
119                    }
120                })
121                .shared()
122            })
123            .clone()
124            .await;
125    }
126
127    let ep_store = cx
128        .update(|cx| EditPredictionStore::try_global(cx))
129        .context("EditPredictionStore not initialized")?;
130
131    ep_store.update(&mut cx, |store, _cx| {
132        let model = match provider {
133            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta,
134            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta,
135            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
136            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
137            PredictionProvider::Teacher(..)
138            | PredictionProvider::TeacherNonBatching(..)
139            | PredictionProvider::Repair
140            | PredictionProvider::Baseten(_) => {
141                unreachable!()
142            }
143        };
144        store.set_edit_prediction_model(model);
145
146        // If user specified a non-default Zeta2 version, configure raw endpoint.
147        // ZED_ZETA_MODEL env var is optional.
148        if let PredictionProvider::Zeta2(format) = provider {
149            if format != ZetaFormat::default() {
150                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
151                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
152            }
153        }
154    });
155    step_progress.set_substatus("configuring model");
156    let state = example.state.as_ref().context("state must be set")?;
157    let run_dir = RUN_DIR.join(&example.spec.name);
158
159    let updated_example = Arc::new(Mutex::new(example.clone()));
160    let current_run_ix = Arc::new(AtomicUsize::new(0));
161
162    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
163    let debug_task = cx.background_spawn({
164        let updated_example = updated_example.clone();
165        let current_run_ix = current_run_ix.clone();
166        let run_dir = run_dir.clone();
167        async move {
168            while let Some(event) = debug_rx.next().await {
169                let run_ix = current_run_ix.load(SeqCst);
170                let mut updated_example = updated_example.lock().unwrap();
171
172                let run_dir = if repetition_count > 1 {
173                    run_dir.join(format!("{:03}", run_ix))
174                } else {
175                    run_dir.clone()
176                };
177
178                match event {
179                    DebugEvent::EditPredictionStarted(request) => {
180                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
181
182                        if let Some(prompt) = request.prompt {
183                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
184                            if matches!(provider, PredictionProvider::Zeta2(_)) {
185                                updated_example.prompt.get_or_insert(ExamplePrompt {
186                                    input: prompt,
187                                    expected_output: String::new(),
188                                    rejected_output: None,
189                                    provider,
190                                    prefill: None,
191                                });
192                            }
193                        }
194                    }
195                    DebugEvent::EditPredictionFinished(request) => {
196                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
197
198                        if let Some(output) = request.model_output {
199                            fs::write(run_dir.join("prediction_response.md"), &output)?;
200                            updated_example
201                                .predictions
202                                .last_mut()
203                                .unwrap()
204                                .actual_output = output;
205                        }
206                        if run_ix >= repetition_count {
207                            break;
208                        }
209                    }
210                    _ => {}
211                }
212            }
213            anyhow::Ok(())
214        }
215    });
216
217    for ix in 0..repetition_count {
218        current_run_ix.store(ix, SeqCst);
219        let run_dir = if repetition_count > 1 {
220            run_dir.join(format!("{:03}", ix))
221        } else {
222            run_dir.clone()
223        };
224
225        if repetition_count > 1 {
226            step_progress.set_substatus(format!(
227                "running prediction {}/{}",
228                ix + 1,
229                repetition_count
230            ));
231        } else {
232            step_progress.set_substatus("running prediction");
233        }
234
235        fs::create_dir_all(&run_dir)?;
236        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
237            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
238        }
239        #[cfg(unix)]
240        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
241        #[cfg(windows)]
242        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
243
244        updated_example
245            .lock()
246            .unwrap()
247            .predictions
248            .push(ExamplePrediction {
249                actual_patch: None,
250                actual_output: String::new(),
251                actual_cursor: None,
252                error: None,
253                provider,
254            });
255
256        step_progress.set_substatus("requesting prediction");
257        let prediction = ep_store
258            .update(&mut cx, |store, cx| {
259                store.request_prediction(
260                    &state.project,
261                    &state.buffer,
262                    state.cursor_position,
263                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
264                    cx,
265                )
266            })
267            .await?;
268
269        let actual_patch = prediction.and_then(|prediction| {
270            let prediction = prediction.prediction.ok()?;
271            prediction
272                .edit_preview
273                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
274        });
275
276        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
277
278        updated_example
279            .lock()
280            .unwrap()
281            .predictions
282            .last_mut()
283            .unwrap()
284            .actual_patch = actual_patch;
285
286        if ix == repetition_count - 1 {
287            let (info, style) = if has_prediction {
288                ("predicted", InfoStyle::Normal)
289            } else {
290                ("no prediction", InfoStyle::Warning)
291            };
292            step_progress.set_info(info, style);
293        }
294    }
295
296    ep_store.update(&mut cx, |store, _| {
297        store.remove_project(&state.project);
298    });
299    debug_task.await?;
300
301    *example = Arc::into_inner(updated_example)
302        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
303        .into_inner()
304        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
305    Ok(())
306}
307
308async fn predict_teacher(
309    example: &mut Example,
310    backend: TeacherBackend,
311    batched: bool,
312    repetition_count: usize,
313    cache_only: bool,
314    step_progress: &crate::progress::StepProgress,
315) -> anyhow::Result<()> {
316    match backend {
317        TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
318            predict_anthropic(
319                example,
320                backend,
321                batched,
322                repetition_count,
323                cache_only,
324                step_progress,
325            )
326            .await
327        }
328        TeacherBackend::Gpt52 => {
329            predict_openai(
330                example,
331                backend,
332                batched,
333                repetition_count,
334                cache_only,
335                step_progress,
336            )
337            .await
338        }
339    }
340}
341
342async fn predict_anthropic(
343    example: &mut Example,
344    backend: TeacherBackend,
345    batched: bool,
346    repetition_count: usize,
347    cache_only: bool,
348    step_progress: &crate::progress::StepProgress,
349) -> anyhow::Result<()> {
350    let llm_model_name = backend.model_name();
351    let max_tokens = 16384;
352    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
353        let client = if batched {
354            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
355        } else {
356            AnthropicClient::plain()
357        };
358        client.expect("Failed to create Anthropic client")
359    });
360
361    let prompt = example.prompt.as_ref().context("Prompt is required")?;
362
363    for ix in 0..repetition_count {
364        if repetition_count > 1 {
365            step_progress.set_substatus(format!(
366                "running prediction {}/{}",
367                ix + 1,
368                repetition_count
369            ));
370        } else {
371            step_progress.set_substatus("running prediction");
372        }
373
374        let messages = vec![anthropic::Message {
375            role: anthropic::Role::User,
376            content: vec![anthropic::RequestContent::Text {
377                text: prompt.input.clone(),
378                cache_control: None,
379            }],
380        }];
381
382        let seed = if repetition_count > 1 { Some(ix) } else { None };
383        let Some(response) = llm_client
384            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
385            .await?
386        else {
387            // Request stashed for batched processing
388            return Ok(());
389        };
390
391        let actual_output = response
392            .content
393            .into_iter()
394            .filter_map(|content| match content {
395                anthropic::ResponseContent::Text { text } => Some(text),
396                _ => None,
397            })
398            .collect::<Vec<String>>()
399            .join("\n");
400
401        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
402
403        let prediction = ExamplePrediction {
404            actual_patch: Some(actual_patch),
405            actual_output,
406            actual_cursor,
407            error: None,
408            provider: if batched {
409                PredictionProvider::Teacher(backend)
410            } else {
411                PredictionProvider::TeacherNonBatching(backend)
412            },
413        };
414
415        example.predictions.push(prediction);
416    }
417    Ok(())
418}
419
420async fn predict_openai(
421    example: &mut Example,
422    backend: TeacherBackend,
423    batched: bool,
424    repetition_count: usize,
425    cache_only: bool,
426    step_progress: &crate::progress::StepProgress,
427) -> anyhow::Result<()> {
428    let llm_model_name = backend.model_name();
429    let max_tokens = 16384;
430    let llm_client = OPENAI_CLIENT.get_or_init(|| {
431        let client = if batched {
432            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
433        } else {
434            OpenAiClient::plain()
435        };
436        client.expect("Failed to create OpenAI client")
437    });
438
439    let prompt = example.prompt.as_ref().context("Prompt is required")?;
440
441    for ix in 0..repetition_count {
442        if repetition_count > 1 {
443            step_progress.set_substatus(format!(
444                "running prediction {}/{}",
445                ix + 1,
446                repetition_count
447            ));
448        } else {
449            step_progress.set_substatus("running prediction");
450        }
451
452        let messages = vec![open_ai::RequestMessage::User {
453            content: open_ai::MessageContent::Plain(prompt.input.clone()),
454        }];
455
456        let seed = if repetition_count > 1 { Some(ix) } else { None };
457        let Some(response) = llm_client
458            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
459            .await?
460        else {
461            // Request stashed for batched processing
462            return Ok(());
463        };
464
465        let actual_output = response
466            .choices
467            .into_iter()
468            .filter_map(|choice| match choice.message {
469                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
470                    open_ai::MessageContent::Plain(text) => text,
471                    open_ai::MessageContent::Multipart(parts) => parts
472                        .into_iter()
473                        .filter_map(|p| match p {
474                            open_ai::MessagePart::Text { text } => Some(text),
475                            _ => None,
476                        })
477                        .collect::<Vec<_>>()
478                        .join(""),
479                }),
480                _ => None,
481            })
482            .collect::<Vec<String>>()
483            .join("\n");
484
485        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
486
487        let prediction = ExamplePrediction {
488            actual_patch: Some(actual_patch),
489            actual_output,
490            actual_cursor,
491            error: None,
492            provider: if batched {
493                PredictionProvider::Teacher(backend)
494            } else {
495                PredictionProvider::TeacherNonBatching(backend)
496            },
497        };
498
499        example.predictions.push(prediction);
500    }
501    Ok(())
502}
503
504pub async fn predict_baseten(
505    example: &mut Example,
506    format: ZetaFormat,
507    step_progress: &StepProgress,
508) -> anyhow::Result<()> {
509    let model_id =
510        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
511
512    let api_key =
513        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
514
515    let prompt = example.prompt.as_ref().context("Prompt is required")?;
516    let prompt_text = prompt.input.clone();
517    let prefill = prompt.prefill.clone().unwrap_or_default();
518
519    step_progress.set_substatus("running prediction via baseten");
520
521    let environment: String = <&'static str>::from(&format).to_lowercase();
522    let url = format!(
523        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
524    );
525
526    let request_body = RawCompletionRequest {
527        model: model_id,
528        prompt: prompt_text.clone(),
529        max_tokens: Some(2048),
530        temperature: Some(0.),
531        stop: vec![],
532        environment: None,
533    };
534
535    let body_bytes =
536        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
537
538    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
539    let request = http_client::Request::builder()
540        .method(Method::POST)
541        .uri(&url)
542        .header("Content-Type", "application/json")
543        .header("Authorization", format!("Api-Key {api_key}"))
544        .body(AsyncBody::from(body_bytes))?;
545
546    let mut response = http_client.send(request).await?;
547    let status = response.status();
548
549    let mut body = String::new();
550    response
551        .body_mut()
552        .read_to_string(&mut body)
553        .await
554        .context("Failed to read Baseten response body")?;
555
556    if !status.is_success() {
557        anyhow::bail!("Baseten API returned {status}: {body}");
558    }
559
560    let completion: RawCompletionResponse =
561        serde_json::from_str(&body).context("Failed to parse Baseten response")?;
562
563    let actual_output = completion
564        .choices
565        .into_iter()
566        .next()
567        .map(|choice| choice.text)
568        .unwrap_or_default();
569
570    let actual_output = format!("{prefill}{actual_output}");
571
572    let (actual_patch, actual_cursor) =
573        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
574
575    let prediction = ExamplePrediction {
576        actual_patch: Some(actual_patch),
577        actual_output,
578        actual_cursor,
579        error: None,
580        provider: PredictionProvider::Baseten(format),
581    };
582
583    example.predictions.push(prediction);
584    Ok(())
585}
586
587pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
588    match provider {
589        Some(PredictionProvider::Teacher(backend)) => match backend {
590            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
591                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
592                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
593                        .expect("Failed to create Anthropic client")
594                });
595                llm_client
596                    .sync_batches()
597                    .await
598                    .context("Failed to sync Anthropic batches")?;
599            }
600            TeacherBackend::Gpt52 => {
601                let llm_client = OPENAI_CLIENT.get_or_init(|| {
602                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
603                        .expect("Failed to create OpenAI client")
604                });
605                llm_client
606                    .sync_batches()
607                    .await
608                    .context("Failed to sync OpenAI batches")?;
609            }
610        },
611        _ => (),
612    };
613    Ok(())
614}