predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 10    progress::{ExampleProgress, InfoStyle, Step},
 11    retrieve_context::run_context_retrieval,
 12};
 13use anyhow::Context as _;
 14use edit_prediction::{DebugEvent, EditPredictionStore};
 15use futures::{FutureExt as _, StreamExt as _, future::Shared};
 16use gpui::{AppContext as _, AsyncApp, Task};
 17use std::{
 18    fs,
 19    sync::{
 20        Arc, Mutex, OnceLock,
 21        atomic::{AtomicUsize, Ordering::SeqCst},
 22    },
 23};
 24
// Process-wide LLM clients, lazily created on first use by
// `predict_anthropic` / `predict_openai` / `sync_batches`. Note that the
// batch-vs-plain mode is captured by whichever call initializes the
// `OnceLock` first; later calls with a different mode reuse that client.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 27
/// Runs edit prediction for `example` according to `args`, appending the
/// result(s) to `example.predictions`.
///
/// Early-outs when a usable prediction from a compatible provider already
/// exists. Teacher providers format a prompt and call an external LLM; all
/// other providers load the project and request predictions through the
/// in-process `EditPredictionStore`, mirroring debug events (prompt and raw
/// model output) into per-run directories under `RUN_DIR`.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Decide whether cached predictions can be reused or must be discarded.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                // No provider requested: any existing prediction suffices.
                None => return Ok(()),
                // Same provider: keep the cached prediction.
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                // Different provider: discard stale predictions and re-run.
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    // Teacher providers bypass the local store entirely: format the prompt,
    // then hand off to the LLM-backed prediction path.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers need a signed-in client; sign in at most once per process
    // and share the task so concurrent examples await the same sign-in.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        // Sign-in failure is only reported here; any hard error
                        // surfaces from the prediction request below.
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    // Map the CLI provider onto the store's model enum. Teacher providers were
    // handled above, so hitting those arms here is a bug.
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // Shared between this function and the debug-event task below: the example
    // being filled in, and the index of the repetition currently in flight.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    // Background task that writes prompt/response debug events to the run
    // directory and records raw model output on the in-flight prediction.
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // With multiple repetitions each run gets its own subdirectory.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The main loop pushes a placeholder prediction before
                        // each request, so exactly run_ix + 1 must exist.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): run_ix is always < repetition_count
                        // (stored from the 0..repetition_count loop below), so
                        // this break never fires; the loop actually ends when
                        // debug_rx closes after remove_project below. Confirm
                        // whether `run_ix + 1 >= repetition_count` was intended.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        // Create the run directory and repoint the "latest run" symlink at it.
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Placeholder prediction; the debug task fills in actual_output as
        // events arrive, and actual_patch is set after the request completes.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor_offset: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Convert a successful prediction into a unified diff, if any edits exist.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Report the outcome of the final repetition in the progress UI.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Tear down: removing the project lets the debug event stream end, so the
    // background task awaited next can complete.
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // The debug task has finished, so we should hold the sole Arc reference.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
260
261async fn predict_teacher(
262    example: &mut Example,
263    backend: TeacherBackend,
264    batched: bool,
265    repetition_count: usize,
266    cache_only: bool,
267) -> anyhow::Result<()> {
268    match backend {
269        TeacherBackend::Sonnet45 => {
270            predict_anthropic(example, backend, batched, repetition_count, cache_only).await
271        }
272        TeacherBackend::Gpt52 => {
273            predict_openai(example, backend, batched, repetition_count, cache_only).await
274        }
275    }
276}
277
278async fn predict_anthropic(
279    example: &mut Example,
280    backend: TeacherBackend,
281    batched: bool,
282    repetition_count: usize,
283    cache_only: bool,
284) -> anyhow::Result<()> {
285    let llm_model_name = backend.model_name();
286    let max_tokens = 16384;
287    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
288        let client = if batched {
289            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
290        } else {
291            AnthropicClient::plain()
292        };
293        client.expect("Failed to create Anthropic client")
294    });
295
296    let prompt = example.prompt.as_ref().context("Prompt is required")?;
297
298    for ix in 0..repetition_count {
299        let messages = vec![anthropic::Message {
300            role: anthropic::Role::User,
301            content: vec![anthropic::RequestContent::Text {
302                text: prompt.input.clone(),
303                cache_control: None,
304            }],
305        }];
306
307        let seed = if repetition_count > 1 { Some(ix) } else { None };
308        let Some(response) = llm_client
309            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
310            .await?
311        else {
312            // Request stashed for batched processing
313            return Ok(());
314        };
315
316        let actual_output = response
317            .content
318            .into_iter()
319            .filter_map(|content| match content {
320                anthropic::ResponseContent::Text { text } => Some(text),
321                _ => None,
322            })
323            .collect::<Vec<String>>()
324            .join("\n");
325
326        let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
327
328        let prediction = ExamplePrediction {
329            actual_patch: Some(actual_patch),
330            actual_output,
331            actual_cursor_offset,
332            error: None,
333            provider: if batched {
334                PredictionProvider::Teacher(backend)
335            } else {
336                PredictionProvider::TeacherNonBatching(backend)
337            },
338        };
339
340        example.predictions.push(prediction);
341    }
342    Ok(())
343}
344
345async fn predict_openai(
346    example: &mut Example,
347    backend: TeacherBackend,
348    batched: bool,
349    repetition_count: usize,
350    cache_only: bool,
351) -> anyhow::Result<()> {
352    let llm_model_name = backend.model_name();
353    let max_tokens = 16384;
354    let llm_client = OPENAI_CLIENT.get_or_init(|| {
355        let client = if batched {
356            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
357        } else {
358            OpenAiClient::plain()
359        };
360        client.expect("Failed to create OpenAI client")
361    });
362
363    let prompt = example.prompt.as_ref().context("Prompt is required")?;
364
365    for ix in 0..repetition_count {
366        let messages = vec![open_ai::RequestMessage::User {
367            content: open_ai::MessageContent::Plain(prompt.input.clone()),
368        }];
369
370        let seed = if repetition_count > 1 { Some(ix) } else { None };
371        let Some(response) = llm_client
372            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
373            .await?
374        else {
375            // Request stashed for batched processing
376            return Ok(());
377        };
378
379        let actual_output = response
380            .choices
381            .into_iter()
382            .filter_map(|choice| match choice.message {
383                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
384                    open_ai::MessageContent::Plain(text) => text,
385                    open_ai::MessageContent::Multipart(parts) => parts
386                        .into_iter()
387                        .filter_map(|p| match p {
388                            open_ai::MessagePart::Text { text } => Some(text),
389                            _ => None,
390                        })
391                        .collect::<Vec<_>>()
392                        .join(""),
393                }),
394                _ => None,
395            })
396            .collect::<Vec<String>>()
397            .join("\n");
398
399        let (actual_patch, actual_cursor_offset) = TeacherPrompt::parse(example, &actual_output)?;
400
401        let prediction = ExamplePrediction {
402            actual_patch: Some(actual_patch),
403            actual_output,
404            actual_cursor_offset,
405            error: None,
406            provider: if batched {
407                PredictionProvider::Teacher(backend)
408            } else {
409                PredictionProvider::TeacherNonBatching(backend)
410            },
411        };
412
413        example.predictions.push(prediction);
414    }
415    Ok(())
416}
417
418pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
419    match provider {
420        Some(PredictionProvider::Teacher(backend)) => match backend {
421            TeacherBackend::Sonnet45 => {
422                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
423                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
424                        .expect("Failed to create Anthropic client")
425                });
426                llm_client
427                    .sync_batches()
428                    .await
429                    .context("Failed to sync Anthropic batches")?;
430            }
431            TeacherBackend::Gpt52 => {
432                let llm_client = OPENAI_CLIENT.get_or_init(|| {
433                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
434                        .expect("Failed to create OpenAI client")
435                });
436                llm_client
437                    .sync_batches()
438                    .await
439                    .context("Failed to sync OpenAI batches")?;
440            }
441        },
442        _ => (),
443    };
444    Ok(())
445}