mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => Err(anyhow!("invalid role '{value}'")),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 52    MistralLargeLatest,
 53    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 54    MistralSmallLatest,
 55    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 56    OpenMistralNemo,
 57    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 58    OpenCodestralMamba,
 59
 60    #[serde(rename = "custom")]
 61    Custom {
 62        name: String,
 63        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 64        display_name: Option<String>,
 65        max_tokens: usize,
 66        max_output_tokens: Option<u32>,
 67        max_completion_tokens: Option<u32>,
 68    },
 69}
 70
 71impl Model {
 72    pub fn default_fast() -> Self {
 73        Model::MistralSmallLatest
 74    }
 75
 76    pub fn from_id(id: &str) -> Result<Self> {
 77        match id {
 78            "codestral-latest" => Ok(Self::CodestralLatest),
 79            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 80            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 81            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 82            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 83            _ => Err(anyhow!("invalid model id")),
 84        }
 85    }
 86
 87    pub fn id(&self) -> &str {
 88        match self {
 89            Self::CodestralLatest => "codestral-latest",
 90            Self::MistralLargeLatest => "mistral-large-latest",
 91            Self::MistralSmallLatest => "mistral-small-latest",
 92            Self::OpenMistralNemo => "open-mistral-nemo",
 93            Self::OpenCodestralMamba => "open-codestral-mamba",
 94            Self::Custom { name, .. } => name,
 95        }
 96    }
 97
 98    pub fn display_name(&self) -> &str {
 99        match self {
100            Self::CodestralLatest => "codestral-latest",
101            Self::MistralLargeLatest => "mistral-large-latest",
102            Self::MistralSmallLatest => "mistral-small-latest",
103            Self::OpenMistralNemo => "open-mistral-nemo",
104            Self::OpenCodestralMamba => "open-codestral-mamba",
105            Self::Custom {
106                name, display_name, ..
107            } => display_name.as_ref().unwrap_or(name),
108        }
109    }
110
111    pub fn max_token_count(&self) -> usize {
112        match self {
113            Self::CodestralLatest => 256000,
114            Self::MistralLargeLatest => 131000,
115            Self::MistralSmallLatest => 32000,
116            Self::OpenMistralNemo => 131000,
117            Self::OpenCodestralMamba => 256000,
118            Self::Custom { max_tokens, .. } => *max_tokens,
119        }
120    }
121
122    pub fn max_output_tokens(&self) -> Option<u32> {
123        match self {
124            Self::Custom {
125                max_output_tokens, ..
126            } => *max_output_tokens,
127            _ => None,
128        }
129    }
130}
131
132#[derive(Debug, Serialize, Deserialize)]
133pub struct Request {
134    pub model: String,
135    pub messages: Vec<RequestMessage>,
136    pub stream: bool,
137    #[serde(default, skip_serializing_if = "Option::is_none")]
138    pub max_tokens: Option<u32>,
139    #[serde(default, skip_serializing_if = "Option::is_none")]
140    pub temperature: Option<f32>,
141    #[serde(default, skip_serializing_if = "Option::is_none")]
142    pub response_format: Option<ResponseFormat>,
143    #[serde(default, skip_serializing_if = "Vec::is_empty")]
144    pub tools: Vec<ToolDefinition>,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148#[serde(rename_all = "snake_case")]
149pub enum ResponseFormat {
150    Text,
151    #[serde(rename = "json_object")]
152    JsonObject,
153}
154
155#[derive(Debug, Serialize, Deserialize)]
156#[serde(tag = "type", rename_all = "snake_case")]
157pub enum ToolDefinition {
158    Function { function: FunctionDefinition },
159}
160
161#[derive(Debug, Serialize, Deserialize)]
162pub struct FunctionDefinition {
163    pub name: String,
164    pub description: Option<String>,
165    pub parameters: Option<Value>,
166}
167
168#[derive(Debug, Serialize, Deserialize)]
169pub struct CompletionRequest {
170    pub model: String,
171    pub prompt: String,
172    pub max_tokens: u32,
173    pub temperature: f32,
174    #[serde(default, skip_serializing_if = "Option::is_none")]
175    pub prediction: Option<Prediction>,
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub rewrite_speculation: Option<bool>,
178}
179
180#[derive(Clone, Deserialize, Serialize, Debug)]
181#[serde(tag = "type", rename_all = "snake_case")]
182pub enum Prediction {
183    Content { content: String },
184}
185
186#[derive(Debug, Serialize, Deserialize)]
187#[serde(untagged)]
188pub enum ToolChoice {
189    Auto,
190    Required,
191    None,
192    Other(ToolDefinition),
193}
194
195#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
196#[serde(tag = "role", rename_all = "lowercase")]
197pub enum RequestMessage {
198    Assistant {
199        content: Option<String>,
200        #[serde(default, skip_serializing_if = "Vec::is_empty")]
201        tool_calls: Vec<ToolCall>,
202    },
203    User {
204        content: String,
205    },
206    System {
207        content: String,
208    },
209    Tool {
210        content: String,
211        tool_call_id: String,
212    },
213}
214
215#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
216pub struct ToolCall {
217    pub id: String,
218    #[serde(flatten)]
219    pub content: ToolCallContent,
220}
221
222#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum ToolCallContent {
225    Function { function: FunctionContent },
226}
227
228#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
229pub struct FunctionContent {
230    pub name: String,
231    pub arguments: String,
232}
233
234#[derive(Serialize, Deserialize, Debug)]
235pub struct CompletionChoice {
236    pub text: String,
237}
238
239#[derive(Serialize, Deserialize, Debug)]
240pub struct Response {
241    pub id: String,
242    pub object: String,
243    pub created: u64,
244    pub model: String,
245    pub choices: Vec<Choice>,
246    pub usage: Usage,
247}
248
249#[derive(Serialize, Deserialize, Debug)]
250pub struct Usage {
251    pub prompt_tokens: u32,
252    pub completion_tokens: u32,
253    pub total_tokens: u32,
254}
255
256#[derive(Serialize, Deserialize, Debug)]
257pub struct Choice {
258    pub index: u32,
259    pub message: RequestMessage,
260    pub finish_reason: Option<String>,
261}
262
263#[derive(Serialize, Deserialize, Debug)]
264pub struct StreamResponse {
265    pub id: String,
266    pub object: String,
267    pub created: u64,
268    pub model: String,
269    pub choices: Vec<StreamChoice>,
270}
271
272#[derive(Serialize, Deserialize, Debug)]
273pub struct StreamChoice {
274    pub index: u32,
275    pub delta: StreamDelta,
276    pub finish_reason: Option<String>,
277}
278
279#[derive(Serialize, Deserialize, Debug)]
280pub struct StreamDelta {
281    pub role: Option<Role>,
282    pub content: Option<String>,
283    #[serde(default, skip_serializing_if = "Option::is_none")]
284    pub tool_calls: Option<Vec<ToolCallChunk>>,
285    #[serde(default, skip_serializing_if = "Option::is_none")]
286    pub reasoning_content: Option<String>,
287}
288
289#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
290pub struct ToolCallChunk {
291    pub index: usize,
292    pub id: Option<String>,
293    pub function: Option<FunctionChunk>,
294}
295
296#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
297pub struct FunctionChunk {
298    pub name: Option<String>,
299    pub arguments: Option<String>,
300}
301
302pub async fn stream_completion(
303    client: &dyn HttpClient,
304    api_url: &str,
305    api_key: &str,
306    request: Request,
307) -> Result<BoxStream<'static, Result<StreamResponse>>> {
308    let uri = format!("{api_url}/chat/completions");
309    let request_builder = HttpRequest::builder()
310        .method(Method::POST)
311        .uri(uri)
312        .header("Content-Type", "application/json")
313        .header("Authorization", format!("Bearer {}", api_key));
314
315    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
316    let mut response = client.send(request).await?;
317
318    if response.status().is_success() {
319        let reader = BufReader::new(response.into_body());
320        Ok(reader
321            .lines()
322            .filter_map(|line| async move {
323                match line {
324                    Ok(line) => {
325                        let line = line.strip_prefix("data: ")?;
326                        if line == "[DONE]" {
327                            None
328                        } else {
329                            match serde_json::from_str(line) {
330                                Ok(response) => Some(Ok(response)),
331                                Err(error) => Some(Err(anyhow!(error))),
332                            }
333                        }
334                    }
335                    Err(error) => Some(Err(anyhow!(error))),
336                }
337            })
338            .boxed())
339    } else {
340        let mut body = String::new();
341        response.body_mut().read_to_string(&mut body).await?;
342        Err(anyhow!(
343            "Failed to connect to Mistral API: {} {}",
344            response.status(),
345            body,
346        ))
347    }
348}