1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of the hosted Mistral API (version 1).
///
/// `stream_completion` appends endpoint paths such as `/chat/completions` to the
/// base URL it is given, so this value must not end with a trailing slash.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17 Tool,
18}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => anyhow::bail!("invalid role '{value}'"),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum Model {
48 #[serde(rename = "codestral-latest", alias = "codestral-latest")]
49 #[default]
50 CodestralLatest,
51 #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
52 MistralLargeLatest,
53 #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
54 MistralMediumLatest,
55 #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
56 MistralSmallLatest,
57 #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
58 OpenMistralNemo,
59 #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
60 OpenCodestralMamba,
61 #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
62 DevstralMediumLatest,
63 #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
64 DevstralSmallLatest,
65 #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
66 Pixtral12BLatest,
67 #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
68 PixtralLargeLatest,
69
70 #[serde(rename = "custom")]
71 Custom {
72 name: String,
73 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
74 display_name: Option<String>,
75 max_tokens: u64,
76 max_output_tokens: Option<u64>,
77 max_completion_tokens: Option<u64>,
78 supports_tools: Option<bool>,
79 supports_images: Option<bool>,
80 },
81}
82
83impl Model {
84 pub fn default_fast() -> Self {
85 Model::MistralSmallLatest
86 }
87
88 pub fn from_id(id: &str) -> Result<Self> {
89 match id {
90 "codestral-latest" => Ok(Self::CodestralLatest),
91 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
92 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
93 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
94 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
95 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
96 "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
97 "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
98 "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
99 "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
100 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
101 }
102 }
103
104 pub fn id(&self) -> &str {
105 match self {
106 Self::CodestralLatest => "codestral-latest",
107 Self::MistralLargeLatest => "mistral-large-latest",
108 Self::MistralMediumLatest => "mistral-medium-latest",
109 Self::MistralSmallLatest => "mistral-small-latest",
110 Self::OpenMistralNemo => "open-mistral-nemo",
111 Self::OpenCodestralMamba => "open-codestral-mamba",
112 Self::DevstralMediumLatest => "devstral-medium-latest",
113 Self::DevstralSmallLatest => "devstral-small-latest",
114 Self::Pixtral12BLatest => "pixtral-12b-latest",
115 Self::PixtralLargeLatest => "pixtral-large-latest",
116 Self::Custom { name, .. } => name,
117 }
118 }
119
120 pub fn display_name(&self) -> &str {
121 match self {
122 Self::CodestralLatest => "codestral-latest",
123 Self::MistralLargeLatest => "mistral-large-latest",
124 Self::MistralMediumLatest => "mistral-medium-latest",
125 Self::MistralSmallLatest => "mistral-small-latest",
126 Self::OpenMistralNemo => "open-mistral-nemo",
127 Self::OpenCodestralMamba => "open-codestral-mamba",
128 Self::DevstralMediumLatest => "devstral-medium-latest",
129 Self::DevstralSmallLatest => "devstral-small-latest",
130 Self::Pixtral12BLatest => "pixtral-12b-latest",
131 Self::PixtralLargeLatest => "pixtral-large-latest",
132 Self::Custom {
133 name, display_name, ..
134 } => display_name.as_ref().unwrap_or(name),
135 }
136 }
137
138 pub fn max_token_count(&self) -> u64 {
139 match self {
140 Self::CodestralLatest => 256000,
141 Self::MistralLargeLatest => 131000,
142 Self::MistralMediumLatest => 128000,
143 Self::MistralSmallLatest => 32000,
144 Self::OpenMistralNemo => 131000,
145 Self::OpenCodestralMamba => 256000,
146 Self::DevstralMediumLatest => 128000,
147 Self::DevstralSmallLatest => 262144,
148 Self::Pixtral12BLatest => 128000,
149 Self::PixtralLargeLatest => 128000,
150 Self::Custom { max_tokens, .. } => *max_tokens,
151 }
152 }
153
154 pub fn max_output_tokens(&self) -> Option<u64> {
155 match self {
156 Self::Custom {
157 max_output_tokens, ..
158 } => *max_output_tokens,
159 _ => None,
160 }
161 }
162
163 pub fn supports_tools(&self) -> bool {
164 match self {
165 Self::CodestralLatest
166 | Self::MistralLargeLatest
167 | Self::MistralMediumLatest
168 | Self::MistralSmallLatest
169 | Self::OpenMistralNemo
170 | Self::OpenCodestralMamba
171 | Self::DevstralMediumLatest
172 | Self::DevstralSmallLatest
173 | Self::Pixtral12BLatest
174 | Self::PixtralLargeLatest => true,
175 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
176 }
177 }
178
179 pub fn supports_images(&self) -> bool {
180 match self {
181 Self::Pixtral12BLatest
182 | Self::PixtralLargeLatest
183 | Self::MistralMediumLatest
184 | Self::MistralSmallLatest => true,
185 Self::CodestralLatest
186 | Self::MistralLargeLatest
187 | Self::OpenMistralNemo
188 | Self::OpenCodestralMamba
189 | Self::DevstralMediumLatest
190 | Self::DevstralSmallLatest => false,
191 Self::Custom {
192 supports_images, ..
193 } => supports_images.unwrap_or(false),
194 }
195 }
196}
197
198#[derive(Debug, Serialize, Deserialize)]
199pub struct Request {
200 pub model: String,
201 pub messages: Vec<RequestMessage>,
202 pub stream: bool,
203 #[serde(default, skip_serializing_if = "Option::is_none")]
204 pub max_tokens: Option<u64>,
205 #[serde(default, skip_serializing_if = "Option::is_none")]
206 pub temperature: Option<f32>,
207 #[serde(default, skip_serializing_if = "Option::is_none")]
208 pub response_format: Option<ResponseFormat>,
209 #[serde(default, skip_serializing_if = "Option::is_none")]
210 pub tool_choice: Option<ToolChoice>,
211 #[serde(default, skip_serializing_if = "Option::is_none")]
212 pub parallel_tool_calls: Option<bool>,
213 #[serde(default, skip_serializing_if = "Vec::is_empty")]
214 pub tools: Vec<ToolDefinition>,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(rename_all = "snake_case")]
219pub enum ResponseFormat {
220 Text,
221 #[serde(rename = "json_object")]
222 JsonObject,
223}
224
225#[derive(Debug, Serialize, Deserialize)]
226#[serde(tag = "type", rename_all = "snake_case")]
227pub enum ToolDefinition {
228 Function { function: FunctionDefinition },
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub struct FunctionDefinition {
233 pub name: String,
234 pub description: Option<String>,
235 pub parameters: Option<Value>,
236}
237
238#[derive(Debug, Serialize, Deserialize)]
239pub struct CompletionRequest {
240 pub model: String,
241 pub prompt: String,
242 pub max_tokens: u32,
243 pub temperature: f32,
244 #[serde(default, skip_serializing_if = "Option::is_none")]
245 pub prediction: Option<Prediction>,
246 #[serde(default, skip_serializing_if = "Option::is_none")]
247 pub rewrite_speculation: Option<bool>,
248}
249
250#[derive(Clone, Deserialize, Serialize, Debug)]
251#[serde(tag = "type", rename_all = "snake_case")]
252pub enum Prediction {
253 Content { content: String },
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257#[serde(rename_all = "snake_case")]
258pub enum ToolChoice {
259 Auto,
260 Required,
261 None,
262 Any,
263 Function(ToolDefinition),
264}
265
266#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
267#[serde(tag = "role", rename_all = "lowercase")]
268pub enum RequestMessage {
269 Assistant {
270 content: Option<String>,
271 #[serde(default, skip_serializing_if = "Vec::is_empty")]
272 tool_calls: Vec<ToolCall>,
273 },
274 User {
275 #[serde(flatten)]
276 content: MessageContent,
277 },
278 System {
279 content: String,
280 },
281 Tool {
282 content: String,
283 tool_call_id: String,
284 },
285}
286
287#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
288#[serde(untagged)]
289pub enum MessageContent {
290 #[serde(rename = "content")]
291 Plain { content: String },
292 #[serde(rename = "content")]
293 Multipart { content: Vec<MessagePart> },
294}
295
296impl MessageContent {
297 pub fn empty() -> Self {
298 Self::Plain {
299 content: String::new(),
300 }
301 }
302
303 pub fn push_part(&mut self, part: MessagePart) {
304 match self {
305 Self::Plain { content } => match part {
306 MessagePart::Text { text } => {
307 content.push_str(&text);
308 }
309 part => {
310 let mut parts = if content.is_empty() {
311 Vec::new()
312 } else {
313 vec![MessagePart::Text {
314 text: content.clone(),
315 }]
316 };
317 parts.push(part);
318 *self = Self::Multipart { content: parts };
319 }
320 },
321 Self::Multipart { content } => {
322 content.push(part);
323 }
324 }
325 }
326}
327
328#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
329#[serde(tag = "type", rename_all = "snake_case")]
330pub enum MessagePart {
331 Text { text: String },
332 ImageUrl { image_url: String },
333}
334
335#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
336pub struct ToolCall {
337 pub id: String,
338 #[serde(flatten)]
339 pub content: ToolCallContent,
340}
341
342#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
343#[serde(tag = "type", rename_all = "lowercase")]
344pub enum ToolCallContent {
345 Function { function: FunctionContent },
346}
347
348#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
349pub struct FunctionContent {
350 pub name: String,
351 pub arguments: String,
352}
353
354#[derive(Serialize, Deserialize, Debug)]
355pub struct CompletionChoice {
356 pub text: String,
357}
358
359#[derive(Serialize, Deserialize, Debug)]
360pub struct Response {
361 pub id: String,
362 pub object: String,
363 pub created: u64,
364 pub model: String,
365 pub choices: Vec<Choice>,
366 pub usage: Usage,
367}
368
369#[derive(Serialize, Deserialize, Debug)]
370pub struct Usage {
371 pub prompt_tokens: u64,
372 pub completion_tokens: u64,
373 pub total_tokens: u64,
374}
375
376#[derive(Serialize, Deserialize, Debug)]
377pub struct Choice {
378 pub index: u32,
379 pub message: RequestMessage,
380 pub finish_reason: Option<String>,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct StreamResponse {
385 pub id: String,
386 pub object: String,
387 pub created: u64,
388 pub model: String,
389 pub choices: Vec<StreamChoice>,
390 pub usage: Option<Usage>,
391}
392
393#[derive(Serialize, Deserialize, Debug)]
394pub struct StreamChoice {
395 pub index: u32,
396 pub delta: StreamDelta,
397 pub finish_reason: Option<String>,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct StreamDelta {
402 pub role: Option<Role>,
403 pub content: Option<String>,
404 #[serde(default, skip_serializing_if = "Option::is_none")]
405 pub tool_calls: Option<Vec<ToolCallChunk>>,
406 #[serde(default, skip_serializing_if = "Option::is_none")]
407 pub reasoning_content: Option<String>,
408}
409
410#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
411pub struct ToolCallChunk {
412 pub index: usize,
413 pub id: Option<String>,
414 pub function: Option<FunctionChunk>,
415}
416
417#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
418pub struct FunctionChunk {
419 pub name: Option<String>,
420 pub arguments: Option<String>,
421}
422
423pub async fn stream_completion(
424 client: &dyn HttpClient,
425 api_url: &str,
426 api_key: &str,
427 request: Request,
428) -> Result<BoxStream<'static, Result<StreamResponse>>> {
429 let uri = format!("{api_url}/chat/completions");
430 let request_builder = HttpRequest::builder()
431 .method(Method::POST)
432 .uri(uri)
433 .header("Content-Type", "application/json")
434 .header("Authorization", format!("Bearer {}", api_key));
435
436 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
437 let mut response = client.send(request).await?;
438
439 if response.status().is_success() {
440 let reader = BufReader::new(response.into_body());
441 Ok(reader
442 .lines()
443 .filter_map(|line| async move {
444 match line {
445 Ok(line) => {
446 let line = line.strip_prefix("data: ")?;
447 if line == "[DONE]" {
448 None
449 } else {
450 match serde_json::from_str(line) {
451 Ok(response) => Some(Ok(response)),
452 Err(error) => Some(Err(anyhow!(error))),
453 }
454 }
455 }
456 Err(error) => Some(Err(anyhow!(error))),
457 }
458 })
459 .boxed())
460 } else {
461 let mut body = String::new();
462 response.body_mut().read_to_string(&mut body).await?;
463 anyhow::bail!(
464 "Failed to connect to Mistral API: {} {}",
465 response.status(),
466 body,
467 );
468 }
469}