ep: Add --cache-only option to avoid sending requests (#48011)

Created by Oleksiy Syvokon

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs |  8 ++++++--
crates/edit_prediction_cli/src/main.rs             |  3 +++
crates/edit_prediction_cli/src/openai_client.rs    |  8 ++++++--
crates/edit_prediction_cli/src/predict.rs          | 15 ++++++++++-----
crates/edit_prediction_cli/src/qa.rs               |  8 ++++++--
crates/edit_prediction_cli/src/repair.rs           |  8 ++++++--
6 files changed, 37 insertions(+), 13 deletions(-)
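
The new flag is declared on PredictArgs in main.rs and threaded through run_prediction and predict_teacher into the Anthropic and OpenAI batching clients, where it gates the mark_for_batch call: with --cache-only set, a cache miss returns None without queueing a new request.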

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs

@@ -267,13 +267,16 @@ impl BatchingLlmClient {
         max_tokens: u64,
         messages: Vec<Message>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<AnthropicResponse>> {
         let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        if !cache_only {
+            self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        }
 
         Ok(None)
     }
@@ -672,6 +675,7 @@ impl AnthropicClient {
         max_tokens: u64,
         messages: Vec<Message>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<AnthropicResponse>> {
         match self {
             AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -680,7 +684,7 @@ impl AnthropicClient {
                 .map(Some),
             AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
-                    .generate(model, max_tokens, messages, seed)
+                    .generate(model, max_tokens, messages, seed, cache_only)
                     .await
             }
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
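
Behaviorally, the batching client now has three outcomes: a cache hit returns the stored response, a cache miss with cache_only set returns None without queueing anything, and a cache miss without it returns None after marking the request for the next batch. A minimal standalone model of that gating, with a HashMap and a Vec standing in for the real lookup and mark_for_batch machinery and a String standing in for the response type:

    use std::collections::HashMap;

    /// Simplified stand-in for BatchingLlmClient: `cache` plays the role of
    /// `lookup`, `pending` the role of `mark_for_batch`.
    struct CacheOnlyModel {
        cache: HashMap<String, String>,
        pending: Vec<String>,
    }

    impl CacheOnlyModel {
        fn generate(&mut self, key: &str, cache_only: bool) -> Option<String> {
            // Cache hit: always return the stored response.
            if let Some(response) = self.cache.get(key) {
                return Some(response.clone());
            }
            // Cache miss: only queue the request when cache_only is not set.
            if !cache_only {
                self.pending.push(key.to_string());
            }
            None
        }
    }

    fn main() {
        let mut model = CacheOnlyModel { cache: HashMap::new(), pending: Vec::new() };
        assert_eq!(model.generate("prompt", true), None); // miss, nothing queued
        assert!(model.pending.is_empty());
        assert_eq!(model.generate("prompt", false), None); // miss, queued for a batch
        assert_eq!(model.pending.len(), 1);
    }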

crates/edit_prediction_cli/src/main.rs

@@ -257,6 +257,9 @@ struct PredictArgs {
     provider: Option<PredictionProvider>,
     #[clap(long, default_value_t = 1)]
     repetitions: usize,
+    /// Only use cached responses, don't queue new requests for batching
+    #[clap(long)]
+    cache_only: bool,
 }
 
 #[derive(Debug, Args, Clone)]
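
The flag is a plain boolean on PredictArgs, so it defaults to false and is enabled by passing --cache-only. A minimal sketch of the same mapping, assuming the clap crate with its derive feature (the struct and binary here are illustrative, not the real CLI):

    use clap::Parser;

    #[derive(Debug, Parser)]
    struct Args {
        /// Only use cached responses, don't queue new requests for batching
        #[clap(long)]
        cache_only: bool, // absent => false, `--cache-only` => true
    }

    fn main() {
        let args = Args::parse();
        println!("cache_only = {}", args.cache_only);
    }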

crates/edit_prediction_cli/src/openai_client.rs

@@ -194,13 +194,16 @@ impl BatchingOpenAiClient {
         max_tokens: u64,
         messages: Vec<RequestMessage>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<OpenAiResponse>> {
         let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        if !cache_only {
+            self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        }
 
         Ok(None)
     }
@@ -643,6 +646,7 @@ impl OpenAiClient {
         max_tokens: u64,
         messages: Vec<RequestMessage>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<OpenAiResponse>> {
         match self {
             OpenAiClient::Plain(plain_client) => plain_client
@@ -651,7 +655,7 @@ impl OpenAiClient {
                 .map(Some),
             OpenAiClient::Batch(batching_client) => {
                 batching_client
-                    .generate(model, max_tokens, messages, seed)
+                    .generate(model, max_tokens, messages, seed, cache_only)
                     .await
             }
             OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),

crates/edit_prediction_cli/src/predict.rs

@@ -69,7 +69,7 @@ pub async fn run_prediction(
         .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_teacher(example, backend, batched, repetition_count).await;
+        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
     }
 
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -262,12 +262,15 @@ async fn predict_teacher(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     match backend {
         TeacherBackend::Sonnet45 => {
-            predict_anthropic(example, backend, batched, repetition_count).await
+            predict_anthropic(example, backend, batched, repetition_count, cache_only).await
+        }
+        TeacherBackend::Gpt52 => {
+            predict_openai(example, backend, batched, repetition_count, cache_only).await
         }
-        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
     }
 }
 
@@ -276,6 +279,7 @@ async fn predict_anthropic(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -301,7 +305,7 @@ async fn predict_anthropic(
 
         let seed = if repetition_count > 1 { Some(ix) } else { None };
         let Some(response) = llm_client
-            .generate(llm_model_name, max_tokens, messages, seed)
+            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
             .await?
         else {
             // Request stashed for batched processing
@@ -341,6 +345,7 @@ async fn predict_openai(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -362,7 +367,7 @@ async fn predict_openai(
 
         let seed = if repetition_count > 1 { Some(ix) } else { None };
         let Some(response) = llm_client
-            .generate(llm_model_name, max_tokens, messages, seed)
+            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
             .await?
         else {
             // Request stashed for batched processing
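
On the predict path a missing response is handled with let-else: without --cache-only it means the request was just stashed for the next batch run, with it the repetition is simply skipped. A self-contained sketch of that skip pattern, with a synchronous generate stub standing in for the async batching client:

    // `generate` is a hypothetical stand-in that always misses the cache and
    // elides the queueing side effect.
    fn generate(seed: Option<usize>, cache_only: bool) -> Option<String> {
        let _ = (seed, cache_only);
        None
    }

    fn main() {
        let repetition_count = 3;
        let cache_only = true;
        for ix in 0..repetition_count {
            let seed = if repetition_count > 1 { Some(ix) } else { None };
            let Some(response) = generate(seed, cache_only) else {
                // With --cache-only nothing is queued; this repetition is skipped.
                println!("repetition {ix}: no cached response, skipping");
                continue;
            };
            println!("repetition {ix}: {response}");
        }
    }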

crates/edit_prediction_cli/src/qa.rs

@@ -172,7 +172,9 @@ impl QaClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
+                let response = client
+                    .generate(model, max_tokens, messages, None, false)
+                    .await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -188,7 +190,9 @@ impl QaClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
+                let response = client
+                    .generate(model, max_tokens, messages, None, false)
+                    .await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()

crates/edit_prediction_cli/src/repair.rs

@@ -152,7 +152,9 @@ impl RepairClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
+                let response = client
+                    .generate(model, max_tokens, messages, None, false)
+                    .await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -168,7 +170,9 @@ impl RepairClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages, None).await?;
+                let response = client
+                    .generate(model, max_tokens, messages, None, false)
+                    .await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()
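
In qa.rs and repair.rs the new argument is passed as a literal false, so QA and repair calls keep queueing requests for batching on a cache miss; the --cache-only flag only affects the predict path, where args.cache_only is threaded through.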