1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of the Mistral API (`v1`), written without a trailing slash.
///
/// `stream_completion` appends `/chat/completions` to whichever base URL it is given.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17 Tool,
18}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => Err(anyhow!("invalid role '{value}'")),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
45#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
46#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
47pub enum Model {
48 #[serde(rename = "codestral-latest", alias = "codestral-latest")]
49 #[default]
50 CodestralLatest,
51 #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
52 MistralLargeLatest,
53 #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-latest")]
54 MistralMediumLatest,
55 #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
56 MistralSmallLatest,
57 #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
58 OpenMistralNemo,
59 #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
60 OpenCodestralMamba,
61
62 #[serde(rename = "custom")]
63 Custom {
64 name: String,
65 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
66 display_name: Option<String>,
67 max_tokens: usize,
68 max_output_tokens: Option<u32>,
69 max_completion_tokens: Option<u32>,
70 },
71}
72
73impl Model {
74 pub fn default_fast() -> Self {
75 Model::MistralSmallLatest
76 }
77
78 pub fn from_id(id: &str) -> Result<Self> {
79 match id {
80 "codestral-latest" => Ok(Self::CodestralLatest),
81 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
82 "mistral-medium-latest" => Ok(Self::MistralMediumLatest),
83 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
84 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
85 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
86 _ => Err(anyhow!("invalid model id")),
87 }
88 }
89
90 pub fn id(&self) -> &str {
91 match self {
92 Self::CodestralLatest => "codestral-latest",
93 Self::MistralLargeLatest => "mistral-large-latest",
94 Self::MistralMediumLatest => "mistral-medium-latest",
95 Self::MistralSmallLatest => "mistral-small-latest",
96 Self::OpenMistralNemo => "open-mistral-nemo",
97 Self::OpenCodestralMamba => "open-codestral-mamba",
98 Self::Custom { name, .. } => name,
99 }
100 }
101
102 pub fn display_name(&self) -> &str {
103 match self {
104 Self::CodestralLatest => "codestral-latest",
105 Self::MistralLargeLatest => "mistral-large-latest",
106 Self::MistralMediumLatest => "mistral-medium-latest",
107 Self::MistralSmallLatest => "mistral-small-latest",
108 Self::OpenMistralNemo => "open-mistral-nemo",
109 Self::OpenCodestralMamba => "open-codestral-mamba",
110 Self::Custom {
111 name, display_name, ..
112 } => display_name.as_ref().unwrap_or(name),
113 }
114 }
115
116 pub fn max_token_count(&self) -> usize {
117 match self {
118 Self::CodestralLatest => 256000,
119 Self::MistralLargeLatest => 131000,
120 Self::MistralMediumLatest => 128000,
121 Self::MistralSmallLatest => 32000,
122 Self::OpenMistralNemo => 131000,
123 Self::OpenCodestralMamba => 256000,
124 Self::Custom { max_tokens, .. } => *max_tokens,
125 }
126 }
127
128 pub fn max_output_tokens(&self) -> Option<u32> {
129 match self {
130 Self::Custom {
131 max_output_tokens, ..
132 } => *max_output_tokens,
133 _ => None,
134 }
135 }
136}
137
138#[derive(Debug, Serialize, Deserialize)]
139pub struct Request {
140 pub model: String,
141 pub messages: Vec<RequestMessage>,
142 pub stream: bool,
143 #[serde(default, skip_serializing_if = "Option::is_none")]
144 pub max_tokens: Option<u32>,
145 #[serde(default, skip_serializing_if = "Option::is_none")]
146 pub temperature: Option<f32>,
147 #[serde(default, skip_serializing_if = "Option::is_none")]
148 pub response_format: Option<ResponseFormat>,
149 #[serde(default, skip_serializing_if = "Vec::is_empty")]
150 pub tools: Vec<ToolDefinition>,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum ResponseFormat {
156 Text,
157 #[serde(rename = "json_object")]
158 JsonObject,
159}
160
161#[derive(Debug, Serialize, Deserialize)]
162#[serde(tag = "type", rename_all = "snake_case")]
163pub enum ToolDefinition {
164 Function { function: FunctionDefinition },
165}
166
167#[derive(Debug, Serialize, Deserialize)]
168pub struct FunctionDefinition {
169 pub name: String,
170 pub description: Option<String>,
171 pub parameters: Option<Value>,
172}
173
174#[derive(Debug, Serialize, Deserialize)]
175pub struct CompletionRequest {
176 pub model: String,
177 pub prompt: String,
178 pub max_tokens: u32,
179 pub temperature: f32,
180 #[serde(default, skip_serializing_if = "Option::is_none")]
181 pub prediction: Option<Prediction>,
182 #[serde(default, skip_serializing_if = "Option::is_none")]
183 pub rewrite_speculation: Option<bool>,
184}
185
186#[derive(Clone, Deserialize, Serialize, Debug)]
187#[serde(tag = "type", rename_all = "snake_case")]
188pub enum Prediction {
189 Content { content: String },
190}
191
192#[derive(Debug, Serialize, Deserialize)]
193#[serde(untagged)]
194pub enum ToolChoice {
195 Auto,
196 Required,
197 None,
198 Other(ToolDefinition),
199}
200
201#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
202#[serde(tag = "role", rename_all = "lowercase")]
203pub enum RequestMessage {
204 Assistant {
205 content: Option<String>,
206 #[serde(default, skip_serializing_if = "Vec::is_empty")]
207 tool_calls: Vec<ToolCall>,
208 },
209 User {
210 content: String,
211 },
212 System {
213 content: String,
214 },
215 Tool {
216 content: String,
217 tool_call_id: String,
218 },
219}
220
221#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
222pub struct ToolCall {
223 pub id: String,
224 #[serde(flatten)]
225 pub content: ToolCallContent,
226}
227
228#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
229#[serde(tag = "type", rename_all = "lowercase")]
230pub enum ToolCallContent {
231 Function { function: FunctionContent },
232}
233
234#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
235pub struct FunctionContent {
236 pub name: String,
237 pub arguments: String,
238}
239
240#[derive(Serialize, Deserialize, Debug)]
241pub struct CompletionChoice {
242 pub text: String,
243}
244
245#[derive(Serialize, Deserialize, Debug)]
246pub struct Response {
247 pub id: String,
248 pub object: String,
249 pub created: u64,
250 pub model: String,
251 pub choices: Vec<Choice>,
252 pub usage: Usage,
253}
254
255#[derive(Serialize, Deserialize, Debug)]
256pub struct Usage {
257 pub prompt_tokens: u32,
258 pub completion_tokens: u32,
259 pub total_tokens: u32,
260}
261
262#[derive(Serialize, Deserialize, Debug)]
263pub struct Choice {
264 pub index: u32,
265 pub message: RequestMessage,
266 pub finish_reason: Option<String>,
267}
268
269#[derive(Serialize, Deserialize, Debug)]
270pub struct StreamResponse {
271 pub id: String,
272 pub object: String,
273 pub created: u64,
274 pub model: String,
275 pub choices: Vec<StreamChoice>,
276}
277
278#[derive(Serialize, Deserialize, Debug)]
279pub struct StreamChoice {
280 pub index: u32,
281 pub delta: StreamDelta,
282 pub finish_reason: Option<String>,
283}
284
285#[derive(Serialize, Deserialize, Debug)]
286pub struct StreamDelta {
287 pub role: Option<Role>,
288 pub content: Option<String>,
289 #[serde(default, skip_serializing_if = "Option::is_none")]
290 pub tool_calls: Option<Vec<ToolCallChunk>>,
291 #[serde(default, skip_serializing_if = "Option::is_none")]
292 pub reasoning_content: Option<String>,
293}
294
295#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
296pub struct ToolCallChunk {
297 pub index: usize,
298 pub id: Option<String>,
299 pub function: Option<FunctionChunk>,
300}
301
302#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
303pub struct FunctionChunk {
304 pub name: Option<String>,
305 pub arguments: Option<String>,
306}
307
308pub async fn stream_completion(
309 client: &dyn HttpClient,
310 api_url: &str,
311 api_key: &str,
312 request: Request,
313) -> Result<BoxStream<'static, Result<StreamResponse>>> {
314 let uri = format!("{api_url}/chat/completions");
315 let request_builder = HttpRequest::builder()
316 .method(Method::POST)
317 .uri(uri)
318 .header("Content-Type", "application/json")
319 .header("Authorization", format!("Bearer {}", api_key));
320
321 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
322 let mut response = client.send(request).await?;
323
324 if response.status().is_success() {
325 let reader = BufReader::new(response.into_body());
326 Ok(reader
327 .lines()
328 .filter_map(|line| async move {
329 match line {
330 Ok(line) => {
331 let line = line.strip_prefix("data: ")?;
332 if line == "[DONE]" {
333 None
334 } else {
335 match serde_json::from_str(line) {
336 Ok(response) => Some(Ok(response)),
337 Err(error) => Some(Err(anyhow!(error))),
338 }
339 }
340 }
341 Err(error) => Some(Err(anyhow!(error))),
342 }
343 })
344 .boxed())
345 } else {
346 let mut body = String::new();
347 response.body_mut().read_to_string(&mut body).await?;
348 Err(anyhow!(
349 "Failed to connect to Mistral API: {} {}",
350 response.status(),
351 body,
352 ))
353 }
354}