// mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => Err(anyhow!("invalid role '{value}'")),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 52    MistralLargeLatest,
 53    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 54    MistralMediumLatest,
 55    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 56    MistralSmallLatest,
 57    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 58    OpenMistralNemo,
 59    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 60    OpenCodestralMamba,
 61
 62    #[serde(rename = "custom")]
 63    Custom {
 64        name: String,
 65        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 66        display_name: Option<String>,
 67        max_tokens: usize,
 68        max_output_tokens: Option<u32>,
 69        max_completion_tokens: Option<u32>,
 70    },
 71}
 72
 73impl Model {
 74    pub fn default_fast() -> Self {
 75        Model::MistralSmallLatest
 76    }
 77
 78    pub fn from_id(id: &str) -> Result<Self> {
 79        match id {
 80            "codestral-latest" => Ok(Self::CodestralLatest),
 81            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 82            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
 83            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 84            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 85            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 86            _ => Err(anyhow!("invalid model id")),
 87        }
 88    }
 89
 90    pub fn id(&self) -> &str {
 91        match self {
 92            Self::CodestralLatest => "codestral-latest",
 93            Self::MistralLargeLatest => "mistral-large-latest",
 94            Self::MistralMediumLatest => "mistral-medium-latest",
 95            Self::MistralSmallLatest => "mistral-small-latest",
 96            Self::OpenMistralNemo => "open-mistral-nemo",
 97            Self::OpenCodestralMamba => "open-codestral-mamba",
 98            Self::Custom { name, .. } => name,
 99        }
100    }
101
102    pub fn display_name(&self) -> &str {
103        match self {
104            Self::CodestralLatest => "codestral-latest",
105            Self::MistralLargeLatest => "mistral-large-latest",
106            Self::MistralMediumLatest => "mistral-medium-latest",
107            Self::MistralSmallLatest => "mistral-small-latest",
108            Self::OpenMistralNemo => "open-mistral-nemo",
109            Self::OpenCodestralMamba => "open-codestral-mamba",
110            Self::Custom {
111                name, display_name, ..
112            } => display_name.as_ref().unwrap_or(name),
113        }
114    }
115
116    pub fn max_token_count(&self) -> usize {
117        match self {
118            Self::CodestralLatest => 256000,
119            Self::MistralLargeLatest => 131000,
120            Self::MistralMediumLatest => 128000,
121            Self::MistralSmallLatest => 32000,
122            Self::OpenMistralNemo => 131000,
123            Self::OpenCodestralMamba => 256000,
124            Self::Custom { max_tokens, .. } => *max_tokens,
125        }
126    }
127
128    pub fn max_output_tokens(&self) -> Option<u32> {
129        match self {
130            Self::Custom {
131                max_output_tokens, ..
132            } => *max_output_tokens,
133            _ => None,
134        }
135    }
136}
137
138#[derive(Debug, Serialize, Deserialize)]
139pub struct Request {
140    pub model: String,
141    pub messages: Vec<RequestMessage>,
142    pub stream: bool,
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub max_tokens: Option<u32>,
145    #[serde(default, skip_serializing_if = "Option::is_none")]
146    pub temperature: Option<f32>,
147    #[serde(default, skip_serializing_if = "Option::is_none")]
148    pub response_format: Option<ResponseFormat>,
149    #[serde(default, skip_serializing_if = "Vec::is_empty")]
150    pub tools: Vec<ToolDefinition>,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum ResponseFormat {
156    Text,
157    #[serde(rename = "json_object")]
158    JsonObject,
159}
160
161#[derive(Debug, Serialize, Deserialize)]
162#[serde(tag = "type", rename_all = "snake_case")]
163pub enum ToolDefinition {
164    Function { function: FunctionDefinition },
165}
166
167#[derive(Debug, Serialize, Deserialize)]
168pub struct FunctionDefinition {
169    pub name: String,
170    pub description: Option<String>,
171    pub parameters: Option<Value>,
172}
173
174#[derive(Debug, Serialize, Deserialize)]
175pub struct CompletionRequest {
176    pub model: String,
177    pub prompt: String,
178    pub max_tokens: u32,
179    pub temperature: f32,
180    #[serde(default, skip_serializing_if = "Option::is_none")]
181    pub prediction: Option<Prediction>,
182    #[serde(default, skip_serializing_if = "Option::is_none")]
183    pub rewrite_speculation: Option<bool>,
184}
185
186#[derive(Clone, Deserialize, Serialize, Debug)]
187#[serde(tag = "type", rename_all = "snake_case")]
188pub enum Prediction {
189    Content { content: String },
190}
191
192#[derive(Debug, Serialize, Deserialize)]
193#[serde(untagged)]
194pub enum ToolChoice {
195    Auto,
196    Required,
197    None,
198    Other(ToolDefinition),
199}
200
201#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
202#[serde(tag = "role", rename_all = "lowercase")]
203pub enum RequestMessage {
204    Assistant {
205        content: Option<String>,
206        #[serde(default, skip_serializing_if = "Vec::is_empty")]
207        tool_calls: Vec<ToolCall>,
208    },
209    User {
210        content: String,
211    },
212    System {
213        content: String,
214    },
215    Tool {
216        content: String,
217        tool_call_id: String,
218    },
219}
220
221#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
222pub struct ToolCall {
223    pub id: String,
224    #[serde(flatten)]
225    pub content: ToolCallContent,
226}
227
228#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
229#[serde(tag = "type", rename_all = "lowercase")]
230pub enum ToolCallContent {
231    Function { function: FunctionContent },
232}
233
234#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
235pub struct FunctionContent {
236    pub name: String,
237    pub arguments: String,
238}
239
240#[derive(Serialize, Deserialize, Debug)]
241pub struct CompletionChoice {
242    pub text: String,
243}
244
245#[derive(Serialize, Deserialize, Debug)]
246pub struct Response {
247    pub id: String,
248    pub object: String,
249    pub created: u64,
250    pub model: String,
251    pub choices: Vec<Choice>,
252    pub usage: Usage,
253}
254
255#[derive(Serialize, Deserialize, Debug)]
256pub struct Usage {
257    pub prompt_tokens: u32,
258    pub completion_tokens: u32,
259    pub total_tokens: u32,
260}
261
262#[derive(Serialize, Deserialize, Debug)]
263pub struct Choice {
264    pub index: u32,
265    pub message: RequestMessage,
266    pub finish_reason: Option<String>,
267}
268
269#[derive(Serialize, Deserialize, Debug)]
270pub struct StreamResponse {
271    pub id: String,
272    pub object: String,
273    pub created: u64,
274    pub model: String,
275    pub choices: Vec<StreamChoice>,
276}
277
278#[derive(Serialize, Deserialize, Debug)]
279pub struct StreamChoice {
280    pub index: u32,
281    pub delta: StreamDelta,
282    pub finish_reason: Option<String>,
283}
284
285#[derive(Serialize, Deserialize, Debug)]
286pub struct StreamDelta {
287    pub role: Option<Role>,
288    pub content: Option<String>,
289    #[serde(default, skip_serializing_if = "Option::is_none")]
290    pub tool_calls: Option<Vec<ToolCallChunk>>,
291    #[serde(default, skip_serializing_if = "Option::is_none")]
292    pub reasoning_content: Option<String>,
293}
294
295#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
296pub struct ToolCallChunk {
297    pub index: usize,
298    pub id: Option<String>,
299    pub function: Option<FunctionChunk>,
300}
301
302#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
303pub struct FunctionChunk {
304    pub name: Option<String>,
305    pub arguments: Option<String>,
306}
307
308pub async fn stream_completion(
309    client: &dyn HttpClient,
310    api_url: &str,
311    api_key: &str,
312    request: Request,
313) -> Result<BoxStream<'static, Result<StreamResponse>>> {
314    let uri = format!("{api_url}/chat/completions");
315    let request_builder = HttpRequest::builder()
316        .method(Method::POST)
317        .uri(uri)
318        .header("Content-Type", "application/json")
319        .header("Authorization", format!("Bearer {}", api_key));
320
321    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
322    let mut response = client.send(request).await?;
323
324    if response.status().is_success() {
325        let reader = BufReader::new(response.into_body());
326        Ok(reader
327            .lines()
328            .filter_map(|line| async move {
329                match line {
330                    Ok(line) => {
331                        let line = line.strip_prefix("data: ")?;
332                        if line == "[DONE]" {
333                            None
334                        } else {
335                            match serde_json::from_str(line) {
336                                Ok(response) => Some(Ok(response)),
337                                Err(error) => Some(Err(anyhow!(error))),
338                            }
339                        }
340                    }
341                    Err(error) => Some(Err(anyhow!(error))),
342                }
343            })
344            .boxed())
345    } else {
346        let mut body = String::new();
347        response.body_mut().read_to_string(&mut body).await?;
348        Err(anyhow!(
349            "Failed to connect to Mistral API: {} {}",
350            response.status(),
351            body,
352        ))
353    }
354}