1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{
5 io::BufReader,
6 stream::{self, BoxStream},
7 AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15 pin::Pin,
16};
17use strum::EnumIter;
18
19pub use supported_countries::*;
20
/// Base URL of the public OpenAI REST API.
///
/// The request functions in this module take an `api_url` like this one and append
/// endpoint paths such as `/chat/completions`, `/completions` or `/embeddings` to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
22
/// Reports whether `opt` is `None` or wraps an empty slice-like value.
///
/// Intended as a `skip_serializing_if` predicate, so that a missing collection and an
/// empty one are both left out of the serialized output.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
26
27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum Role {
30 User,
31 Assistant,
32 System,
33 Tool,
34}
35
36impl TryFrom<String> for Role {
37 type Error = anyhow::Error;
38
39 fn try_from(value: String) -> Result<Self> {
40 match value.as_str() {
41 "user" => Ok(Self::User),
42 "assistant" => Ok(Self::Assistant),
43 "system" => Ok(Self::System),
44 "tool" => Ok(Self::Tool),
45 _ => Err(anyhow!("invalid role '{value}'")),
46 }
47 }
48}
49
50impl From<Role> for String {
51 fn from(val: Role) -> Self {
52 match val {
53 Role::User => "user".to_owned(),
54 Role::Assistant => "assistant".to_owned(),
55 Role::System => "system".to_owned(),
56 Role::Tool => "tool".to_owned(),
57 }
58 }
59}
60
61#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
62#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
63pub enum Model {
64 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
65 ThreePointFiveTurbo,
66 #[serde(rename = "gpt-4", alias = "gpt-4")]
67 Four,
68 #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
69 FourTurbo,
70 #[serde(rename = "gpt-4o", alias = "gpt-4o")]
71 #[default]
72 FourOmni,
73 #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
74 FourOmniMini,
75 #[serde(rename = "o1", alias = "o1")]
76 O1,
77 #[serde(rename = "o1-preview", alias = "o1-preview")]
78 O1Preview,
79 #[serde(rename = "o1-mini", alias = "o1-mini")]
80 O1Mini,
81
82 #[serde(rename = "custom")]
83 Custom {
84 name: String,
85 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
86 display_name: Option<String>,
87 max_tokens: usize,
88 max_output_tokens: Option<u32>,
89 max_completion_tokens: Option<u32>,
90 },
91}
92
93impl Model {
94 pub fn from_id(id: &str) -> Result<Self> {
95 match id {
96 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
97 "gpt-4" => Ok(Self::Four),
98 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
99 "gpt-4o" => Ok(Self::FourOmni),
100 "gpt-4o-mini" => Ok(Self::FourOmniMini),
101 "o1" => Ok(Self::O1),
102 "o1-preview" => Ok(Self::O1Preview),
103 "o1-mini" => Ok(Self::O1Mini),
104 _ => Err(anyhow!("invalid model id")),
105 }
106 }
107
108 pub fn id(&self) -> &str {
109 match self {
110 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
111 Self::Four => "gpt-4",
112 Self::FourTurbo => "gpt-4-turbo",
113 Self::FourOmni => "gpt-4o",
114 Self::FourOmniMini => "gpt-4o-mini",
115 Self::O1 => "o1",
116 Self::O1Preview => "o1-preview",
117 Self::O1Mini => "o1-mini",
118 Self::Custom { name, .. } => name,
119 }
120 }
121
122 pub fn display_name(&self) -> &str {
123 match self {
124 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
125 Self::Four => "gpt-4",
126 Self::FourTurbo => "gpt-4-turbo",
127 Self::FourOmni => "gpt-4o",
128 Self::FourOmniMini => "gpt-4o-mini",
129 Self::O1 => "o1",
130 Self::O1Preview => "o1-preview",
131 Self::O1Mini => "o1-mini",
132 Self::Custom {
133 name, display_name, ..
134 } => display_name.as_ref().unwrap_or(name),
135 }
136 }
137
138 pub fn max_token_count(&self) -> usize {
139 match self {
140 Self::ThreePointFiveTurbo => 16385,
141 Self::Four => 8192,
142 Self::FourTurbo => 128000,
143 Self::FourOmni => 128000,
144 Self::FourOmniMini => 128000,
145 Self::O1 => 200000,
146 Self::O1Preview => 128000,
147 Self::O1Mini => 128000,
148 Self::Custom { max_tokens, .. } => *max_tokens,
149 }
150 }
151
152 pub fn max_output_tokens(&self) -> Option<u32> {
153 match self {
154 Self::Custom {
155 max_output_tokens, ..
156 } => *max_output_tokens,
157 _ => None,
158 }
159 }
160}
161
162#[derive(Debug, Serialize, Deserialize)]
163pub struct Request {
164 pub model: String,
165 pub messages: Vec<RequestMessage>,
166 pub stream: bool,
167 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub max_tokens: Option<u32>,
169 #[serde(default, skip_serializing_if = "Vec::is_empty")]
170 pub stop: Vec<String>,
171 pub temperature: f32,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub tool_choice: Option<ToolChoice>,
174 #[serde(default, skip_serializing_if = "Vec::is_empty")]
175 pub tools: Vec<ToolDefinition>,
176}
177
178#[derive(Debug, Serialize, Deserialize)]
179pub struct CompletionRequest {
180 pub model: String,
181 pub prompt: String,
182 pub max_tokens: u32,
183 pub temperature: f32,
184 #[serde(default, skip_serializing_if = "Option::is_none")]
185 pub prediction: Option<Prediction>,
186 #[serde(default, skip_serializing_if = "Option::is_none")]
187 pub rewrite_speculation: Option<bool>,
188}
189
190#[derive(Clone, Deserialize, Serialize, Debug)]
191#[serde(tag = "type", rename_all = "snake_case")]
192pub enum Prediction {
193 Content { content: String },
194}
195
196#[derive(Debug, Serialize, Deserialize)]
197#[serde(untagged)]
198pub enum ToolChoice {
199 Auto,
200 Required,
201 None,
202 Other(ToolDefinition),
203}
204
205#[derive(Clone, Deserialize, Serialize, Debug)]
206#[serde(tag = "type", rename_all = "snake_case")]
207pub enum ToolDefinition {
208 #[allow(dead_code)]
209 Function { function: FunctionDefinition },
210}
211
212#[derive(Clone, Debug, Serialize, Deserialize)]
213pub struct FunctionDefinition {
214 pub name: String,
215 pub description: Option<String>,
216 pub parameters: Option<Value>,
217}
218
219#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
220#[serde(tag = "role", rename_all = "lowercase")]
221pub enum RequestMessage {
222 Assistant {
223 content: Option<String>,
224 #[serde(default, skip_serializing_if = "Vec::is_empty")]
225 tool_calls: Vec<ToolCall>,
226 },
227 User {
228 content: String,
229 },
230 System {
231 content: String,
232 },
233 Tool {
234 content: String,
235 tool_call_id: String,
236 },
237}
238
239#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
240pub struct ToolCall {
241 pub id: String,
242 #[serde(flatten)]
243 pub content: ToolCallContent,
244}
245
246#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
247#[serde(tag = "type", rename_all = "lowercase")]
248pub enum ToolCallContent {
249 Function { function: FunctionContent },
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253pub struct FunctionContent {
254 pub name: String,
255 pub arguments: String,
256}
257
258#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
259pub struct ResponseMessageDelta {
260 pub role: Option<Role>,
261 pub content: Option<String>,
262 #[serde(default, skip_serializing_if = "is_none_or_empty")]
263 pub tool_calls: Option<Vec<ToolCallChunk>>,
264}
265
266#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
267pub struct ToolCallChunk {
268 pub index: usize,
269 pub id: Option<String>,
270
271 // There is also an optional `type` field that would determine if a
272 // function is there. Sometimes this streams in with the `function` before
273 // it streams in the `type`
274 pub function: Option<FunctionChunk>,
275}
276
277#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
278pub struct FunctionChunk {
279 pub name: Option<String>,
280 pub arguments: Option<String>,
281}
282
283#[derive(Serialize, Deserialize, Debug)]
284pub struct Usage {
285 pub prompt_tokens: u32,
286 pub completion_tokens: u32,
287 pub total_tokens: u32,
288}
289
290#[derive(Serialize, Deserialize, Debug)]
291pub struct ChoiceDelta {
292 pub index: u32,
293 pub delta: ResponseMessageDelta,
294 pub finish_reason: Option<String>,
295}
296
297#[derive(Serialize, Deserialize, Debug)]
298#[serde(untagged)]
299pub enum ResponseStreamResult {
300 Ok(ResponseStreamEvent),
301 Err { error: String },
302}
303
304#[derive(Serialize, Deserialize, Debug)]
305pub struct ResponseStreamEvent {
306 pub created: u32,
307 pub model: String,
308 pub choices: Vec<ChoiceDelta>,
309 pub usage: Option<Usage>,
310}
311
312#[derive(Serialize, Deserialize, Debug)]
313pub struct CompletionResponse {
314 pub id: String,
315 pub object: String,
316 pub created: u64,
317 pub model: String,
318 pub choices: Vec<CompletionChoice>,
319 pub usage: Usage,
320}
321
322#[derive(Serialize, Deserialize, Debug)]
323pub struct CompletionChoice {
324 pub text: String,
325}
326
327#[derive(Serialize, Deserialize, Debug)]
328pub struct Response {
329 pub id: String,
330 pub object: String,
331 pub created: u64,
332 pub model: String,
333 pub choices: Vec<Choice>,
334 pub usage: Usage,
335}
336
337#[derive(Serialize, Deserialize, Debug)]
338pub struct Choice {
339 pub index: u32,
340 pub message: RequestMessage,
341 pub finish_reason: Option<String>,
342}
343
344pub async fn complete(
345 client: &dyn HttpClient,
346 api_url: &str,
347 api_key: &str,
348 request: Request,
349) -> Result<Response> {
350 let uri = format!("{api_url}/chat/completions");
351 let request_builder = HttpRequest::builder()
352 .method(Method::POST)
353 .uri(uri)
354 .header("Content-Type", "application/json")
355 .header("Authorization", format!("Bearer {}", api_key));
356
357 let mut request_body = request;
358 request_body.stream = false;
359
360 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
361 let mut response = client.send(request).await?;
362
363 if response.status().is_success() {
364 let mut body = String::new();
365 response.body_mut().read_to_string(&mut body).await?;
366 let response: Response = serde_json::from_str(&body)?;
367 Ok(response)
368 } else {
369 let mut body = String::new();
370 response.body_mut().read_to_string(&mut body).await?;
371
372 #[derive(Deserialize)]
373 struct OpenAiResponse {
374 error: OpenAiError,
375 }
376
377 #[derive(Deserialize)]
378 struct OpenAiError {
379 message: String,
380 }
381
382 match serde_json::from_str::<OpenAiResponse>(&body) {
383 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
384 "Failed to connect to OpenAI API: {}",
385 response.error.message,
386 )),
387
388 _ => Err(anyhow!(
389 "Failed to connect to OpenAI API: {} {}",
390 response.status(),
391 body,
392 )),
393 }
394 }
395}
396
397pub async fn complete_text(
398 client: &dyn HttpClient,
399 api_url: &str,
400 api_key: &str,
401 request: CompletionRequest,
402) -> Result<CompletionResponse> {
403 let uri = format!("{api_url}/completions");
404 let request_builder = HttpRequest::builder()
405 .method(Method::POST)
406 .uri(uri)
407 .header("Content-Type", "application/json")
408 .header("Authorization", format!("Bearer {}", api_key));
409
410 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
411 let mut response = client.send(request).await?;
412
413 if response.status().is_success() {
414 let mut body = String::new();
415 response.body_mut().read_to_string(&mut body).await?;
416 let response = serde_json::from_str(&body)?;
417 Ok(response)
418 } else {
419 let mut body = String::new();
420 response.body_mut().read_to_string(&mut body).await?;
421
422 #[derive(Deserialize)]
423 struct OpenAiResponse {
424 error: OpenAiError,
425 }
426
427 #[derive(Deserialize)]
428 struct OpenAiError {
429 message: String,
430 }
431
432 match serde_json::from_str::<OpenAiResponse>(&body) {
433 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
434 "Failed to connect to OpenAI API: {}",
435 response.error.message,
436 )),
437
438 _ => Err(anyhow!(
439 "Failed to connect to OpenAI API: {} {}",
440 response.status(),
441 body,
442 )),
443 }
444 }
445}
446
447fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
448 ResponseStreamEvent {
449 created: response.created as u32,
450 model: response.model,
451 choices: response
452 .choices
453 .into_iter()
454 .map(|choice| ChoiceDelta {
455 index: choice.index,
456 delta: ResponseMessageDelta {
457 role: Some(match choice.message {
458 RequestMessage::Assistant { .. } => Role::Assistant,
459 RequestMessage::User { .. } => Role::User,
460 RequestMessage::System { .. } => Role::System,
461 RequestMessage::Tool { .. } => Role::Tool,
462 }),
463 content: match choice.message {
464 RequestMessage::Assistant { content, .. } => content,
465 RequestMessage::User { content } => Some(content),
466 RequestMessage::System { content } => Some(content),
467 RequestMessage::Tool { content, .. } => Some(content),
468 },
469 tool_calls: None,
470 },
471 finish_reason: choice.finish_reason,
472 })
473 .collect(),
474 usage: Some(response.usage),
475 }
476}
477
478pub async fn stream_completion(
479 client: &dyn HttpClient,
480 api_url: &str,
481 api_key: &str,
482 request: Request,
483) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
484 if request.model.starts_with("o1") {
485 let response = complete(client, api_url, api_key, request).await;
486 let response_stream_event = response.map(adapt_response_to_stream);
487 return Ok(stream::once(future::ready(response_stream_event)).boxed());
488 }
489
490 let uri = format!("{api_url}/chat/completions");
491 let request_builder = HttpRequest::builder()
492 .method(Method::POST)
493 .uri(uri)
494 .header("Content-Type", "application/json")
495 .header("Authorization", format!("Bearer {}", api_key));
496
497 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
498 let mut response = client.send(request).await?;
499 if response.status().is_success() {
500 let reader = BufReader::new(response.into_body());
501 Ok(reader
502 .lines()
503 .filter_map(|line| async move {
504 match line {
505 Ok(line) => {
506 let line = line.strip_prefix("data: ")?;
507 if line == "[DONE]" {
508 None
509 } else {
510 match serde_json::from_str(line) {
511 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
512 Ok(ResponseStreamResult::Err { error }) => {
513 Some(Err(anyhow!(error)))
514 }
515 Err(error) => Some(Err(anyhow!(error))),
516 }
517 }
518 }
519 Err(error) => Some(Err(anyhow!(error))),
520 }
521 })
522 .boxed())
523 } else {
524 let mut body = String::new();
525 response.body_mut().read_to_string(&mut body).await?;
526
527 #[derive(Deserialize)]
528 struct OpenAiResponse {
529 error: OpenAiError,
530 }
531
532 #[derive(Deserialize)]
533 struct OpenAiError {
534 message: String,
535 }
536
537 match serde_json::from_str::<OpenAiResponse>(&body) {
538 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
539 "Failed to connect to OpenAI API: {}",
540 response.error.message,
541 )),
542
543 _ => Err(anyhow!(
544 "Failed to connect to OpenAI API: {} {}",
545 response.status(),
546 body,
547 )),
548 }
549 }
550}
551
552#[derive(Copy, Clone, Serialize, Deserialize)]
553pub enum OpenAiEmbeddingModel {
554 #[serde(rename = "text-embedding-3-small")]
555 TextEmbedding3Small,
556 #[serde(rename = "text-embedding-3-large")]
557 TextEmbedding3Large,
558}
559
560#[derive(Serialize)]
561struct OpenAiEmbeddingRequest<'a> {
562 model: OpenAiEmbeddingModel,
563 input: Vec<&'a str>,
564}
565
566#[derive(Deserialize)]
567pub struct OpenAiEmbeddingResponse {
568 pub data: Vec<OpenAiEmbedding>,
569}
570
571#[derive(Deserialize)]
572pub struct OpenAiEmbedding {
573 pub embedding: Vec<f32>,
574}
575
576pub fn embed<'a>(
577 client: &dyn HttpClient,
578 api_url: &str,
579 api_key: &str,
580 model: OpenAiEmbeddingModel,
581 texts: impl IntoIterator<Item = &'a str>,
582) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
583 let uri = format!("{api_url}/embeddings");
584
585 let request = OpenAiEmbeddingRequest {
586 model,
587 input: texts.into_iter().collect(),
588 };
589 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
590 let request = HttpRequest::builder()
591 .method(Method::POST)
592 .uri(uri)
593 .header("Content-Type", "application/json")
594 .header("Authorization", format!("Bearer {}", api_key))
595 .body(body)
596 .map(|request| client.send(request));
597
598 async move {
599 let mut response = request?.await?;
600 let mut body = String::new();
601 response.body_mut().read_to_string(&mut body).await?;
602
603 if response.status().is_success() {
604 let response: OpenAiEmbeddingResponse =
605 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
606 Ok(response)
607 } else {
608 Err(anyhow!(
609 "error during embedding, status: {:?}, body: {:?}",
610 response.status(),
611 body
612 ))
613 }
614 }
615}
616
617pub async fn extract_tool_args_from_events(
618 tool_name: String,
619 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
620) -> Result<impl Send + Stream<Item = Result<String>>> {
621 let mut tool_use_index = None;
622 let mut first_chunk = None;
623 while let Some(event) = events.next().await {
624 let call = event?.choices.into_iter().find_map(|choice| {
625 choice.delta.tool_calls?.into_iter().find_map(|call| {
626 if call.function.as_ref()?.name.as_deref()? == tool_name {
627 Some(call)
628 } else {
629 None
630 }
631 })
632 });
633 if let Some(call) = call {
634 tool_use_index = Some(call.index);
635 first_chunk = call.function.and_then(|func| func.arguments);
636 break;
637 }
638 }
639
640 let Some(tool_use_index) = tool_use_index else {
641 return Err(anyhow!("tool not used"));
642 };
643
644 Ok(events.filter_map(move |event| {
645 let result = match event {
646 Err(error) => Some(Err(error)),
647 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
648 choice.delta.tool_calls?.into_iter().find_map(|call| {
649 if call.index == tool_use_index {
650 let func = call.function?;
651 let mut arguments = func.arguments?;
652 if let Some(mut first_chunk) = first_chunk.take() {
653 first_chunk.push_str(&arguments);
654 arguments = first_chunk
655 }
656 Some(Ok(arguments))
657 } else {
658 None
659 }
660 })
661 }),
662 };
663
664 async move { result }
665 }))
666}
667
668pub fn extract_text_from_events(
669 response: impl Stream<Item = Result<ResponseStreamEvent>>,
670) -> impl Stream<Item = Result<String>> {
671 response.filter_map(|response| async move {
672 match response {
673 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
674 Err(error) => Some(Err(error)),
675 }
676 })
677}