1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of the Mistral HTTP API (version 1).
///
/// NOTE(review): presumably the default value callers pass as `api_url` to
/// `stream_completion` — confirm against call sites.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
/// Role attached to a chat message.
///
/// Serde maps each variant to its lowercase name (`user`, `assistant`,
/// `system`, `tool`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => Err(anyhow!("invalid role '{value}'")),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
/// Mistral models selectable by this crate, plus a user-configured `Custom` entry.
///
/// Each built-in variant (de)serializes as the model identifier given in its
/// `serde(rename)`; `CodestralLatest` is the `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
    #[default]
    CodestralLatest,
    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
    MistralLargeLatest,
    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
    MistralSmallLatest,
    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
    OpenMistralNemo,
    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
    OpenCodestralMamba,

    /// A model described entirely by its fields rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier returned by `Model::id`.
        /// NOTE(review): presumably sent to the API as the model name — confirm.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Value returned by `Model::max_token_count`.
        max_tokens: usize,
        /// Value returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u32>,
        /// Not read by any code visible in this file.
        /// NOTE(review): confirm where this is consumed.
        max_completion_tokens: Option<u32>,
    },
}
70
71impl Model {
72 pub fn default_fast() -> Self {
73 Model::MistralSmallLatest
74 }
75
76 pub fn from_id(id: &str) -> Result<Self> {
77 match id {
78 "codestral-latest" => Ok(Self::CodestralLatest),
79 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
80 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
81 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
82 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
83 _ => Err(anyhow!("invalid model id")),
84 }
85 }
86
87 pub fn id(&self) -> &str {
88 match self {
89 Self::CodestralLatest => "codestral-latest",
90 Self::MistralLargeLatest => "mistral-large-latest",
91 Self::MistralSmallLatest => "mistral-small-latest",
92 Self::OpenMistralNemo => "open-mistral-nemo",
93 Self::OpenCodestralMamba => "open-codestral-mamba",
94 Self::Custom { name, .. } => name,
95 }
96 }
97
98 pub fn display_name(&self) -> &str {
99 match self {
100 Self::CodestralLatest => "codestral-latest",
101 Self::MistralLargeLatest => "mistral-large-latest",
102 Self::MistralSmallLatest => "mistral-small-latest",
103 Self::OpenMistralNemo => "open-mistral-nemo",
104 Self::OpenCodestralMamba => "open-codestral-mamba",
105 Self::Custom {
106 name, display_name, ..
107 } => display_name.as_ref().unwrap_or(name),
108 }
109 }
110
111 pub fn max_token_count(&self) -> usize {
112 match self {
113 Self::CodestralLatest => 256000,
114 Self::MistralLargeLatest => 131000,
115 Self::MistralSmallLatest => 32000,
116 Self::OpenMistralNemo => 131000,
117 Self::OpenCodestralMamba => 256000,
118 Self::Custom { max_tokens, .. } => *max_tokens,
119 }
120 }
121
122 pub fn max_output_tokens(&self) -> Option<u32> {
123 match self {
124 Self::Custom {
125 max_output_tokens, ..
126 } => *max_output_tokens,
127 _ => None,
128 }
129 }
130}
131
/// JSON body that `stream_completion` POSTs to the `/chat/completions` endpoint.
///
/// Optional fields are left out of the serialized JSON when they are `None`,
/// and `tools` is left out when it is empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier.
    /// NOTE(review): presumably a value from `Model::id` — confirm with callers.
    pub model: String,
    /// Chat messages included in the request.
    pub messages: Vec<RequestMessage>,
    /// NOTE(review): presumably asks for a streamed response; `stream_completion`
    /// does not set this itself — confirm callers pass `true`.
    pub stream: bool,
    /// Optional token limit; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional output format; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Tools offered to the model; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
146
/// Output format requested through `Request::response_format`.
///
/// Serde writes the variants as the snake_case strings `"text"` and
/// `"json_object"`.
///
/// NOTE(review): this is a bare string on the wire, not an object such as
/// `{"type": "json_object"}` — confirm that matches what the API expects.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    Text,
    #[serde(rename = "json_object")]
    JsonObject,
}
154
/// A tool offered to the model through `Request::tools`.
///
/// Internally tagged: the snake_case variant name is written to a `type`
/// field, so the only variant serializes with `"type": "function"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    Function { function: FunctionDefinition },
}
160
/// Description of a callable function carried by `ToolDefinition::Function`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Name of the function.
    pub name: String,
    /// Optional free-text description of the function.
    pub description: Option<String>,
    /// Optional arbitrary JSON describing the parameters.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub parameters: Option<Value>,
}
167
/// Request body for a prompt-based (non-chat) completion.
///
/// NOTE(review): no code visible in this file sends this type — confirm which
/// endpoint, if any, consumes it.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Model identifier.
    pub model: String,
    /// Prompt text to complete.
    pub prompt: String,
    /// Token limit for the completion.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional prediction payload; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    /// Optional flag; omitted from the JSON when `None`.
    /// NOTE(review): semantics are not evident from this file — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
