1use crate::{
2 FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
3 anthropic_client::AnthropicClient,
4 example::{Example, ExamplePrediction, ExamplePrompt},
5 format_prompt::{TeacherPrompt, run_format_prompt},
6 headless::EpAppState,
7 load_project::run_load_project,
8 openai_client::OpenAiClient,
9 paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
10 progress::{ExampleProgress, InfoStyle, Step},
11 retrieve_context::run_context_retrieval,
12};
13use anyhow::Context as _;
14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
15use futures::{FutureExt as _, StreamExt as _, future::Shared};
16use gpui::{AppContext as _, AsyncApp, Task};
17use std::{
18 fs,
19 sync::{
20 Arc, Mutex, OnceLock,
21 atomic::{AtomicUsize, Ordering::SeqCst},
22 },
23};
24use zeta_prompt::ZetaFormat;
25
// Process-wide LLM clients, created lazily on first use. Whichever call site
// initializes one first decides its mode (batched vs. plain) for the rest of
// the process — see `predict_anthropic` / `predict_openai` / `sync_batches`.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
28
/// Runs edit prediction for `example`, appending one `ExamplePrediction` per
/// repetition and writing prompt/response artifacts under the run directory.
///
/// Existing predictions are reused when they match the requested provider.
/// Teacher providers are routed to `predict_teacher`; all other providers go
/// through the shared `EditPredictionStore`.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Reuse cached predictions: keep them when no provider was requested or
    // the cached provider matches; otherwise clear and re-predict.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    // Teacher providers bypass the EditPredictionStore entirely: format the
    // prompt here and call the LLM backend directly.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        // `Teacher` uses the batching client; `TeacherNonBatching` does not.
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers require a signed-in client. Sign in at most once per
    // process; concurrent callers await the same shared task.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        // Only logged here; any resulting failure surfaces
                        // from the prediction request below.
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                // Teacher providers returned early above; Repair is not
                // handled by this code path.
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If user specified a non-default Zeta2 version, configure raw endpoint.
        // ZED_ZETA_MODEL env var is optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // Shared with the debug listener task below, which fills in the prompt
    // and raw model output as debug events arrive.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // Per-repetition subdirectory; mirrors the layout chosen in
                // the request loop below.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The request loop pushes a placeholder prediction
                        // before requesting, so one must exist for this run.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `run_ix` is stored from the
                        // `0..repetition_count` loop below, so for
                        // repetitions >= 1 it never reaches
                        // `repetition_count` and this break never fires;
                        // the task instead ends when the debug channel
                        // closes. Confirm whether
                        // `run_ix + 1 >= repetition_count` was intended.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        // Repoint the "latest example run" symlink at this run's directory.
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Placeholder entry for this run; the debug task and the code below
        // fill in `actual_output` / `actual_patch`.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Render predicted edits as a unified diff; `None` means no usable
        // prediction was produced.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Only the final repetition's outcome is reported in the progress UI.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Removing the project presumably closes the debug event stream so the
    // listener task's `while let` loop can end — TODO confirm, since the
    // loop's own break condition never fires (see NOTE above).
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // The debug task has dropped its Arc clone by now, so the accumulated
    // example can be moved back out.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
