1mod supported_countries;
2
3use anyhow::{anyhow, Context as _, Result};
4use futures::{
5 io::BufReader,
6 stream::{self, BoxStream},
7 AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15 pin::Pin,
16};
17use strum::EnumIter;
18
19pub use supported_countries::*;
20
/// Base URL of the OpenAI v1 API. It has no trailing slash because the request
/// helpers in this file build endpoints as `{api_url}/...`.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
22
/// Returns `true` when `opt` is `None`, or when the contained value viewed as a
/// slice has no elements.
///
/// Used as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
26
27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum Role {
30 User,
31 Assistant,
32 System,
33 Tool,
34}
35
36impl TryFrom<String> for Role {
37 type Error = anyhow::Error;
38
39 fn try_from(value: String) -> Result<Self> {
40 match value.as_str() {
41 "user" => Ok(Self::User),
42 "assistant" => Ok(Self::Assistant),
43 "system" => Ok(Self::System),
44 "tool" => Ok(Self::Tool),
45 _ => Err(anyhow!("invalid role '{value}'")),
46 }
47 }
48}
49
50impl From<Role> for String {
51 fn from(val: Role) -> Self {
52 match val {
53 Role::User => "user".to_owned(),
54 Role::Assistant => "assistant".to_owned(),
55 Role::System => "system".to_owned(),
56 Role::Tool => "tool".to_owned(),
57 }
58 }
59}
60
61#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
62#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
63pub enum Model {
64 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
65 ThreePointFiveTurbo,
66 #[serde(rename = "gpt-4", alias = "gpt-4")]
67 Four,
68 #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
69 FourTurbo,
70 #[serde(rename = "gpt-4o", alias = "gpt-4o")]
71 #[default]
72 FourOmni,
73 #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
74 FourOmniMini,
75 #[serde(rename = "o1", alias = "o1")]
76 O1,
77 #[serde(rename = "o1-preview", alias = "o1-preview")]
78 O1Preview,
79 #[serde(rename = "o1-mini", alias = "o1-mini")]
80 O1Mini,
81 #[serde(rename = "o3-mini", alias = "o3-mini")]
82 O3Mini,
83
84 #[serde(rename = "custom")]
85 Custom {
86 name: String,
87 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
88 display_name: Option<String>,
89 max_tokens: usize,
90 max_output_tokens: Option<u32>,
91 max_completion_tokens: Option<u32>,
92 },
93}
94
95impl Model {
96 pub fn from_id(id: &str) -> Result<Self> {
97 match id {
98 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
99 "gpt-4" => Ok(Self::Four),
100 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
101 "gpt-4o" => Ok(Self::FourOmni),
102 "gpt-4o-mini" => Ok(Self::FourOmniMini),
103 "o1" => Ok(Self::O1),
104 "o1-preview" => Ok(Self::O1Preview),
105 "o1-mini" => Ok(Self::O1Mini),
106 _ => Err(anyhow!("invalid model id")),
107 }
108 }
109
110 pub fn id(&self) -> &str {
111 match self {
112 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
113 Self::Four => "gpt-4",
114 Self::FourTurbo => "gpt-4-turbo",
115 Self::FourOmni => "gpt-4o",
116 Self::FourOmniMini => "gpt-4o-mini",
117 Self::O1 => "o1",
118 Self::O1Preview => "o1-preview",
119 Self::O1Mini => "o1-mini",
120 Self::O3Mini => "o3-mini",
121 Self::Custom { name, .. } => name,
122 }
123 }
124
125 pub fn display_name(&self) -> &str {
126 match self {
127 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
128 Self::Four => "gpt-4",
129 Self::FourTurbo => "gpt-4-turbo",
130 Self::FourOmni => "gpt-4o",
131 Self::FourOmniMini => "gpt-4o-mini",
132 Self::O1 => "o1",
133 Self::O1Preview => "o1-preview",
134 Self::O1Mini => "o1-mini",
135 Self::O3Mini => "o3-mini",
136 Self::Custom {
137 name, display_name, ..
138 } => display_name.as_ref().unwrap_or(name),
139 }
140 }
141
142 pub fn max_token_count(&self) -> usize {
143 match self {
144 Self::ThreePointFiveTurbo => 16385,
145 Self::Four => 8192,
146 Self::FourTurbo => 128000,
147 Self::FourOmni => 128000,
148 Self::FourOmniMini => 128000,
149 Self::O1 => 200000,
150 Self::O1Preview => 128000,
151 Self::O1Mini => 128000,
152 Self::O3Mini => 200000,
153 Self::Custom { max_tokens, .. } => *max_tokens,
154 }
155 }
156
157 pub fn max_output_tokens(&self) -> Option<u32> {
158 match self {
159 Self::Custom {
160 max_output_tokens, ..
161 } => *max_output_tokens,
162 _ => None,
163 }
164 }
165}
166
167#[derive(Debug, Serialize, Deserialize)]
168pub struct Request {
169 pub model: String,
170 pub messages: Vec<RequestMessage>,
171 pub stream: bool,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub max_tokens: Option<u32>,
174 #[serde(default, skip_serializing_if = "Vec::is_empty")]
175 pub stop: Vec<String>,
176 pub temperature: f32,
177 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub tool_choice: Option<ToolChoice>,
179 #[serde(default, skip_serializing_if = "Vec::is_empty")]
180 pub tools: Vec<ToolDefinition>,
181}
182
183#[derive(Debug, Serialize, Deserialize)]
184pub struct CompletionRequest {
185 pub model: String,
186 pub prompt: String,
187 pub max_tokens: u32,
188 pub temperature: f32,
189 #[serde(default, skip_serializing_if = "Option::is_none")]
190 pub prediction: Option<Prediction>,
191 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub rewrite_speculation: Option<bool>,
193}
194
195#[derive(Clone, Deserialize, Serialize, Debug)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum Prediction {
198 Content { content: String },
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202#[serde(untagged)]
203pub enum ToolChoice {
204 Auto,
205 Required,
206 None,
207 Other(ToolDefinition),
208}
209
210#[derive(Clone, Deserialize, Serialize, Debug)]
211#[serde(tag = "type", rename_all = "snake_case")]
212pub enum ToolDefinition {
213 #[allow(dead_code)]
214 Function { function: FunctionDefinition },
215}
216
217#[derive(Clone, Debug, Serialize, Deserialize)]
218pub struct FunctionDefinition {
219 pub name: String,
220 pub description: Option<String>,
221 pub parameters: Option<Value>,
222}
223
224#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
225#[serde(tag = "role", rename_all = "lowercase")]
226pub enum RequestMessage {
227 Assistant {
228 content: Option<String>,
229 #[serde(default, skip_serializing_if = "Vec::is_empty")]
230 tool_calls: Vec<ToolCall>,
231 },
232 User {
233 content: String,
234 },
235 System {
236 content: String,
237 },
238 Tool {
239 content: String,
240 tool_call_id: String,
241 },
242}
243
244#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
245pub struct ToolCall {
246 pub id: String,
247 #[serde(flatten)]
248 pub content: ToolCallContent,
249}
250
251#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
252#[serde(tag = "type", rename_all = "lowercase")]
253pub enum ToolCallContent {
254 Function { function: FunctionContent },
255}
256
257#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
258pub struct FunctionContent {
259 pub name: String,
260 pub arguments: String,
261}
262
263#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
264pub struct ResponseMessageDelta {
265 pub role: Option<Role>,
266 pub content: Option<String>,
267 #[serde(default, skip_serializing_if = "is_none_or_empty")]
268 pub tool_calls: Option<Vec<ToolCallChunk>>,
269}
270
271#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
272pub struct ToolCallChunk {
273 pub index: usize,
274 pub id: Option<String>,
275
276 // There is also an optional `type` field that would determine if a
277 // function is there. Sometimes this streams in with the `function` before
278 // it streams in the `type`
279 pub function: Option<FunctionChunk>,
280}
281
282#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
283pub struct FunctionChunk {
284 pub name: Option<String>,
285 pub arguments: Option<String>,
286}
287
288#[derive(Serialize, Deserialize, Debug)]
289pub struct Usage {
290 pub prompt_tokens: u32,
291 pub completion_tokens: u32,
292 pub total_tokens: u32,
293}
294
295#[derive(Serialize, Deserialize, Debug)]
296pub struct ChoiceDelta {
297 pub index: u32,
298 pub delta: ResponseMessageDelta,
299 pub finish_reason: Option<String>,
300}
301
302#[derive(Serialize, Deserialize, Debug)]
303#[serde(untagged)]
304pub enum ResponseStreamResult {
305 Ok(ResponseStreamEvent),
306 Err { error: String },
307}
308
309#[derive(Serialize, Deserialize, Debug)]
310pub struct ResponseStreamEvent {
311 pub created: u32,
312 pub model: String,
313 pub choices: Vec<ChoiceDelta>,
314 pub usage: Option<Usage>,
315}
316
317#[derive(Serialize, Deserialize, Debug)]
318pub struct CompletionResponse {
319 pub id: String,
320 pub object: String,
321 pub created: u64,
322 pub model: String,
323 pub choices: Vec<CompletionChoice>,
324 pub usage: Usage,
325}
326
327#[derive(Serialize, Deserialize, Debug)]
328pub struct CompletionChoice {
329 pub text: String,
330}
331
332#[derive(Serialize, Deserialize, Debug)]
333pub struct Response {
334 pub id: String,
335 pub object: String,
336 pub created: u64,
337 pub model: String,
338 pub choices: Vec<Choice>,
339 pub usage: Usage,
340}
341
342#[derive(Serialize, Deserialize, Debug)]
343pub struct Choice {
344 pub index: u32,
345 pub message: RequestMessage,
346 pub finish_reason: Option<String>,
347}
348
349pub async fn complete(
350 client: &dyn HttpClient,
351 api_url: &str,
352 api_key: &str,
353 request: Request,
354) -> Result<Response> {
355 let uri = format!("{api_url}/chat/completions");
356 let request_builder = HttpRequest::builder()
357 .method(Method::POST)
358 .uri(uri)
359 .header("Content-Type", "application/json")
360 .header("Authorization", format!("Bearer {}", api_key));
361
362 let mut request_body = request;
363 request_body.stream = false;
364
365 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
366 let mut response = client.send(request).await?;
367
368 if response.status().is_success() {
369 let mut body = String::new();
370 response.body_mut().read_to_string(&mut body).await?;
371 let response: Response = serde_json::from_str(&body)?;
372 Ok(response)
373 } else {
374 let mut body = String::new();
375 response.body_mut().read_to_string(&mut body).await?;
376
377 #[derive(Deserialize)]
378 struct OpenAiResponse {
379 error: OpenAiError,
380 }
381
382 #[derive(Deserialize)]
383 struct OpenAiError {
384 message: String,
385 }
386
387 match serde_json::from_str::<OpenAiResponse>(&body) {
388 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
389 "Failed to connect to OpenAI API: {}",
390 response.error.message,
391 )),
392
393 _ => Err(anyhow!(
394 "Failed to connect to OpenAI API: {} {}",
395 response.status(),
396 body,
397 )),
398 }
399 }
400}
401
402pub async fn complete_text(
403 client: &dyn HttpClient,
404 api_url: &str,
405 api_key: &str,
406 request: CompletionRequest,
407) -> Result<CompletionResponse> {
408 let uri = format!("{api_url}/completions");
409 let request_builder = HttpRequest::builder()
410 .method(Method::POST)
411 .uri(uri)
412 .header("Content-Type", "application/json")
413 .header("Authorization", format!("Bearer {}", api_key));
414
415 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
416 let mut response = client.send(request).await?;
417
418 if response.status().is_success() {
419 let mut body = String::new();
420 response.body_mut().read_to_string(&mut body).await?;
421 let response = serde_json::from_str(&body)?;
422 Ok(response)
423 } else {
424 let mut body = String::new();
425 response.body_mut().read_to_string(&mut body).await?;
426
427 #[derive(Deserialize)]
428 struct OpenAiResponse {
429 error: OpenAiError,
430 }
431
432 #[derive(Deserialize)]
433 struct OpenAiError {
434 message: String,
435 }
436
437 match serde_json::from_str::<OpenAiResponse>(&body) {
438 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
439 "Failed to connect to OpenAI API: {}",
440 response.error.message,
441 )),
442
443 _ => Err(anyhow!(
444 "Failed to connect to OpenAI API: {} {}",
445 response.status(),
446 body,
447 )),
448 }
449 }
450}
451
452fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
453 ResponseStreamEvent {
454 created: response.created as u32,
455 model: response.model,
456 choices: response
457 .choices
458 .into_iter()
459 .map(|choice| ChoiceDelta {
460 index: choice.index,
461 delta: ResponseMessageDelta {
462 role: Some(match choice.message {
463 RequestMessage::Assistant { .. } => Role::Assistant,
464 RequestMessage::User { .. } => Role::User,
465 RequestMessage::System { .. } => Role::System,
466 RequestMessage::Tool { .. } => Role::Tool,
467 }),
468 content: match choice.message {
469 RequestMessage::Assistant { content, .. } => content,
470 RequestMessage::User { content } => Some(content),
471 RequestMessage::System { content } => Some(content),
472 RequestMessage::Tool { content, .. } => Some(content),
473 },
474 tool_calls: None,
475 },
476 finish_reason: choice.finish_reason,
477 })
478 .collect(),
479 usage: Some(response.usage),
480 }
481}
482
483pub async fn stream_completion(
484 client: &dyn HttpClient,
485 api_url: &str,
486 api_key: &str,
487 request: Request,
488) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
489 if request.model.starts_with("o1") {
490 let response = complete(client, api_url, api_key, request).await;
491 let response_stream_event = response.map(adapt_response_to_stream);
492 return Ok(stream::once(future::ready(response_stream_event)).boxed());
493 }
494
495 let uri = format!("{api_url}/chat/completions");
496 let request_builder = HttpRequest::builder()
497 .method(Method::POST)
498 .uri(uri)
499 .header("Content-Type", "application/json")
500 .header("Authorization", format!("Bearer {}", api_key));
501
502 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
503 let mut response = client.send(request).await?;
504 if response.status().is_success() {
505 let reader = BufReader::new(response.into_body());
506 Ok(reader
507 .lines()
508 .filter_map(|line| async move {
509 match line {
510 Ok(line) => {
511 let line = line.strip_prefix("data: ")?;
512 if line == "[DONE]" {
513 None
514 } else {
515 match serde_json::from_str(line) {
516 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
517 Ok(ResponseStreamResult::Err { error }) => {
518 Some(Err(anyhow!(error)))
519 }
520 Err(error) => Some(Err(anyhow!(error))),
521 }
522 }
523 }
524 Err(error) => Some(Err(anyhow!(error))),
525 }
526 })
527 .boxed())
528 } else {
529 let mut body = String::new();
530 response.body_mut().read_to_string(&mut body).await?;
531
532 #[derive(Deserialize)]
533 struct OpenAiResponse {
534 error: OpenAiError,
535 }
536
537 #[derive(Deserialize)]
538 struct OpenAiError {
539 message: String,
540 }
541
542 match serde_json::from_str::<OpenAiResponse>(&body) {
543 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
544 "Failed to connect to OpenAI API: {}",
545 response.error.message,
546 )),
547
548 _ => Err(anyhow!(
549 "Failed to connect to OpenAI API: {} {}",
550 response.status(),
551 body,
552 )),
553 }
554 }
555}
556
557#[derive(Copy, Clone, Serialize, Deserialize)]
558pub enum OpenAiEmbeddingModel {
559 #[serde(rename = "text-embedding-3-small")]
560 TextEmbedding3Small,
561 #[serde(rename = "text-embedding-3-large")]
562 TextEmbedding3Large,
563}
564
565#[derive(Serialize)]
566struct OpenAiEmbeddingRequest<'a> {
567 model: OpenAiEmbeddingModel,
568 input: Vec<&'a str>,
569}
570
571#[derive(Deserialize)]
572pub struct OpenAiEmbeddingResponse {
573 pub data: Vec<OpenAiEmbedding>,
574}
575
576#[derive(Deserialize)]
577pub struct OpenAiEmbedding {
578 pub embedding: Vec<f32>,
579}
580
581pub fn embed<'a>(
582 client: &dyn HttpClient,
583 api_url: &str,
584 api_key: &str,
585 model: OpenAiEmbeddingModel,
586 texts: impl IntoIterator<Item = &'a str>,
587) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
588 let uri = format!("{api_url}/embeddings");
589
590 let request = OpenAiEmbeddingRequest {
591 model,
592 input: texts.into_iter().collect(),
593 };
594 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
595 let request = HttpRequest::builder()
596 .method(Method::POST)
597 .uri(uri)
598 .header("Content-Type", "application/json")
599 .header("Authorization", format!("Bearer {}", api_key))
600 .body(body)
601 .map(|request| client.send(request));
602
603 async move {
604 let mut response = request?.await?;
605 let mut body = String::new();
606 response.body_mut().read_to_string(&mut body).await?;
607
608 if response.status().is_success() {
609 let response: OpenAiEmbeddingResponse =
610 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
611 Ok(response)
612 } else {
613 Err(anyhow!(
614 "error during embedding, status: {:?}, body: {:?}",
615 response.status(),
616 body
617 ))
618 }
619 }
620}
621
622pub async fn extract_tool_args_from_events(
623 tool_name: String,
624 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
625) -> Result<impl Send + Stream<Item = Result<String>>> {
626 let mut tool_use_index = None;
627 let mut first_chunk = None;
628 while let Some(event) = events.next().await {
629 let call = event?.choices.into_iter().find_map(|choice| {
630 choice.delta.tool_calls?.into_iter().find_map(|call| {
631 if call.function.as_ref()?.name.as_deref()? == tool_name {
632 Some(call)
633 } else {
634 None
635 }
636 })
637 });
638 if let Some(call) = call {
639 tool_use_index = Some(call.index);
640 first_chunk = call.function.and_then(|func| func.arguments);
641 break;
642 }
643 }
644
645 let Some(tool_use_index) = tool_use_index else {
646 return Err(anyhow!("tool not used"));
647 };
648
649 Ok(events.filter_map(move |event| {
650 let result = match event {
651 Err(error) => Some(Err(error)),
652 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
653 choice.delta.tool_calls?.into_iter().find_map(|call| {
654 if call.index == tool_use_index {
655 let func = call.function?;
656 let mut arguments = func.arguments?;
657 if let Some(mut first_chunk) = first_chunk.take() {
658 first_chunk.push_str(&arguments);
659 arguments = first_chunk
660 }
661 Some(Ok(arguments))
662 } else {
663 None
664 }
665 })
666 }),
667 };
668
669 async move { result }
670 }))
671}
672
673pub fn extract_text_from_events(
674 response: impl Stream<Item = Result<ResponseStreamEvent>>,
675) -> impl Stream<Item = Result<String>> {
676 response.filter_map(|response| async move {
677 match response {
678 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
679 Err(error) => Some(Err(error)),
680 }
681 })
682}