1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Default base URL of the Mistral REST API (no trailing slash); `stream_completion`
/// appends `/chat/completions` to whatever base URL it is given.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
/// Author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`, `"tool"`), matching
/// the strings accepted by `TryFrom<String>` and produced by `From<Role> for String`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => Err(anyhow!("invalid role '{value}'")),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum Model {
48 #[serde(rename = "codestral-latest", alias = "codestral-latest")]
49 #[default]
50 CodestralLatest,
51 #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
52 MistralLargeLatest,
53 #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
54 MistralMediumLatest,
55 #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
56 MistralSmallLatest,
57 #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
58 OpenMistralNemo,
59 #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
60 OpenCodestralMamba,
61
62 #[serde(rename = "custom")]
63 Custom {
64 name: String,
65 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
66 display_name: Option<String>,
67 max_tokens: usize,
68 max_output_tokens: Option<u32>,
69 max_completion_tokens: Option<u32>,
70 supports_tools: Option<bool>,
71 },
72}
73
74impl Model {
75 pub fn default_fast() -> Self {
76 Model::MistralSmallLatest
77 }
78
79 pub fn from_id(id: &str) -> Result<Self> {
80 match id {
81 "codestral-latest" => Ok(Self::CodestralLatest),
82 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
83 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
84 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
85 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
86 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
87 _ => Err(anyhow!("invalid model id")),
88 }
89 }
90
91 pub fn id(&self) -> &str {
92 match self {
93 Self::CodestralLatest => "codestral-latest",
94 Self::MistralLargeLatest => "mistral-large-latest",
95 Self::MistralMediumLatest => "mistral-medium-latest",
96 Self::MistralSmallLatest => "mistral-small-latest",
97 Self::OpenMistralNemo => "open-mistral-nemo",
98 Self::OpenCodestralMamba => "open-codestral-mamba",
99 Self::Custom { name, .. } => name,
100 }
101 }
102
103 pub fn display_name(&self) -> &str {
104 match self {
105 Self::CodestralLatest => "codestral-latest",
106 Self::MistralLargeLatest => "mistral-large-latest",
107 Self::MistralMediumLatest => "mistral-medium-latest",
108 Self::MistralSmallLatest => "mistral-small-latest",
109 Self::OpenMistralNemo => "open-mistral-nemo",
110 Self::OpenCodestralMamba => "open-codestral-mamba",
111 Self::Custom {
112 name, display_name, ..
113 } => display_name.as_ref().unwrap_or(name),
114 }
115 }
116
117 pub fn max_token_count(&self) -> usize {
118 match self {
119 Self::CodestralLatest => 256000,
120 Self::MistralLargeLatest => 131000,
121 Self::MistralMediumLatest => 128000,
122 Self::MistralSmallLatest => 32000,
123 Self::OpenMistralNemo => 131000,
124 Self::OpenCodestralMamba => 256000,
125 Self::Custom { max_tokens, .. } => *max_tokens,
126 }
127 }
128
129 pub fn max_output_tokens(&self) -> Option<u32> {
130 match self {
131 Self::Custom {
132 max_output_tokens, ..
133 } => *max_output_tokens,
134 _ => None,
135 }
136 }
137
138 pub fn supports_tools(&self) -> bool {
139 match self {
140 Self::CodestralLatest
141 | Self::MistralLargeLatest
142 | Self::MistralMediumLatest
143 | Self::MistralSmallLatest
144 | Self::OpenMistralNemo
145 | Self::OpenCodestralMamba => true,
146 Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
147 }
148 }
149}
150
/// Body of a chat-completion request (serialized as JSON by `stream_completion`).
///
/// Optional fields are omitted from the JSON when `None`, and `tools` is omitted
/// when empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, e.g. the value of `Model::id`.
    pub model: String,
    /// Conversation history, in order.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed response is requested.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools offered to the model; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
