1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of Mistral's hosted v1 REST API.
///
/// Nothing in the visible part of this file reads the constant; presumably
/// callers pass it as the `api_url` argument of `stream_completion`, which
/// appends `/chat/completions` to whatever base URL it is given.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17 Tool,
18}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => anyhow::bail!("invalid role '{value}'"),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum Model {
48 #[serde(rename = "codestral-latest", alias = "codestral-latest")]
49 #[default]
50 CodestralLatest,
51 #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
52 MistralLargeLatest,
53 #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
54 MistralMediumLatest,
55 #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
56 MistralSmallLatest,
57 #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
58 OpenMistralNemo,
59 #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
60 OpenCodestralMamba,
61 #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
62 DevstralSmallLatest,
63 #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
64 Pixtral12BLatest,
65 #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
66 PixtralLargeLatest,
67
68 #[serde(rename = "custom")]
69 Custom {
70 name: String,
71 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
72 display_name: Option<String>,
73 max_tokens: usize,
74 max_output_tokens: Option<u32>,
75 max_completion_tokens: Option<u32>,
76 supports_tools: Option<bool>,
77 supports_images: Option<bool>,
78 },
79}
80
81impl Model {
82 pub fn default_fast() -> Self {
83 Model::MistralSmallLatest
84 }
85
86 pub fn from_id(id: &str) -> Result<Self> {
87 match id {
88 "codestral-latest" => Ok(Self::CodestralLatest),
89 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
90 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
91 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
92 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
93 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
94 "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
95 "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
96 "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
97 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
98 }
99 }
100
101 pub fn id(&self) -> &str {
102 match self {
103 Self::CodestralLatest => "codestral-latest",
104 Self::MistralLargeLatest => "mistral-large-latest",
105 Self::MistralMediumLatest => "mistral-medium-latest",
106 Self::MistralSmallLatest => "mistral-small-latest",
107 Self::OpenMistralNemo => "open-mistral-nemo",
108 Self::OpenCodestralMamba => "open-codestral-mamba",
109 Self::DevstralSmallLatest => "devstral-small-latest",
110 Self::Pixtral12BLatest => "pixtral-12b-latest",
111 Self::PixtralLargeLatest => "pixtral-large-latest",
112 Self::Custom { name, .. } => name,
113 }
114 }
115
116 pub fn display_name(&self) -> &str {
117 match self {
118 Self::CodestralLatest => "codestral-latest",
119 Self::MistralLargeLatest => "mistral-large-latest",
120 Self::MistralMediumLatest => "mistral-medium-latest",
121 Self::MistralSmallLatest => "mistral-small-latest",
122 Self::OpenMistralNemo => "open-mistral-nemo",
123 Self::OpenCodestralMamba => "open-codestral-mamba",
124 Self::DevstralSmallLatest => "devstral-small-latest",
125 Self::Pixtral12BLatest => "pixtral-12b-latest",
126 Self::PixtralLargeLatest => "pixtral-large-latest",
127 Self::Custom {
128 name, display_name, ..
129 } => display_name.as_ref().unwrap_or(name),
130 }
131 }
132
133 pub fn max_token_count(&self) -> usize {
134 match self {
135 Self::CodestralLatest => 256000,
136 Self::MistralLargeLatest => 131000,
137 Self::MistralMediumLatest => 128000,
138 Self::MistralSmallLatest => 32000,
139 Self::OpenMistralNemo => 131000,
140 Self::OpenCodestralMamba => 256000,
141 Self::DevstralSmallLatest => 262144,
142 Self::Pixtral12BLatest => 128000,
143 Self::PixtralLargeLatest => 128000,
144 Self::Custom { max_tokens, .. } => *max_tokens,
145 }
146 }
147
148 pub fn max_output_tokens(&self) -> Option<u32> {
149 match self {
150 Self::Custom {
151 max_output_tokens, ..
152 } => *max_output_tokens,
153 _ => None,
154 }
155 }
156
157 pub fn supports_tools(&self) -> bool {
158 match self {
159 Self::CodestralLatest
160 | Self::MistralLargeLatest
161 | Self::MistralMediumLatest
162 | Self::MistralSmallLatest
163 | Self::OpenMistralNemo
164 | Self::OpenCodestralMamba
165 | Self::DevstralSmallLatest
166 | Self::Pixtral12BLatest
167 | Self::PixtralLargeLatest => true,
168 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
169 }
170 }
171
172 pub fn supports_images(&self) -> bool {
173 match self {
174 Self::Pixtral12BLatest
175 | Self::PixtralLargeLatest
176 | Self::MistralMediumLatest
177 | Self::MistralSmallLatest => true,
178 Self::CodestralLatest
179 | Self::MistralLargeLatest
180 | Self::OpenMistralNemo
181 | Self::OpenCodestralMamba
182 | Self::DevstralSmallLatest => false,
183 Self::Custom {
184 supports_images, ..
185 } => supports_images.unwrap_or(false),
186 }
187 }
188}
189
190#[derive(Debug, Serialize, Deserialize)]
191pub struct Request {
192 pub model: String,
193 pub messages: Vec<RequestMessage>,
194 pub stream: bool,
195 #[serde(default, skip_serializing_if = "Option::is_none")]
196 pub max_tokens: Option<u32>,
197 #[serde(default, skip_serializing_if = "Option::is_none")]
198 pub temperature: Option<f32>,
199 #[serde(default, skip_serializing_if = "Option::is_none")]
200 pub response_format: Option<ResponseFormat>,
201 #[serde(default, skip_serializing_if = "Option::is_none")]
202 pub tool_choice: Option<ToolChoice>,
203 #[serde(default, skip_serializing_if = "Option::is_none")]
204 pub parallel_tool_calls: Option<bool>,
205 #[serde(default, skip_serializing_if = "Vec::is_empty")]
206 pub tools: Vec<ToolDefinition>,
207}
208
209#[derive(Debug, Serialize, Deserialize)]
210#[serde(rename_all = "snake_case")]
211pub enum ResponseFormat {
212 Text,
213 #[serde(rename = "json_object")]
214 JsonObject,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(tag = "type", rename_all = "snake_case")]
219pub enum ToolDefinition {
220 Function { function: FunctionDefinition },
221}
222
223#[derive(Debug, Serialize, Deserialize)]
224pub struct FunctionDefinition {
225 pub name: String,
226 pub description: Option<String>,
227 pub parameters: Option<Value>,
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231pub struct CompletionRequest {
232 pub model: String,
233 pub prompt: String,
234 pub max_tokens: u32,
235 pub temperature: f32,
236 #[serde(default, skip_serializing_if = "Option::is_none")]
237 pub prediction: Option<Prediction>,
238 #[serde(default, skip_serializing_if = "Option::is_none")]
239 pub rewrite_speculation: Option<bool>,
240}
241
242#[derive(Clone, Deserialize, Serialize, Debug)]
243#[serde(tag = "type", rename_all = "snake_case")]
244pub enum Prediction {
245 Content { content: String },
246}
247
248#[derive(Debug, Serialize, Deserialize)]
249#[serde(rename_all = "snake_case")]
250pub enum ToolChoice {
251 Auto,
252 Required,
253 None,
254 Any,
255 Function(ToolDefinition),
256}
257
258#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
259#[serde(tag = "role", rename_all = "lowercase")]
260pub enum RequestMessage {
261 Assistant {
262 content: Option<String>,
263 #[serde(default, skip_serializing_if = "Vec::is_empty")]
264 tool_calls: Vec<ToolCall>,
265 },
266 User {
267 #[serde(flatten)]
268 content: MessageContent,
269 },
270 System {
271 content: String,
272 },
273 Tool {
274 content: String,
275 tool_call_id: String,
276 },
277}
278
279#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
280#[serde(untagged)]
281pub enum MessageContent {
282 #[serde(rename = "content")]
283 Plain { content: String },
284 #[serde(rename = "content")]
285 Multipart { content: Vec<MessagePart> },
286}
287
288impl MessageContent {
289 pub fn empty() -> Self {
290 Self::Plain {
291 content: String::new(),
292 }
293 }
294
295 pub fn push_part(&mut self, part: MessagePart) {
296 match self {
297 Self::Plain { content } => match part {
298 MessagePart::Text { text } => {
299 content.push_str(&text);
300 }
301 part => {
302 let mut parts = if content.is_empty() {
303 Vec::new()
304 } else {
305 vec![MessagePart::Text {
306 text: content.clone(),
307 }]
308 };
309 parts.push(part);
310 *self = Self::Multipart { content: parts };
311 }
312 },
313 Self::Multipart { content } => {
314 content.push(part);
315 }
316 }
317 }
318}
319
320#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
321#[serde(tag = "type", rename_all = "snake_case")]
322pub enum MessagePart {
323 Text { text: String },
324 ImageUrl { image_url: String },
325}
326
327#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
328pub struct ToolCall {
329 pub id: String,
330 #[serde(flatten)]
331 pub content: ToolCallContent,
332}
333
334#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
335#[serde(tag = "type", rename_all = "lowercase")]
336pub enum ToolCallContent {
337 Function { function: FunctionContent },
338}
339
340#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
341pub struct FunctionContent {
342 pub name: String,
343 pub arguments: String,
344}
345
346#[derive(Serialize, Deserialize, Debug)]
347pub struct CompletionChoice {
348 pub text: String,
349}
350
351#[derive(Serialize, Deserialize, Debug)]
352pub struct Response {
353 pub id: String,
354 pub object: String,
355 pub created: u64,
356 pub model: String,
357 pub choices: Vec<Choice>,
358 pub usage: Usage,
359}
360
361#[derive(Serialize, Deserialize, Debug)]
362pub struct Usage {
363 pub prompt_tokens: u32,
364 pub completion_tokens: u32,
365 pub total_tokens: u32,
366}
367
368#[derive(Serialize, Deserialize, Debug)]
369pub struct Choice {
370 pub index: u32,
371 pub message: RequestMessage,
372 pub finish_reason: Option<String>,
373}
374
375#[derive(Serialize, Deserialize, Debug)]
376pub struct StreamResponse {
377 pub id: String,
378 pub object: String,
379 pub created: u64,
380 pub model: String,
381 pub choices: Vec<StreamChoice>,
382}
383
384#[derive(Serialize, Deserialize, Debug)]
385pub struct StreamChoice {
386 pub index: u32,
387 pub delta: StreamDelta,
388 pub finish_reason: Option<String>,
389}
390
391#[derive(Serialize, Deserialize, Debug)]
392pub struct StreamDelta {
393 pub role: Option<Role>,
394 pub content: Option<String>,
395 #[serde(default, skip_serializing_if = "Option::is_none")]
396 pub tool_calls: Option<Vec<ToolCallChunk>>,
397 #[serde(default, skip_serializing_if = "Option::is_none")]
398 pub reasoning_content: Option<String>,
399}
400
401#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
402pub struct ToolCallChunk {
403 pub index: usize,
404 pub id: Option<String>,
405 pub function: Option<FunctionChunk>,
406}
407
408#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
409pub struct FunctionChunk {
410 pub name: Option<String>,
411 pub arguments: Option<String>,
412}
413
414pub async fn stream_completion(
415 client: &dyn HttpClient,
416 api_url: &str,
417 api_key: &str,
418 request: Request,
419) -> Result<BoxStream<'static, Result<StreamResponse>>> {
420 let uri = format!("{api_url}/chat/completions");
421 let request_builder = HttpRequest::builder()
422 .method(Method::POST)
423 .uri(uri)
424 .header("Content-Type", "application/json")
425 .header("Authorization", format!("Bearer {}", api_key));
426
427 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
428 let mut response = client.send(request).await?;
429
430 if response.status().is_success() {
431 let reader = BufReader::new(response.into_body());
432 Ok(reader
433 .lines()
434 .filter_map(|line| async move {
435 match line {
436 Ok(line) => {
437 let line = line.strip_prefix("data: ")?;
438 if line == "[DONE]" {
439 None
440 } else {
441 match serde_json::from_str(line) {
442 Ok(response) => Some(Ok(response)),
443 Err(error) => Some(Err(anyhow!(error))),
444 }
445 }
446 }
447 Err(error) => Some(Err(anyhow!(error))),
448 }
449 })
450 .boxed())
451 } else {
452 let mut body = String::new();
453 response.body_mut().read_to_string(&mut body).await?;
454 anyhow::bail!(
455 "Failed to connect to Mistral API: {} {}",
456 response.status(),
457 body,
458 );
459 }
460}