1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7use strum::EnumIter;
8
/// Base URL of the hosted Mistral API (`v1`).
///
/// Callers presumably pass this as the `api_url` argument of `stream_completion`,
/// which appends `/chat/completions` to whatever base URL it is given.
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
10
11#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
12#[serde(rename_all = "lowercase")]
13pub enum Role {
14 User,
15 Assistant,
16 System,
17 Tool,
18}
19
20impl TryFrom<String> for Role {
21 type Error = anyhow::Error;
22
23 fn try_from(value: String) -> Result<Self> {
24 match value.as_str() {
25 "user" => Ok(Self::User),
26 "assistant" => Ok(Self::Assistant),
27 "system" => Ok(Self::System),
28 "tool" => Ok(Self::Tool),
29 _ => Err(anyhow!("invalid role '{value}'")),
30 }
31 }
32}
33
34impl From<Role> for String {
35 fn from(val: Role) -> Self {
36 match val {
37 Role::User => "user".to_owned(),
38 Role::Assistant => "assistant".to_owned(),
39 Role::System => "system".to_owned(),
40 Role::Tool => "tool".to_owned(),
41 }
42 }
43}
44
// Mistral models known to this crate, plus a user-configured `Custom` entry.
// Each built-in variant (de)serializes as the model id given in its
// `serde(rename)`. `Custom` uses serde's default externally-tagged form under
// the `custom` key.
//
// NOTE: plain `//` comments are used on purpose inside this enum. With the
// `schemars` feature enabled, `///` doc comments may be emitted into the
// generated JSON schema as descriptions. That is presumably why the one
// existing `///` below is user-facing text.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "codestral-latest", alias = "codestral-latest")]
    CodestralLatest,
    #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
    MistralLargeLatest,
    #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
    MistralSmallLatest,
    #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
    OpenMistralNemo,
    // `Model::default()` returns this variant.
    #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
    #[default]
    OpenCodestralMamba,

    // A model not listed above, described entirely by user configuration.
    #[serde(rename = "custom")]
    Custom {
        // Id returned by `Model::id` for this variant.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Token limit returned by `Model::max_token_count`.
        max_tokens: usize,
        // Returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u32>,
        // Not read by any code visible in this file. NOTE(review): confirm where
        // (or whether) it is consumed.
        max_completion_tokens: Option<u32>,
    },
}
70
71impl Model {
72 pub fn from_id(id: &str) -> Result<Self> {
73 match id {
74 "codestral-latest" => Ok(Self::CodestralLatest),
75 "mistral-large-latest" => Ok(Self::MistralLargeLatest),
76 "mistral-small-latest" => Ok(Self::MistralSmallLatest),
77 "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
78 "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
79 _ => Err(anyhow!("invalid model id")),
80 }
81 }
82
83 pub fn id(&self) -> &str {
84 match self {
85 Self::CodestralLatest => "codestral-latest",
86 Self::MistralLargeLatest => "mistral-large-latest",
87 Self::MistralSmallLatest => "mistral-small-latest",
88 Self::OpenMistralNemo => "open-mistral-nemo",
89 Self::OpenCodestralMamba => "open-codestral-mamba",
90 Self::Custom { name, .. } => name,
91 }
92 }
93
94 pub fn display_name(&self) -> &str {
95 match self {
96 Self::CodestralLatest => "codestral-latest",
97 Self::MistralLargeLatest => "mistral-large-latest",
98 Self::MistralSmallLatest => "mistral-small-latest",
99 Self::OpenMistralNemo => "open-mistral-nemo",
100 Self::OpenCodestralMamba => "open-codestral-mamba",
101 Self::Custom {
102 name, display_name, ..
103 } => display_name.as_ref().unwrap_or(name),
104 }
105 }
106
107 pub fn max_token_count(&self) -> usize {
108 match self {
109 Self::CodestralLatest => 256000,
110 Self::MistralLargeLatest => 131000,
111 Self::MistralSmallLatest => 32000,
112 Self::OpenMistralNemo => 131000,
113 Self::OpenCodestralMamba => 256000,
114 Self::Custom { max_tokens, .. } => *max_tokens,
115 }
116 }
117
118 pub fn max_output_tokens(&self) -> Option<u32> {
119 match self {
120 Self::Custom {
121 max_output_tokens, ..
122 } => *max_output_tokens,
123 _ => None,
124 }
125 }
126}
127
128#[derive(Debug, Serialize, Deserialize)]
129pub struct Request {
130 pub model: String,
131 pub messages: Vec<RequestMessage>,
132 pub stream: bool,
133 #[serde(default, skip_serializing_if = "Option::is_none")]
134 pub max_tokens: Option<u32>,
135 #[serde(default, skip_serializing_if = "Option::is_none")]
136 pub temperature: Option<f32>,
137 #[serde(default, skip_serializing_if = "Option::is_none")]
138 pub response_format: Option<ResponseFormat>,
139 #[serde(default, skip_serializing_if = "Vec::is_empty")]
140 pub tools: Vec<ToolDefinition>,
141}
142
143#[derive(Debug, Serialize, Deserialize)]
144#[serde(rename_all = "snake_case")]
145pub enum ResponseFormat {
146 Text,
147 #[serde(rename = "json_object")]
148 JsonObject,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152#[serde(tag = "type", rename_all = "snake_case")]
153pub enum ToolDefinition {
154 Function { function: FunctionDefinition },
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158pub struct FunctionDefinition {
159 pub name: String,
160 pub description: Option<String>,
161 pub parameters: Option<Value>,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
165pub struct CompletionRequest {
166 pub model: String,
167 pub prompt: String,
168 pub max_tokens: u32,
169 pub temperature: f32,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
171 pub prediction: Option<Prediction>,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub rewrite_speculation: Option<bool>,
174}
175
176#[derive(Clone, Deserialize, Serialize, Debug)]
177#[serde(tag = "type", rename_all = "snake_case")]
178pub enum Prediction {
179 Content { content: String },
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183#[serde(untagged)]
184pub enum ToolChoice {
185 Auto,
186 Required,
187 None,
188 Other(ToolDefinition),
189}
190
191#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
192#[serde(tag = "role", rename_all = "lowercase")]
193pub enum RequestMessage {
194 Assistant {
195 content: Option<String>,
196 #[serde(default, skip_serializing_if = "Vec::is_empty")]
197 tool_calls: Vec<ToolCall>,
198 },
199 User {
200 content: String,
201 },
202 System {
203 content: String,
204 },
205 Tool {
206 content: String,
207 tool_call_id: String,
208 },
209}
210
211#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
212pub struct ToolCall {
213 pub id: String,
214 #[serde(flatten)]
215 pub content: ToolCallContent,
216}
217
218#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
219#[serde(tag = "type", rename_all = "lowercase")]
220pub enum ToolCallContent {
221 Function { function: FunctionContent },
222}
223
224#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
225pub struct FunctionContent {
226 pub name: String,
227 pub arguments: String,
228}
229
230#[derive(Serialize, Deserialize, Debug)]
231pub struct CompletionChoice {
232 pub text: String,
233}
234
235#[derive(Serialize, Deserialize, Debug)]
236pub struct Response {
237 pub id: String,
238 pub object: String,
239 pub created: u64,
240 pub model: String,
241 pub choices: Vec<Choice>,
242 pub usage: Usage,
243}
244
245#[derive(Serialize, Deserialize, Debug)]
246pub struct Usage {
247 pub prompt_tokens: u32,
248 pub completion_tokens: u32,
249 pub total_tokens: u32,
250}
251
252#[derive(Serialize, Deserialize, Debug)]
253pub struct Choice {
254 pub index: u32,
255 pub message: RequestMessage,
256 pub finish_reason: Option<String>,
257}
258
259#[derive(Serialize, Deserialize, Debug)]
260pub struct StreamResponse {
261 pub id: String,
262 pub object: String,
263 pub created: u64,
264 pub model: String,
265 pub choices: Vec<StreamChoice>,
266}
267
268#[derive(Serialize, Deserialize, Debug)]
269pub struct StreamChoice {
270 pub index: u32,
271 pub delta: StreamDelta,
272 pub finish_reason: Option<String>,
273}
274
275#[derive(Serialize, Deserialize, Debug)]
276pub struct StreamDelta {
277 pub role: Option<Role>,
278 pub content: Option<String>,
279 #[serde(default, skip_serializing_if = "Option::is_none")]
280 pub tool_calls: Option<Vec<ToolCallChunk>>,
281 #[serde(default, skip_serializing_if = "Option::is_none")]
282 pub reasoning_content: Option<String>,
283}
284
285#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
286pub struct ToolCallChunk {
287 pub index: usize,
288 pub id: Option<String>,
289 pub function: Option<FunctionChunk>,
290}
291
292#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
293pub struct FunctionChunk {
294 pub name: Option<String>,
295 pub arguments: Option<String>,
296}
297
298pub async fn stream_completion(
299 client: &dyn HttpClient,
300 api_url: &str,
301 api_key: &str,
302 request: Request,
303) -> Result<BoxStream<'static, Result<StreamResponse>>> {
304 let uri = format!("{api_url}/chat/completions");
305 let request_builder = HttpRequest::builder()
306 .method(Method::POST)
307 .uri(uri)
308 .header("Content-Type", "application/json")
309 .header("Authorization", format!("Bearer {}", api_key));
310
311 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
312 let mut response = client.send(request).await?;
313
314 if response.status().is_success() {
315 let reader = BufReader::new(response.into_body());
316 Ok(reader
317 .lines()
318 .filter_map(|line| async move {
319 match line {
320 Ok(line) => {
321 let line = line.strip_prefix("data: ")?;
322 if line == "[DONE]" {
323 None
324 } else {
325 match serde_json::from_str(line) {
326 Ok(response) => Some(Ok(response)),
327 Err(error) => Some(Err(anyhow!(error))),
328 }
329 }
330 }
331 Err(error) => Some(Err(anyhow!(error))),
332 }
333 })
334 .boxed())
335 } else {
336 let mut body = String::new();
337 response.body_mut().read_to_string(&mut body).await?;
338 Err(anyhow!(
339 "Failed to connect to Mistral API: {} {}",
340 response.status(),
341 body,
342 ))
343 }
344}