Detailed changes
@@ -267,13 +267,16 @@ impl BatchingLlmClient {
         max_tokens: u64,
         messages: Vec<Message>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<AnthropicResponse>> {
         let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
-        self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        if !cache_only {
+            self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        }
         Ok(None)
     }

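With cache_only set, a None return from generate gains a second meaning: not "request stashed for a future batch" but "cache miss, nothing queued". A minimal caller sketch of that contract (the model-as-&str parameter and the process handler are assumptions, not part of the diff):

    // Sketch only: the caller-side contract of the new flag.
    async fn consume_cached_only(
        client: &AnthropicClient,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<()> {
        if let Some(response) = client
            .generate(model, 16384, messages, None, /* cache_only */ true)
            .await?
        {
            process(response); // hypothetical downstream handling
        } else {
            // Cache miss; nothing was queued because cache_only is set.
        }
        Ok(())
    }
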
@@ -672,6 +675,7 @@ impl AnthropicClient {
         max_tokens: u64,
         messages: Vec<Message>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<AnthropicResponse>> {
         match self {
             AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -680,7 +684,7 @@
                 .map(Some),
             AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
-                    .generate(model, max_tokens, messages, seed)
+                    .generate(model, max_tokens, messages, seed, cache_only)
                     .await
             }
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),

@@ -257,6 +257,9 @@ struct PredictArgs {
     provider: Option<PredictionProvider>,
     #[clap(long, default_value_t = 1)]
     repetitions: usize,
+    /// Only use cached responses, don't queue new requests for batching
+    #[clap(long)]
+    cache_only: bool,
 }

 #[derive(Debug, Args, Clone)]

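Because the new field is a bool annotated with #[clap(long)], clap's derive exposes it as a --cache-only flag that defaults to false when omitted. A trimmed, self-contained sketch of that behavior (the Cli wrapper and the "predict" binary name are assumptions; the real struct derives Args inside a larger CLI):

    use clap::Parser;

    #[derive(Parser)]
    struct Cli {
        /// Only use cached responses, don't queue new requests for batching
        #[clap(long)]
        cache_only: bool,
    }

    fn main() {
        let cli = Cli::parse_from(["predict", "--cache-only"]);
        assert!(cli.cache_only);
    }
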
@@ -194,13 +194,16 @@ impl BatchingOpenAiClient {
         max_tokens: u64,
         messages: Vec<RequestMessage>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<OpenAiResponse>> {
         let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
-        self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        if !cache_only {
+            self.mark_for_batch(model, max_tokens, &messages, seed)?;
+        }
         Ok(None)
     }

@@ -643,6 +646,7 @@ impl OpenAiClient {
         max_tokens: u64,
         messages: Vec<RequestMessage>,
         seed: Option<usize>,
+        cache_only: bool,
     ) -> Result<Option<OpenAiResponse>> {
         match self {
             OpenAiClient::Plain(plain_client) => plain_client
@@ -651,7 +655,7 @@
                 .map(Some),
             OpenAiClient::Batch(batching_client) => {
                 batching_client
-                    .generate(model, max_tokens, messages, seed, cache_only)
+                    .generate(model, max_tokens, messages, seed, cache_only)
                     .await
             }
             OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),

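In both dispatches only the Batch variant forwards the flag; a Plain client always issues a live request, so cache_only is silently ignored there. If a caller ever needs to catch that, a guard along these lines (hypothetical, not part of the diff) would surface the misconfiguration:

    // Hypothetical: fail fast instead of ignoring cache_only on a plain client.
    fn check_cache_only(client: &OpenAiClient, cache_only: bool) -> anyhow::Result<()> {
        if cache_only && !matches!(client, OpenAiClient::Batch(_)) {
            anyhow::bail!("cache_only requires the batching client");
        }
        Ok(())
    }
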
@@ -69,7 +69,7 @@ pub async fn run_prediction(
             .await?;
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_teacher(example, backend, batched, repetition_count).await;
+        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
     }

     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

@@ -262,12 +262,15 @@ async fn predict_teacher(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     match backend {
         TeacherBackend::Sonnet45 => {
-            predict_anthropic(example, backend, batched, repetition_count).await
+            predict_anthropic(example, backend, batched, repetition_count, cache_only).await
+        }
+        TeacherBackend::Gpt52 => {
+            predict_openai(example, backend, batched, repetition_count, cache_only).await
         }
-        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
     }
 }

@@ -276,6 +279,7 @@ async fn predict_anthropic(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;

@@ -301,7 +305,7 @@
         let seed = if repetition_count > 1 { Some(ix) } else { None };
         let Some(response) = llm_client
-            .generate(llm_model_name, max_tokens, messages, seed)
+            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
             .await?
         else {
             // Request stashed for batched processing

@@ -341,6 +345,7 @@ async fn predict_openai(
     backend: TeacherBackend,
     batched: bool,
     repetition_count: usize,
+    cache_only: bool,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;

@@ -362,7 +367,7 @@
         let seed = if repetition_count > 1 { Some(ix) } else { None };
         let Some(response) = llm_client
-            .generate(llm_model_name, max_tokens, messages, seed)
+            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
             .await?
         else {
             // Request stashed for batched processing

@@ -172,7 +172,9 @@ impl QaClient {
                 cache_control: None,
             }],
         }];
-        let response = client.generate(model, max_tokens, messages, None).await?;
+        let response = client
+            .generate(model, max_tokens, messages, None, false)
+            .await?;
         Ok(response.map(|r| {
             r.content
                 .iter()

@@ -188,7 +190,9 @@ impl QaClient {
         let messages = vec![open_ai::RequestMessage::User {
             content: open_ai::MessageContent::Plain(prompt.to_string()),
         }];
-        let response = client.generate(model, max_tokens, messages, None).await?;
+        let response = client
+            .generate(model, max_tokens, messages, None, false)
+            .await?;
         Ok(response.map(|r| {
             r.choices
                 .into_iter()

@@ -152,7 +152,9 @@ impl RepairClient {
                 cache_control: None,
             }],
         }];
-        let response = client.generate(model, max_tokens, messages, None).await?;
+        let response = client
+            .generate(model, max_tokens, messages, None, false)
+            .await?;
         Ok(response.map(|r| {
             r.content
                 .iter()

@@ -168,7 +170,9 @@ impl RepairClient {
         let messages = vec![open_ai::RequestMessage::User {
             content: open_ai::MessageContent::Plain(prompt.to_string()),
         }];
-        let response = client.generate(model, max_tokens, messages, None).await?;
+        let response = client
+            .generate(model, max_tokens, messages, None, false)
+            .await?;
         Ok(response.map(|r| {
             r.choices
                 .into_iter()