predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 10    progress::{ExampleProgress, InfoStyle, Step},
 11    retrieve_context::run_context_retrieval,
 12};
 13use anyhow::Context as _;
 14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
 15use futures::{FutureExt as _, StreamExt as _, future::Shared};
 16use gpui::{AppContext as _, AsyncApp, Task};
 17use std::{
 18    fs,
 19    sync::{
 20        Arc, Mutex, OnceLock,
 21        atomic::{AtomicUsize, Ordering::SeqCst},
 22    },
 23};
 24use zeta_prompt::ZetaFormat;
 25
// Process-wide LLM clients, created lazily on first use.
//
// NOTE: whichever call initializes a client first fixes its mode (batch vs.
// plain) for the remainder of the process — `get_or_init` callers below
// capture `batched` but the closure only runs once.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 28
 29pub async fn run_prediction(
 30    example: &mut Example,
 31    args: &PredictArgs,
 32    app_state: Arc<EpAppState>,
 33    example_progress: &ExampleProgress,
 34    mut cx: AsyncApp,
 35) -> anyhow::Result<()> {
 36    let repetition_count = args.repetitions;
 37
 38    if let Some(existing_prediction) = example.predictions.first() {
 39        let has_prediction = existing_prediction.actual_patch.is_some()
 40            || !existing_prediction.actual_output.is_empty();
 41        if has_prediction {
 42            match args.provider {
 43                None => return Ok(()),
 44                Some(provider) if existing_prediction.provider == provider => return Ok(()),
 45                Some(_) => example.predictions.clear(),
 46            }
 47        }
 48    }
 49
 50    let Some(provider) = args.provider else {
 51        anyhow::bail!(
 52            "No existing predictions found. Use --provider to specify which model to use for prediction."
 53        );
 54    };
 55
 56    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 57
 58    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
 59        provider
 60    {
 61        let _step_progress = example_progress.start(Step::Predict);
 62
 63        run_format_prompt(
 64            example,
 65            &FormatPromptArgs { provider },
 66            app_state.clone(),
 67            example_progress,
 68            cx,
 69        )
 70        .await?;
 71
 72        let batched = matches!(provider, PredictionProvider::Teacher(..));
 73        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
 74    }
 75
 76    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 77
 78    let step_progress = example_progress.start(Step::Predict);
 79
 80    if matches!(
 81        provider,
 82        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
 83    ) {
 84        step_progress.set_substatus("authenticating");
 85        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
 86        AUTHENTICATED
 87            .get_or_init(|| {
 88                let client = app_state.client.clone();
 89                cx.spawn(async move |cx| {
 90                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
 91                        eprintln!("Authentication failed: {}", e);
 92                    }
 93                })
 94                .shared()
 95            })
 96            .clone()
 97            .await;
 98    }
 99
