mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51
 52    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 53    MistralLargeLatest,
 54    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 55    MistralMediumLatest,
 56    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 57    MistralSmallLatest,
 58
 59    #[serde(rename = "magistral-medium-latest", alias = "magistral-medium-latest")]
 60    MagistralMediumLatest,
 61    #[serde(rename = "magistral-small-latest", alias = "magistral-small-latest")]
 62    MagistralSmallLatest,
 63
 64    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 65    OpenMistralNemo,
 66    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 67    OpenCodestralMamba,
 68
 69    #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
 70    DevstralMediumLatest,
 71    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
 72    DevstralSmallLatest,
 73
 74    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
 75    Pixtral12BLatest,
 76    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
 77    PixtralLargeLatest,
 78
 79    #[serde(rename = "custom")]
 80    Custom {
 81        name: String,
 82        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 83        display_name: Option<String>,
 84        max_tokens: u64,
 85        max_output_tokens: Option<u64>,
 86        max_completion_tokens: Option<u64>,
 87        supports_tools: Option<bool>,
 88        supports_images: Option<bool>,
 89        supports_thinking: Option<bool>,
 90    },
 91}
 92
 93impl Model {
 94    pub fn default_fast() -> Self {
 95        Model::MistralSmallLatest
 96    }
 97
 98    pub fn from_id(id: &str) -> Result<Self> {
 99        match id {
100            "codestral-latest" => Ok(Self::CodestralLatest),
101            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
102            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
103            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
104            "magistral-medium-latest" => Ok(Self::MagistralMediumLatest),
105            "magistral-small-latest" => Ok(Self::MagistralSmallLatest),
106            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
107            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
108            "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
109            "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
110            "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
111            "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
112            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
113        }
114    }
115
116    pub fn id(&self) -> &str {
117        match self {
118            Self::CodestralLatest => "codestral-latest",
119            Self::MistralLargeLatest => "mistral-large-latest",
120            Self::MistralMediumLatest => "mistral-medium-latest",
121            Self::MistralSmallLatest => "mistral-small-latest",
122            Self::MagistralMediumLatest => "magistral-medium-latest",
123            Self::MagistralSmallLatest => "magistral-small-latest",
124            Self::OpenMistralNemo => "open-mistral-nemo",
125            Self::OpenCodestralMamba => "open-codestral-mamba",
126            Self::DevstralMediumLatest => "devstral-medium-latest",
127            Self::DevstralSmallLatest => "devstral-small-latest",
128            Self::Pixtral12BLatest => "pixtral-12b-latest",
129            Self::PixtralLargeLatest => "pixtral-large-latest",
130            Self::Custom { name, .. } => name,
131        }
132    }
133
134    pub fn display_name(&self) -> &str {
135        match self {
136            Self::CodestralLatest => "codestral-latest",
137            Self::MistralLargeLatest => "mistral-large-latest",
138            Self::MistralMediumLatest => "mistral-medium-latest",
139            Self::MistralSmallLatest => "mistral-small-latest",
140            Self::MagistralMediumLatest => "magistral-medium-latest",
141            Self::MagistralSmallLatest => "magistral-small-latest",
142            Self::OpenMistralNemo => "open-mistral-nemo",
143            Self::OpenCodestralMamba => "open-codestral-mamba",
144            Self::DevstralMediumLatest => "devstral-medium-latest",
145            Self::DevstralSmallLatest => "devstral-small-latest",
146            Self::Pixtral12BLatest => "pixtral-12b-latest",
147            Self::PixtralLargeLatest => "pixtral-large-latest",
148            Self::Custom {
149                name, display_name, ..
150            } => display_name.as_ref().unwrap_or(name),
151        }
152    }
153
154    pub fn max_token_count(&self) -> u64 {
155        match self {
156            Self::CodestralLatest => 256000,
157            Self::MistralLargeLatest => 131000,
158            Self::MistralMediumLatest => 128000,
159            Self::MistralSmallLatest => 32000,
160            Self::MagistralMediumLatest => 40000,
161            Self::MagistralSmallLatest => 40000,
162            Self::OpenMistralNemo => 131000,
163            Self::OpenCodestralMamba => 256000,
164            Self::DevstralMediumLatest => 128000,
165            Self::DevstralSmallLatest => 262144,
166            Self::Pixtral12BLatest => 128000,
167            Self::PixtralLargeLatest => 128000,
168            Self::Custom { max_tokens, .. } => *max_tokens,
169        }
170    }
171
172    pub fn max_output_tokens(&self) -> Option<u64> {
173        match self {
174            Self::Custom {
175                max_output_tokens, ..
176            } => *max_output_tokens,
177            _ => None,
178        }
179    }
180
181    pub fn supports_tools(&self) -> bool {
182        match self {
183            Self::CodestralLatest
184            | Self::MistralLargeLatest
185            | Self::MistralMediumLatest
186            | Self::MistralSmallLatest
187            | Self::MagistralMediumLatest
188            | Self::MagistralSmallLatest
189            | Self::OpenMistralNemo
190            | Self::OpenCodestralMamba
191            | Self::DevstralMediumLatest
192            | Self::DevstralSmallLatest
193            | Self::Pixtral12BLatest
194            | Self::PixtralLargeLatest => true,
195            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
196        }
197    }
198
199    pub fn supports_images(&self) -> bool {
200        match self {
201            Self::Pixtral12BLatest
202            | Self::PixtralLargeLatest
203            | Self::MistralMediumLatest
204            | Self::MistralSmallLatest => true,
205            Self::CodestralLatest
206            | Self::MistralLargeLatest
207            | Self::MagistralMediumLatest
208            | Self::MagistralSmallLatest
209            | Self::OpenMistralNemo
210            | Self::OpenCodestralMamba
211            | Self::DevstralMediumLatest
212            | Self::DevstralSmallLatest => false,
213            Self::Custom {
214                supports_images, ..
215            } => supports_images.unwrap_or(false),
216        }
217    }
218
219    pub fn supports_thinking(&self) -> bool {
220        match self {
221            Self::MagistralMediumLatest | Self::MagistralSmallLatest => true,
222            Self::Custom {
223                supports_thinking, ..
224            } => supports_thinking.unwrap_or(false),
225            _ => false,
226        }
227    }
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231pub struct Request {
232    pub model: String,
233    pub messages: Vec<RequestMessage>,
234    pub stream: bool,
235    #[serde(default, skip_serializing_if = "Option::is_none")]
236    pub max_tokens: Option<u64>,
237    #[serde(default, skip_serializing_if = "Option::is_none")]
238    pub temperature: Option<f32>,
239    #[serde(default, skip_serializing_if = "Option::is_none")]
240    pub response_format: Option<ResponseFormat>,
241    #[serde(default, skip_serializing_if = "Option::is_none")]
242    pub tool_choice: Option<ToolChoice>,
243    #[serde(default, skip_serializing_if = "Option::is_none")]
244    pub parallel_tool_calls: Option<bool>,
245    #[serde(default, skip_serializing_if = "Vec::is_empty")]
246    pub tools: Vec<ToolDefinition>,
247}
248
249#[derive(Debug, Serialize, Deserialize)]
250#[serde(rename_all = "snake_case")]
251pub enum ResponseFormat {
252    Text,
253    #[serde(rename = "json_object")]
254    JsonObject,
255}
256
257#[derive(Debug, Serialize, Deserialize)]
258#[serde(tag = "type", rename_all = "snake_case")]
259pub enum ToolDefinition {
260    Function { function: FunctionDefinition },
261}
262
263#[derive(Debug, Serialize, Deserialize)]
264pub struct FunctionDefinition {
265    pub name: String,
266    pub description: Option<String>,
267    pub parameters: Option<Value>,
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271pub struct CompletionRequest {
272    pub model: String,
273    pub prompt: String,
274    pub max_tokens: u32,
275    pub temperature: f32,
276    #[serde(default, skip_serializing_if = "Option::is_none")]
277    pub prediction: Option<Prediction>,
278    #[serde(default, skip_serializing_if = "Option::is_none")]
279    pub rewrite_speculation: Option<bool>,
280}
281
282#[derive(Clone, Deserialize, Serialize, Debug)]
283#[serde(tag = "type", rename_all = "snake_case")]
284pub enum Prediction {
285    Content { content: String },
286}
287
288#[derive(Debug, Serialize, Deserialize)]
289#[serde(rename_all = "lowercase")]
290pub enum ToolChoice {
291    Auto,
292    Required,
293    None,
294    Any,
295    #[serde(untagged)]
296    Function(ToolDefinition),
297}
298
299#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
300#[serde(tag = "role", rename_all = "lowercase")]
301pub enum RequestMessage {
302    Assistant {
303        #[serde(flatten)]
304        #[serde(default, skip_serializing_if = "Option::is_none")]
305        content: Option<MessageContent>,
306        #[serde(default, skip_serializing_if = "Vec::is_empty")]
307        tool_calls: Vec<ToolCall>,
308    },
309    User {
310        #[serde(flatten)]
311        content: MessageContent,
312    },
313    System {
314        #[serde(flatten)]
315        content: MessageContent,
316    },
317    Tool {
318        content: String,
319        tool_call_id: String,
320    },
321}
322
323#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
324#[serde(untagged)]
325pub enum MessageContent {
326    #[serde(rename = "content")]
327    Plain { content: String },
328    #[serde(rename = "content")]
329    Multipart { content: Vec<MessagePart> },
330}
331
332impl MessageContent {
333    pub fn empty() -> Self {
334        Self::Plain {
335            content: String::new(),
336        }
337    }
338
339    pub fn push_part(&mut self, part: MessagePart) {
340        match self {
341            Self::Plain { content } => match part {
342                MessagePart::Text { text } => {
343                    content.push_str(&text);
344                }
345                part => {
346                    let mut parts = if content.is_empty() {
347                        Vec::new()
348                    } else {
349                        vec![MessagePart::Text {
350                            text: content.clone(),
351                        }]
352                    };
353                    parts.push(part);
354                    *self = Self::Multipart { content: parts };
355                }
356            },
357            Self::Multipart { content } => {
358                content.push(part);
359            }
360        }
361    }
362}
363
364#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
365#[serde(tag = "type", rename_all = "snake_case")]
366pub enum MessagePart {
367    Text { text: String },
368    ImageUrl { image_url: String },
369    Thinking { thinking: Vec<ThinkingPart> },
370}
371
372// Backwards-compatibility alias for provider code that refers to ContentPart
373pub type ContentPart = MessagePart;
374
375#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
376#[serde(tag = "type", rename_all = "snake_case")]
377pub enum ThinkingPart {
378    Text { text: String },
379}
380
381#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
382pub struct ToolCall {
383    pub id: String,
384    #[serde(flatten)]
385    pub content: ToolCallContent,
386}
387
388#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
389#[serde(tag = "type", rename_all = "lowercase")]
390pub enum ToolCallContent {
391    Function { function: FunctionContent },
392}
393
394#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
395pub struct FunctionContent {
396    pub name: String,
397    pub arguments: String,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct CompletionChoice {
402    pub text: String,
403}
404
405#[derive(Serialize, Deserialize, Debug)]
406pub struct Response {
407    pub id: String,
408    pub object: String,
409    pub created: u64,
410    pub model: String,
411    pub choices: Vec<Choice>,
412    pub usage: Usage,
413}
414
415#[derive(Serialize, Deserialize, Debug)]
416pub struct Usage {
417    pub prompt_tokens: u64,
418    pub completion_tokens: u64,
419    pub total_tokens: u64,
420}
421
422#[derive(Serialize, Deserialize, Debug)]
423pub struct Choice {
424    pub index: u32,
425    pub message: RequestMessage,
426    pub finish_reason: Option<String>,
427}
428
429#[derive(Serialize, Deserialize, Debug)]
430pub struct StreamResponse {
431    pub id: String,
432    pub object: String,
433    pub created: u64,
434    pub model: String,
435    pub choices: Vec<StreamChoice>,
436    pub usage: Option<Usage>,
437}
438
439#[derive(Serialize, Deserialize, Debug)]
440pub struct StreamChoice {
441    pub index: u32,
442    pub delta: StreamDelta,
443    pub finish_reason: Option<String>,
444}
445
446#[derive(Serialize, Deserialize, Debug, Clone)]
447pub struct StreamDelta {
448    pub role: Option<Role>,
449    #[serde(default, skip_serializing_if = "Option::is_none")]
450    pub content: Option<MessageContentDelta>,
451    #[serde(default, skip_serializing_if = "Option::is_none")]
452    pub tool_calls: Option<Vec<ToolCallChunk>>,
453}
454
455#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
456#[serde(untagged)]
457pub enum MessageContentDelta {
458    Text(String),
459    Parts(Vec<MessagePart>),
460}
461
462#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
463pub struct ToolCallChunk {
464    pub index: usize,
465    pub id: Option<String>,
466    pub function: Option<FunctionChunk>,
467}
468
469#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
470pub struct FunctionChunk {
471    pub name: Option<String>,
472    pub arguments: Option<String>,
473}
474
475pub async fn stream_completion(
476    client: &dyn HttpClient,
477    api_url: &str,
478    api_key: &str,
479    request: Request,
480) -> Result<BoxStream<'static, Result<StreamResponse>>> {
481    let uri = format!("{api_url}/chat/completions");
482    let request_builder = HttpRequest::builder()
483        .method(Method::POST)
484        .uri(uri)
485        .header("Content-Type", "application/json")
486        .header("Authorization", format!("Bearer {}", api_key.trim()));
487
488    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
489    let mut response = client.send(request).await?;
490
491    if response.status().is_success() {
492        let reader = BufReader::new(response.into_body());
493        Ok(reader
494            .lines()
495            .filter_map(|line| async move {
496                match line {
497                    Ok(line) => {
498                        let line = line.strip_prefix("data: ")?;
499                        if line == "[DONE]" {
500                            None
501                        } else {
502                            match serde_json::from_str(line) {
503                                Ok(response) => Some(Ok(response)),
504                                Err(error) => Some(Err(anyhow!(error))),
505                            }
506                        }
507                    }
508                    Err(error) => Some(Err(anyhow!(error))),
509                }
510            })
511            .boxed())
512    } else {
513        let mut body = String::new();
514        response.body_mut().read_to_string(&mut body).await?;
515        anyhow::bail!(
516            "Failed to connect to Mistral API: {} {}",
517            response.status(),
518            body,
519        );
520    }
521}