diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 8975115d907875569f63e4247cf7edcdbcb91f8a..a80cacfc4a02521af74b32c34cc3360e9665a7d9 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -11,8 +11,8 @@ use language_model::{ LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use ollama::{ - ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool, - OllamaToolCall, get_models, show_model, stream_chat_completion, + ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionCall, + OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -282,59 +282,85 @@ impl OllamaLanguageModel { fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest { let supports_vision = self.model.supports_vision.unwrap_or(false); - ChatRequest { - model: self.model.name.clone(), - messages: request - .messages - .into_iter() - .map(|msg| { - let images = if supports_vision { - msg.content - .iter() - .filter_map(|content| match content { - MessageContent::Image(image) => Some(image.source.to_string()), - _ => None, - }) - .collect::>() - } else { - vec![] - }; - - match msg.role { - Role::User => ChatMessage::User { + let mut messages = Vec::with_capacity(request.messages.len()); + + for mut msg in request.messages.into_iter() { + let images = if supports_vision { + msg.content + .iter() + .filter_map(|content| match content { + MessageContent::Image(image) => Some(image.source.to_string()), + _ => None, + }) + .collect::>() + } else { + vec![] + }; + + match msg.role { + Role::User => { + for tool_result in msg + .content + .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..))) + { + match tool_result { + MessageContent::ToolResult(tool_result) => { + messages.push(ChatMessage::Tool { + tool_name: tool_result.tool_name.to_string(), + content: tool_result.content.to_str().unwrap_or("").to_string(), + }) + } + _ => unreachable!("Only tool result should be extracted"), + } + } + if !msg.content.is_empty() { + messages.push(ChatMessage::User { content: msg.string_contents(), images: if images.is_empty() { None } else { Some(images) }, - }, - Role::Assistant => { - let content = msg.string_contents(); - let thinking = - msg.content.into_iter().find_map(|content| match content { - MessageContent::Thinking { text, .. } if !text.is_empty() => { - Some(text) - } - _ => None, - }); - ChatMessage::Assistant { - content, - tool_calls: None, - images: if images.is_empty() { - None - } else { - Some(images) - }, - thinking, + }) + } + } + Role::Assistant => { + let content = msg.string_contents(); + let mut thinking = None; + let mut tool_calls = Vec::new(); + for content in msg.content.into_iter() { + match content { + MessageContent::Thinking { text, .. } if !text.is_empty() => { + thinking = Some(text) } + MessageContent::ToolUse(tool_use) => { + tool_calls.push(OllamaToolCall::Function(OllamaFunctionCall { + name: tool_use.name.to_string(), + arguments: tool_use.input, + })); + } + _ => (), } - Role::System => ChatMessage::System { - content: msg.string_contents(), - }, } - }) - .collect(), + messages.push(ChatMessage::Assistant { + content, + tool_calls: Some(tool_calls), + images: if images.is_empty() { + None + } else { + Some(images) + }, + thinking, + }) + } + Role::System => messages.push(ChatMessage::System { + content: msg.string_contents(), + }), + } + } + ChatRequest { + model: self.model.name.clone(), + messages, keep_alive: self.model.keep_alive.clone().unwrap_or_default(), stream: true, options: Some(ChatOptions { @@ -483,6 +509,9 @@ fn map_to_language_model_completion_events( ChatMessage::System { content } => { events.push(Ok(LanguageModelCompletionEvent::Text(content))); } + ChatMessage::Tool { content, .. } => { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } ChatMessage::Assistant { content, tool_calls, diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 64cd1cc0cbc06607ee9b3b72ee81cbeb9489c344..3c935d2152556393829f648abe31a717b239ce76 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -117,6 +117,10 @@ pub enum ChatMessage { System { content: String, }, + Tool { + tool_name: String, + content: String, + }, } #[derive(Serialize, Deserialize, Debug)]