@@ -16,6 +16,7 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
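+// `pretty_assertions::assert_eq` shadows the std macro to print readable diffs on failure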
+use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_prompt_caching(cx: &mut TestAppContext) {
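+ // Only the most recent message in each request should carry `cache: true`;
+ // earlier messages drop back to `cache: false` once a newer one arrives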
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ // Send an initial user message and verify it's marked for caching
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Message 1"], cx)
+ });
+ cx.run_until_parked();
+
+ let completion = fake_model.pending_completions().pop().unwrap();
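+ // `messages[0]` is the system prompt, so the assertions compare from index 1 onward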
+ assert_eq!(
+ completion.messages[1..],
+ vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Message 1".into()],
+ cache: true
+ }]
+ );
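+ // Stream the assistant's reply and end the turn before the next message is sent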
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+ "Response to Message 1".into(),
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Send another user message and verify only the latest message is marked for caching
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Message 2"], cx)
+ });
+ cx.run_until_parked();
+
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Message 1".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec!["Response to Message 1".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Message 2".into()],
+ cache: true
+ }
+ ]
+ );
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
+ "Response to Message 2".into(),
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Simulate a tool call and verify that the latest tool result is marked for caching
+ thread.update(cx, |thread, _| thread.add_tool(EchoTool));
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
+ });
+ cx.run_until_parked();
+
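+ // Respond with a tool call; the tool runs and the thread sends a follow-up request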
+ let tool_use = LanguageModelToolUse {
+ id: "tool_1".into(),
+ name: EchoTool.name().into(),
+ raw_input: json!({"text": "test"}).to_string(),
+ input: json!({"text": "test"}),
+ is_input_complete: true,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let completion = fake_model.pending_completions().pop().unwrap();
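+ // EchoTool echoes its input, so the expected tool result mirrors the "test" argument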
+ let tool_result = LanguageModelToolResult {
+ tool_use_id: "tool_1".into(),
+ tool_name: EchoTool.name().into(),
+ is_error: false,
+ content: "test".into(),
+ output: Some("test".into()),
+ };
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Message 1".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec!["Response to Message 1".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Message 2".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec!["Response to Message 2".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Use the echo tool".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolUse(tool_use)],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::ToolResult(tool_result)],
+ cache: true
+ }
+ ]
+ );
+}
+
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
@@ -440,7 +569,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result.clone())],
- cache: false
+ cache: true
},
]
);
@@ -481,7 +610,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
- cache: false
+ cache: true
}
]
);
@@ -574,7 +703,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
LanguageModelRequestMessage {
role: Role::User,
content: vec!["ghi".into()],
- cache: false
+ cache: true
}
]
);