mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 52    MistralLargeLatest,
 53    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 54    MistralMediumLatest,
 55    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 56    MistralSmallLatest,
 57    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 58    OpenMistralNemo,
 59    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 60    OpenCodestralMamba,
 61    #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
 62    DevstralMediumLatest,
 63    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
 64    DevstralSmallLatest,
 65    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
 66    Pixtral12BLatest,
 67    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
 68    PixtralLargeLatest,
 69
 70    #[serde(rename = "custom")]
 71    Custom {
 72        name: String,
 73        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 74        display_name: Option<String>,
 75        max_tokens: u64,
 76        max_output_tokens: Option<u64>,
 77        max_completion_tokens: Option<u64>,
 78        supports_tools: Option<bool>,
 79        supports_images: Option<bool>,
 80    },
 81}
 82
 83impl Model {
 84    pub fn default_fast() -> Self {
 85        Model::MistralSmallLatest
 86    }
 87
 88    pub fn from_id(id: &str) -> Result<Self> {
 89        match id {
 90            "codestral-latest" => Ok(Self::CodestralLatest),
 91            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 92            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
 93            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 94            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 95            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 96            "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
 97            "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
 98            "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
 99            "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
100            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
101        }
102    }
103
104    pub fn id(&self) -> &str {
105        match self {
106            Self::CodestralLatest => "codestral-latest",
107            Self::MistralLargeLatest => "mistral-large-latest",
108            Self::MistralMediumLatest => "mistral-medium-latest",
109            Self::MistralSmallLatest => "mistral-small-latest",
110            Self::OpenMistralNemo => "open-mistral-nemo",
111            Self::OpenCodestralMamba => "open-codestral-mamba",
112            Self::DevstralMediumLatest => "devstral-medium-latest",
113            Self::DevstralSmallLatest => "devstral-small-latest",
114            Self::Pixtral12BLatest => "pixtral-12b-latest",
115            Self::PixtralLargeLatest => "pixtral-large-latest",
116            Self::Custom { name, .. } => name,
117        }
118    }
119
120    pub fn display_name(&self) -> &str {
121        match self {
122            Self::CodestralLatest => "codestral-latest",
123            Self::MistralLargeLatest => "mistral-large-latest",
124            Self::MistralMediumLatest => "mistral-medium-latest",
125            Self::MistralSmallLatest => "mistral-small-latest",
126            Self::OpenMistralNemo => "open-mistral-nemo",
127            Self::OpenCodestralMamba => "open-codestral-mamba",
128            Self::DevstralMediumLatest => "devstral-medium-latest",
129            Self::DevstralSmallLatest => "devstral-small-latest",
130            Self::Pixtral12BLatest => "pixtral-12b-latest",
131            Self::PixtralLargeLatest => "pixtral-large-latest",
132            Self::Custom {
133                name, display_name, ..
134            } => display_name.as_ref().unwrap_or(name),
135        }
136    }
137
138    pub fn max_token_count(&self) -> u64 {
139        match self {
140            Self::CodestralLatest => 256000,
141            Self::MistralLargeLatest => 131000,
142            Self::MistralMediumLatest => 128000,
143            Self::MistralSmallLatest => 32000,
144            Self::OpenMistralNemo => 131000,
145            Self::OpenCodestralMamba => 256000,
146            Self::DevstralMediumLatest => 128000,
147            Self::DevstralSmallLatest => 262144,
148            Self::Pixtral12BLatest => 128000,
149            Self::PixtralLargeLatest => 128000,
150            Self::Custom { max_tokens, .. } => *max_tokens,
151        }
152    }
153
154    pub fn max_output_tokens(&self) -> Option<u64> {
155        match self {
156            Self::Custom {
157                max_output_tokens, ..
158            } => *max_output_tokens,
159            _ => None,
160        }
161    }
162
163    pub fn supports_tools(&self) -> bool {
164        match self {
165            Self::CodestralLatest
166            | Self::MistralLargeLatest
167            | Self::MistralMediumLatest
168            | Self::MistralSmallLatest
169            | Self::OpenMistralNemo
170            | Self::OpenCodestralMamba
171            | Self::DevstralMediumLatest
172            | Self::DevstralSmallLatest
173            | Self::Pixtral12BLatest
174            | Self::PixtralLargeLatest => true,
175            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
176        }
177    }
178
179    pub fn supports_images(&self) -> bool {
180        match self {
181            Self::Pixtral12BLatest
182            | Self::PixtralLargeLatest
183            | Self::MistralMediumLatest
184            | Self::MistralSmallLatest => true,
185            Self::CodestralLatest
186            | Self::MistralLargeLatest
187            | Self::OpenMistralNemo
188            | Self::OpenCodestralMamba
189            | Self::DevstralMediumLatest
190            | Self::DevstralSmallLatest => false,
191            Self::Custom {
192                supports_images, ..
193            } => supports_images.unwrap_or(false),
194        }
195    }
196}
197
198#[derive(Debug, Serialize, Deserialize)]
199pub struct Request {
200    pub model: String,
201    pub messages: Vec<RequestMessage>,
202    pub stream: bool,
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub max_tokens: Option<u64>,
205    #[serde(default, skip_serializing_if = "Option::is_none")]
206    pub temperature: Option<f32>,
207    #[serde(default, skip_serializing_if = "Option::is_none")]
208    pub response_format: Option<ResponseFormat>,
209    #[serde(default, skip_serializing_if = "Option::is_none")]
210    pub tool_choice: Option<ToolChoice>,
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub parallel_tool_calls: Option<bool>,
213    #[serde(default, skip_serializing_if = "Vec::is_empty")]
214    pub tools: Vec<ToolDefinition>,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(rename_all = "snake_case")]
219pub enum ResponseFormat {
220    Text,
221    #[serde(rename = "json_object")]
222    JsonObject,
223}
224
225#[derive(Debug, Serialize, Deserialize)]
226#[serde(tag = "type", rename_all = "snake_case")]
227pub enum ToolDefinition {
228    Function { function: FunctionDefinition },
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub struct FunctionDefinition {
233    pub name: String,
234    pub description: Option<String>,
235    pub parameters: Option<Value>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239pub struct CompletionRequest {
240    pub model: String,
241    pub prompt: String,
242    pub max_tokens: u32,
243    pub temperature: f32,
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub prediction: Option<Prediction>,
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub rewrite_speculation: Option<bool>,
248}
249
250#[derive(Clone, Deserialize, Serialize, Debug)]
251#[serde(tag = "type", rename_all = "snake_case")]
252pub enum Prediction {
253    Content { content: String },
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257#[serde(rename_all = "snake_case")]
258pub enum ToolChoice {
259    Auto,
260    Required,
261    None,
262    Any,
263    Function(ToolDefinition),
264}
265
266#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
267#[serde(tag = "role", rename_all = "lowercase")]
268pub enum RequestMessage {
269    Assistant {
270        content: Option<String>,
271        #[serde(default, skip_serializing_if = "Vec::is_empty")]
272        tool_calls: Vec<ToolCall>,
273    },
274    User {
275        #[serde(flatten)]
276        content: MessageContent,
277    },
278    System {
279        content: String,
280    },
281    Tool {
282        content: String,
283        tool_call_id: String,
284    },
285}
286
287#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
288#[serde(untagged)]
289pub enum MessageContent {
290    #[serde(rename = "content")]
291    Plain { content: String },
292    #[serde(rename = "content")]
293    Multipart { content: Vec<MessagePart> },
294}
295
296impl MessageContent {
297    pub fn empty() -> Self {
298        Self::Plain {
299            content: String::new(),
300        }
301    }
302
303    pub fn push_part(&mut self, part: MessagePart) {
304        match self {
305            Self::Plain { content } => match part {
306                MessagePart::Text { text } => {
307                    content.push_str(&text);
308                }
309                part => {
310                    let mut parts = if content.is_empty() {
311                        Vec::new()
312                    } else {
313                        vec![MessagePart::Text {
314                            text: content.clone(),
315                        }]
316                    };
317                    parts.push(part);
318                    *self = Self::Multipart { content: parts };
319                }
320            },
321            Self::Multipart { content } => {
322                content.push(part);
323            }
324        }
325    }
326}
327
328#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
329#[serde(tag = "type", rename_all = "snake_case")]
330pub enum MessagePart {
331    Text { text: String },
332    ImageUrl { image_url: String },
333}
334
335#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
336pub struct ToolCall {
337    pub id: String,
338    #[serde(flatten)]
339    pub content: ToolCallContent,
340}
341
342#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
343#[serde(tag = "type", rename_all = "lowercase")]
344pub enum ToolCallContent {
345    Function { function: FunctionContent },
346}
347
348#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
349pub struct FunctionContent {
350    pub name: String,
351    pub arguments: String,
352}
353
354#[derive(Serialize, Deserialize, Debug)]
355pub struct CompletionChoice {
356    pub text: String,
357}
358
359#[derive(Serialize, Deserialize, Debug)]
360pub struct Response {
361    pub id: String,
362    pub object: String,
363    pub created: u64,
364    pub model: String,
365    pub choices: Vec<Choice>,
366    pub usage: Usage,
367}
368
369#[derive(Serialize, Deserialize, Debug)]
370pub struct Usage {
371    pub prompt_tokens: u64,
372    pub completion_tokens: u64,
373    pub total_tokens: u64,
374}
375
376#[derive(Serialize, Deserialize, Debug)]
377pub struct Choice {
378    pub index: u32,
379    pub message: RequestMessage,
380    pub finish_reason: Option<String>,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct StreamResponse {
385    pub id: String,
386    pub object: String,
387    pub created: u64,
388    pub model: String,
389    pub choices: Vec<StreamChoice>,
390    pub usage: Option<Usage>,
391}
392
393#[derive(Serialize, Deserialize, Debug)]
394pub struct StreamChoice {
395    pub index: u32,
396    pub delta: StreamDelta,
397    pub finish_reason: Option<String>,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct StreamDelta {
402    pub role: Option<Role>,
403    pub content: Option<String>,
404    #[serde(default, skip_serializing_if = "Option::is_none")]
405    pub tool_calls: Option<Vec<ToolCallChunk>>,
406    #[serde(default, skip_serializing_if = "Option::is_none")]
407    pub reasoning_content: Option<String>,
408}
409
410#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
411pub struct ToolCallChunk {
412    pub index: usize,
413    pub id: Option<String>,
414    pub function: Option<FunctionChunk>,
415}
416
417#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
418pub struct FunctionChunk {
419    pub name: Option<String>,
420    pub arguments: Option<String>,
421}
422
423pub async fn stream_completion(
424    client: &dyn HttpClient,
425    api_url: &str,
426    api_key: &str,
427    request: Request,
428) -> Result<BoxStream<'static, Result<StreamResponse>>> {
429    let uri = format!("{api_url}/chat/completions");
430    let request_builder = HttpRequest::builder()
431        .method(Method::POST)
432        .uri(uri)
433        .header("Content-Type", "application/json")
434        .header("Authorization", format!("Bearer {}", api_key));
435
436    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
437    let mut response = client.send(request).await?;
438
439    if response.status().is_success() {
440        let reader = BufReader::new(response.into_body());
441        Ok(reader
442            .lines()
443            .filter_map(|line| async move {
444                match line {
445                    Ok(line) => {
446                        let line = line.strip_prefix("data: ")?;
447                        if line == "[DONE]" {
448                            None
449                        } else {
450                            match serde_json::from_str(line) {
451                                Ok(response) => Some(Ok(response)),
452                                Err(error) => Some(Err(anyhow!(error))),
453                            }
454                        }
455                    }
456                    Err(error) => Some(Err(anyhow!(error))),
457                }
458            })
459            .boxed())
460    } else {
461        let mut body = String::new();
462        response.body_mut().read_to_string(&mut body).await?;
463        anyhow::bail!(
464            "Failed to connect to Mistral API: {} {}",
465            response.status(),
466            body,
467        );
468    }
469}