1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{
5 io::BufReader,
6 stream::{self, BoxStream},
7 AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15 pin::Pin,
16};
17use strum::EnumIter;
18
19pub use supported_countries::*;
20
/// Base URL of the public OpenAI API (v1).
///
/// Suitable as the `api_url` argument of the request functions in this module,
/// which append endpoint paths such as `/chat/completions` to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
22
/// Returns `true` when `opt` is `None` or holds an empty sequence.
///
/// Used as a `skip_serializing_if` predicate so that an absent list and an
/// empty list are both left out of the serialized JSON.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
26
27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum Role {
30 User,
31 Assistant,
32 System,
33 Tool,
34}
35
36impl TryFrom<String> for Role {
37 type Error = anyhow::Error;
38
39 fn try_from(value: String) -> Result<Self> {
40 match value.as_str() {
41 "user" => Ok(Self::User),
42 "assistant" => Ok(Self::Assistant),
43 "system" => Ok(Self::System),
44 "tool" => Ok(Self::Tool),
45 _ => Err(anyhow!("invalid role '{value}'")),
46 }
47 }
48}
49
50impl From<Role> for String {
51 fn from(val: Role) -> Self {
52 match val {
53 Role::User => "user".to_owned(),
54 Role::Assistant => "assistant".to_owned(),
55 Role::System => "system".to_owned(),
56 Role::Tool => "tool".to_owned(),
57 }
58 }
59}
60
61#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
62#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
63pub enum Model {
64 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
65 ThreePointFiveTurbo,
66 #[serde(rename = "gpt-4", alias = "gpt-4")]
67 Four,
68 #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
69 FourTurbo,
70 #[serde(rename = "gpt-4o", alias = "gpt-4o")]
71 #[default]
72 FourOmni,
73 #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
74 FourOmniMini,
75 #[serde(rename = "o1-preview", alias = "o1-preview")]
76 O1Preview,
77 #[serde(rename = "o1-mini", alias = "o1-mini")]
78 O1Mini,
79
80 #[serde(rename = "custom")]
81 Custom {
82 name: String,
83 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
84 display_name: Option<String>,
85 max_tokens: usize,
86 max_output_tokens: Option<u32>,
87 max_completion_tokens: Option<u32>,
88 },
89}
90
91impl Model {
92 pub fn from_id(id: &str) -> Result<Self> {
93 match id {
94 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
95 "gpt-4" => Ok(Self::Four),
96 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
97 "gpt-4o" => Ok(Self::FourOmni),
98 "gpt-4o-mini" => Ok(Self::FourOmniMini),
99 "o1-preview" => Ok(Self::O1Preview),
100 "o1-mini" => Ok(Self::O1Mini),
101 _ => Err(anyhow!("invalid model id")),
102 }
103 }
104
105 pub fn id(&self) -> &str {
106 match self {
107 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
108 Self::Four => "gpt-4",
109 Self::FourTurbo => "gpt-4-turbo",
110 Self::FourOmni => "gpt-4o",
111 Self::FourOmniMini => "gpt-4o-mini",
112 Self::O1Preview => "o1-preview",
113 Self::O1Mini => "o1-mini",
114 Self::Custom { name, .. } => name,
115 }
116 }
117
118 pub fn display_name(&self) -> &str {
119 match self {
120 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
121 Self::Four => "gpt-4",
122 Self::FourTurbo => "gpt-4-turbo",
123 Self::FourOmni => "gpt-4o",
124 Self::FourOmniMini => "gpt-4o-mini",
125 Self::O1Preview => "o1-preview",
126 Self::O1Mini => "o1-mini",
127 Self::Custom {
128 name, display_name, ..
129 } => display_name.as_ref().unwrap_or(name),
130 }
131 }
132
133 pub fn max_token_count(&self) -> usize {
134 match self {
135 Self::ThreePointFiveTurbo => 16385,
136 Self::Four => 8192,
137 Self::FourTurbo => 128000,
138 Self::FourOmni => 128000,
139 Self::FourOmniMini => 128000,
140 Self::O1Preview => 128000,
141 Self::O1Mini => 128000,
142 Self::Custom { max_tokens, .. } => *max_tokens,
143 }
144 }
145
146 pub fn max_output_tokens(&self) -> Option<u32> {
147 match self {
148 Self::Custom {
149 max_output_tokens, ..
150 } => *max_output_tokens,
151 _ => None,
152 }
153 }
154}
155
156#[derive(Debug, Serialize, Deserialize)]
157pub struct Request {
158 pub model: String,
159 pub messages: Vec<RequestMessage>,
160 pub stream: bool,
161 #[serde(default, skip_serializing_if = "Option::is_none")]
162 pub max_tokens: Option<u32>,
163 #[serde(default, skip_serializing_if = "Vec::is_empty")]
164 pub stop: Vec<String>,
165 pub temperature: f32,
166 #[serde(default, skip_serializing_if = "Option::is_none")]
167 pub tool_choice: Option<ToolChoice>,
168 #[serde(default, skip_serializing_if = "Vec::is_empty")]
169 pub tools: Vec<ToolDefinition>,
170}
171
172#[derive(Debug, Serialize, Deserialize)]
173pub struct CompletionRequest {
174 pub model: String,
175 pub prompt: String,
176 pub max_tokens: u32,
177 pub temperature: f32,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
179 pub prediction: Option<Prediction>,
180 #[serde(default, skip_serializing_if = "Option::is_none")]
181 pub rewrite_speculation: Option<bool>,
182}
183
184#[derive(Clone, Deserialize, Serialize, Debug)]
185#[serde(tag = "type", rename_all = "snake_case")]
186pub enum Prediction {
187 Content { content: String },
188}
189
190#[derive(Debug, Serialize, Deserialize)]
191#[serde(untagged)]
192pub enum ToolChoice {
193 Auto,
194 Required,
195 None,
196 Other(ToolDefinition),
197}
198
199#[derive(Clone, Deserialize, Serialize, Debug)]
200#[serde(tag = "type", rename_all = "snake_case")]
201pub enum ToolDefinition {
202 #[allow(dead_code)]
203 Function { function: FunctionDefinition },
204}
205
206#[derive(Clone, Debug, Serialize, Deserialize)]
207pub struct FunctionDefinition {
208 pub name: String,
209 pub description: Option<String>,
210 pub parameters: Option<Value>,
211}
212
213#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
214#[serde(tag = "role", rename_all = "lowercase")]
215pub enum RequestMessage {
216 Assistant {
217 content: Option<String>,
218 #[serde(default, skip_serializing_if = "Vec::is_empty")]
219 tool_calls: Vec<ToolCall>,
220 },
221 User {
222 content: String,
223 },
224 System {
225 content: String,
226 },
227 Tool {
228 content: String,
229 tool_call_id: String,
230 },
231}
232
233#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
234pub struct ToolCall {
235 pub id: String,
236 #[serde(flatten)]
237 pub content: ToolCallContent,
238}
239
240#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
241#[serde(tag = "type", rename_all = "lowercase")]
242pub enum ToolCallContent {
243 Function { function: FunctionContent },
244}
245
246#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
247pub struct FunctionContent {
248 pub name: String,
249 pub arguments: String,
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253pub struct ResponseMessageDelta {
254 pub role: Option<Role>,
255 pub content: Option<String>,
256 #[serde(default, skip_serializing_if = "is_none_or_empty")]
257 pub tool_calls: Option<Vec<ToolCallChunk>>,
258}
259
260#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
261pub struct ToolCallChunk {
262 pub index: usize,
263 pub id: Option<String>,
264
265 // There is also an optional `type` field that would determine if a
266 // function is there. Sometimes this streams in with the `function` before
267 // it streams in the `type`
268 pub function: Option<FunctionChunk>,
269}
270
271#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
272pub struct FunctionChunk {
273 pub name: Option<String>,
274 pub arguments: Option<String>,
275}
276
277#[derive(Serialize, Deserialize, Debug)]
278pub struct Usage {
279 pub prompt_tokens: u32,
280 pub completion_tokens: u32,
281 pub total_tokens: u32,
282}
283
284#[derive(Serialize, Deserialize, Debug)]
285pub struct ChoiceDelta {
286 pub index: u32,
287 pub delta: ResponseMessageDelta,
288 pub finish_reason: Option<String>,
289}
290
291#[derive(Serialize, Deserialize, Debug)]
292#[serde(untagged)]
293pub enum ResponseStreamResult {
294 Ok(ResponseStreamEvent),
295 Err { error: String },
296}
297
298#[derive(Serialize, Deserialize, Debug)]
299pub struct ResponseStreamEvent {
300 pub created: u32,
301 pub model: String,
302 pub choices: Vec<ChoiceDelta>,
303 pub usage: Option<Usage>,
304}
305
306#[derive(Serialize, Deserialize, Debug)]
307pub struct CompletionResponse {
308 pub id: String,
309 pub object: String,
310 pub created: u64,
311 pub model: String,
312 pub choices: Vec<CompletionChoice>,
313 pub usage: Usage,
314}
315
316#[derive(Serialize, Deserialize, Debug)]
317pub struct CompletionChoice {
318 pub text: String,
319}
320
321#[derive(Serialize, Deserialize, Debug)]
322pub struct Response {
323 pub id: String,
324 pub object: String,
325 pub created: u64,
326 pub model: String,
327 pub choices: Vec<Choice>,
328 pub usage: Usage,
329}
330
331#[derive(Serialize, Deserialize, Debug)]
332pub struct Choice {
333 pub index: u32,
334 pub message: RequestMessage,
335 pub finish_reason: Option<String>,
336}
337
338pub async fn complete(
339 client: &dyn HttpClient,
340 api_url: &str,
341 api_key: &str,
342 request: Request,
343) -> Result<Response> {
344 let uri = format!("{api_url}/chat/completions");
345 let request_builder = HttpRequest::builder()
346 .method(Method::POST)
347 .uri(uri)
348 .header("Content-Type", "application/json")
349 .header("Authorization", format!("Bearer {}", api_key));
350
351 let mut request_body = request;
352 request_body.stream = false;
353
354 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
355 let mut response = client.send(request).await?;
356
357 if response.status().is_success() {
358 let mut body = String::new();
359 response.body_mut().read_to_string(&mut body).await?;
360 let response: Response = serde_json::from_str(&body)?;
361 Ok(response)
362 } else {
363 let mut body = String::new();
364 response.body_mut().read_to_string(&mut body).await?;
365
366 #[derive(Deserialize)]
367 struct OpenAiResponse {
368 error: OpenAiError,
369 }
370
371 #[derive(Deserialize)]
372 struct OpenAiError {
373 message: String,
374 }
375
376 match serde_json::from_str::<OpenAiResponse>(&body) {
377 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
378 "Failed to connect to OpenAI API: {}",
379 response.error.message,
380 )),
381
382 _ => Err(anyhow!(
383 "Failed to connect to OpenAI API: {} {}",
384 response.status(),
385 body,
386 )),
387 }
388 }
389}
390
391pub async fn complete_text(
392 client: &dyn HttpClient,
393 api_url: &str,
394 api_key: &str,
395 request: CompletionRequest,
396) -> Result<CompletionResponse> {
397 let uri = format!("{api_url}/completions");
398 let request_builder = HttpRequest::builder()
399 .method(Method::POST)
400 .uri(uri)
401 .header("Content-Type", "application/json")
402 .header("Authorization", format!("Bearer {}", api_key));
403
404 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
405 let mut response = client.send(request).await?;
406
407 if response.status().is_success() {
408 let mut body = String::new();
409 response.body_mut().read_to_string(&mut body).await?;
410 let response = serde_json::from_str(&body)?;
411 Ok(response)
412 } else {
413 let mut body = String::new();
414 response.body_mut().read_to_string(&mut body).await?;
415
416 #[derive(Deserialize)]
417 struct OpenAiResponse {
418 error: OpenAiError,
419 }
420
421 #[derive(Deserialize)]
422 struct OpenAiError {
423 message: String,
424 }
425
426 match serde_json::from_str::<OpenAiResponse>(&body) {
427 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
428 "Failed to connect to OpenAI API: {}",
429 response.error.message,
430 )),
431
432 _ => Err(anyhow!(
433 "Failed to connect to OpenAI API: {} {}",
434 response.status(),
435 body,
436 )),
437 }
438 }
439}
440
441fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
442 ResponseStreamEvent {
443 created: response.created as u32,
444 model: response.model,
445 choices: response
446 .choices
447 .into_iter()
448 .map(|choice| ChoiceDelta {
449 index: choice.index,
450 delta: ResponseMessageDelta {
451 role: Some(match choice.message {
452 RequestMessage::Assistant { .. } => Role::Assistant,
453 RequestMessage::User { .. } => Role::User,
454 RequestMessage::System { .. } => Role::System,
455 RequestMessage::Tool { .. } => Role::Tool,
456 }),
457 content: match choice.message {
458 RequestMessage::Assistant { content, .. } => content,
459 RequestMessage::User { content } => Some(content),
460 RequestMessage::System { content } => Some(content),
461 RequestMessage::Tool { content, .. } => Some(content),
462 },
463 tool_calls: None,
464 },
465 finish_reason: choice.finish_reason,
466 })
467 .collect(),
468 usage: Some(response.usage),
469 }
470}
471
472pub async fn stream_completion(
473 client: &dyn HttpClient,
474 api_url: &str,
475 api_key: &str,
476 request: Request,
477) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
478 if request.model == "o1-preview" || request.model == "o1-mini" {
479 let response = complete(client, api_url, api_key, request).await;
480 let response_stream_event = response.map(adapt_response_to_stream);
481 return Ok(stream::once(future::ready(response_stream_event)).boxed());
482 }
483
484 let uri = format!("{api_url}/chat/completions");
485 let request_builder = HttpRequest::builder()
486 .method(Method::POST)
487 .uri(uri)
488 .header("Content-Type", "application/json")
489 .header("Authorization", format!("Bearer {}", api_key));
490
491 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
492 let mut response = client.send(request).await?;
493 if response.status().is_success() {
494 let reader = BufReader::new(response.into_body());
495 Ok(reader
496 .lines()
497 .filter_map(|line| async move {
498 match line {
499 Ok(line) => {
500 let line = line.strip_prefix("data: ")?;
501 if line == "[DONE]" {
502 None
503 } else {
504 match serde_json::from_str(line) {
505 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
506 Ok(ResponseStreamResult::Err { error }) => {
507 Some(Err(anyhow!(error)))
508 }
509 Err(error) => Some(Err(anyhow!(error))),
510 }
511 }
512 }
513 Err(error) => Some(Err(anyhow!(error))),
514 }
515 })
516 .boxed())
517 } else {
518 let mut body = String::new();
519 response.body_mut().read_to_string(&mut body).await?;
520
521 #[derive(Deserialize)]
522 struct OpenAiResponse {
523 error: OpenAiError,
524 }
525
526 #[derive(Deserialize)]
527 struct OpenAiError {
528 message: String,
529 }
530
531 match serde_json::from_str::<OpenAiResponse>(&body) {
532 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
533 "Failed to connect to OpenAI API: {}",
534 response.error.message,
535 )),
536
537 _ => Err(anyhow!(
538 "Failed to connect to OpenAI API: {} {}",
539 response.status(),
540 body,
541 )),
542 }
543 }
544}
545
546#[derive(Copy, Clone, Serialize, Deserialize)]
547pub enum OpenAiEmbeddingModel {
548 #[serde(rename = "text-embedding-3-small")]
549 TextEmbedding3Small,
550 #[serde(rename = "text-embedding-3-large")]
551 TextEmbedding3Large,
552}
553
554#[derive(Serialize)]
555struct OpenAiEmbeddingRequest<'a> {
556 model: OpenAiEmbeddingModel,
557 input: Vec<&'a str>,
558}
559
560#[derive(Deserialize)]
561pub struct OpenAiEmbeddingResponse {
562 pub data: Vec<OpenAiEmbedding>,
563}
564
565#[derive(Deserialize)]
566pub struct OpenAiEmbedding {
567 pub embedding: Vec<f32>,
568}
569
570pub fn embed<'a>(
571 client: &dyn HttpClient,
572 api_url: &str,
573 api_key: &str,
574 model: OpenAiEmbeddingModel,
575 texts: impl IntoIterator<Item = &'a str>,
576) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
577 let uri = format!("{api_url}/embeddings");
578
579 let request = OpenAiEmbeddingRequest {
580 model,
581 input: texts.into_iter().collect(),
582 };
583 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
584 let request = HttpRequest::builder()
585 .method(Method::POST)
586 .uri(uri)
587 .header("Content-Type", "application/json")
588 .header("Authorization", format!("Bearer {}", api_key))
589 .body(body)
590 .map(|request| client.send(request));
591
592 async move {
593 let mut response = request?.await?;
594 let mut body = String::new();
595 response.body_mut().read_to_string(&mut body).await?;
596
597 if response.status().is_success() {
598 let response: OpenAiEmbeddingResponse =
599 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
600 Ok(response)
601 } else {
602 Err(anyhow!(
603 "error during embedding, status: {:?}, body: {:?}",
604 response.status(),
605 body
606 ))
607 }
608 }
609}
610
611pub async fn extract_tool_args_from_events(
612 tool_name: String,
613 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
614) -> Result<impl Send + Stream<Item = Result<String>>> {
615 let mut tool_use_index = None;
616 let mut first_chunk = None;
617 while let Some(event) = events.next().await {
618 let call = event?.choices.into_iter().find_map(|choice| {
619 choice.delta.tool_calls?.into_iter().find_map(|call| {
620 if call.function.as_ref()?.name.as_deref()? == tool_name {
621 Some(call)
622 } else {
623 None
624 }
625 })
626 });
627 if let Some(call) = call {
628 tool_use_index = Some(call.index);
629 first_chunk = call.function.and_then(|func| func.arguments);
630 break;
631 }
632 }
633
634 let Some(tool_use_index) = tool_use_index else {
635 return Err(anyhow!("tool not used"));
636 };
637
638 Ok(events.filter_map(move |event| {
639 let result = match event {
640 Err(error) => Some(Err(error)),
641 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
642 choice.delta.tool_calls?.into_iter().find_map(|call| {
643 if call.index == tool_use_index {
644 let func = call.function?;
645 let mut arguments = func.arguments?;
646 if let Some(mut first_chunk) = first_chunk.take() {
647 first_chunk.push_str(&arguments);
648 arguments = first_chunk
649 }
650 Some(Ok(arguments))
651 } else {
652 None
653 }
654 })
655 }),
656 };
657
658 async move { result }
659 }))
660}
661
662pub fn extract_text_from_events(
663 response: impl Stream<Item = Result<ResponseStreamEvent>>,
664) -> impl Stream<Item = Result<String>> {
665 response.filter_map(|response| async move {
666 match response {
667 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
668 Err(error) => Some(Err(error)),
669 }
670 })
671}