mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 11
 12#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 13#[serde(rename_all = "lowercase")]
 14pub enum Role {
 15    User,
 16    Assistant,
 17    System,
 18    Tool,
 19}
 20
 21impl TryFrom<String> for Role {
 22    type Error = anyhow::Error;
 23
 24    fn try_from(value: String) -> Result<Self> {
 25        match value.as_str() {
 26            "user" => Ok(Self::User),
 27            "assistant" => Ok(Self::Assistant),
 28            "system" => Ok(Self::System),
 29            "tool" => Ok(Self::Tool),
 30            _ => anyhow::bail!("invalid role '{value}'"),
 31        }
 32    }
 33}
 34
 35impl From<Role> for String {
 36    fn from(val: Role) -> Self {
 37        match val {
 38            Role::User => "user".to_owned(),
 39            Role::Assistant => "assistant".to_owned(),
 40            Role::System => "system".to_owned(),
 41            Role::Tool => "tool".to_owned(),
 42        }
 43    }
 44}
 45
 46#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 47#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 48pub enum Model {
 49    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 50    #[default]
 51    CodestralLatest,
 52
 53    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 54    MistralLargeLatest,
 55    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 56    MistralMediumLatest,
 57    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 58    MistralSmallLatest,
 59
 60    #[serde(rename = "magistral-medium-latest", alias = "magistral-medium-latest")]
 61    MagistralMediumLatest,
 62    #[serde(rename = "magistral-small-latest", alias = "magistral-small-latest")]
 63    MagistralSmallLatest,
 64
 65    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 66    OpenMistralNemo,
 67    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 68    OpenCodestralMamba,
 69
 70    #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
 71    DevstralMediumLatest,
 72    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
 73    DevstralSmallLatest,
 74
 75    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
 76    Pixtral12BLatest,
 77    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
 78    PixtralLargeLatest,
 79
 80    #[serde(rename = "custom")]
 81    Custom {
 82        name: String,
 83        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 84        display_name: Option<String>,
 85        max_tokens: u64,
 86        max_output_tokens: Option<u64>,
 87        max_completion_tokens: Option<u64>,
 88        supports_tools: Option<bool>,
 89        supports_images: Option<bool>,
 90        supports_thinking: Option<bool>,
 91    },
 92}
 93
 94impl Model {
 95    pub fn default_fast() -> Self {
 96        Model::MistralSmallLatest
 97    }
 98
 99    pub fn from_id(id: &str) -> Result<Self> {
100        match id {
101            "codestral-latest" => Ok(Self::CodestralLatest),
102            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
103            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
104            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
105            "magistral-medium-latest" => Ok(Self::MagistralMediumLatest),
106            "magistral-small-latest" => Ok(Self::MagistralSmallLatest),
107            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
108            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
109            "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
110            "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
111            "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
112            "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
113            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
114        }
115    }
116
117    pub fn id(&self) -> &str {
118        match self {
119            Self::CodestralLatest => "codestral-latest",
120            Self::MistralLargeLatest => "mistral-large-latest",
121            Self::MistralMediumLatest => "mistral-medium-latest",
122            Self::MistralSmallLatest => "mistral-small-latest",
123            Self::MagistralMediumLatest => "magistral-medium-latest",
124            Self::MagistralSmallLatest => "magistral-small-latest",
125            Self::OpenMistralNemo => "open-mistral-nemo",
126            Self::OpenCodestralMamba => "open-codestral-mamba",
127            Self::DevstralMediumLatest => "devstral-medium-latest",
128            Self::DevstralSmallLatest => "devstral-small-latest",
129            Self::Pixtral12BLatest => "pixtral-12b-latest",
130            Self::PixtralLargeLatest => "pixtral-large-latest",
131            Self::Custom { name, .. } => name,
132        }
133    }
134
135    pub fn display_name(&self) -> &str {
136        match self {
137            Self::CodestralLatest => "codestral-latest",
138            Self::MistralLargeLatest => "mistral-large-latest",
139            Self::MistralMediumLatest => "mistral-medium-latest",
140            Self::MistralSmallLatest => "mistral-small-latest",
141            Self::MagistralMediumLatest => "magistral-medium-latest",
142            Self::MagistralSmallLatest => "magistral-small-latest",
143            Self::OpenMistralNemo => "open-mistral-nemo",
144            Self::OpenCodestralMamba => "open-codestral-mamba",
145            Self::DevstralMediumLatest => "devstral-medium-latest",
146            Self::DevstralSmallLatest => "devstral-small-latest",
147            Self::Pixtral12BLatest => "pixtral-12b-latest",
148            Self::PixtralLargeLatest => "pixtral-large-latest",
149            Self::Custom {
150                name, display_name, ..
151            } => display_name.as_ref().unwrap_or(name),
152        }
153    }
154
155    pub fn max_token_count(&self) -> u64 {
156        match self {
157            Self::CodestralLatest => 256000,
158            Self::MistralLargeLatest => 131000,
159            Self::MistralMediumLatest => 128000,
160            Self::MistralSmallLatest => 32000,
161            Self::MagistralMediumLatest => 40000,
162            Self::MagistralSmallLatest => 40000,
163            Self::OpenMistralNemo => 131000,
164            Self::OpenCodestralMamba => 256000,
165            Self::DevstralMediumLatest => 128000,
166            Self::DevstralSmallLatest => 262144,
167            Self::Pixtral12BLatest => 128000,
168            Self::PixtralLargeLatest => 128000,
169            Self::Custom { max_tokens, .. } => *max_tokens,
170        }
171    }
172
173    pub fn max_output_tokens(&self) -> Option<u64> {
174        match self {
175            Self::Custom {
176                max_output_tokens, ..
177            } => *max_output_tokens,
178            _ => None,
179        }
180    }
181
182    pub fn supports_tools(&self) -> bool {
183        match self {
184            Self::CodestralLatest
185            | Self::MistralLargeLatest
186            | Self::MistralMediumLatest
187            | Self::MistralSmallLatest
188            | Self::MagistralMediumLatest
189            | Self::MagistralSmallLatest
190            | Self::OpenMistralNemo
191            | Self::OpenCodestralMamba
192            | Self::DevstralMediumLatest
193            | Self::DevstralSmallLatest
194            | Self::Pixtral12BLatest
195            | Self::PixtralLargeLatest => true,
196            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
197        }
198    }
199
200    pub fn supports_images(&self) -> bool {
201        match self {
202            Self::Pixtral12BLatest
203            | Self::PixtralLargeLatest
204            | Self::MistralMediumLatest
205            | Self::MistralSmallLatest => true,
206            Self::CodestralLatest
207            | Self::MistralLargeLatest
208            | Self::MagistralMediumLatest
209            | Self::MagistralSmallLatest
210            | Self::OpenMistralNemo
211            | Self::OpenCodestralMamba
212            | Self::DevstralMediumLatest
213            | Self::DevstralSmallLatest => false,
214            Self::Custom {
215                supports_images, ..
216            } => supports_images.unwrap_or(false),
217        }
218    }
219
220    pub fn supports_thinking(&self) -> bool {
221        match self {
222            Self::MagistralMediumLatest | Self::MagistralSmallLatest => true,
223            Self::Custom {
224                supports_thinking, ..
225            } => supports_thinking.unwrap_or(false),
226            _ => false,
227        }
228    }
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub struct Request {
233    pub model: String,
234    pub messages: Vec<RequestMessage>,
235    pub stream: bool,
236    #[serde(default, skip_serializing_if = "Option::is_none")]
237    pub max_tokens: Option<u64>,
238    #[serde(default, skip_serializing_if = "Option::is_none")]
239    pub temperature: Option<f32>,
240    #[serde(default, skip_serializing_if = "Option::is_none")]
241    pub response_format: Option<ResponseFormat>,
242    #[serde(default, skip_serializing_if = "Option::is_none")]
243    pub tool_choice: Option<ToolChoice>,
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub parallel_tool_calls: Option<bool>,
246    #[serde(default, skip_serializing_if = "Vec::is_empty")]
247    pub tools: Vec<ToolDefinition>,
248}
249
250#[derive(Debug, Serialize, Deserialize)]
251#[serde(rename_all = "snake_case")]
252pub enum ResponseFormat {
253    Text,
254    #[serde(rename = "json_object")]
255    JsonObject,
256}
257
258#[derive(Debug, Serialize, Deserialize)]
259#[serde(tag = "type", rename_all = "snake_case")]
260pub enum ToolDefinition {
261    Function { function: FunctionDefinition },
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265pub struct FunctionDefinition {
266    pub name: String,
267    pub description: Option<String>,
268    pub parameters: Option<Value>,
269}
270
271#[derive(Debug, Serialize, Deserialize)]
272#[serde(rename_all = "lowercase")]
273pub enum ToolChoice {
274    Auto,
275    Required,
276    None,
277    Any,
278    #[serde(untagged)]
279    Function(ToolDefinition),
280}
281
282#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
283#[serde(tag = "role", rename_all = "lowercase")]
284pub enum RequestMessage {
285    Assistant {
286        #[serde(flatten)]
287        #[serde(default, skip_serializing_if = "Option::is_none")]
288        content: Option<MessageContent>,
289        #[serde(default, skip_serializing_if = "Vec::is_empty")]
290        tool_calls: Vec<ToolCall>,
291    },
292    User {
293        #[serde(flatten)]
294        content: MessageContent,
295    },
296    System {
297        #[serde(flatten)]
298        content: MessageContent,
299    },
300    Tool {
301        content: String,
302        tool_call_id: String,
303    },
304}
305
306#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
307#[serde(untagged)]
308pub enum MessageContent {
309    #[serde(rename = "content")]
310    Plain { content: String },
311    #[serde(rename = "content")]
312    Multipart { content: Vec<MessagePart> },
313}
314
315impl MessageContent {
316    pub fn empty() -> Self {
317        Self::Plain {
318            content: String::new(),
319        }
320    }
321
322    pub fn push_part(&mut self, part: MessagePart) {
323        match self {
324            Self::Plain { content } => match part {
325                MessagePart::Text { text } => {
326                    content.push_str(&text);
327                }
328                part => {
329                    let mut parts = if content.is_empty() {
330                        Vec::new()
331                    } else {
332                        vec![MessagePart::Text {
333                            text: content.clone(),
334                        }]
335                    };
336                    parts.push(part);
337                    *self = Self::Multipart { content: parts };
338                }
339            },
340            Self::Multipart { content } => {
341                content.push(part);
342            }
343        }
344    }
345}
346
347#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
348#[serde(tag = "type", rename_all = "snake_case")]
349pub enum MessagePart {
350    Text { text: String },
351    ImageUrl { image_url: String },
352    Thinking { thinking: Vec<ThinkingPart> },
353}
354
355// Backwards-compatibility alias for provider code that refers to ContentPart
356pub type ContentPart = MessagePart;
357
358#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
359#[serde(tag = "type", rename_all = "snake_case")]
360pub enum ThinkingPart {
361    Text { text: String },
362}
363
364#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
365pub struct ToolCall {
366    pub id: String,
367    #[serde(flatten)]
368    pub content: ToolCallContent,
369}
370
371#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
372#[serde(tag = "type", rename_all = "lowercase")]
373pub enum ToolCallContent {
374    Function { function: FunctionContent },
375}
376
377#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
378pub struct FunctionContent {
379    pub name: String,
380    pub arguments: String,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct Usage {
385    pub prompt_tokens: u64,
386    pub completion_tokens: u64,
387    pub total_tokens: u64,
388}
389
390#[derive(Serialize, Deserialize, Debug)]
391pub struct StreamResponse {
392    pub id: String,
393    pub object: String,
394    pub created: u64,
395    pub model: String,
396    pub choices: Vec<StreamChoice>,
397    pub usage: Option<Usage>,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct StreamChoice {
402    pub index: u32,
403    pub delta: StreamDelta,
404    pub finish_reason: Option<String>,
405}
406
407#[derive(Serialize, Deserialize, Debug, Clone)]
408pub struct StreamDelta {
409    pub role: Option<Role>,
410    #[serde(default, skip_serializing_if = "Option::is_none")]
411    pub content: Option<MessageContentDelta>,
412    #[serde(default, skip_serializing_if = "Option::is_none")]
413    pub tool_calls: Option<Vec<ToolCallChunk>>,
414}
415
416#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
417#[serde(untagged)]
418pub enum MessageContentDelta {
419    Text(String),
420    Parts(Vec<MessagePart>),
421}
422
423#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
424pub struct ToolCallChunk {
425    pub index: usize,
426    pub id: Option<String>,
427    pub function: Option<FunctionChunk>,
428}
429
430#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
431pub struct FunctionChunk {
432    pub name: Option<String>,
433    pub arguments: Option<String>,
434}
435
436pub async fn stream_completion(
437    client: &dyn HttpClient,
438    api_url: &str,
439    api_key: &str,
440    request: Request,
441) -> Result<BoxStream<'static, Result<StreamResponse>>> {
442    let uri = format!("{api_url}/chat/completions");
443    let request_builder = HttpRequest::builder()
444        .method(Method::POST)
445        .uri(uri)
446        .header("Content-Type", "application/json")
447        .header("Authorization", format!("Bearer {}", api_key.trim()));
448
449    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
450    let mut response = client.send(request).await?;
451
452    if response.status().is_success() {
453        let reader = BufReader::new(response.into_body());
454        Ok(reader
455            .lines()
456            .filter_map(|line| async move {
457                match line {
458                    Ok(line) => {
459                        let line = line.strip_prefix("data: ")?;
460                        if line == "[DONE]" {
461                            None
462                        } else {
463                            match serde_json::from_str(line) {
464                                Ok(response) => Some(Ok(response)),
465                                Err(error) => Some(Err(anyhow!(error))),
466                            }
467                        }
468                    }
469                    Err(error) => Some(Err(anyhow!(error))),
470                }
471            })
472            .boxed())
473    } else {
474        let mut body = String::new();
475        response.body_mut().read_to_string(&mut body).await?;
476        anyhow::bail!(
477            "Failed to connect to Mistral API: {} {}",
478            response.status(),
479            body,
480        );
481    }
482}