@@ -581,7 +581,7 @@ async fn stream_completion(
api_key: String,
request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
- let is_vision_request = request.messages.last().map_or(false, |message| match message {
+ let is_vision_request = request.messages.iter().any(|message| match message {
ChatMessage::User { content }
| ChatMessage::Assistant { content, .. }
| ChatMessage::Tool { content, .. } => {
@@ -736,4 +736,116 @@ mod tests {
assert_eq!(schema.data[0].id, "gpt-4");
assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
}
+
+ #[test]
+ fn test_vision_request_detection() {
+ fn message_contains_image(message: &ChatMessage) -> bool {
+ match message {
+ ChatMessage::User { content }
+ | ChatMessage::Assistant { content, .. }
+ | ChatMessage::Tool { content, .. } => {
+ matches!(content, ChatMessageContent::Multipart(parts) if
+ parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
+ }
+ _ => false,
+ }
+ }
+ // Test-local check: a request counts as a vision request when ANY of its messages (not
+ // only the last) has an image part. NOTE(review): the asserts use this helper, not `stream_completion`.
+ fn is_vision_request(request: &Request) -> bool {
+ request.messages.iter().any(message_contains_image)
+ }
+ // Case 1: the image part sits in the final message of the conversation.
+ let request_with_image_in_last = Request {
+ intent: true,
+ n: 1,
+ stream: true,
+ temperature: 0.1,
+ model: "claude-3.7-sonnet".to_string(),
+ messages: vec![
+ ChatMessage::User {
+ content: ChatMessageContent::Plain("Hello".to_string()),
+ },
+ ChatMessage::Assistant {
+ content: ChatMessageContent::Plain("How can I help?".to_string()),
+ tool_calls: vec![],
+ },
+ ChatMessage::User {
+ content: ChatMessageContent::Multipart(vec![
+ ChatMessagePart::Text {
+ text: "What's in this image?".to_string(),
+ },
+ ChatMessagePart::Image {
+ image_url: ImageUrl {
+ url: "".to_string(),
+ },
+ },
+ ]),
+ },
+ ],
+ tools: vec![],
+ tool_choice: None,
+ };
+ // Case 2: the image part sits in an earlier message; the final message is plain text.
+ let request_with_image_in_earlier = Request {
+ intent: true,
+ n: 1,
+ stream: true,
+ temperature: 0.1,
+ model: "claude-3.7-sonnet".to_string(),
+ messages: vec![
+ ChatMessage::User {
+ content: ChatMessageContent::Plain("Hello".to_string()),
+ },
+ ChatMessage::User {
+ content: ChatMessageContent::Multipart(vec![
+ ChatMessagePart::Text {
+ text: "What's in this image?".to_string(),
+ },
+ ChatMessagePart::Image {
+ image_url: ImageUrl {
+ url: "".to_string(),
+ },
+ },
+ ]),
+ },
+ ChatMessage::Assistant {
+ content: ChatMessageContent::Plain("I see a cat in the image.".to_string()),
+ tool_calls: vec![],
+ },
+ ChatMessage::User {
+ content: ChatMessageContent::Plain("What color is it?".to_string()),
+ },
+ ],
+ tools: vec![],
+ tool_choice: None,
+ };
+ // Case 3: every message is plain text, so the request holds no image part at all.
+ let request_with_no_images = Request {
+ intent: true,
+ n: 1,
+ stream: true,
+ temperature: 0.1,
+ model: "claude-3.7-sonnet".to_string(),
+ messages: vec![
+ ChatMessage::User {
+ content: ChatMessageContent::Plain("Hello".to_string()),
+ },
+ ChatMessage::Assistant {
+ content: ChatMessageContent::Plain("How can I help?".to_string()),
+ tool_calls: vec![],
+ },
+ ChatMessage::User {
+ content: ChatMessageContent::Plain("Tell me about Rust.".to_string()),
+ },
+ ],
+ tools: vec![],
+ tool_choice: None,
+ };
+ // An image in either the last or an earlier message is expected to be detected.
+ assert!(is_vision_request(&request_with_image_in_last));
+ assert!(is_vision_request(&request_with_image_in_earlier));
+ // A conversation without any image part is expected not to be flagged.
+ assert!(!is_vision_request(&request_with_no_images));
+ }
}