1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// URL of the Mistral API, including the `/v1` path segment.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
/// Role attached to a chat message.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => anyhow::bail!("invalid role '{value}'"),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
/// Models this module knows by name, plus a user-described `Custom` entry.
///
/// Each built-in variant is (de)serialized as the model id string given in
/// its `serde(rename = ...)` attribute. `CodestralLatest` is the `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
    #[default]
    CodestralLatest,
    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
    MistralLargeLatest,
    #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
    MistralMediumLatest,
    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
    MistralSmallLatest,
    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
    OpenMistralNemo,
    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
    OpenCodestralMamba,
    #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
    DevstralSmallLatest,

    /// A model described entirely by its fields rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// Model id; this is what `Model::id` returns for a custom model.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Value returned by `Model::max_token_count` for a custom model.
        max_tokens: usize,
        /// Value returned by `Model::max_output_tokens` for a custom model.
        max_output_tokens: Option<u32>,
        // NOTE(review): not read by any method visible in this part of the file.
        max_completion_tokens: Option<u32>,
        /// Whether the model supports tools; `Model::supports_tools` treats `None` as `false`.
        supports_tools: Option<bool>,
    },
}
75
76impl Model {
77 pub fn default_fast() -> Self {
78 Model::MistralSmallLatest
79 }
80
81 pub fn from_id(id: &str) -> Result<Self> {
82 match id {
83 "codestral-latest" => Ok(Self::CodestralLatest),
84 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
85 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
86 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
87 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
88 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
89 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
90 }
91 }
92
93 pub fn id(&self) -> &str {
94 match self {
95 Self::CodestralLatest => "codestral-latest",
96 Self::MistralLargeLatest => "mistral-large-latest",
97 Self::MistralMediumLatest => "mistral-medium-latest",
98 Self::MistralSmallLatest => "mistral-small-latest",
99 Self::OpenMistralNemo => "open-mistral-nemo",
100 Self::OpenCodestralMamba => "open-codestral-mamba",
101 Self::DevstralSmallLatest => "devstral-small-latest",
102 Self::Custom { name, .. } => name,
103 }
104 }
105
106 pub fn display_name(&self) -> &str {
107 match self {
108 Self::CodestralLatest => "codestral-latest",
109 Self::MistralLargeLatest => "mistral-large-latest",
110 Self::MistralMediumLatest => "mistral-medium-latest",
111 Self::MistralSmallLatest => "mistral-small-latest",
112 Self::OpenMistralNemo => "open-mistral-nemo",
113 Self::OpenCodestralMamba => "open-codestral-mamba",
114 Self::DevstralSmallLatest => "devstral-small-latest",
115 Self::Custom {
116 name, display_name, ..
117 } => display_name.as_ref().unwrap_or(name),
118 }
119 }
120
121 pub fn max_token_count(&self) -> usize {
122 match self {
123 Self::CodestralLatest => 256000,
124 Self::MistralLargeLatest => 131000,
125 Self::MistralMediumLatest => 128000,
126 Self::MistralSmallLatest => 32000,
127 Self::OpenMistralNemo => 131000,
128 Self::OpenCodestralMamba => 256000,
129 Self::DevstralSmallLatest => 262144,
130 Self::Custom { max_tokens, .. } => *max_tokens,
131 }
132 }
133
134 pub fn max_output_tokens(&self) -> Option<u32> {
135 match self {
136 Self::Custom {
137 max_output_tokens, ..
138 } => *max_output_tokens,
139 _ => None,
140 }
141 }
142
143 pub fn supports_tools(&self) -> bool {
144 match self {
145 Self::CodestralLatest
146 | Self::MistralLargeLatest
147 | Self::MistralMediumLatest
148 | Self::MistralSmallLatest
149 | Self::OpenMistralNemo
150 | Self::OpenCodestralMamba
151 | Self::DevstralSmallLatest => true,
152 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
153 }
154 }
155}
156
/// Chat request payload; `stream_completion` serializes it to JSON and posts
/// it to the `/chat/completions` endpoint.
///
/// Optional fields that are `None` (and an empty `tools` list) are omitted
/// from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
175
/// Value type of `Request::response_format`.
///
/// Variants are (de)serialized under the names `text` and `json_object`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    Text,
    #[serde(rename = "json_object")]
    JsonObject,
}
183
/// A tool entry for `Request::tools`.
///
/// Internally tagged by a `type` field, so `Function` is serialized with
/// `"type": "function"` next to its `function` object.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    Function { function: FunctionDefinition },
}
189
/// Description of a function carried by `ToolDefinition::Function`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // NOTE(review): arbitrary JSON; presumably a JSON Schema for the arguments — confirm.
    pub parameters: Option<Value>,
}
196
/// Prompt-based completion request payload.
///
/// NOTE(review): no function visible in this part of the file sends this
/// type; which endpoint it targets should be confirmed against its callers.
/// `prediction` and `rewrite_speculation` are omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
208
/// Value type of `CompletionRequest::prediction`.
///
/// Internally tagged by a `type` field; `Content` is serialized with
/// `"type": "content"` next to its `content` string.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
214
/// Value type of `Request::tool_choice`.
///
/// Variant names are (de)serialized in snake_case (`auto`, `required`,
/// `none`, `any`, `function`); `Function` additionally wraps a
/// [`ToolDefinition`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Any,
    Function(ToolDefinition),
}
224
/// A single chat message, used in `Request::messages` and `Choice::message`.
///
/// Internally tagged by a `role` field holding the lowercase variant name
/// (`assistant`, `user`, `system`, `tool`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        // Omitted from the JSON when empty; defaults to empty when absent.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}
244
/// A tool call listed in an assistant message.
///
/// `content` is flattened, so its `type` tag and payload appear as siblings
/// of `id` in the JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
251
/// Payload of a [`ToolCall`], internally tagged by a `type` field
/// (`"function"` for the only variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
257
/// Function name and arguments carried by `ToolCallContent::Function`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // NOTE(review): kept as a raw string; presumably JSON-encoded arguments — confirm.
    pub arguments: String,
}
263
/// A completion choice holding only its generated `text`.
///
/// NOTE(review): not referenced by any other item visible in this part of
/// the file.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
268
/// A complete chat response: identifying metadata, the list of choices and
/// token usage.
///
/// All fields are required when deserializing.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp — confirm against the API docs.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
278
/// Token counts reported in a [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
285
/// One entry of `Response::choices`: its index, the full message and an
/// optional finish reason.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
292
/// One streamed chunk of a chat response.
///
/// `stream_completion` yields one of these for every `data: ` line of the
/// response body that parses as JSON.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}
301
/// One entry of `StreamResponse::choices`: its index, the incremental delta
/// and an optional finish reason.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}
308
/// Incremental message content carried by a [`StreamChoice`].
///
/// Every field is optional; `tool_calls` and `reasoning_content` are omitted
/// from serialized JSON when `None`.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
318
/// A partial tool call inside a [`StreamDelta`].
///
/// `index` is required; `id` and `function` may be absent.
// NOTE(review): presumably `index` identifies which tool call the fragment
// belongs to across chunks — confirm against the consumer of this type.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
325
/// Partial function data inside a [`ToolCallChunk`]; both fields may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
331
332pub async fn stream_completion(
333 client: &dyn HttpClient,
334 api_url: &str,
335 api_key: &str,
336 request: Request,
337) -> Result<BoxStream<'static, Result<StreamResponse>>> {
338 let uri = format!("{api_url}/chat/completions");
339 let request_builder = HttpRequest::builder()
340 .method(Method::POST)
341 .uri(uri)
342 .header("Content-Type", "application/json")
343 .header("Authorization", format!("Bearer {}", api_key));
344
345 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
346 let mut response = client.send(request).await?;
347
348 if response.status().is_success() {
349 let reader = BufReader::new(response.into_body());
350 Ok(reader
351 .lines()
352 .filter_map(|line| async move {
353 match line {
354 Ok(line) => {
355 let line = line.strip_prefix("data: ")?;
356 if line == "[DONE]" {
357 None
358 } else {
359 match serde_json::from_str(line) {
360 Ok(response) => Some(Ok(response)),
361 Err(error) => Some(Err(anyhow!(error))),
362 }
363 }
364 }
365 Err(error) => Some(Err(anyhow!(error))),
366 }
367 })
368 .boxed())
369 } else {
370 let mut body = String::new();
371 response.body_mut().read_to_string(&mut body).await?;
372 anyhow::bail!(
373 "Failed to connect to Mistral API: {} {}",
374 response.status(),
375 body,
376 );
377 }
378}