1mod supported_countries;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{
5 AsyncBufReadExt, AsyncReadExt, StreamExt,
6 io::BufReader,
7 stream::{self, BoxStream},
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15};
16use strum::EnumIter;
17
18pub use supported_countries::*;
19
/// Base URL of the OpenAI REST API (no trailing slash); the request helpers in
/// this file append endpoint paths such as `/chat/completions` to an `api_url`
/// of this shape.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
21
/// Returns `true` when `opt` is `None` or wraps a value whose slice view is empty.
///
/// Referenced by name as a serde `skip_serializing_if` predicate in this file.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
25
26#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
27#[serde(rename_all = "lowercase")]
28pub enum Role {
29 User,
30 Assistant,
31 System,
32 Tool,
33}
34
35impl TryFrom<String> for Role {
36 type Error = anyhow::Error;
37
38 fn try_from(value: String) -> Result<Self> {
39 match value.as_str() {
40 "user" => Ok(Self::User),
41 "assistant" => Ok(Self::Assistant),
42 "system" => Ok(Self::System),
43 "tool" => Ok(Self::Tool),
44 _ => Err(anyhow!("invalid role '{value}'")),
45 }
46 }
47}
48
49impl From<Role> for String {
50 fn from(val: Role) -> Self {
51 match val {
52 Role::User => "user".to_owned(),
53 Role::Assistant => "assistant".to_owned(),
54 Role::System => "system".to_owned(),
55 Role::Tool => "tool".to_owned(),
56 }
57 }
58}
59
60#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
61#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
62pub enum Model {
63 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
64 ThreePointFiveTurbo,
65 #[serde(rename = "gpt-4", alias = "gpt-4")]
66 Four,
67 #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
68 FourTurbo,
69 #[serde(rename = "gpt-4o", alias = "gpt-4o")]
70 #[default]
71 FourOmni,
72 #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
73 FourOmniMini,
74 #[serde(rename = "o1", alias = "o1")]
75 O1,
76 #[serde(rename = "o1-preview", alias = "o1-preview")]
77 O1Preview,
78 #[serde(rename = "o1-mini", alias = "o1-mini")]
79 O1Mini,
80 #[serde(rename = "o3-mini", alias = "o3-mini")]
81 O3Mini,
82
83 #[serde(rename = "custom")]
84 Custom {
85 name: String,
86 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
87 display_name: Option<String>,
88 max_tokens: usize,
89 max_output_tokens: Option<u32>,
90 max_completion_tokens: Option<u32>,
91 },
92}
93
94impl Model {
95 pub fn from_id(id: &str) -> Result<Self> {
96 match id {
97 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
98 "gpt-4" => Ok(Self::Four),
99 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
100 "gpt-4o" => Ok(Self::FourOmni),
101 "gpt-4o-mini" => Ok(Self::FourOmniMini),
102 "o1" => Ok(Self::O1),
103 "o1-preview" => Ok(Self::O1Preview),
104 "o1-mini" => Ok(Self::O1Mini),
105 "o3-mini" => Ok(Self::O3Mini),
106 _ => Err(anyhow!("invalid model id")),
107 }
108 }
109
110 pub fn id(&self) -> &str {
111 match self {
112 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
113 Self::Four => "gpt-4",
114 Self::FourTurbo => "gpt-4-turbo",
115 Self::FourOmni => "gpt-4o",
116 Self::FourOmniMini => "gpt-4o-mini",
117 Self::O1 => "o1",
118 Self::O1Preview => "o1-preview",
119 Self::O1Mini => "o1-mini",
120 Self::O3Mini => "o3-mini",
121 Self::Custom { name, .. } => name,
122 }
123 }
124
125 pub fn display_name(&self) -> &str {
126 match self {
127 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
128 Self::Four => "gpt-4",
129 Self::FourTurbo => "gpt-4-turbo",
130 Self::FourOmni => "gpt-4o",
131 Self::FourOmniMini => "gpt-4o-mini",
132 Self::O1 => "o1",
133 Self::O1Preview => "o1-preview",
134 Self::O1Mini => "o1-mini",
135 Self::O3Mini => "o3-mini",
136 Self::Custom {
137 name, display_name, ..
138 } => display_name.as_ref().unwrap_or(name),
139 }
140 }
141
142 pub fn max_token_count(&self) -> usize {
143 match self {
144 Self::ThreePointFiveTurbo => 16_385,
145 Self::Four => 8_192,
146 Self::FourTurbo => 128_000,
147 Self::FourOmni => 128_000,
148 Self::FourOmniMini => 128_000,
149 Self::O1 => 200_000,
150 Self::O1Preview => 128_000,
151 Self::O1Mini => 128_000,
152 Self::O3Mini => 200_000,
153 Self::Custom { max_tokens, .. } => *max_tokens,
154 }
155 }
156
157 pub fn max_output_tokens(&self) -> Option<u32> {
158 match self {
159 Self::Custom {
160 max_output_tokens, ..
161 } => *max_output_tokens,
162 _ => None,
163 }
164 }
165}
166
167#[derive(Debug, Serialize, Deserialize)]
168pub struct Request {
169 pub model: String,
170 pub messages: Vec<RequestMessage>,
171 pub stream: bool,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub max_tokens: Option<u32>,
174 #[serde(default, skip_serializing_if = "Vec::is_empty")]
175 pub stop: Vec<String>,
176 pub temperature: f32,
177 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub tool_choice: Option<ToolChoice>,
179 #[serde(default, skip_serializing_if = "Vec::is_empty")]
180 pub tools: Vec<ToolDefinition>,
181}
182
183#[derive(Debug, Serialize, Deserialize)]
184pub struct CompletionRequest {
185 pub model: String,
186 pub prompt: String,
187 pub max_tokens: u32,
188 pub temperature: f32,
189 #[serde(default, skip_serializing_if = "Option::is_none")]
190 pub prediction: Option<Prediction>,
191 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub rewrite_speculation: Option<bool>,
193}
194
195#[derive(Clone, Deserialize, Serialize, Debug)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum Prediction {
198 Content { content: String },
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202#[serde(untagged)]
203pub enum ToolChoice {
204 Auto,
205 Required,
206 None,
207 Other(ToolDefinition),
208}
209
210#[derive(Clone, Deserialize, Serialize, Debug)]
211#[serde(tag = "type", rename_all = "snake_case")]
212pub enum ToolDefinition {
213 #[allow(dead_code)]
214 Function { function: FunctionDefinition },
215}
216
217#[derive(Clone, Debug, Serialize, Deserialize)]
218pub struct FunctionDefinition {
219 pub name: String,
220 pub description: Option<String>,
221 pub parameters: Option<Value>,
222}
223
224#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
225#[serde(tag = "role", rename_all = "lowercase")]
226pub enum RequestMessage {
227 Assistant {
228 content: Option<String>,
229 #[serde(default, skip_serializing_if = "Vec::is_empty")]
230 tool_calls: Vec<ToolCall>,
231 },
232 User {
233 content: String,
234 },
235 System {
236 content: String,
237 },
238 Tool {
239 content: String,
240 tool_call_id: String,
241 },
242}
243
244#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
245pub struct ToolCall {
246 pub id: String,
247 #[serde(flatten)]
248 pub content: ToolCallContent,
249}
250
251#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
252#[serde(tag = "type", rename_all = "lowercase")]
253pub enum ToolCallContent {
254 Function { function: FunctionContent },
255}
256
257#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
258pub struct FunctionContent {
259 pub name: String,
260 pub arguments: String,
261}
262
263#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
264pub struct ResponseMessageDelta {
265 pub role: Option<Role>,
266 pub content: Option<String>,
267 #[serde(default, skip_serializing_if = "is_none_or_empty")]
268 pub tool_calls: Option<Vec<ToolCallChunk>>,
269}
270
271#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
272pub struct ToolCallChunk {
273 pub index: usize,
274 pub id: Option<String>,
275
276 // There is also an optional `type` field that would determine if a
277 // function is there. Sometimes this streams in with the `function` before
278 // it streams in the `type`
279 pub function: Option<FunctionChunk>,
280}
281
282#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
283pub struct FunctionChunk {
284 pub name: Option<String>,
285 pub arguments: Option<String>,
286}
287
288#[derive(Serialize, Deserialize, Debug)]
289pub struct Usage {
290 pub prompt_tokens: u32,
291 pub completion_tokens: u32,
292 pub total_tokens: u32,
293}
294
295#[derive(Serialize, Deserialize, Debug)]
296pub struct ChoiceDelta {
297 pub index: u32,
298 pub delta: ResponseMessageDelta,
299 pub finish_reason: Option<String>,
300}
301
302#[derive(Serialize, Deserialize, Debug)]
303#[serde(untagged)]
304pub enum ResponseStreamResult {
305 Ok(ResponseStreamEvent),
306 Err { error: String },
307}
308
309#[derive(Serialize, Deserialize, Debug)]
310pub struct ResponseStreamEvent {
311 pub created: u32,
312 pub model: String,
313 pub choices: Vec<ChoiceDelta>,
314 pub usage: Option<Usage>,
315}
316
317#[derive(Serialize, Deserialize, Debug)]
318pub struct CompletionResponse {
319 pub id: String,
320 pub object: String,
321 pub created: u64,
322 pub model: String,
323 pub choices: Vec<CompletionChoice>,
324 pub usage: Usage,
325}
326
327#[derive(Serialize, Deserialize, Debug)]
328pub struct CompletionChoice {
329 pub text: String,
330}
331
332#[derive(Serialize, Deserialize, Debug)]
333pub struct Response {
334 pub id: String,
335 pub object: String,
336 pub created: u64,
337 pub model: String,
338 pub choices: Vec<Choice>,
339 pub usage: Usage,
340}
341
342#[derive(Serialize, Deserialize, Debug)]
343pub struct Choice {
344 pub index: u32,
345 pub message: RequestMessage,
346 pub finish_reason: Option<String>,
347}
348
349pub async fn complete(
350 client: &dyn HttpClient,
351 api_url: &str,
352 api_key: &str,
353 request: Request,
354) -> Result<Response> {
355 let uri = format!("{api_url}/chat/completions");
356 let request_builder = HttpRequest::builder()
357 .method(Method::POST)
358 .uri(uri)
359 .header("Content-Type", "application/json")
360 .header("Authorization", format!("Bearer {}", api_key));
361
362 let mut request_body = request;
363 request_body.stream = false;
364
365 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
366 let mut response = client.send(request).await?;
367
368 if response.status().is_success() {
369 let mut body = String::new();
370 response.body_mut().read_to_string(&mut body).await?;
371 let response: Response = serde_json::from_str(&body)?;
372 Ok(response)
373 } else {
374 let mut body = String::new();
375 response.body_mut().read_to_string(&mut body).await?;
376
377 #[derive(Deserialize)]
378 struct OpenAiResponse {
379 error: OpenAiError,
380 }
381
382 #[derive(Deserialize)]
383 struct OpenAiError {
384 message: String,
385 }
386
387 match serde_json::from_str::<OpenAiResponse>(&body) {
388 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
389 "Failed to connect to OpenAI API: {}",
390 response.error.message,
391 )),
392
393 _ => Err(anyhow!(
394 "Failed to connect to OpenAI API: {} {}",
395 response.status(),
396 body,
397 )),
398 }
399 }
400}
401
402pub async fn complete_text(
403 client: &dyn HttpClient,
404 api_url: &str,
405 api_key: &str,
406 request: CompletionRequest,
407) -> Result<CompletionResponse> {
408 let uri = format!("{api_url}/completions");
409 let request_builder = HttpRequest::builder()
410 .method(Method::POST)
411 .uri(uri)
412 .header("Content-Type", "application/json")
413 .header("Authorization", format!("Bearer {}", api_key));
414
415 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
416 let mut response = client.send(request).await?;
417
418 if response.status().is_success() {
419 let mut body = String::new();
420 response.body_mut().read_to_string(&mut body).await?;
421 let response = serde_json::from_str(&body)?;
422 Ok(response)
423 } else {
424 let mut body = String::new();
425 response.body_mut().read_to_string(&mut body).await?;
426
427 #[derive(Deserialize)]
428 struct OpenAiResponse {
429 error: OpenAiError,
430 }
431
432 #[derive(Deserialize)]
433 struct OpenAiError {
434 message: String,
435 }
436
437 match serde_json::from_str::<OpenAiResponse>(&body) {
438 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
439 "Failed to connect to OpenAI API: {}",
440 response.error.message,
441 )),
442
443 _ => Err(anyhow!(
444 "Failed to connect to OpenAI API: {} {}",
445 response.status(),
446 body,
447 )),
448 }
449 }
450}
451
452fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
453 ResponseStreamEvent {
454 created: response.created as u32,
455 model: response.model,
456 choices: response
457 .choices
458 .into_iter()
459 .map(|choice| ChoiceDelta {
460 index: choice.index,
461 delta: ResponseMessageDelta {
462 role: Some(match choice.message {
463 RequestMessage::Assistant { .. } => Role::Assistant,
464 RequestMessage::User { .. } => Role::User,
465 RequestMessage::System { .. } => Role::System,
466 RequestMessage::Tool { .. } => Role::Tool,
467 }),
468 content: match choice.message {
469 RequestMessage::Assistant { content, .. } => content,
470 RequestMessage::User { content } => Some(content),
471 RequestMessage::System { content } => Some(content),
472 RequestMessage::Tool { content, .. } => Some(content),
473 },
474 tool_calls: None,
475 },
476 finish_reason: choice.finish_reason,
477 })
478 .collect(),
479 usage: Some(response.usage),
480 }
481}
482
483pub async fn stream_completion(
484 client: &dyn HttpClient,
485 api_url: &str,
486 api_key: &str,
487 request: Request,
488) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
489 if request.model.starts_with("o1") {
490 let response = complete(client, api_url, api_key, request).await;
491 let response_stream_event = response.map(adapt_response_to_stream);
492 return Ok(stream::once(future::ready(response_stream_event)).boxed());
493 }
494
495 let uri = format!("{api_url}/chat/completions");
496 let request_builder = HttpRequest::builder()
497 .method(Method::POST)
498 .uri(uri)
499 .header("Content-Type", "application/json")
500 .header("Authorization", format!("Bearer {}", api_key));
501
502 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
503 let mut response = client.send(request).await?;
504 if response.status().is_success() {
505 let reader = BufReader::new(response.into_body());
506 Ok(reader
507 .lines()
508 .filter_map(|line| async move {
509 match line {
510 Ok(line) => {
511 let line = line.strip_prefix("data: ")?;
512 if line == "[DONE]" {
513 None
514 } else {
515 match serde_json::from_str(line) {
516 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
517 Ok(ResponseStreamResult::Err { error }) => {
518 Some(Err(anyhow!(error)))
519 }
520 Err(error) => Some(Err(anyhow!(error))),
521 }
522 }
523 }
524 Err(error) => Some(Err(anyhow!(error))),
525 }
526 })
527 .boxed())
528 } else {
529 let mut body = String::new();
530 response.body_mut().read_to_string(&mut body).await?;
531
532 #[derive(Deserialize)]
533 struct OpenAiResponse {
534 error: OpenAiError,
535 }
536
537 #[derive(Deserialize)]
538 struct OpenAiError {
539 message: String,
540 }
541
542 match serde_json::from_str::<OpenAiResponse>(&body) {
543 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
544 "Failed to connect to OpenAI API: {}",
545 response.error.message,
546 )),
547
548 _ => Err(anyhow!(
549 "Failed to connect to OpenAI API: {} {}",
550 response.status(),
551 body,
552 )),
553 }
554 }
555}
556
557#[derive(Copy, Clone, Serialize, Deserialize)]
558pub enum OpenAiEmbeddingModel {
559 #[serde(rename = "text-embedding-3-small")]
560 TextEmbedding3Small,
561 #[serde(rename = "text-embedding-3-large")]
562 TextEmbedding3Large,
563}
564
565#[derive(Serialize)]
566struct OpenAiEmbeddingRequest<'a> {
567 model: OpenAiEmbeddingModel,
568 input: Vec<&'a str>,
569}
570
571#[derive(Deserialize)]
572pub struct OpenAiEmbeddingResponse {
573 pub data: Vec<OpenAiEmbedding>,
574}
575
576#[derive(Deserialize)]
577pub struct OpenAiEmbedding {
578 pub embedding: Vec<f32>,
579}
580
581pub fn embed<'a>(
582 client: &dyn HttpClient,
583 api_url: &str,
584 api_key: &str,
585 model: OpenAiEmbeddingModel,
586 texts: impl IntoIterator<Item = &'a str>,
587) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
588 let uri = format!("{api_url}/embeddings");
589
590 let request = OpenAiEmbeddingRequest {
591 model,
592 input: texts.into_iter().collect(),
593 };
594 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
595 let request = HttpRequest::builder()
596 .method(Method::POST)
597 .uri(uri)
598 .header("Content-Type", "application/json")
599 .header("Authorization", format!("Bearer {}", api_key))
600 .body(body)
601 .map(|request| client.send(request));
602
603 async move {
604 let mut response = request?.await?;
605 let mut body = String::new();
606 response.body_mut().read_to_string(&mut body).await?;
607
608 if response.status().is_success() {
609 let response: OpenAiEmbeddingResponse =
610 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
611 Ok(response)
612 } else {
613 Err(anyhow!(
614 "error during embedding, status: {:?}, body: {:?}",
615 response.status(),
616 body
617 ))
618 }
619 }
620}