mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
/// Mistral API base URL, including the `/v1` path segment.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
/// Author role attached to a chat message.
///
/// Serde (de)serializes each variant under its lowercase name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
/// Models this client knows about, plus a user-described [`Model::Custom`].
///
/// Each built-in variant is (de)serialized under the string given in its
/// `rename` attribute. The `alias` on each built-in repeats the same string
/// as its `rename`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
    #[default]
    CodestralLatest,
    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
    MistralLargeLatest,
    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
    MistralMediumLatest,
    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
    MistralSmallLatest,
    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
    OpenMistralNemo,
    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
    OpenCodestralMamba,
    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
    DevstralSmallLatest,
    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
    Pixtral12BLatest,
    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
    PixtralLargeLatest,

    /// A model described entirely by its fields rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier returned by `Model::id` for this variant.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Value returned by `Model::max_token_count` for this variant.
        max_tokens: u64,
        /// Value returned by `Model::max_output_tokens` for this variant.
        max_output_tokens: Option<u64>,
        // NOTE(review): not read by any method in this file — confirm where it is consumed.
        max_completion_tokens: Option<u64>,
        /// Tool support flag; `Model::supports_tools` treats `None` as `false`.
        supports_tools: Option<bool>,
        /// Image support flag; `Model::supports_images` treats `None` as `false`.
        supports_images: Option<bool>,
    },
}
 80
 81impl Model {
 82    pub fn default_fast() -> Self {
 83        Model::MistralSmallLatest
 84    }
 85
 86    pub fn from_id(id: &str) -> Result<Self> {
 87        match id {
 88            "codestral-latest" => Ok(Self::CodestralLatest),
 89            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 90            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
 91            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 92            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 93            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 94            "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
 95            "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
 96            "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
 97            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
 98        }
 99    }
100
101    pub fn id(&self) -> &str {
102        match self {
103            Self::CodestralLatest => "codestral-latest",
104            Self::MistralLargeLatest => "mistral-large-latest",
105            Self::MistralMediumLatest => "mistral-medium-latest",
106            Self::MistralSmallLatest => "mistral-small-latest",
107            Self::OpenMistralNemo => "open-mistral-nemo",
108            Self::OpenCodestralMamba => "open-codestral-mamba",
109            Self::DevstralSmallLatest => "devstral-small-latest",
110            Self::Pixtral12BLatest => "pixtral-12b-latest",
111            Self::PixtralLargeLatest => "pixtral-large-latest",
112            Self::Custom { name, .. } => name,
113        }
114    }
115
116    pub fn display_name(&self) -> &str {
117        match self {
118            Self::CodestralLatest => "codestral-latest",
119            Self::MistralLargeLatest => "mistral-large-latest",
120            Self::MistralMediumLatest => "mistral-medium-latest",
121            Self::MistralSmallLatest => "mistral-small-latest",
122            Self::OpenMistralNemo => "open-mistral-nemo",
123            Self::OpenCodestralMamba => "open-codestral-mamba",
124            Self::DevstralSmallLatest => "devstral-small-latest",
125            Self::Pixtral12BLatest => "pixtral-12b-latest",
126            Self::PixtralLargeLatest => "pixtral-large-latest",
127            Self::Custom {
128                name, display_name, ..
129            } => display_name.as_ref().unwrap_or(name),
130        }
131    }
132
133    pub fn max_token_count(&self) -> u64 {
134        match self {
135            Self::CodestralLatest => 256000,
136            Self::MistralLargeLatest => 131000,
137            Self::MistralMediumLatest => 128000,
138            Self::MistralSmallLatest => 32000,
139            Self::OpenMistralNemo => 131000,
140            Self::OpenCodestralMamba => 256000,
141            Self::DevstralSmallLatest => 262144,
142            Self::Pixtral12BLatest => 128000,
143            Self::PixtralLargeLatest => 128000,
144            Self::Custom { max_tokens, .. } => *max_tokens,
145        }
146    }
147
148    pub fn max_output_tokens(&self) -> Option<u64> {
149        match self {
150            Self::Custom {
151                max_output_tokens, ..
152            } => *max_output_tokens,
153            _ => None,
154        }
155    }
156
157    pub fn supports_tools(&self) -> bool {
158        match self {
159            Self::CodestralLatest
160            | Self::MistralLargeLatest
161            | Self::MistralMediumLatest
162            | Self::MistralSmallLatest
163            | Self::OpenMistralNemo
164            | Self::OpenCodestralMamba
165            | Self::DevstralSmallLatest
166            | Self::Pixtral12BLatest
167            | Self::PixtralLargeLatest => true,
168            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
169        }
170    }
171
172    pub fn supports_images(&self) -> bool {
173        match self {
174            Self::Pixtral12BLatest
175            | Self::PixtralLargeLatest
176            | Self::MistralMediumLatest
177            | Self::MistralSmallLatest => true,
178            Self::CodestralLatest
179            | Self::MistralLargeLatest
180            | Self::OpenMistralNemo
181            | Self::OpenCodestralMamba
182            | Self::DevstralSmallLatest => false,
183            Self::Custom {
184                supports_images, ..
185            } => supports_images.unwrap_or(false),
186        }
187    }
188}
189
/// Chat request body; `stream_completion` serializes it to JSON and posts it
/// to `{api_url}/chat/completions`.
///
/// Optional fields are omitted from the serialized JSON when `None`, and
/// `tools` is omitted when empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
208
/// Value of [`Request::response_format`]; variants use snake_case names
/// (`text`, `json_object`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    Text,
    // The explicit rename yields the same string `rename_all` would produce.
    #[serde(rename = "json_object")]
    JsonObject,
}
216
/// A tool offered to the model, tagged by a `type` field whose value is the
/// snake_case variant name (`function`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    Function { function: FunctionDefinition },
}
222
/// Description of a callable function carried by [`ToolDefinition::Function`].
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Arbitrary JSON value; presumably a JSON Schema for the arguments — TODO confirm.
    pub parameters: Option<Value>,
}
229
/// Prompt-based completion payload.
///
/// `prediction` and `rewrite_speculation` are omitted from the serialized
/// JSON when `None`. No function visible in this file sends this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
241
/// Value of [`CompletionRequest::prediction`], tagged by a `type` field whose
/// value is the snake_case variant name (`content`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
247
/// Value of [`Request::tool_choice`]; variants use snake_case names
/// (`auto`, `required`, `none`, `any`, `function`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Any,
    // Carries a complete tool definition as its payload.
    Function(ToolDefinition),
}
257
/// A single chat message, tagged by a `role` field whose value is the
/// lowercase variant name (`assistant`, `user`, `system`, `tool`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        // Omitted from the serialized message when empty; defaults to empty when absent.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        // Flattened: the fields of `MessageContent` appear directly on the message.
        #[serde(flatten)]
        content: MessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}
