agent2: Implement prompt caching (#36236)

Bennet Bo Fenner created

Release Notes:

- N/A

Change summary

crates/agent2/src/tests/mod.rs | 135 +++++++++++++++++++++++++++++++++++
crates/agent2/src/thread.rs    |   8 ++
2 files changed, 140 insertions(+), 3 deletions(-)

Detailed changes

crates/agent2/src/tests/mod.rs 🔗

@@ -16,6 +16,7 @@ use language_model::{
     LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
     Role, StopReason, fake_provider::FakeLanguageModel,
 };
+use pretty_assertions::assert_eq;
 use project::Project;
 use prompt_store::ProjectContext;
 use reqwest_client::ReqwestClient;
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_prompt_caching(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Send initial user message and verify it's cached
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Message 1"], cx)
+    });
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec!["Message 1".into()],
+            cache: true
+        }]
+    );
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+        "Response to Message 1".into(),
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Send another user message and verify only the latest is cached
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Message 2"], cx)
+    });
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 2".into()],
+                cache: true
+            }
+        ]
+    );
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+        "Response to Message 2".into(),
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Simulate a tool call and verify that the latest tool result is cached
+    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
+    });
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: EchoTool.name().into(),
+        raw_input: json!({"text": "test"}).to_string(),
+        input: json!({"text": "test"}),
+        is_input_complete: true,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let completion = fake_model.pending_completions().pop().unwrap();
+    let tool_result = LanguageModelToolResult {
+        tool_use_id: "tool_1".into(),
+        tool_name: EchoTool.name().into(),
+        is_error: false,
+        content: "test".into(),
+        output: Some("test".into()),
+    };
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 1".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Message 2".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec!["Response to Message 2".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Use the echo tool".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![MessageContent::ToolUse(tool_use)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(tool_result)],
+                cache: true
+            }
+        ]
+    );
+}
+
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_basic_tool_calls(cx: &mut TestAppContext) {
@@ -440,7 +569,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec![MessageContent::ToolResult(tool_result.clone())],
-                cache: false
+                cache: true
             },
         ]
     );
@@ -481,7 +610,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec!["Continue where you left off".into()],
-                cache: false
+                cache: true
             }
         ]
     );
@@ -574,7 +703,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
             LanguageModelRequestMessage {
                 role: Role::User,
                 content: vec!["ghi".into()],
-                cache: false
+                cache: true
             }
         ]
     );

crates/agent2/src/thread.rs 🔗

@@ -1041,6 +1041,14 @@ impl Thread {
             messages.extend(message.to_request());
         }
 
+        if let Some(last_user_message) = messages
+            .iter_mut()
+            .rev()
+            .find(|message| message.role == Role::User)
+        {
+            last_user_message.cache = true;
+        }
+
         messages
     }