mistral.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7use strum::EnumIter;
  8
  9pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
 10
 11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 12#[serde(rename_all = "lowercase")]
 13pub enum Role {
 14    User,
 15    Assistant,
 16    System,
 17    Tool,
 18}
 19
 20impl TryFrom<String> for Role {
 21    type Error = anyhow::Error;
 22
 23    fn try_from(value: String) -> Result<Self> {
 24        match value.as_str() {
 25            "user" => Ok(Self::User),
 26            "assistant" => Ok(Self::Assistant),
 27            "system" => Ok(Self::System),
 28            "tool" => Ok(Self::Tool),
 29            _ => Err(anyhow!("invalid role '{value}'")),
 30        }
 31    }
 32}
 33
 34impl From<Role> for String {
 35    fn from(val: Role) -> Self {
 36        match val {
 37            Role::User => "user".to_owned(),
 38            Role::Assistant => "assistant".to_owned(),
 39            Role::System => "system".to_owned(),
 40            Role::Tool => "tool".to_owned(),
 41        }
 42    }
 43}
 44
 45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 47pub enum Model {
 48    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
 49    CodestralLatest,
 50    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
 51    MistralLargeLatest,
 52    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
 53    MistralSmallLatest,
 54    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
 55    OpenMistralNemo,
 56    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
 57    #[default]
 58    OpenCodestralMamba,
 59
 60    #[serde(rename = "custom")]
 61    Custom {
 62        name: String,
 63        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 64        display_name: Option<String>,
 65        max_tokens: usize,
 66        max_output_tokens: Option<u32>,
 67        max_completion_tokens: Option<u32>,
 68    },
 69}
 70
 71impl Model {
 72    pub fn from_id(id: &str) -> Result<Self> {
 73        match id {
 74            "codestral-latest" => Ok(Self::CodestralLatest),
 75            "mistral-large-latest" => Ok(Self::MistralLargeLatest),
 76            "mistral-small-latest" => Ok(Self::MistralSmallLatest),
 77            "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
 78            "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
 79            _ => Err(anyhow!("invalid model id")),
 80        }
 81    }
 82
 83    pub fn id(&self) -> &str {
 84        match self {
 85            Self::CodestralLatest => "codestral-latest",
 86            Self::MistralLargeLatest => "mistral-large-latest",
 87            Self::MistralSmallLatest => "mistral-small-latest",
 88            Self::OpenMistralNemo => "open-mistral-nemo",
 89            Self::OpenCodestralMamba => "open-codestral-mamba",
 90            Self::Custom { name, .. } => name,
 91        }
 92    }
 93
 94    pub fn display_name(&self) -> &str {
 95        match self {
 96            Self::CodestralLatest => "codestral-latest",
 97            Self::MistralLargeLatest => "mistral-large-latest",
 98            Self::MistralSmallLatest => "mistral-small-latest",
 99            Self::OpenMistralNemo => "open-mistral-nemo",
100            Self::OpenCodestralMamba => "open-codestral-mamba",
101            Self::Custom {
102                name, display_name, ..
103            } => display_name.as_ref().unwrap_or(name),
104        }
105    }
106
107    pub fn max_token_count(&self) -> usize {
108        match self {
109            Self::CodestralLatest => 256000,
110            Self::MistralLargeLatest => 131000,
111            Self::MistralSmallLatest => 32000,
112            Self::OpenMistralNemo => 131000,
113            Self::OpenCodestralMamba => 256000,
114            Self::Custom { max_tokens, .. } => *max_tokens,
115        }
116    }
117
118    pub fn max_output_tokens(&self) -> Option<u32> {
119        match self {
120            Self::Custom {
121                max_output_tokens, ..
122            } => *max_output_tokens,
123            _ => None,
124        }
125    }
126}
127
128#[derive(Debug, Serialize, Deserialize)]
129pub struct Request {
130    pub model: String,
131    pub messages: Vec<RequestMessage>,
132    pub stream: bool,
133    #[serde(default, skip_serializing_if = "Option::is_none")]
134    pub max_tokens: Option<u32>,
135    #[serde(default, skip_serializing_if = "Option::is_none")]
136    pub temperature: Option<f32>,
137    #[serde(default, skip_serializing_if = "Option::is_none")]
138    pub response_format: Option<ResponseFormat>,
139    #[serde(default, skip_serializing_if = "Vec::is_empty")]
140    pub tools: Vec<ToolDefinition>,
141}
142
143#[derive(Debug, Serialize, Deserialize)]
144#[serde(rename_all = "snake_case")]
145pub enum ResponseFormat {
146    Text,
147    #[serde(rename = "json_object")]
148    JsonObject,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(tag = "type", rename_all = "snake_case")]
153pub enum ToolDefinition {
154    Function { function: FunctionDefinition },
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158pub struct FunctionDefinition {
159    pub name: String,
160    pub description: Option<String>,
161    pub parameters: Option<Value>,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
165pub struct CompletionRequest {
166    pub model: String,
167    pub prompt: String,
168    pub max_tokens: u32,
169    pub temperature: f32,
170    #[serde(default, skip_serializing_if = "Option::is_none")]
171    pub prediction: Option<Prediction>,
172    #[serde(default, skip_serializing_if = "Option::is_none")]
173    pub rewrite_speculation: Option<bool>,
174}
175
176#[derive(Clone, Deserialize, Serialize, Debug)]
177#[serde(tag = "type", rename_all = "snake_case")]
178pub enum Prediction {
179    Content { content: String },
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183#[serde(untagged)]
184pub enum ToolChoice {
185    Auto,
186    Required,
187    None,
188    Other(ToolDefinition),
189}
190
191#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
192#[serde(tag = "role", rename_all = "lowercase")]
193pub enum RequestMessage {
194    Assistant {
195        content: Option<String>,
196        #[serde(default, skip_serializing_if = "Vec::is_empty")]
197        tool_calls: Vec<ToolCall>,
198    },
199    User {
200        content: String,
201    },
202    System {
203        content: String,
204    },
205    Tool {
206        content: String,
207        tool_call_id: String,
208    },
209}
210
211#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
212pub struct ToolCall {
213    pub id: String,
214    #[serde(flatten)]
215    pub content: ToolCallContent,
216}
217
218#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
219#[serde(tag = "type", rename_all = "lowercase")]
220pub enum ToolCallContent {
221    Function { function: FunctionContent },
222}
223
224#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
225pub struct FunctionContent {
226    pub name: String,
227    pub arguments: String,
228}
229
230#[derive(Serialize, Deserialize, Debug)]
231pub struct CompletionChoice {
232    pub text: String,
233}
234
235#[derive(Serialize, Deserialize, Debug)]
236pub struct Response {
237    pub id: String,
238    pub object: String,
239    pub created: u64,
240    pub model: String,
241    pub choices: Vec<Choice>,
242    pub usage: Usage,
243}
244
245#[derive(Serialize, Deserialize, Debug)]
246pub struct Usage {
247    pub prompt_tokens: u32,
248    pub completion_tokens: u32,
249    pub total_tokens: u32,
250}
251
252#[derive(Serialize, Deserialize, Debug)]
253pub struct Choice {
254    pub index: u32,
255    pub message: RequestMessage,
256    pub finish_reason: Option<String>,
257}
258
259#[derive(Serialize, Deserialize, Debug)]
260pub struct StreamResponse {
261    pub id: String,
262    pub object: String,
263    pub created: u64,
264    pub model: String,
265    pub choices: Vec<StreamChoice>,
266}
267
268#[derive(Serialize, Deserialize, Debug)]
269pub struct StreamChoice {
270    pub index: u32,
271    pub delta: StreamDelta,
272    pub finish_reason: Option<String>,
273}
274
275#[derive(Serialize, Deserialize, Debug)]
276pub struct StreamDelta {
277    pub role: Option<Role>,
278    pub content: Option<String>,
279    #[serde(default, skip_serializing_if = "Option::is_none")]
280    pub tool_calls: Option<Vec<ToolCallChunk>>,
281    #[serde(default, skip_serializing_if = "Option::is_none")]
282    pub reasoning_content: Option<String>,
283}
284
285#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
286pub struct ToolCallChunk {
287    pub index: usize,
288    pub id: Option<String>,
289    pub function: Option<FunctionChunk>,
290}
291
292#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
293pub struct FunctionChunk {
294    pub name: Option<String>,
295    pub arguments: Option<String>,
296}
297
298pub async fn stream_completion(
299    client: &dyn HttpClient,
300    api_url: &str,
301    api_key: &str,
302    request: Request,
303) -> Result<BoxStream<'static, Result<StreamResponse>>> {
304    let uri = format!("{api_url}/chat/completions");
305    let request_builder = HttpRequest::builder()
306        .method(Method::POST)
307        .uri(uri)
308        .header("Content-Type", "application/json")
309        .header("Authorization", format!("Bearer {}", api_key));
310
311    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
312    let mut response = client.send(request).await?;
313
314    if response.status().is_success() {
315        let reader = BufReader::new(response.into_body());
316        Ok(reader
317            .lines()
318            .filter_map(|line| async move {
319                match line {
320                    Ok(line) => {
321                        let line = line.strip_prefix("data: ")?;
322                        if line == "[DONE]" {
323                            None
324                        } else {
325                            match serde_json::from_str(line) {
326                                Ok(response) => Some(Ok(response)),
327                                Err(error) => Some(Err(anyhow!(error))),
328                            }
329                        }
330                    }
331                    Err(error) => Some(Err(anyhow!(error))),
332                }
333            })
334            .boxed())
335    } else {
336        let mut body = String::new();
337        response.body_mut().read_to_string(&mut body).await?;
338        Err(anyhow!(
339            "Failed to connect to Mistral API: {} {}",
340            response.status(),
341            body,
342        ))
343    }
344}