100    let ep_store = cx
101        .update(|cx| EditPredictionStore::try_global(cx))
102        .context("EditPredictionStore not initialized")?;
103
104    ep_store.update(&mut cx, |store, _cx| {
105        let model = match provider {
106            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
107            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
108            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
109            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
110            PredictionProvider::Teacher(..)
111            | PredictionProvider::TeacherNonBatching(..)
112            | PredictionProvider::Repair => {
113                unreachable!()
114            }
115        };
116        store.set_edit_prediction_model(model);
117
118        // If user specified a non-default Zeta2 version, configure raw endpoint.
119        // ZED_ZETA_MODEL env var is optional.
120        if let PredictionProvider::Zeta2(format) = provider {
121            if format != ZetaFormat::default() {
122                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
123                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
124            }
125        }
126    });
127    step_progress.set_substatus("configuring model");
128    let state = example.state.as_ref().context("state must be set")?;
129    let run_dir = RUN_DIR.join(&example.spec.name);
130
131    let updated_example = Arc::new(Mutex::new(example.clone()));
132    let current_run_ix = Arc::new(AtomicUsize::new(0));
133
134    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
135    let debug_task = cx.background_spawn({
136        let updated_example = updated_example.clone();
137        let current_run_ix = current_run_ix.clone();
138        let run_dir = run_dir.clone();
139        async move {
140            while let Some(event) = debug_rx.next().await {
141                let run_ix = current_run_ix.load(SeqCst);
142                let mut updated_example = updated_example.lock().unwrap();
143
144                let run_dir = if repetition_count > 1 {
145                    run_dir.join(format!("{:03}", run_ix))
146                } else {
147                    run_dir.clone()
148                };
149
150                match event {
151                    DebugEvent::EditPredictionStarted(request) => {
152                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
153
154                        if let Some(prompt) = request.prompt {
155                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
156                            if matches!(provider, PredictionProvider::Zeta2(_)) {
157                                updated_example.prompt.get_or_insert(ExamplePrompt {
158                                    input: prompt,
159                                    expected_output: String::new(),
160                                    rejected_output: None,
161                                    provider,
162                                    prefill: None,
163                                });
164                            }
165                        }
166                    }
167                    DebugEvent::EditPredictionFinished(request) => {
168                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
169
170                        if let Some(output) = request.model_output {
171                            fs::write(run_dir.join("prediction_response.md"), &output)?;
172                            updated_example
173                                .predictions
174                                .last_mut()
175                                .unwrap()
176                                .actual_output = output;
177                        }
178                        if run_ix >= repetition_count {
179                            break;
180                        }
181                    }
182                    _ => {}
183                }
184            }
185            anyhow::Ok(())
186        }
187    });
188
189    for ix in 0..repetition_count {
190        current_run_ix.store(ix, SeqCst);
191        let run_dir = if repetition_count > 1 {
192            run_dir.join(format!("{:03}", ix))
193        } else {
194            run_dir.clone()
195        };
196
197        fs::create_dir_all(&run_dir)?;
198        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
199            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
200        }
201        #[cfg(unix)]
202        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
203        #[cfg(windows)]
204        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
205
206        updated_example
207            .lock()
208            .unwrap()
209            .predictions
210            .push(ExamplePrediction {
211                actual_patch: None,
212                actual_output: String::new(),
213                actual_cursor: None,
214                error: None,
215                provider,
216            });
217
218        step_progress.set_substatus("requesting prediction");
219        let prediction = ep_store
220            .update(&mut cx, |store, cx| {
221                store.request_prediction(
222                    &state.project,
223                    &state.buffer,
224                    state.cursor_position,
225                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
226                    cx,
227                )
228            })
229            .await?;
230
231        let actual_patch = prediction.and_then(|prediction| {
232            let prediction = prediction.prediction.ok()?;
233            prediction
234                .edit_preview
235                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
236        });
237
238        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
239
240        updated_example
241            .lock()
242            .unwrap()
243            .predictions
244            .last_mut()
245            .unwrap()
246            .actual_patch = actual_patch;
247
248        if ix == repetition_count - 1 {
249            let (info, style) = if has_prediction {
250                ("predicted", InfoStyle::Normal)
251            } else {
252                ("no prediction", InfoStyle::Warning)
253            };
254            step_progress.set_info(info, style);
255        }
256    }
257
258    ep_store.update(&mut cx, |store, _| {
259        store.remove_project(&state.project);
260    });
261    debug_task.await?;
262
263    *example = Arc::into_inner(updated_example)
264        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
265        .into_inner()
266        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
267    Ok(())
268}
269
270async fn predict_teacher(
271    example: &mut Example,
272    backend: TeacherBackend,
273    batched: bool,
274    repetition_count: usize,
275    cache_only: bool,
276) -> anyhow::Result<()> {
277    match backend {
278        TeacherBackend::Sonnet45 => {
279            predict_anthropic(example, backend, batched, repetition_count, cache_only).await
280        }
281        TeacherBackend::Gpt52 => {
282            predict_openai(example, backend, batched, repetition_count, cache_only).await
283        }
284    }
285}
286
287async fn predict_anthropic(
288    example: &mut Example,
289    backend: TeacherBackend,
290    batched: bool,
291    repetition_count: usize,
292    cache_only: bool,
293) -> anyhow::Result<()> {
294    let llm_model_name = backend.model_name();
295    let max_tokens = 16384;
296    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
297        let client = if batched {
298            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
299        } else {
300            AnthropicClient::plain()
301        };
302        client.expect("Failed to create Anthropic client")
303    });
304
305    let prompt = example.prompt.as_ref().context("Prompt is required")?;
306
307    for ix in 0..repetition_count {
308        let messages = vec![anthropic::Message {
309            role: anthropic::Role::User,
310            content: vec![anthropic::RequestContent::Text {
311                text: prompt.input.clone(),
312                cache_control: None,
313            }],
314        }];
315
316        let seed = if repetition_count > 1 { Some(ix) } else { None };
317        let Some(response) = llm_client
318            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
319            .await?
320        else {
321            // Request stashed for batched processing
322            return Ok(());
323        };
324
325        let actual_output = response
326            .content
327            .into_iter()
328            .filter_map(|content| match content {
329                anthropic::ResponseContent::Text { text } => Some(text),
330                _ => None,
331            })
332            .collect::<Vec<String>>()
333            .join("\n");
334
335        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
336
337        let prediction = ExamplePrediction {
338            actual_patch: Some(actual_patch),
339            actual_output,
340            actual_cursor,
341            error: None,
342            provider: if batched {
343                PredictionProvider::Teacher(backend)
344            } else {
345                PredictionProvider::TeacherNonBatching(backend)
346            },
347        };
348
349        example.predictions.push(prediction);
350    }
351    Ok(())
352}
353
354async fn predict_openai(
355    example: &mut Example,
356    backend: TeacherBackend,
357    batched: bool,
358    repetition_count: usize,
359    cache_only: bool,
360) -> anyhow::Result<()> {
361    let llm_model_name = backend.model_name();
362    let max_tokens = 16384;
363    let llm_client = OPENAI_CLIENT.get_or_init(|| {
364        let client = if batched {
365            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
366        } else {
367            OpenAiClient::plain()
368        };
369        client.expect("Failed to create OpenAI client")
370    });
371
372    let prompt = example.prompt.as_ref().context("Prompt is required")?;
373
374    for ix in 0..repetition_count {
375        let messages = vec![open_ai::RequestMessage::User {
376            content: open_ai::MessageContent::Plain(prompt.input.clone()),
377        }];
378
379        let seed = if repetition_count > 1 { Some(ix) } else { None };
380        let Some(response) = llm_client
381            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
382            .await?
383        else {
384            // Request stashed for batched processing
385            return Ok(());
386        };
387
388        let actual_output = response
389            .choices
390            .into_iter()
391            .filter_map(|choice| match choice.message {
392                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
393                    open_ai::MessageContent::Plain(text) => text,
394                    open_ai::MessageContent::Multipart(parts) => parts
395                        .into_iter()
396                        .filter_map(|p| match p {
397                            open_ai::MessagePart::Text { text } => Some(text),
398                            _ => None,
399                        })
400                        .collect::<Vec<_>>()
401                        .join(""),
402                }),
403                _ => None,
404            })
405            .collect::<Vec<String>>()
406            .join("\n");
407
408        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
409
410        let prediction = ExamplePrediction {
411            actual_patch: Some(actual_patch),
412            actual_output,
413            actual_cursor,
414            error: None,
415            provider: if batched {
416                PredictionProvider::Teacher(backend)
417            } else {
418                PredictionProvider::TeacherNonBatching(backend)
419            },
420        };
421
422        example.predictions.push(prediction);
423    }
424    Ok(())
425}
426
427pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
428    match provider {
429        Some(PredictionProvider::Teacher(backend)) => match backend {
430            TeacherBackend::Sonnet45 => {
431                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
432                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
433                        .expect("Failed to create Anthropic client")
434                });
435                llm_client
436                    .sync_batches()
437                    .await
438                    .context("Failed to sync Anthropic batches")?;
439            }
440            TeacherBackend::Gpt52 => {
441                let llm_client = OPENAI_CLIENT.get_or_init(|| {
442                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
443                        .expect("Failed to create OpenAI client")
444                });
445                llm_client
446                    .sync_batches()
447                    .await
448                    .context("Failed to sync OpenAI batches")?;
449            }
450        },
451        _ => (),
452    };
453    Ok(())
454}