1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of the Mistral platform API (includes the `/v1` version segment).
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
/// Base URL of the dedicated Codestral endpoint (note: no version segment).
pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
11
12#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
13#[serde(rename_all = "lowercase")]
14pub enum Role {
15 User,
16 Assistant,
17 System,
18 Tool,
19}
20
21impl TryFrom<String> for Role {
22 type Error = anyhow::Error;
23
24 fn try_from(value: String) -> Result<Self> {
25 match value.as_str() {
26 "user" => Ok(Self::User),
27 "assistant" => Ok(Self::Assistant),
28 "system" => Ok(Self::System),
29 "tool" => Ok(Self::Tool),
30 _ => anyhow::bail!("invalid role '{value}'"),
31 }
32 }
33}
34
35impl From<Role> for String {
36 fn from(val: Role) -> Self {
37 match val {
38 Role::User => "user".to_owned(),
39 Role::Assistant => "assistant".to_owned(),
40 Role::System => "system".to_owned(),
41 Role::Tool => "tool".to_owned(),
42 }
43 }
44}
45
/// Models selectable for the Mistral provider.
///
/// Built-in variants (de)serialize as their API model id. `Custom` describes a
/// model not listed here, together with its limits and capability flags.
/// The `#[default]` variant is `CodestralLatest`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // NOTE: variant order is also the order yielded by `EnumIter`.
    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
    #[default]
    CodestralLatest,

    // Mistral family.
    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
    MistralLargeLatest,
    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
    MistralMediumLatest,
    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
    MistralSmallLatest,

    // Magistral family (the only built-ins reporting `supports_thinking`).
    #[serde(rename = "magistral-medium-latest", alias = "magistral-medium-latest")]
    MagistralMediumLatest,
    #[serde(rename = "magistral-small-latest", alias = "magistral-small-latest")]
    MagistralSmallLatest,

    // Open-weight models.
    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
    OpenMistralNemo,
    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
    OpenCodestralMamba,

    // Devstral family.
    #[serde(rename = "devstral-medium-latest", alias = "devstral-medium-latest")]
    DevstralMediumLatest,
    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
    DevstralSmallLatest,

    // Pixtral family (reported as image-capable by `supports_images`).
    #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
    Pixtral12BLatest,
    #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
    PixtralLargeLatest,

    // A model configured by name, with explicitly supplied limits and capability flags.
    // Unset `supports_*` flags are treated as `false` by the `Model` accessors.
    #[serde(rename = "custom")]
    Custom {
        // Model id sent to the API; also the fallback display name.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Value returned by `Model::max_token_count`.
        max_tokens: u64,
        // Value returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        // NOTE(review): not read by any accessor in this file — confirm where it is consumed.
        max_completion_tokens: Option<u64>,
        supports_tools: Option<bool>,
        supports_images: Option<bool>,
        supports_thinking: Option<bool>,
    },
}
93
94impl Model {
95 pub fn default_fast() -> Self {
96 Model::MistralSmallLatest
97 }
98
99 pub fn from_id(id: &str) -> Result<Self> {
100 match id {
101 "codestral-latest" => Ok(Self::CodestralLatest),
102 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
103 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
104 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
105 "magistral-medium-latest" => Ok(Self::MagistralMediumLatest),
106 "magistral-small-latest" => Ok(Self::MagistralSmallLatest),
107 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
108 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
109 "devstral-medium-latest" => Ok(Self::DevstralMediumLatest),
110 "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
111 "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
112 "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
113 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
114 }
115 }
116
117 pub fn id(&self) -> &str {
118 match self {
119 Self::CodestralLatest => "codestral-latest",
120 Self::MistralLargeLatest => "mistral-large-latest",
121 Self::MistralMediumLatest => "mistral-medium-latest",
122 Self::MistralSmallLatest => "mistral-small-latest",
123 Self::MagistralMediumLatest => "magistral-medium-latest",
124 Self::MagistralSmallLatest => "magistral-small-latest",
125 Self::OpenMistralNemo => "open-mistral-nemo",
126 Self::OpenCodestralMamba => "open-codestral-mamba",
127 Self::DevstralMediumLatest => "devstral-medium-latest",
128 Self::DevstralSmallLatest => "devstral-small-latest",
129 Self::Pixtral12BLatest => "pixtral-12b-latest",
130 Self::PixtralLargeLatest => "pixtral-large-latest",
131 Self::Custom { name, .. } => name,
132 }
133 }
134
135 pub fn display_name(&self) -> &str {
136 match self {
137 Self::CodestralLatest => "codestral-latest",
138 Self::MistralLargeLatest => "mistral-large-latest",
139 Self::MistralMediumLatest => "mistral-medium-latest",
140 Self::MistralSmallLatest => "mistral-small-latest",
141 Self::MagistralMediumLatest => "magistral-medium-latest",
142 Self::MagistralSmallLatest => "magistral-small-latest",
143 Self::OpenMistralNemo => "open-mistral-nemo",
144 Self::OpenCodestralMamba => "open-codestral-mamba",
145 Self::DevstralMediumLatest => "devstral-medium-latest",
146 Self::DevstralSmallLatest => "devstral-small-latest",
147 Self::Pixtral12BLatest => "pixtral-12b-latest",
148 Self::PixtralLargeLatest => "pixtral-large-latest",
149 Self::Custom {
150 name, display_name, ..
151 } => display_name.as_ref().unwrap_or(name),
152 }
153 }
154
155 pub fn max_token_count(&self) -> u64 {
156 match self {
157 Self::CodestralLatest => 256000,
158 Self::MistralLargeLatest => 131000,
159 Self::MistralMediumLatest => 128000,
160 Self::MistralSmallLatest => 32000,
161 Self::MagistralMediumLatest => 40000,
162 Self::MagistralSmallLatest => 40000,
163 Self::OpenMistralNemo => 131000,
164 Self::OpenCodestralMamba => 256000,
165 Self::DevstralMediumLatest => 128000,
166 Self::DevstralSmallLatest => 262144,
167 Self::Pixtral12BLatest => 128000,
168 Self::PixtralLargeLatest => 128000,
169 Self::Custom { max_tokens, .. } => *max_tokens,
170 }
171 }
172
173 pub fn max_output_tokens(&self) -> Option<u64> {
174 match self {
175 Self::Custom {
176 max_output_tokens, ..
177 } => *max_output_tokens,
178 _ => None,
179 }
180 }
181
182 pub fn supports_tools(&self) -> bool {
183 match self {
184 Self::CodestralLatest
185 | Self::MistralLargeLatest
186 | Self::MistralMediumLatest
187 | Self::MistralSmallLatest
188 | Self::MagistralMediumLatest
189 | Self::MagistralSmallLatest
190 | Self::OpenMistralNemo
191 | Self::OpenCodestralMamba
192 | Self::DevstralMediumLatest
193 | Self::DevstralSmallLatest
194 | Self::Pixtral12BLatest
195 | Self::PixtralLargeLatest => true,
196 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
197 }
198 }
199
200 pub fn supports_images(&self) -> bool {
201 match self {
202 Self::Pixtral12BLatest
203 | Self::PixtralLargeLatest
204 | Self::MistralMediumLatest
205 | Self::MistralSmallLatest => true,
206 Self::CodestralLatest
207 | Self::MistralLargeLatest
208 | Self::MagistralMediumLatest
209 | Self::MagistralSmallLatest
210 | Self::OpenMistralNemo
211 | Self::OpenCodestralMamba
212 | Self::DevstralMediumLatest
213 | Self::DevstralSmallLatest => false,
214 Self::Custom {
215 supports_images, ..
216 } => supports_images.unwrap_or(false),
217 }
218 }
219
220 pub fn supports_thinking(&self) -> bool {
221 match self {
222 Self::MagistralMediumLatest | Self::MagistralSmallLatest => true,
223 Self::Custom {
224 supports_thinking, ..
225 } => supports_thinking.unwrap_or(false),
226 _ => false,
227 }
228 }
229}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub struct Request {
233 pub model: String,
234 pub messages: Vec<RequestMessage>,
235 pub stream: bool,
236 #[serde(default, skip_serializing_if = "Option::is_none")]
237 pub max_tokens: Option<u64>,
238 #[serde(default, skip_serializing_if = "Option::is_none")]
239 pub temperature: Option<f32>,
240 #[serde(default, skip_serializing_if = "Option::is_none")]
241 pub response_format: Option<ResponseFormat>,
242 #[serde(default, skip_serializing_if = "Option::is_none")]
243 pub tool_choice: Option<ToolChoice>,
244 #[serde(default, skip_serializing_if = "Option::is_none")]
245 pub parallel_tool_calls: Option<bool>,
246 #[serde(default, skip_serializing_if = "Vec::is_empty")]
247 pub tools: Vec<ToolDefinition>,
248}
249
250#[derive(Debug, Serialize, Deserialize)]
251#[serde(rename_all = "snake_case")]
252pub enum ResponseFormat {
253 Text,
254 #[serde(rename = "json_object")]
255 JsonObject,
256}
257
258#[derive(Debug, Serialize, Deserialize)]
259#[serde(tag = "type", rename_all = "snake_case")]
260pub enum ToolDefinition {
261 Function { function: FunctionDefinition },
262}
263
264#[derive(Debug, Serialize, Deserialize)]
265pub struct FunctionDefinition {
266 pub name: String,
267 pub description: Option<String>,
268 pub parameters: Option<Value>,
269}
270
/// How the model should choose among the offered tools.
///
/// Unit variants serialize as lowercase strings (`"auto"`, `"required"`,
/// `"none"`, `"any"`). `Function` is untagged, so it serializes as the inner
/// [`ToolDefinition`] object itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Any,
    // Must stay last: serde only allows untagged variants after all tagged ones,
    // and tries them only when no string form matched.
    #[serde(untagged)]
    Function(ToolDefinition),
}
281
/// One message in a request's conversation, internally tagged by `role`
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A model-authored turn.
    Assistant {
        // Flattened so the `content` key sits directly on the message object;
        // nothing is emitted when `None`.
        #[serde(flatten)]
        #[serde(default, skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        // Omitted from the JSON when there are no tool calls.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// A user turn; `content` is flattened onto the message object.
    User {
        #[serde(flatten)]
        content: MessageContent,
    },
    /// A system turn; `content` is flattened onto the message object.
    System {
        #[serde(flatten)]
        content: MessageContent,
    },
    /// A tool result as plain text.
    Tool {
        content: String,
        // Presumably echoes the `id` of the corresponding `ToolCall` — confirm with callers.
        tool_call_id: String,
    },
}
305
306#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
307#[serde(untagged)]
308pub enum MessageContent {
309 #[serde(rename = "content")]
310 Plain { content: String },
311 #[serde(rename = "content")]
312 Multipart { content: Vec<MessagePart> },
313}
314
315impl MessageContent {
316 pub fn empty() -> Self {
317 Self::Plain {
318 content: String::new(),
319 }
320 }
321
322 pub fn push_part(&mut self, part: MessagePart) {
323 match self {
324 Self::Plain { content } => match part {
325 MessagePart::Text { text } => {
326 content.push_str(&text);
327 }
328 part => {
329 let mut parts = if content.is_empty() {
330 Vec::new()
331 } else {
332 vec![MessagePart::Text {
333 text: content.clone(),
334 }]
335 };
336 parts.push(part);
337 *self = Self::Multipart { content: parts };
338 }
339 },
340 Self::Multipart { content } => {
341 content.push(part);
342 }
343 }
344 }
345}
346
347#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
348#[serde(tag = "type", rename_all = "snake_case")]
349pub enum MessagePart {
350 Text { text: String },
351 ImageUrl { image_url: String },
352 Thinking { thinking: Vec<ThinkingPart> },
353}
354
/// Backwards-compatibility alias for provider code that still refers to
/// `ContentPart`; it is the same type as [`MessagePart`].
pub type ContentPart = MessagePart;
357
358#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
359#[serde(tag = "type", rename_all = "snake_case")]
360pub enum ThinkingPart {
361 Text { text: String },
362}
363
364#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
365pub struct ToolCall {
366 pub id: String,
367 #[serde(flatten)]
368 pub content: ToolCallContent,
369}
370
371#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
372#[serde(tag = "type", rename_all = "lowercase")]
373pub enum ToolCallContent {
374 Function { function: FunctionContent },
375}
376
377#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
378pub struct FunctionContent {
379 pub name: String,
380 pub arguments: String,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct Usage {
385 pub prompt_tokens: u64,
386 pub completion_tokens: u64,
387 pub total_tokens: u64,
388}
389
390#[derive(Serialize, Deserialize, Debug)]
391pub struct StreamResponse {
392 pub id: String,
393 pub object: String,
394 pub created: u64,
395 pub model: String,
396 pub choices: Vec<StreamChoice>,
397 pub usage: Option<Usage>,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct StreamChoice {
402 pub index: u32,
403 pub delta: StreamDelta,
404 pub finish_reason: Option<String>,
405}
406
407#[derive(Serialize, Deserialize, Debug, Clone)]
408pub struct StreamDelta {
409 pub role: Option<Role>,
410 #[serde(default, skip_serializing_if = "Option::is_none")]
411 pub content: Option<MessageContentDelta>,
412 #[serde(default, skip_serializing_if = "Option::is_none")]
413 pub tool_calls: Option<Vec<ToolCallChunk>>,
414}
415
/// Content carried by a [`StreamDelta`]: either a bare text fragment or a list
/// of [`MessagePart`]s.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
// Untagged: deserialization tries the variants in declaration order (`Text`, then `Parts`).
#[serde(untagged)]
pub enum MessageContentDelta {
    Text(String),
    Parts(Vec<MessagePart>),
}
422
423#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
424pub struct ToolCallChunk {
425 pub index: usize,
426 pub id: Option<String>,
427 pub function: Option<FunctionChunk>,
428}
429
430#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
431pub struct FunctionChunk {
432 pub name: Option<String>,
433 pub arguments: Option<String>,
434}
435
436pub async fn stream_completion(
437 client: &dyn HttpClient,
438 api_url: &str,
439 api_key: &str,
440 request: Request,
441) -> Result<BoxStream<'static, Result<StreamResponse>>> {
442 let uri = format!("{api_url}/chat/completions");
443 let request_builder = HttpRequest::builder()
444 .method(Method::POST)
445 .uri(uri)
446 .header("Content-Type", "application/json")
447 .header("Authorization", format!("Bearer {}", api_key.trim()));
448
449 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
450 let mut response = client.send(request).await?;
451
452 if response.status().is_success() {
453 let reader = BufReader::new(response.into_body());
454 Ok(reader
455 .lines()
456 .filter_map(|line| async move {
457 match line {
458 Ok(line) => {
459 let line = line.strip_prefix("data: ")?;
460 if line == "[DONE]" {
461 None
462 } else {
463 match serde_json::from_str(line) {
464 Ok(response) => Some(Ok(response)),
465 Err(error) => Some(Err(anyhow!(error))),
466 }
467 }
468 }
469 Err(error) => Some(Err(anyhow!(error))),
470 }
471 })
472 .boxed())
473 } else {
474 let mut body = String::new();
475 response.body_mut().read_to_string(&mut body).await?;
476 anyhow::bail!(
477 "Failed to connect to Mistral API: {} {}",
478 response.status(),
479 body,
480 );
481 }
482}