diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index a2d7c2925b74906acc6bbe62356ff80cf6a2967c..3cc583ddde1cb03a4fd312b36f4358c0fbf3b4c1 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -225,6 +225,7 @@ impl MistralLanguageModel { fn stream_completion( &self, request: mistral::Request, + affinity: Option<String>, cx: &AsyncApp, ) -> BoxFuture< 'static, @@ -243,8 +244,13 @@ impl MistralLanguageModel { provider: PROVIDER_NAME, }); }; - let request = - mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = mistral::stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + affinity, + ); let response = request.await?; Ok(response) }); @@ -331,8 +337,9 @@ impl LanguageModel for MistralLanguageModel { LanguageModelCompletionError, >, > { - let request = into_mistral(request, self.model.clone(), self.max_output_tokens()); - let stream = self.stream_completion(request, cx); + let (request, affinity) = + into_mistral(request, self.model.clone(), self.max_output_tokens()); + let stream = self.stream_completion(request, affinity, cx); async move { let stream = stream.await?; @@ -347,7 +354,7 @@ pub fn into_mistral( request: LanguageModelRequest, model: mistral::Model, max_output_tokens: Option<u64>, -) -> mistral::Request { +) -> (mistral::Request, Option<String>) { let stream = true; let mut messages = Vec::new(); @@ -496,41 +503,44 @@ pub fn into_mistral( } } - mistral::Request { - model: model.id().to_string(), - messages, - stream, - max_tokens: max_output_tokens, - temperature: request.temperature, - response_format: None, - tool_choice: match request.tool_choice { - Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => { - Some(mistral::ToolChoice::Auto) - } - Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => { - Some(mistral::ToolChoice::Any) - } - 
Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None), - _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto), - _ => None, - }, - parallel_tool_calls: if !request.tools.is_empty() { - Some(false) - } else { - None + ( + mistral::Request { + model: model.id().to_string(), + messages, + stream, + max_tokens: max_output_tokens, + temperature: request.temperature, + response_format: None, + tool_choice: match request.tool_choice { + Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => { + Some(mistral::ToolChoice::Auto) + } + Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => { + Some(mistral::ToolChoice::Any) + } + Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None), + _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto), + _ => None, + }, + parallel_tool_calls: if !request.tools.is_empty() { + Some(false) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| mistral::ToolDefinition::Function { + function: mistral::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), }, - tools: request - .tools - .into_iter() - .map(|tool| mistral::ToolDefinition::Function { - function: mistral::FunctionDefinition { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - }, - }) - .collect(), - } + request.thread_id, + ) } pub struct MistralEventMapper { @@ -867,7 +877,7 @@ mod tests { temperature: Some(0.5), tools: vec![], tool_choice: None, - thread_id: None, + thread_id: Some("abcdef".into()), prompt_id: None, intent: None, stop: vec![], @@ -875,12 +885,14 @@ mod tests { thinking_effort: None, }; - let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None); + let (mistral_request, affinity) = + into_mistral(request, mistral::Model::MistralSmallLatest, None); assert_eq!(mistral_request.model, 
"mistral-small-latest"); assert_eq!(mistral_request.temperature, Some(0.5)); assert_eq!(mistral_request.messages.len(), 2); assert!(mistral_request.stream); + assert_eq!(affinity, Some("abcdef".into())); } #[test] @@ -909,7 +921,7 @@ mod tests { thinking_effort: None, }; - let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None); + let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None); assert_eq!(mistral_request.messages.len(), 1); assert!(matches!( diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index 04e641c23b5387966fe8228e4bc13aa27758e5b2..cc9f94304d989c69c3f5a4bd3763704314564a19 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -1,6 +1,6 @@ use anyhow::{Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::convert::TryFrom; @@ -437,13 +437,17 @@ pub async fn stream_completion( api_url: &str, api_key: &str, request: Request, + affinity: Option, ) -> Result>> { let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key.trim())); + .header("Authorization", format!("Bearer {}", api_key.trim())) + .when_some(affinity, |this, affinity| { + this.header("x-affinity", affinity) + }); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?;