mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51
 52    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 53    MistralLargeLatest,
 54    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 55    MistralMediumLatest,
 56    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 57    MistralSmallLatest,
 58
 59    #[serde(rename = "magistral-medium-latest", alias = "magistral-medium-latest")]
 60    MagistralMediumLatest,
 61    #[serde(rename = "magistral-small-latest", alias = "magistral-small-latest")]
 62    MagistralSmallLatest,
 63
 64    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 65    OpenMistralNemo,
 66    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 67    OpenCodestralMamba,
 68
 69    #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
 70    DevstralMediumLatest,
 71    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
 72    DevstralSmallLatest,
 73
 74    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
 75    Pixtral12BLatest,
 76    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
 77    PixtralLargeLatest,
 78
 79    #[serde(rename = "custom")]
 80    Custom {
 81        name: String,
 82        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 83        display_name: Option<String>,
 84        max_tokens: u64,
 85        max_output_tokens: Option<u64>,
 86        max_completion_tokens: Option<u64>,
 87        supports_tools: Option<bool>,
 88        supports_images: Option<bool>,
 89    },
 90}
 91
 92impl Model {
 93    pub fn default_fast() -> Self {
 94        Model::MistralSmallLatest
 95    }
 96
 97    pub fn from_id(id: &str) -> Result<Self> {
 98        match id {
 99            "codestral-latest" => Ok(Self::CodestralLatest),
100            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
101            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
102            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
103            "magistral-medium-latest" => Ok(Self::MagistralMediumLatest),
104            "magistral-small-latest" => Ok(Self::MagistralSmallLatest),
105            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
106            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
107            "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
108            "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
109            "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
110            "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
111            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
112        }
113    }
114
115    pub fn id(&self) -> &str {
116        match self {
117            Self::CodestralLatest => "codestral-latest",
118            Self::MistralLargeLatest => "mistral-large-latest",
119            Self::MistralMediumLatest => "mistral-medium-latest",
120            Self::MistralSmallLatest => "mistral-small-latest",
121            Self::MagistralMediumLatest => "magistral-medium-latest",
122            Self::MagistralSmallLatest => "magistral-small-latest",
123            Self::OpenMistralNemo => "open-mistral-nemo",
124            Self::OpenCodestralMamba => "open-codestral-mamba",
125            Self::DevstralMediumLatest => "devstral-medium-latest",
126            Self::DevstralSmallLatest => "devstral-small-latest",
127            Self::Pixtral12BLatest => "pixtral-12b-latest",
128            Self::PixtralLargeLatest => "pixtral-large-latest",
129            Self::Custom { name, .. } => name,
130        }
131    }
132
133    pub fn display_name(&self) -> &str {
134        match self {
135            Self::CodestralLatest => "codestral-latest",
136            Self::MistralLargeLatest => "mistral-large-latest",
137            Self::MistralMediumLatest => "mistral-medium-latest",
138            Self::MistralSmallLatest => "mistral-small-latest",
139            Self::MagistralMediumLatest => "magistral-medium-latest",
140            Self::MagistralSmallLatest => "magistral-small-latest",
141            Self::OpenMistralNemo => "open-mistral-nemo",
142            Self::OpenCodestralMamba => "open-codestral-mamba",
143            Self::DevstralMediumLatest => "devstral-medium-latest",
144            Self::DevstralSmallLatest => "devstral-small-latest",
145            Self::Pixtral12BLatest => "pixtral-12b-latest",
146            Self::PixtralLargeLatest => "pixtral-large-latest",
147            Self::Custom {
148                name, display_name, ..
149            } => display_name.as_ref().unwrap_or(name),
150        }
151    }
152
153    pub fn max_token_count(&self) -> u64 {
154        match self {
155            Self::CodestralLatest => 256000,
156            Self::MistralLargeLatest => 131000,
157            Self::MistralMediumLatest => 128000,
158            Self::MistralSmallLatest => 32000,
159            Self::MagistralMediumLatest => 40000,
160            Self::MagistralSmallLatest => 40000,
161            Self::OpenMistralNemo => 131000,
162            Self::OpenCodestralMamba => 256000,
163            Self::DevstralMediumLatest => 128000,
164            Self::DevstralSmallLatest => 262144,
165            Self::Pixtral12BLatest => 128000,
166            Self::PixtralLargeLatest => 128000,
167            Self::Custom { max_tokens, .. } => *max_tokens,
168        }
169    }
170
171    pub fn max_output_tokens(&self) -> Option<u64> {
172        match self {
173            Self::Custom {
174                max_output_tokens, ..
175            } => *max_output_tokens,
176            _ => None,
177        }
178    }
179
180    pub fn supports_tools(&self) -> bool {
181        match self {
182            Self::CodestralLatest
183            | Self::MistralLargeLatest
184            | Self::MistralMediumLatest
185            | Self::MistralSmallLatest
186            | Self::MagistralMediumLatest
187            | Self::MagistralSmallLatest
188            | Self::OpenMistralNemo
189            | Self::OpenCodestralMamba
190            | Self::DevstralMediumLatest
191            | Self::DevstralSmallLatest
192            | Self::Pixtral12BLatest
193            | Self::PixtralLargeLatest => true,
194            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
195        }
196    }
197
198    pub fn supports_images(&self) -> bool {
199        match self {
200            Self::Pixtral12BLatest
201            | Self::PixtralLargeLatest
202            | Self::MistralMediumLatest
203            | Self::MistralSmallLatest => true,
204            Self::CodestralLatest
205            | Self::MistralLargeLatest
206            | Self::MagistralMediumLatest
207            | Self::MagistralSmallLatest
208            | Self::OpenMistralNemo
209            | Self::OpenCodestralMamba
210            | Self::DevstralMediumLatest
211            | Self::DevstralSmallLatest => false,
212            Self::Custom {
213                supports_images, ..
214            } => supports_images.unwrap_or(false),
215        }
216    }
217}
218
219#[derive(Debug, Serialize, Deserialize)]
220pub struct Request {
221    pub model: String,
222    pub messages: Vec<RequestMessage>,
223    pub stream: bool,
224    #[serde(default, skip_serializing_if = "Option::is_none")]
225    pub max_tokens: Option<u64>,
226    #[serde(default, skip_serializing_if = "Option::is_none")]
227    pub temperature: Option<f32>,
228    #[serde(default, skip_serializing_if = "Option::is_none")]
229    pub response_format: Option<ResponseFormat>,
230    #[serde(default, skip_serializing_if = "Option::is_none")]
231    pub tool_choice: Option<ToolChoice>,
232    #[serde(default, skip_serializing_if = "Option::is_none")]
233    pub parallel_tool_calls: Option<bool>,
234    #[serde(default, skip_serializing_if = "Vec::is_empty")]
235    pub tools: Vec<ToolDefinition>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239#[serde(rename_all = "snake_case")]
240pub enum ResponseFormat {
241    Text,
242    #[serde(rename = "json_object")]
243    JsonObject,
244}
245
246#[derive(Debug, Serialize, Deserialize)]
247#[serde(tag = "type", rename_all = "snake_case")]
248pub enum ToolDefinition {
249    Function { function: FunctionDefinition },
250}
251
252#[derive(Debug, Serialize, Deserialize)]
253pub struct FunctionDefinition {
254    pub name: String,
255    pub description: Option<String>,
256    pub parameters: Option<Value>,
257}
258
259#[derive(Debug, Serialize, Deserialize)]
260pub struct CompletionRequest {
261    pub model: String,
262    pub prompt: String,
263    pub max_tokens: u32,
264    pub temperature: f32,
265    #[serde(default, skip_serializing_if = "Option::is_none")]
266    pub prediction: Option<Prediction>,
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    pub rewrite_speculation: Option<bool>,
269}
270
271#[derive(Clone, Deserialize, Serialize, Debug)]
272#[serde(tag = "type", rename_all = "snake_case")]
273pub enum Prediction {
274    Content { content: String },
275}
276
277#[derive(Debug, Serialize, Deserialize)]
278#[serde(rename_all = "snake_case")]
279pub enum ToolChoice {
280    Auto,
281    Required,
282    None,
283    Any,
284    Function(ToolDefinition),
285}
286
287#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
288#[serde(tag = "role", rename_all = "lowercase")]
289pub enum RequestMessage {
290    Assistant {
291        content: Option<String>,
292        #[serde(default, skip_serializing_if = "Vec::is_empty")]
293        tool_calls: Vec<ToolCall>,
294    },
295    User {
296        #[serde(flatten)]
297        content: MessageContent,
298    },
299    System {
300        content: String,
301    },
302    Tool {
303        content: String,
304        tool_call_id: String,
305    },
306}
307
308#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
309#[serde(untagged)]
310pub enum MessageContent {
311    #[serde(rename = "content")]
312    Plain { content: String },
313    #[serde(rename = "content")]
314    Multipart { content: Vec<MessagePart> },
315}
316
317impl MessageContent {
318    pub fn empty() -> Self {
319        Self::Plain {
320            content: String::new(),
321        }
322    }
323
324    pub fn push_part(&mut self, part: MessagePart) {
325        match self {
326            Self::Plain { content } => match part {
327                MessagePart::Text { text } => {
328                    content.push_str(&text);
329                }
330                part => {
331                    let mut parts = if content.is_empty() {
332                        Vec::new()
333                    } else {
334                        vec![MessagePart::Text {
335                            text: content.clone(),
336                        }]
337                    };
338                    parts.push(part);
339                    *self = Self::Multipart { content: parts };
340                }
341            },
342            Self::Multipart { content } => {
343                content.push(part);
344            }
345        }
346    }
347}
348
349#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
350#[serde(tag = "type", rename_all = "snake_case")]
351pub enum MessagePart {
352    Text { text: String },
353    ImageUrl { image_url: String },
354}
355
356#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
357pub struct ToolCall {
358    pub id: String,
359    #[serde(flatten)]
360    pub content: ToolCallContent,
361}
362
363#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
364#[serde(tag = "type", rename_all = "lowercase")]
365pub enum ToolCallContent {
366    Function { function: FunctionContent },
367}
368
369#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
370pub struct FunctionContent {
371    pub name: String,
372    pub arguments: String,
373}
374
375#[derive(Serialize, Deserialize, Debug)]
376pub struct CompletionChoice {
377    pub text: String,
378}
379
380#[derive(Serialize, Deserialize, Debug)]
381pub struct Response {
382    pub id: String,
383    pub object: String,
384    pub created: u64,
385    pub model: String,
386    pub choices: Vec<Choice>,
387    pub usage: Usage,
388}
389
390#[derive(Serialize, Deserialize, Debug)]
391pub struct Usage {
392    pub prompt_tokens: u64,
393    pub completion_tokens: u64,
394    pub total_tokens: u64,
395}
396
397#[derive(Serialize, Deserialize, Debug)]
398pub struct Choice {
399    pub index: u32,
400    pub message: RequestMessage,
401    pub finish_reason: Option<String>,
402}
403
404#[derive(Serialize, Deserialize, Debug)]
405pub struct StreamResponse {
406    pub id: String,
407    pub object: String,
408    pub created: u64,
409    pub model: String,
410    pub choices: Vec<StreamChoice>,
411    pub usage: Option<Usage>,
412}
413
414#[derive(Serialize, Deserialize, Debug)]
415pub struct StreamChoice {
416    pub index: u32,
417    pub delta: StreamDelta,
418    pub finish_reason: Option<String>,
419}
420
421#[derive(Serialize, Deserialize, Debug)]
422pub struct StreamDelta {
423    pub role: Option<Role>,
424    pub content: Option<String>,
425    #[serde(default, skip_serializing_if = "Option::is_none")]
426    pub tool_calls: Option<Vec<ToolCallChunk>>,
427    #[serde(default, skip_serializing_if = "Option::is_none")]
428    pub reasoning_content: Option<String>,
429}
430
431#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
432pub struct ToolCallChunk {
433    pub index: usize,
434    pub id: Option<String>,
435    pub function: Option<FunctionChunk>,
436}
437
438#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
439pub struct FunctionChunk {
440    pub name: Option<String>,
441    pub arguments: Option<String>,
442}
443
444pub async fn stream_completion(
445    client: &dyn HttpClient,
446    api_url: &str,
447    api_key: &str,
448    request: Request,
449) -> Result<BoxStream<'static, Result<StreamResponse>>> {
450    let uri = format!("{api_url}/chat/completions");
451    let request_builder = HttpRequest::builder()
452        .method(Method::POST)
453        .uri(uri)
454        .header("Content-Type", "application/json")
455        .header("Authorization", format!("Bearer {}", api_key));
456
457    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
458    let mut response = client.send(request).await?;
459
460    if response.status().is_success() {
461        let reader = BufReader::new(response.into_body());
462        Ok(reader
463            .lines()
464            .filter_map(|line| async move {
465                match line {
466                    Ok(line) => {
467                        let line = line.strip_prefix("data: ")?;
468                        if line == "[DONE]" {
469                            None
470                        } else {
471                            match serde_json::from_str(line) {
472                                Ok(response) => Some(Ok(response)),
473                                Err(error) => Some(Err(anyhow!(error))),
474                            }
475                        }
476                    }
477                    Err(error) => Some(Err(anyhow!(error))),
478                }
479            })
480            .boxed())
481    } else {
482        let mut body = String::new();
483        response.body_mut().read_to_string(&mut body).await?;
484        anyhow::bail!(
485            "Failed to connect to Mistral API: {} {}",
486            response.status(),
487            body,
488        );
489    }
490}