mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => anyhow::bail!("invalid role '{value}'"),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    #[default]
 50    CodestralLatest,
 51    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 52    MistralLargeLatest,
 53    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
 54    MistralMediumLatest,
 55    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 56    MistralSmallLatest,
 57    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 58    OpenMistralNemo,
 59    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 60    OpenCodestralMamba,
 61
 62    #[serde(rename = "custom")]
 63    Custom {
 64        name: String,
 65        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 66        display_name: Option<String>,
 67        max_tokens: usize,
 68        max_output_tokens: Option<u32>,
 69        max_completion_tokens: Option<u32>,
 70        supports_tools: Option<bool>,
 71    },
 72}
 73
 74impl Model {
 75    pub fn default_fast() -> Self {
 76        Model::MistralSmallLatest
 77    }
 78
 79    pub fn from_id(id: &str) -> Result<Self> {
 80        match id {
 81            "codestral-latest" => Ok(Self::CodestralLatest),
 82            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 83            "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
 84            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 85            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 86            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 87            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
 88        }
 89    }
 90
 91    pub fn id(&self) -> &str {
 92        match self {
 93            Self::CodestralLatest => "codestral-latest",
 94            Self::MistralLargeLatest => "mistral-large-latest",
 95            Self::MistralMediumLatest => "mistral-medium-latest",
 96            Self::MistralSmallLatest => "mistral-small-latest",
 97            Self::OpenMistralNemo => "open-mistral-nemo",
 98            Self::OpenCodestralMamba => "open-codestral-mamba",
 99            Self::Custom { name, .. } => name,
100        }
101    }
102
103    pub fn display_name(&self) -> &str {
104        match self {
105            Self::CodestralLatest => "codestral-latest",
106            Self::MistralLargeLatest => "mistral-large-latest",
107            Self::MistralMediumLatest => "mistral-medium-latest",
108            Self::MistralSmallLatest => "mistral-small-latest",
109            Self::OpenMistralNemo => "open-mistral-nemo",
110            Self::OpenCodestralMamba => "open-codestral-mamba",
111            Self::Custom {
112                name, display_name, ..
113            } => display_name.as_ref().unwrap_or(name),
114        }
115    }
116
117    pub fn max_token_count(&self) -> usize {
118        match self {
119            Self::CodestralLatest => 256000,
120            Self::MistralLargeLatest => 131000,
121            Self::MistralMediumLatest => 128000,
122            Self::MistralSmallLatest => 32000,
123            Self::OpenMistralNemo => 131000,
124            Self::OpenCodestralMamba => 256000,
125            Self::Custom { max_tokens, .. } => *max_tokens,
126        }
127    }
128
129    pub fn max_output_tokens(&self) -> Option<u32> {
130        match self {
131            Self::Custom {
132                max_output_tokens, ..
133            } => *max_output_tokens,
134            _ => None,
135        }
136    }
137
138    pub fn supports_tools(&self) -> bool {
139        match self {
140            Self::CodestralLatest
141            | Self::MistralLargeLatest
142            | Self::MistralMediumLatest
143            | Self::MistralSmallLatest
144            | Self::OpenMistralNemo
145            | Self::OpenCodestralMamba => true,
146            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
147        }
148    }
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152pub struct Request {
153    pub model: String,
154    pub messages: Vec<RequestMessage>,
155    pub stream: bool,
156    #[serde(default, skip_serializing_if = "Option::is_none")]
157    pub max_tokens: Option<u32>,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub temperature: Option<f32>,
160    #[serde(default, skip_serializing_if = "Option::is_none")]
161    pub response_format: Option<ResponseFormat>,
162    #[serde(default, skip_serializing_if = "Option::is_none")]
163    pub tool_choice: Option<ToolChoice>,
164    #[serde(default, skip_serializing_if = "Option::is_none")]
165    pub parallel_tool_calls: Option<bool>,
166    #[serde(default, skip_serializing_if = "Vec::is_empty")]
167    pub tools: Vec<ToolDefinition>,
168}
169
170#[derive(Debug, Serialize, Deserialize)]
171#[serde(rename_all = "snake_case")]
172pub enum ResponseFormat {
173    Text,
174    #[serde(rename = "json_object")]
175    JsonObject,
176}
177
178#[derive(Debug, Serialize, Deserialize)]
179#[serde(tag = "type", rename_all = "snake_case")]
180pub enum ToolDefinition {
181    Function { function: FunctionDefinition },
182}
183
184#[derive(Debug, Serialize, Deserialize)]
185pub struct FunctionDefinition {
186    pub name: String,
187    pub description: Option<String>,
188    pub parameters: Option<Value>,
189}
190
191#[derive(Debug, Serialize, Deserialize)]
192pub struct CompletionRequest {
193    pub model: String,
194    pub prompt: String,
195    pub max_tokens: u32,
196    pub temperature: f32,
197    #[serde(default, skip_serializing_if = "Option::is_none")]
198    pub prediction: Option<Prediction>,
199    #[serde(default, skip_serializing_if = "Option::is_none")]
200    pub rewrite_speculation: Option<bool>,
201}
202
203#[derive(Clone, Deserialize, Serialize, Debug)]
204#[serde(tag = "type", rename_all = "snake_case")]
205pub enum Prediction {
206    Content { content: String },
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "snake_case")]
211pub enum ToolChoice {
212    Auto,
213    Required,
214    None,
215    Any,
216    Function(ToolDefinition),
217}
218
219#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
220#[serde(tag = "role", rename_all = "lowercase")]
221pub enum RequestMessage {
222    Assistant {
223        content: Option<String>,
224        #[serde(default, skip_serializing_if = "Vec::is_empty")]
225        tool_calls: Vec<ToolCall>,
226    },
227    User {
228        content: String,
229    },
230    System {
231        content: String,
232    },
233    Tool {
234        content: String,
235        tool_call_id: String,
236    },
237}
238
239#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
240pub struct ToolCall {
241    pub id: String,
242    #[serde(flatten)]
243    pub content: ToolCallContent,
244}
245
246#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
247#[serde(tag = "type", rename_all = "lowercase")]
248pub enum ToolCallContent {
249    Function { function: FunctionContent },
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253pub struct FunctionContent {
254    pub name: String,
255    pub arguments: String,
256}
257
258#[derive(Serialize, Deserialize, Debug)]
259pub struct CompletionChoice {
260    pub text: String,
261}
262
263#[derive(Serialize, Deserialize, Debug)]
264pub struct Response {
265    pub id: String,
266    pub object: String,
267    pub created: u64,
268    pub model: String,
269    pub choices: Vec<Choice>,
270    pub usage: Usage,
271}
272
273#[derive(Serialize, Deserialize, Debug)]
274pub struct Usage {
275    pub prompt_tokens: u32,
276    pub completion_tokens: u32,
277    pub total_tokens: u32,
278}
279
280#[derive(Serialize, Deserialize, Debug)]
281pub struct Choice {
282    pub index: u32,
283    pub message: RequestMessage,
284    pub finish_reason: Option<String>,
285}
286
287#[derive(Serialize, Deserialize, Debug)]
288pub struct StreamResponse {
289    pub id: String,
290    pub object: String,
291    pub created: u64,
292    pub model: String,
293    pub choices: Vec<StreamChoice>,
294}
295
296#[derive(Serialize, Deserialize, Debug)]
297pub struct StreamChoice {
298    pub index: u32,
299    pub delta: StreamDelta,
300    pub finish_reason: Option<String>,
301}
302
303#[derive(Serialize, Deserialize, Debug)]
304pub struct StreamDelta {
305    pub role: Option<Role>,
306    pub content: Option<String>,
307    #[serde(default, skip_serializing_if = "Option::is_none")]
308    pub tool_calls: Option<Vec<ToolCallChunk>>,
309    #[serde(default, skip_serializing_if = "Option::is_none")]
310    pub reasoning_content: Option<String>,
311}
312
313#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
314pub struct ToolCallChunk {
315    pub index: usize,
316    pub id: Option<String>,
317    pub function: Option<FunctionChunk>,
318}
319
320#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
321pub struct FunctionChunk {
322    pub name: Option<String>,
323    pub arguments: Option<String>,
324}
325
326pub async fn stream_completion(
327    client: &dyn HttpClient,
328    api_url: &str,
329    api_key: &str,
330    request: Request,
331) -> Result<BoxStream<'static, Result<StreamResponse>>> {
332    let uri = format!("{api_url}/chat/completions");
333    let request_builder = HttpRequest::builder()
334        .method(Method::POST)
335        .uri(uri)
336        .header("Content-Type", "application/json")
337        .header("Authorization", format!("Bearer {}", api_key));
338
339    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
340    let mut response = client.send(request).await?;
341
342    if response.status().is_success() {
343        let reader = BufReader::new(response.into_body());
344        Ok(reader
345            .lines()
346            .filter_map(|line| async move {
347                match line {
348                    Ok(line) => {
349                        let line = line.strip_prefix("data: ")?;
350                        if line == "[DONE]" {
351                            None
352                        } else {
353                            match serde_json::from_str(line) {
354                                Ok(response) => Some(Ok(response)),
355                                Err(error) => Some(Err(anyhow!(error))),
356                            }
357                        }
358                    }
359                    Err(error) => Some(Err(anyhow!(error))),
360                }
361            })
362            .boxed())
363    } else {
364        let mut body = String::new();
365        response.body_mut().read_to_string(&mut body).await?;
366        anyhow::bail!(
367            "Failed to connect to Mistral API: {} {}",
368            response.status(),
369            body,
370        );
371    }
372}