169
170#[derive(Debug, Serialize, Deserialize)]
171#[serde(rename_all = "snake_case")]
172pub enum ResponseFormat {
173 Text,
174 #[serde(rename = "json_object")]
175 JsonObject,
176}
177
/// A tool offered to the model.
///
/// Serialized with a `type` tag, i.e. `{"type": "function", "function": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    Function { function: FunctionDefinition },
}
183
/// Declaration of a callable function tool.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (no `skip_serializing_if` on this field).
    pub description: Option<String>,
    /// Arbitrary JSON; presumably a JSON Schema for the arguments — TODO confirm with callers.
    pub parameters: Option<Value>,
}
190
/// Body of a prompt-based (non-chat) completion request.
///
/// NOTE(review): no function in this file sends this type; its endpoint is not visible here.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
202
/// Predicted output supplied alongside a completion request.
///
/// Serialized with a `type` tag: `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
208
/// How the model may use the offered tools.
///
/// Unit variants serialize as snake_case strings (`"auto"`, `"required"`, `"none"`,
/// `"any"`). `Function` uses serde's default external tagging, so it serializes as
/// `{"function": {"type": "function", "function": {...}}}`.
/// NOTE(review): confirm that nested shape is what the Mistral API expects for a
/// forced function choice.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Any,
    Function(ToolDefinition),
}
218
/// One message in a chat conversation, tagged by a lowercase `role` field.
///
/// Also used for the `message` of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Serialized as `null` when `None`.
        content: Option<String>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        /// Tool output text.
        content: String,
        /// Id of the [`ToolCall`] this message answers.
        tool_call_id: String,
    },
}
238
/// A complete tool call made by the assistant.
///
/// `content` is flattened, so the JSON is `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
245
/// Payload of a [`ToolCall`], discriminated by a lowercase `type` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
251
/// Name and arguments of a function tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments kept as a raw string; presumably JSON text — TODO confirm against the API.
    pub arguments: String,
}
257
/// One choice of a prompt-based completion response, carrying the generated text.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
262
/// Non-streaming chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation time; presumably a Unix timestamp in seconds — TODO confirm.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
272
/// Token accounting reported with a [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
279
/// One choice of a non-streaming [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    /// Reuses the request message type, so the role-tagged shape is shared.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
286
/// One chunk of a streamed chat completion, parsed from each SSE `data` line by
/// `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}
295
/// A choice within a [`StreamResponse`] chunk; `delta` holds the incremental update.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}
302
/// Incremental content of a streamed choice; every field may be absent from a chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Partial tool calls; chunks are correlated by [`ToolCallChunk::index`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Presumably reasoning text from reasoning-capable models — TODO confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
312
/// A fragment of a tool call within a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
319
/// Partial function name and/or arguments of a streamed tool call.
///
/// NOTE(review): presumably `arguments` fragments must be concatenated across
/// chunks by the caller — confirm against the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
325
326pub async fn stream_completion(
327 client: &dyn HttpClient,
328 api_url: &str,
329 api_key: &str,
330 request: Request,
331) -> Result<BoxStream<'static, Result<StreamResponse>>> {
332 let uri = format!("{api_url}/chat/completions");
333 let request_builder = HttpRequest::builder()
334 .method(Method::POST)
335 .uri(uri)
336 .header("Content-Type", "application/json")
337 .header("Authorization", format!("Bearer {}", api_key));
338
339 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
340 let mut response = client.send(request).await?;
341
342 if response.status().is_success() {
343 let reader = BufReader::new(response.into_body());
344 Ok(reader
345 .lines()
346 .filter_map(|line| async move {
347 match line {
348 Ok(line) => {
349 let line = line.strip_prefix("data: ")?;
350 if line == "[DONE]" {
351 None
352 } else {
353 match serde_json::from_str(line) {
354 Ok(response) => Some(Ok(response)),
355 Err(error) => Some(Err(anyhow!(error))),
356 }
357 }
358 }
359 Err(error) => Some(Err(anyhow!(error))),
360 }
361 })
362 .boxed())
363 } else {
364 let mut body = String::new();
365 response.body_mut().read_to_string(&mut body).await?;
366 Err(anyhow!(
367 "Failed to connect to Mistral API: {} {}",
368 response.status(),
369 body,
370 ))
371 }
372}