278
/// Body of a user message: either a single string or a list of parts.
///
/// The enum is untagged and both variants expose one field named `content`,
/// so the two forms differ only in whether `content` holds a string or a list.
// NOTE(review): the variant-level `rename = "content"` attributes presumably
// have no effect on an untagged enum — confirm before removing them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    #[serde(rename = "content")]
    Plain { content: String },
    #[serde(rename = "content")]
    Multipart { content: Vec<MessagePart> },
}
287
288impl MessageContent {
289    pub fn empty() -> Self {
290        Self::Plain {
291            content: String::new(),
292        }
293    }
294
295    pub fn push_part(&mut self, part: MessagePart) {
296        match self {
297            Self::Plain { content } => match part {
298                MessagePart::Text { text } => {
299                    content.push_str(&text);
300                }
301                part => {
302                    let mut parts = if content.is_empty() {
303                        Vec::new()
304                    } else {
305                        vec![MessagePart::Text {
306                            text: content.clone(),
307                        }]
308                    };
309                    parts.push(part);
310                    *self = Self::Multipart { content: parts };
311                }
312            },
313            Self::Multipart { content } => {
314                content.push(part);
315            }
316        }
317    }
318}
319
/// One element of multipart message content, tagged by a `type` field whose
/// value is the snake_case variant name (`text`, `image_url`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text { text: String },
    ImageUrl { image_url: String },
}
326
/// A tool invocation attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened: the fields of `ToolCallContent` sit next to `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
333
/// Payload of a [`ToolCall`], tagged by a `type` field whose value is the
/// lowercase variant name (`function`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
339
/// Function name and arguments carried by [`ToolCallContent::Function`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Held as a raw string; presumably JSON-encoded arguments — TODO confirm.
    pub arguments: String,
}
345
/// A completion choice consisting only of text.
// NOTE(review): not referenced by any other item in this file.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
350
/// Response envelope whose choices hold complete messages (see [`Choice`]),
/// in contrast to [`StreamResponse`], whose choices hold deltas.
///
/// `usage` is mandatory here, whereas it is optional in [`StreamResponse`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
360
/// Token counts reported with a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
367
/// One entry of [`Response::choices`], carrying a complete message.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
374
/// One streamed chunk; `stream_completion` yields a value of this type for
/// each `data: ` line it successfully parses.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
    pub usage: Option<Usage>,
}
384
/// One entry of [`StreamResponse::choices`], carrying an incremental delta
/// rather than a complete message.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}
391
/// Incremental update carried by a [`StreamChoice`]; every field is optional.
///
/// `tool_calls` and `reasoning_content` default to `None` when absent and are
/// omitted from serialized output when `None`.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
401
/// Piece of a tool call delivered inside a [`StreamDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    // Presumably identifies which tool call this chunk belongs to — verify against callers.
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
408
/// Function portion of a [`ToolCallChunk`]; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
414
415pub async fn stream_completion(
416    client: &dyn HttpClient,
417    api_url: &str,
418    api_key: &str,
419    request: Request,
420) -> Result<BoxStream<'static, Result<StreamResponse>>> {
421    let uri = format!("{api_url}/chat/completions");
422    let request_builder = HttpRequest::builder()
423        .method(Method::POST)
424        .uri(uri)
425        .header("Content-Type", "application/json")
426        .header("Authorization", format!("Bearer {}", api_key));
427
428    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
429    let mut response = client.send(request).await?;
430
431    if response.status().is_success() {
432        let reader = BufReader::new(response.into_body());
433        Ok(reader
434            .lines()
435            .filter_map(|line| async move {
436                match line {
437                    Ok(line) => {
438                        let line = line.strip_prefix("data: ")?;
439                        if line == "[DONE]" {
440                            None
441                        } else {
442                            match serde_json::from_str(line) {
443                                Ok(response) => Some(Ok(response)),
444                                Err(error) => Some(Err(anyhow!(error))),
445                            }
446                        }
447                    }
448                    Err(error) => Some(Err(anyhow!(error))),
449                }
450            })
451            .boxed())
452    } else {
453        let mut body = String::new();
454        response.body_mut().read_to_string(&mut body).await?;
455        anyhow::bail!(
456            "Failed to connect to Mistral API: {} {}",
457            response.status(),
458            body,
459        );
460    }
461}