mistral: Add x-affinity header (#48584)

Created by Vianney le Clément de Saint-Marcq

Mistral Vibe, the official coding agent from Mistral, sets the
x-affinity header to the session ID to enable prompt caching. This patch
implements the same mechanism, resulting in a faster agent loop.

Release Notes:

- Added prompt caching for Mistral AI.

Change summary

crates/language_models/src/provider/mistral.rs | 96 +++++++++++--------
crates/mistral/src/mistral.rs                  |  8 +
2 files changed, 60 insertions(+), 44 deletions(-)

Detailed changes

crates/language_models/src/provider/mistral.rs 🔗

@@ -225,6 +225,7 @@ impl MistralLanguageModel {
     fn stream_completion(
         &self,
         request: mistral::Request,
+        affinity: Option<String>,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
@@ -243,8 +244,13 @@ impl MistralLanguageModel {
                     provider: PROVIDER_NAME,
                 });
             };
-            let request =
-                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = mistral::stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                affinity,
+            );
             let response = request.await?;
             Ok(response)
         });
@@ -331,8 +337,9 @@ impl LanguageModel for MistralLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
-        let stream = self.stream_completion(request, cx);
+        let (request, affinity) =
+            into_mistral(request, self.model.clone(), self.max_output_tokens());
+        let stream = self.stream_completion(request, affinity, cx);
 
         async move {
             let stream = stream.await?;
@@ -347,7 +354,7 @@ pub fn into_mistral(
     request: LanguageModelRequest,
     model: mistral::Model,
     max_output_tokens: Option<u64>,
-) -> mistral::Request {
+) -> (mistral::Request, Option<String>) {
     let stream = true;
 
     let mut messages = Vec::new();
@@ -496,41 +503,44 @@ pub fn into_mistral(
         }
     }
 
-    mistral::Request {
-        model: model.id().to_string(),
-        messages,
-        stream,
-        max_tokens: max_output_tokens,
-        temperature: request.temperature,
-        response_format: None,
-        tool_choice: match request.tool_choice {
-            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
-                Some(mistral::ToolChoice::Auto)
-            }
-            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
-                Some(mistral::ToolChoice::Any)
-            }
-            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
-            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
-            _ => None,
-        },
-        parallel_tool_calls: if !request.tools.is_empty() {
-            Some(false)
-        } else {
-            None
+    (
+        mistral::Request {
+            model: model.id().to_string(),
+            messages,
+            stream,
+            max_tokens: max_output_tokens,
+            temperature: request.temperature,
+            response_format: None,
+            tool_choice: match request.tool_choice {
+                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
+                    Some(mistral::ToolChoice::Auto)
+                }
+                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
+                    Some(mistral::ToolChoice::Any)
+                }
+                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
+                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
+                _ => None,
+            },
+            parallel_tool_calls: if !request.tools.is_empty() {
+                Some(false)
+            } else {
+                None
+            },
+            tools: request
+                .tools
+                .into_iter()
+                .map(|tool| mistral::ToolDefinition::Function {
+                    function: mistral::FunctionDefinition {
+                        name: tool.name,
+                        description: Some(tool.description),
+                        parameters: Some(tool.input_schema),
+                    },
+                })
+                .collect(),
         },
-        tools: request
-            .tools
-            .into_iter()
-            .map(|tool| mistral::ToolDefinition::Function {
-                function: mistral::FunctionDefinition {
-                    name: tool.name,
-                    description: Some(tool.description),
-                    parameters: Some(tool.input_schema),
-                },
-            })
-            .collect(),
-    }
+        request.thread_id,
+    )
 }
 
 pub struct MistralEventMapper {
@@ -867,7 +877,7 @@ mod tests {
             temperature: Some(0.5),
             tools: vec![],
             tool_choice: None,
-            thread_id: None,
+            thread_id: Some("abcdef".into()),
             prompt_id: None,
             intent: None,
             stop: vec![],
@@ -875,12 +885,14 @@ mod tests {
             thinking_effort: None,
         };
 
-        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
+        let (mistral_request, affinity) =
+            into_mistral(request, mistral::Model::MistralSmallLatest, None);
 
         assert_eq!(mistral_request.model, "mistral-small-latest");
         assert_eq!(mistral_request.temperature, Some(0.5));
         assert_eq!(mistral_request.messages.len(), 2);
         assert!(mistral_request.stream);
+        assert_eq!(affinity, Some("abcdef".into()));
     }
 
     #[test]
@@ -909,7 +921,7 @@ mod tests {
             thinking_effort: None,
         };
 
-        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
+        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
 
         assert_eq!(mistral_request.messages.len(), 1);
         assert!(matches!(

crates/mistral/src/mistral.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::convert::TryFrom;
@@ -437,13 +437,17 @@ pub async fn stream_completion(
     api_url: &str,
     api_key: &str,
     request: Request,
+    affinity: Option<String>,
 ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key.trim()));
+        .header("Authorization", format!("Bearer {}", api_key.trim()))
+        .when_some(affinity, |this, affinity| {
+            this.header("x-affinity", affinity)
+        });
 
     let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;