286
287async fn predict_teacher(
288 example: &mut Example,
289 backend: TeacherBackend,
290 batched: bool,
291 repetition_count: usize,
292 cache_only: bool,
293 step_progress: &crate::progress::StepProgress,
294) -> anyhow::Result<()> {
295 match backend {
296 TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
297 predict_anthropic(
298 example,
299 backend,
300 batched,
301 repetition_count,
302 cache_only,
303 step_progress,
304 )
305 .await
306 }
307 TeacherBackend::Gpt52 => {
308 predict_openai(
309 example,
310 backend,
311 batched,
312 repetition_count,
313 cache_only,
314 step_progress,
315 )
316 .await
317 }
318 }
319}
320
321async fn predict_anthropic(
322 example: &mut Example,
323 backend: TeacherBackend,
324 batched: bool,
325 repetition_count: usize,
326 cache_only: bool,
327 step_progress: &crate::progress::StepProgress,
328) -> anyhow::Result<()> {
329 let llm_model_name = backend.model_name();
330 let max_tokens = 16384;
331 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
332 let client = if batched {
333 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
334 } else {
335 AnthropicClient::plain()
336 };
337 client.expect("Failed to create Anthropic client")
338 });
339
340 let prompt = example.prompt.as_ref().context("Prompt is required")?;
341
342 for ix in 0..repetition_count {
343 if repetition_count > 1 {
344 step_progress.set_substatus(format!(
345 "running prediction {}/{}",
346 ix + 1,
347 repetition_count
348 ));
349 } else {
350 step_progress.set_substatus("running prediction");
351 }
352
353 let messages = vec![anthropic::Message {
354 role: anthropic::Role::User,
355 content: vec![anthropic::RequestContent::Text {
356 text: prompt.input.clone(),
357 cache_control: None,
358 }],
359 }];
360
361 let seed = if repetition_count > 1 { Some(ix) } else { None };
362 let Some(response) = llm_client
363 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
364 .await?
365 else {
366 // Request stashed for batched processing
367 return Ok(());
368 };
369
370 let actual_output = response
371 .content
372 .into_iter()
373 .filter_map(|content| match content {
374 anthropic::ResponseContent::Text { text } => Some(text),
375 _ => None,
376 })
377 .collect::<Vec<String>>()
378 .join("\n");
379
380 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
381
382 let prediction = ExamplePrediction {
383 actual_patch: Some(actual_patch),
384 actual_output,
385 actual_cursor,
386 error: None,
387 provider: if batched {
388 PredictionProvider::Teacher(backend)
389 } else {
390 PredictionProvider::TeacherNonBatching(backend)
391 },
392 };
393
394 example.predictions.push(prediction);
395 }
396 Ok(())
397}
398
399async fn predict_openai(
400 example: &mut Example,
401 backend: TeacherBackend,
402 batched: bool,
403 repetition_count: usize,
404 cache_only: bool,
405 step_progress: &crate::progress::StepProgress,
406) -> anyhow::Result<()> {
407 let llm_model_name = backend.model_name();
408 let max_tokens = 16384;
409 let llm_client = OPENAI_CLIENT.get_or_init(|| {
410 let client = if batched {
411 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
412 } else {
413 OpenAiClient::plain()
414 };
415 client.expect("Failed to create OpenAI client")
416 });
417
418 let prompt = example.prompt.as_ref().context("Prompt is required")?;
419
420 for ix in 0..repetition_count {
421 if repetition_count > 1 {
422 step_progress.set_substatus(format!(
423 "running prediction {}/{}",
424 ix + 1,
425 repetition_count
426 ));
427 } else {
428 step_progress.set_substatus("running prediction");
429 }
430
431 let messages = vec![open_ai::RequestMessage::User {
432 content: open_ai::MessageContent::Plain(prompt.input.clone()),
433 }];
434
435 let seed = if repetition_count > 1 { Some(ix) } else { None };
436 let Some(response) = llm_client
437 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
438 .await?
439 else {
440 // Request stashed for batched processing
441 return Ok(());
442 };
443
444 let actual_output = response
445 .choices
446 .into_iter()
447 .filter_map(|choice| match choice.message {
448 open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
449 open_ai::MessageContent::Plain(text) => text,
450 open_ai::MessageContent::Multipart(parts) => parts
451 .into_iter()
452 .filter_map(|p| match p {
453 open_ai::MessagePart::Text { text } => Some(text),
454 _ => None,
455 })
456 .collect::<Vec<_>>()
457 .join(""),
458 }),
459 _ => None,
460 })
461 .collect::<Vec<String>>()
462 .join("\n");
463
464 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
465
466 let prediction = ExamplePrediction {
467 actual_patch: Some(actual_patch),
468 actual_output,
469 actual_cursor,
470 error: None,
471 provider: if batched {
472 PredictionProvider::Teacher(backend)
473 } else {
474 PredictionProvider::TeacherNonBatching(backend)
475 },
476 };
477
478 example.predictions.push(prediction);
479 }
480 Ok(())
481}
482
483pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
484 match provider {
485 Some(PredictionProvider::Teacher(backend)) => match backend {
486 TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
487 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
488 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
489 .expect("Failed to create Anthropic client")
490 });
491 llm_client
492 .sync_batches()
493 .await
494 .context("Failed to sync Anthropic batches")?;
495 }
496 TeacherBackend::Gpt52 => {
497 let llm_client = OPENAI_CLIENT.get_or_init(|| {
498 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
499 .expect("Failed to create OpenAI client")
500 });
501 llm_client
502 .sync_batches()
503 .await
504 .context("Failed to sync OpenAI batches")?;
505 }
506 },
507 _ => (),
508 };
509 Ok(())
510}