mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 52    MistralLargeLatest,
 53    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 54    MistralMediumLatest,
 55    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 56    MistralSmallLatest,
 57    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 58    OpenMistralNemo,
 59    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 60    OpenCodestralMamba,
 61    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
 62    DevstralSmallLatest,
 63
 64    #[serde(rename = "custom")]
 65    Custom {
 66        name: String,
 67        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 68        display_name: Option<String>,
 69        max_tokens: usize,
 70        max_output_tokens: Option<u32>,
 71        max_completion_tokens: Option<u32>,
 72        supports_tools: Option<bool>,
 73    },
 74}
 75
 76impl Model {
 77    pub fn default_fast() -> Self {
 78        Model::MistralSmallLatest
 79    }
 80
 81    pub fn from_id(id: &str) -> Result<Self> {
 82        match id {
 83            "codestral-latest" => Ok(Self::CodestralLatest),
 84            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 85            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
 86            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 87            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 88            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 89            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
 90        }
 91    }
 92
 93    pub fn id(&self) -> &str {
 94        match self {
 95            Self::CodestralLatest => "codestral-latest",
 96            Self::MistralLargeLatest => "mistral-large-latest",
 97            Self::MistralMediumLatest => "mistral-medium-latest",
 98            Self::MistralSmallLatest => "mistral-small-latest",
 99            Self::OpenMistralNemo => "open-mistral-nemo",
100            Self::OpenCodestralMamba => "open-codestral-mamba",
101            Self::DevstralSmallLatest => "devstral-small-latest",
102            Self::Custom { name, .. } => name,
103        }
104    }
105
106    pub fn display_name(&self) -> &str {
107        match self {
108            Self::CodestralLatest => "codestral-latest",
109            Self::MistralLargeLatest => "mistral-large-latest",
110            Self::MistralMediumLatest => "mistral-medium-latest",
111            Self::MistralSmallLatest => "mistral-small-latest",
112            Self::OpenMistralNemo => "open-mistral-nemo",
113            Self::OpenCodestralMamba => "open-codestral-mamba",
114            Self::DevstralSmallLatest => "devstral-small-latest",
115            Self::Custom {
116                name, display_name, ..
117            } => display_name.as_ref().unwrap_or(name),
118        }
119    }
120
121    pub fn max_token_count(&self) -> usize {
122        match self {
123            Self::CodestralLatest => 256000,
124            Self::MistralLargeLatest => 131000,
125            Self::MistralMediumLatest => 128000,
126            Self::MistralSmallLatest => 32000,
127            Self::OpenMistralNemo => 131000,
128            Self::OpenCodestralMamba => 256000,
129            Self::DevstralSmallLatest => 262144,
130            Self::Custom { max_tokens, .. } => *max_tokens,
131        }
132    }
133
134    pub fn max_output_tokens(&self) -> Option<u32> {
135        match self {
136            Self::Custom {
137                max_output_tokens, ..
138            } => *max_output_tokens,
139            _ => None,
140        }
141    }
142
143    pub fn supports_tools(&self) -> bool {
144        match self {
145            Self::CodestralLatest
146            | Self::MistralLargeLatest
147            | Self::MistralMediumLatest
148            | Self::MistralSmallLatest
149            | Self::OpenMistralNemo
150            | Self::OpenCodestralMamba
151            | Self::DevstralSmallLatest => true,
152            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
153        }
154    }
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158pub struct Request {
159    pub model: String,
160    pub messages: Vec<RequestMessage>,
161    pub stream: bool,
162    #[serde(default, skip_serializing_if = "Option::is_none")]
163    pub max_tokens: Option<u32>,
164    #[serde(default, skip_serializing_if = "Option::is_none")]
165    pub temperature: Option<f32>,
166    #[serde(default, skip_serializing_if = "Option::is_none")]
167    pub response_format: Option<ResponseFormat>,
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub tool_choice: Option<ToolChoice>,
170    #[serde(default, skip_serializing_if = "Option::is_none")]
171    pub parallel_tool_calls: Option<bool>,
172    #[serde(default, skip_serializing_if = "Vec::is_empty")]
173    pub tools: Vec<ToolDefinition>,
174}
175
176#[derive(Debug, Serialize, Deserialize)]
177#[serde(rename_all = "snake_case")]
178pub enum ResponseFormat {
179    Text,
180    #[serde(rename = "json_object")]
181    JsonObject,
182}
183
184#[derive(Debug, Serialize, Deserialize)]
185#[serde(tag = "type", rename_all = "snake_case")]
186pub enum ToolDefinition {
187    Function { function: FunctionDefinition },
188}
189
190#[derive(Debug, Serialize, Deserialize)]
191pub struct FunctionDefinition {
192    pub name: String,
193    pub description: Option<String>,
194    pub parameters: Option<Value>,
195}
196
197#[derive(Debug, Serialize, Deserialize)]
198pub struct CompletionRequest {
199    pub model: String,
200    pub prompt: String,
201    pub max_tokens: u32,
202    pub temperature: f32,
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub prediction: Option<Prediction>,
205    #[serde(default, skip_serializing_if = "Option::is_none")]
206    pub rewrite_speculation: Option<bool>,
207}
208
209#[derive(Clone, Deserialize, Serialize, Debug)]
210#[serde(tag = "type", rename_all = "snake_case")]
211pub enum Prediction {
212    Content { content: String },
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216#[serde(rename_all = "snake_case")]
217pub enum ToolChoice {
218    Auto,
219    Required,
220    None,
221    Any,
222    Function(ToolDefinition),
223}
224
225#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
226#[serde(tag = "role", rename_all = "lowercase")]
227pub enum RequestMessage {
228    Assistant {
229        content: Option<String>,
230        #[serde(default, skip_serializing_if = "Vec::is_empty")]
231        tool_calls: Vec<ToolCall>,
232    },
233    User {
234        content: String,
235    },
236    System {
237        content: String,
238    },
239    Tool {
240        content: String,
241        tool_call_id: String,
242    },
243}
244
245#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
246pub struct ToolCall {
247    pub id: String,
248    #[serde(flatten)]
249    pub content: ToolCallContent,
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253#[serde(tag = "type", rename_all = "lowercase")]
254pub enum ToolCallContent {
255    Function { function: FunctionContent },
256}
257
258#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
259pub struct FunctionContent {
260    pub name: String,
261    pub arguments: String,
262}
263
264#[derive(Serialize, Deserialize, Debug)]
265pub struct CompletionChoice {
266    pub text: String,
267}
268
269#[derive(Serialize, Deserialize, Debug)]
270pub struct Response {
271    pub id: String,
272    pub object: String,
273    pub created: u64,
274    pub model: String,
275    pub choices: Vec<Choice>,
276    pub usage: Usage,
277}
278
279#[derive(Serialize, Deserialize, Debug)]
280pub struct Usage {
281    pub prompt_tokens: u32,
282    pub completion_tokens: u32,
283    pub total_tokens: u32,
284}
285
286#[derive(Serialize, Deserialize, Debug)]
287pub struct Choice {
288    pub index: u32,
289    pub message: RequestMessage,
290    pub finish_reason: Option<String>,
291}
292
293#[derive(Serialize, Deserialize, Debug)]
294pub struct StreamResponse {
295    pub id: String,
296    pub object: String,
297    pub created: u64,
298    pub model: String,
299    pub choices: Vec<StreamChoice>,
300}
301
302#[derive(Serialize, Deserialize, Debug)]
303pub struct StreamChoice {
304    pub index: u32,
305    pub delta: StreamDelta,
306    pub finish_reason: Option<String>,
307}
308
309#[derive(Serialize, Deserialize, Debug)]
310pub struct StreamDelta {
311    pub role: Option<Role>,
312    pub content: Option<String>,
313    #[serde(default, skip_serializing_if = "Option::is_none")]
314    pub tool_calls: Option<Vec<ToolCallChunk>>,
315    #[serde(default, skip_serializing_if = "Option::is_none")]
316    pub reasoning_content: Option<String>,
317}
318
319#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
320pub struct ToolCallChunk {
321    pub index: usize,
322    pub id: Option<String>,
323    pub function: Option<FunctionChunk>,
324}
325
326#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
327pub struct FunctionChunk {
328    pub name: Option<String>,
329    pub arguments: Option<String>,
330}
331
332pub async fn stream_completion(
333    client: &dyn HttpClient,
334    api_url: &str,
335    api_key: &str,
336    request: Request,
337) -> Result<BoxStream<'static, Result<StreamResponse>>> {
338    let uri = format!("{api_url}/chat/completions");
339    let request_builder = HttpRequest::builder()
340        .method(Method::POST)
341        .uri(uri)
342        .header("Content-Type", "application/json")
343        .header("Authorization", format!("Bearer {}", api_key));
344
345    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
346    let mut response = client.send(request).await?;
347
348    if response.status().is_success() {
349        let reader = BufReader::new(response.into_body());
350        Ok(reader
351            .lines()
352            .filter_map(|line| async move {
353                match line {
354                    Ok(line) => {
355                        let line = line.strip_prefix("data: ")?;
356                        if line == "[DONE]" {
357                            None
358                        } else {
359                            match serde_json::from_str(line) {
360                                Ok(response) => Some(Ok(response)),
361                                Err(error) => Some(Err(anyhow!(error))),
362                            }
363                        }
364                    }
365                    Err(error) => Some(Err(anyhow!(error))),
366                }
367            })
368            .boxed())
369    } else {
370        let mut body = String::new();
371        response.body_mut().read_to_string(&mut body).await?;
372        anyhow::bail!(
373            "Failed to connect to Mistral API: {} {}",
374            response.status(),
375            body,
376        );
377    }
378}