1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of Mistral's hosted API, including the `/v1` version segment.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17 Tool,
18}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => anyhow::bail!("invalid role '{value}'"),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum Model {
48 #[serde(rename = "codestral-latest", alias = "codestral-latest")]
49 #[default]
50 CodestralLatest,
51 #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
52 MistralLargeLatest,
53 #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
54 MistralMediumLatest,
55 #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
56 MistralSmallLatest,
57 #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
58 OpenMistralNemo,
59 #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
60 OpenCodestralMamba,
61 #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
62 DevstralSmallLatest,
63 #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
64 Pixtral12BLatest,
65 #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
66 PixtralLargeLatest,
67
68 #[serde(rename = "custom")]
69 Custom {
70 name: String,
71 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
72 display_name: Option<String>,
73 max_tokens: u64,
74 max_output_tokens: Option<u64>,
75 max_completion_tokens: Option<u64>,
76 supports_tools: Option<bool>,
77 supports_images: Option<bool>,
78 },
79}
80
81impl Model {
82 pub fn default_fast() -> Self {
83 Model::MistralSmallLatest
84 }
85
86 pub fn from_id(id: &str) -> Result<Self> {
87 match id {
88 "codestral-latest" => Ok(Self::CodestralLatest),
89 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
90 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
91 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
92 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
93 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
94 "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
95 "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
96 "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
97 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
98 }
99 }
100
101 pub fn id(&self) -> &str {
102 match self {
103 Self::CodestralLatest => "codestral-latest",
104 Self::MistralLargeLatest => "mistral-large-latest",
105 Self::MistralMediumLatest => "mistral-medium-latest",
106 Self::MistralSmallLatest => "mistral-small-latest",
107 Self::OpenMistralNemo => "open-mistral-nemo",
108 Self::OpenCodestralMamba => "open-codestral-mamba",
109 Self::DevstralSmallLatest => "devstral-small-latest",
110 Self::Pixtral12BLatest => "pixtral-12b-latest",
111 Self::PixtralLargeLatest => "pixtral-large-latest",
112 Self::Custom { name, .. } => name,
113 }
114 }
115
116 pub fn display_name(&self) -> &str {
117 match self {
118 Self::CodestralLatest => "codestral-latest",
119 Self::MistralLargeLatest => "mistral-large-latest",
120 Self::MistralMediumLatest => "mistral-medium-latest",
121 Self::MistralSmallLatest => "mistral-small-latest",
122 Self::OpenMistralNemo => "open-mistral-nemo",
123 Self::OpenCodestralMamba => "open-codestral-mamba",
124 Self::DevstralSmallLatest => "devstral-small-latest",
125 Self::Pixtral12BLatest => "pixtral-12b-latest",
126 Self::PixtralLargeLatest => "pixtral-large-latest",
127 Self::Custom {
128 name, display_name, ..
129 } => display_name.as_ref().unwrap_or(name),
130 }
131 }
132
133 pub fn max_token_count(&self) -> u64 {
134 match self {
135 Self::CodestralLatest => 256000,
136 Self::MistralLargeLatest => 131000,
137 Self::MistralMediumLatest => 128000,
138 Self::MistralSmallLatest => 32000,
139 Self::OpenMistralNemo => 131000,
140 Self::OpenCodestralMamba => 256000,
141 Self::DevstralSmallLatest => 262144,
142 Self::Pixtral12BLatest => 128000,
143 Self::PixtralLargeLatest => 128000,
144 Self::Custom { max_tokens, .. } => *max_tokens,
145 }
146 }
147
148 pub fn max_output_tokens(&self) -> Option<u64> {
149 match self {
150 Self::Custom {
151 max_output_tokens, ..
152 } => *max_output_tokens,
153 _ => None,
154 }
155 }
156
157 pub fn supports_tools(&self) -> bool {
158 match self {
159 Self::CodestralLatest
160 | Self::MistralLargeLatest
161 | Self::MistralMediumLatest
162 | Self::MistralSmallLatest
163 | Self::OpenMistralNemo
164 | Self::OpenCodestralMamba
165 | Self::DevstralSmallLatest
166 | Self::Pixtral12BLatest
167 | Self::PixtralLargeLatest => true,
168 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
169 }
170 }
171
172 pub fn supports_images(&self) -> bool {
173 match self {
174 Self::Pixtral12BLatest
175 | Self::PixtralLargeLatest
176 | Self::MistralMediumLatest
177 | Self::MistralSmallLatest => true,
178 Self::CodestralLatest
179 | Self::MistralLargeLatest
180 | Self::OpenMistralNemo
181 | Self::OpenCodestralMamba
182 | Self::DevstralSmallLatest => false,
183 Self::Custom {
184 supports_images, ..
185 } => supports_images.unwrap_or(false),
186 }
187 }
188}
189
190#[derive(Debug, Serialize, Deserialize)]
191pub struct Request {
192 pub model: String,
193 pub messages: Vec<RequestMessage>,
194 pub stream: bool,
195 #[serde(default, skip_serializing_if = "Option::is_none")]
196 pub max_tokens: Option<u64>,
197 #[serde(default, skip_serializing_if = "Option::is_none")]
198 pub temperature: Option<f32>,
199 #[serde(default, skip_serializing_if = "Option::is_none")]
200 pub response_format: Option<ResponseFormat>,
201 #[serde(default, skip_serializing_if = "Option::is_none")]
202 pub tool_choice: Option<ToolChoice>,
203 #[serde(default, skip_serializing_if = "Option::is_none")]
204 pub parallel_tool_calls: Option<bool>,
205 #[serde(default, skip_serializing_if = "Vec::is_empty")]
206 pub tools: Vec<ToolDefinition>,
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "snake_case")]
211pub enum ResponseFormat {
212 Text,
213 #[serde(rename = "json_object")]
214 JsonObject,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(tag = "type", rename_all = "snake_case")]
219pub enum ToolDefinition {
220 Function { function: FunctionDefinition },
221}
222
223#[derive(Debug, Serialize, Deserialize)]
224pub struct FunctionDefinition {
225 pub name: String,
226 pub description: Option<String>,
227 pub parameters: Option<Value>,
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231pub struct CompletionRequest {
232 pub model: String,
233 pub prompt: String,
234 pub max_tokens: u32,
235 pub temperature: f32,
236 #[serde(default, skip_serializing_if = "Option::is_none")]
237 pub prediction: Option<Prediction>,
238 #[serde(default, skip_serializing_if = "Option::is_none")]
239 pub rewrite_speculation: Option<bool>,
240}
241
242#[derive(Clone, Deserialize, Serialize, Debug)]
243#[serde(tag = "type", rename_all = "snake_case")]
244pub enum Prediction {
245 Content { content: String },
246}
247
248#[derive(Debug, Serialize, Deserialize)]
249#[serde(rename_all = "snake_case")]
250pub enum ToolChoice {
251 Auto,
252 Required,
253 None,
254 Any,
255 Function(ToolDefinition),
256}
257
258#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
259#[serde(tag = "role", rename_all = "lowercase")]
260pub enum RequestMessage {
261 Assistant {
262 content: Option<String>,
263 #[serde(default, skip_serializing_if = "Vec::is_empty")]
264 tool_calls: Vec<ToolCall>,
265 },
266 User {
267 #[serde(flatten)]
268 content: MessageContent,
269 },
270 System {
271 content: String,
272 },
273 Tool {
274 content: String,
275 tool_call_id: String,
276 },
277}
278
279#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
280#[serde(untagged)]
281pub enum MessageContent {
282 #[serde(rename = "content")]
283 Plain { content: String },
284 #[serde(rename = "content")]
285 Multipart { content: Vec<MessagePart> },
286}
287
288impl MessageContent {
289 pub fn empty() -> Self {
290 Self::Plain {
291 content: String::new(),
292 }
293 }
294
295 pub fn push_part(&mut self, part: MessagePart) {
296 match self {
297 Self::Plain { content } => match part {
298 MessagePart::Text { text } => {
299 content.push_str(&text);
300 }
301 part => {
302 let mut parts = if content.is_empty() {
303 Vec::new()
304 } else {
305 vec![MessagePart::Text {
306 text: content.clone(),
307 }]
308 };
309 parts.push(part);
310 *self = Self::Multipart { content: parts };
311 }
312 },
313 Self::Multipart { content } => {
314 content.push(part);
315 }
316 }
317 }
318}
319
320#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
321#[serde(tag = "type", rename_all = "snake_case")]
322pub enum MessagePart {
323 Text { text: String },
324 ImageUrl { image_url: String },
325}
326
327#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
328pub struct ToolCall {
329 pub id: String,
330 #[serde(flatten)]
331 pub content: ToolCallContent,
332}
333
334#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
335#[serde(tag = "type", rename_all = "lowercase")]
336pub enum ToolCallContent {
337 Function { function: FunctionContent },
338}
339
340#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
341pub struct FunctionContent {
342 pub name: String,
343 pub arguments: String,
344}
345
346#[derive(Serialize, Deserialize, Debug)]
347pub struct CompletionChoice {
348 pub text: String,
349}
350
351#[derive(Serialize, Deserialize, Debug)]
352pub struct Response {
353 pub id: String,
354 pub object: String,
355 pub created: u64,
356 pub model: String,
357 pub choices: Vec<Choice>,
358 pub usage: Usage,
359}
360
361#[derive(Serialize, Deserialize, Debug)]
362pub struct Usage {
363 pub prompt_tokens: u64,
364 pub completion_tokens: u64,
365 pub total_tokens: u64,
366}
367
368#[derive(Serialize, Deserialize, Debug)]
369pub struct Choice {
370 pub index: u32,
371 pub message: RequestMessage,
372 pub finish_reason: Option<String>,
373}
374
375#[derive(Serialize, Deserialize, Debug)]
376pub struct StreamResponse {
377 pub id: String,
378 pub object: String,
379 pub created: u64,
380 pub model: String,
381 pub choices: Vec<StreamChoice>,
382 pub usage: Option<Usage>,
383}
384
385#[derive(Serialize, Deserialize, Debug)]
386pub struct StreamChoice {
387 pub index: u32,
388 pub delta: StreamDelta,
389 pub finish_reason: Option<String>,
390}
391
392#[derive(Serialize, Deserialize, Debug)]
393pub struct StreamDelta {
394 pub role: Option<Role>,
395 pub content: Option<String>,
396 #[serde(default, skip_serializing_if = "Option::is_none")]
397 pub tool_calls: Option<Vec<ToolCallChunk>>,
398 #[serde(default, skip_serializing_if = "Option::is_none")]
399 pub reasoning_content: Option<String>,
400}
401
402#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
403pub struct ToolCallChunk {
404 pub index: usize,
405 pub id: Option<String>,
406 pub function: Option<FunctionChunk>,
407}
408
409#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
410pub struct FunctionChunk {
411 pub name: Option<String>,
412 pub arguments: Option<String>,
413}
414
415pub async fn stream_completion(
416 client: &dyn HttpClient,
417 api_url: &str,
418 api_key: &str,
419 request: Request,
420) -> Result<BoxStream<'static, Result<StreamResponse>>> {
421 let uri = format!("{api_url}/chat/completions");
422 let request_builder = HttpRequest::builder()
423 .method(Method::POST)
424 .uri(uri)
425 .header("Content-Type", "application/json")
426 .header("Authorization", format!("Bearer {}", api_key));
427
428 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
429 let mut response = client.send(request).await?;
430
431 if response.status().is_success() {
432 let reader = BufReader::new(response.into_body());
433 Ok(reader
434 .lines()
435 .filter_map(|line| async move {
436 match line {
437 Ok(line) => {
438 let line = line.strip_prefix("data: ")?;
439 if line == "[DONE]" {
440 None
441 } else {
442 match serde_json::from_str(line) {
443 Ok(response) => Some(Ok(response)),
444 Err(error) => Some(Err(anyhow!(error))),
445 }
446 }
447 }
448 Err(error) => Some(Err(anyhow!(error))),
449 }
450 })
451 .boxed())
452 } else {
453 let mut body = String::new();
454 response.body_mut().read_to_string(&mut body).await?;
455 anyhow::bail!(
456 "Failed to connect to Mistral API: {} {}",
457 response.status(),
458 body,
459 );
460 }
461}