179
/// Prediction payload used by `CompletionRequest::prediction`.
///
/// Internally tagged: the snake_case variant name is written to a `type`
/// field, so the only variant serializes with `"type": "content"`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
185
186#[derive(Debug, Serialize, Deserialize)]
187#[serde(untagged)]
188pub enum ToolChoice {
189 Auto,
190 Required,
191 None,
192 Other(ToolDefinition),
193}
194
/// A single chat message, discriminated by a lowercase `role` field
/// (`assistant`, `user`, `system`, `tool`).
///
/// Besides `Request::messages`, this type is also used for `Choice::message`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Message with the `assistant` role.
    Assistant {
        /// Optional text content.
        content: Option<String>,
        /// Tool calls attached to the message; defaults to empty when absent
        /// and is omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with the `user` role.
    User {
        content: String,
    },
    /// Message with the `system` role.
    System {
        content: String,
    },
    /// Message with the `tool` role.
    Tool {
        content: String,
        /// NOTE(review): presumably the `ToolCall::id` this message answers — confirm.
        tool_call_id: String,
    },
}
214
/// A tool call attached to an assistant message.
///
/// `content` is flattened, so its `type` tag and payload appear next to `id`
/// in the same JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Kind and payload of the call.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
221
/// Payload of a `ToolCall`.
///
/// Internally tagged: the lowercase variant name is written to a `type`
/// field, so the only variant serializes with `"type": "function"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
227
/// Function invocation carried by `ToolCallContent::Function`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function being called.
    pub name: String,
    /// Arguments for the call, kept as a string.
    /// NOTE(review): presumably JSON-encoded — confirm.
    pub arguments: String,
}
233
/// A single completion result holding plain text.
///
/// NOTE(review): presumably one choice of the response to a
/// `CompletionRequest`; no visible code references it — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    /// Text of the completion.
    pub text: String,
}
238
/// A complete chat completion response with its choices and token usage.
///
/// NOTE(review): presumably the non-streaming counterpart of
/// `StreamResponse`; no visible code constructs or parses it — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Identifier of the response.
    pub id: String,
    /// Object type string as reported by the API.
    pub object: String,
    /// Creation time.
    /// NOTE(review): presumably a Unix timestamp — confirm.
    pub created: u64,
    /// Model identifier reported in the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
248
/// Token counts reported in `Response::usage`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total token count.
    /// NOTE(review): presumably the sum of the two fields above — confirm.
    pub total_tokens: u32,
}
255
/// One entry of `Response::choices`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    /// Index of this choice.
    pub index: u32,
    /// The message for this choice, reusing the request message type.
    pub message: RequestMessage,
    /// Optional reason the generation finished, as a free-form string.
    pub finish_reason: Option<String>,
}
262
/// One event of a streamed chat completion, as yielded by `stream_completion`
/// for each JSON `data` line it parses.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamResponse {
    /// Identifier of the response.
    pub id: String,
    /// Object type string as reported by the API.
    pub object: String,
    /// Creation time.
    /// NOTE(review): presumably a Unix timestamp — confirm.
    pub created: u64,
    /// Model identifier reported in the event.
    pub model: String,
    /// Incremental choices carried by this event.
    pub choices: Vec<StreamChoice>,
}
271
/// One entry of `StreamResponse::choices`.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamChoice {
    /// Index of this choice.
    pub index: u32,
    /// Incremental update for this choice.
    pub delta: StreamDelta,
    /// Optional reason the generation finished, as a free-form string.
    pub finish_reason: Option<String>,
}
278
/// Incremental update carried by a `StreamChoice`; every field is optional.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamDelta {
    /// Role of the message, when present.
    pub role: Option<Role>,
    /// Text fragment, when present.
    pub content: Option<String>,
    /// Tool-call fragments; `None` when absent and omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Reasoning text fragment; `None` when absent and omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
}
288
/// Fragment of a tool call carried in `StreamDelta::tool_calls`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// NOTE(review): presumably identifies which tool call this fragment
    /// belongs to when several are streamed — confirm.
    pub index: usize,
    /// Tool call identifier, when present in this fragment.
    pub id: Option<String>,
    /// Function fragment, when present.
    pub function: Option<FunctionChunk>,
}
295
/// Function part of a `ToolCallChunk`; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name, when present in this fragment.
    pub name: Option<String>,
    /// Arguments text, when present in this fragment.
    /// NOTE(review): presumably a partial string to be concatenated across
    /// fragments — confirm.
    pub arguments: Option<String>,
}
301
302pub async fn stream_completion(
303 client: &dyn HttpClient,
304 api_url: &str,
305 api_key: &str,
306 request: Request,
307) -> Result<BoxStream<'static, Result<StreamResponse>>> {
308 let uri = format!("{api_url}/chat/completions");
309 let request_builder = HttpRequest::builder()
310 .method(Method::POST)
311 .uri(uri)
312 .header("Content-Type", "application/json")
313 .header("Authorization", format!("Bearer {}", api_key));
314
315 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
316 let mut response = client.send(request).await?;
317
318 if response.status().is_success() {
319 let reader = BufReader::new(response.into_body());
320 Ok(reader
321 .lines()
322 .filter_map(|line| async move {
323 match line {
324 Ok(line) => {
325 let line = line.strip_prefix("data: ")?;
326 if line == "[DONE]" {
327 None
328 } else {
329 match serde_json::from_str(line) {
330 Ok(response) => Some(Ok(response)),
331 Err(error) => Some(Err(anyhow!(error))),
332 }
333 }
334 }
335 Err(error) => Some(Err(anyhow!(error))),
336 }
337 })
338 .boxed())
339 } else {
340 let mut body = String::new();
341 response.body_mut().read_to_string(&mut body).await?;
342 Err(anyhow!(
343 "Failed to connect to Mistral API: {} {}",
344 response.status(),
345 body,
346 ))
347 }
348}