mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{
    io::BufReader,
    stream::{self, BoxStream},
    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    convert::TryFrom,
    future::{self, Future},
    pin::Pin,
    time::Duration,
};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

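/// The OpenAI chat models this crate knows about, plus a `Custom` escape hatch
/// for models that are not listed here.
///
/// A minimal usage sketch based only on the methods defined below (error
/// handling is assumed to be provided by the caller):
///
/// ```ignore
/// let model = Model::from_id("gpt-4o")?;
/// assert_eq!(model.id(), "gpt-4o");
/// assert_eq!(model.max_token_count(), 128000);
/// ```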
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
97 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
98 "gpt-4o" => Ok(Self::FourOmni),
99 "gpt-4o-mini" => Ok(Self::FourOmniMini),
100 "o1-preview" => Ok(Self::O1Preview),
101 "o1-mini" => Ok(Self::O1Mini),
102 _ => Err(anyhow!("invalid model id")),
103 }
104 }
105
106 pub fn id(&self) -> &str {
107 match self {
108 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
109 Self::Four => "gpt-4",
110 Self::FourTurbo => "gpt-4-turbo",
111 Self::FourOmni => "gpt-4o",
112 Self::FourOmniMini => "gpt-4o-mini",
113 Self::O1Preview => "o1-preview",
114 Self::O1Mini => "o1-mini",
115 Self::Custom { name, .. } => name,
116 }
117 }
118
119 pub fn display_name(&self) -> &str {
120 match self {
121 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
122 Self::Four => "gpt-4",
123 Self::FourTurbo => "gpt-4-turbo",
124 Self::FourOmni => "gpt-4o",
125 Self::FourOmniMini => "gpt-4o-mini",
126 Self::O1Preview => "o1-preview",
127 Self::O1Mini => "o1-mini",
128 Self::Custom {
129 name, display_name, ..
130 } => display_name.as_ref().unwrap_or(name),
131 }
132 }
133
134 pub fn max_token_count(&self) -> usize {
135 match self {
136 Self::ThreePointFiveTurbo => 16385,
137 Self::Four => 8192,
138 Self::FourTurbo => 128000,
139 Self::FourOmni => 128000,
140 Self::FourOmniMini => 128000,
141 Self::O1Preview => 128000,
142 Self::O1Mini => 128000,
143 Self::Custom { max_tokens, .. } => *max_tokens,
144 }
145 }
146
147 pub fn max_output_tokens(&self) -> Option<u32> {
148 match self {
149 Self::Custom {
150 max_output_tokens, ..
151 } => *max_output_tokens,
152 _ => None,
153 }
154 }
155}
156
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

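/// A chat message as sent in the request body. Serialization is internally
/// tagged by `role` (see the serde attributes below), so for example a
/// `RequestMessage::User { content }` value serializes roughly as
/// `{"role": "user", "content": "..."}`; this is a sketch of the wire shape,
/// not an exhaustive schema.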
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a function
    // call is present. It sometimes streams in after the `function` field has
    // already arrived.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

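/// Sends a non-streaming chat completion request to the OpenAI-compatible
/// endpoint at `api_url` and parses the JSON response.
///
/// Usage sketch, assuming the caller already has some `HttpClient`
/// implementation (`client`) and a valid `api_key`; both are assumptions here,
/// not provided by this module:
///
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().into(),
///     messages: vec![RequestMessage::User { content: "Hello!".into() }],
///     stream: false,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 1.0,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// let response = complete(client, OPEN_AI_API_URL, &api_key, request, None).await?;
/// ```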
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));
    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.read_timeout(low_speed_timeout);
    };

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
    ResponseStreamEvent {
        created: response.created as u32,
        model: response.model,
        choices: response
            .choices
            .into_iter()
            .map(|choice| ChoiceDelta {
                index: choice.index,
                delta: ResponseMessageDelta {
                    role: Some(match choice.message {
                        RequestMessage::Assistant { .. } => Role::Assistant,
                        RequestMessage::User { .. } => Role::User,
                        RequestMessage::System { .. } => Role::System,
                        RequestMessage::Tool { .. } => Role::Tool,
                    }),
                    content: match choice.message {
                        RequestMessage::Assistant { content, .. } => content,
                        RequestMessage::User { content } => Some(content),
                        RequestMessage::System { content } => Some(content),
                        RequestMessage::Tool { content, .. } => Some(content),
                    },
                    tool_calls: None,
                },
                finish_reason: choice.finish_reason,
            })
            .collect(),
        usage: Some(response.usage),
    }
}

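/// Sends a streaming chat completion request and returns a stream of
/// server-sent-event payloads parsed into [`ResponseStreamEvent`]s. For the
/// `o1-preview` and `o1-mini` models, which are not streamed here, the request
/// is executed via [`complete`] and adapted into a single-event stream.
///
/// Usage sketch, with the same `client` / `api_key` assumptions as
/// [`complete`]:
///
/// ```ignore
/// let mut events = stream_completion(client, OPEN_AI_API_URL, &api_key, request, None).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```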
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.read_timeout(low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

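/// Requests embeddings for `texts` from the `/embeddings` endpoint, returning
/// a future that resolves to the parsed response.
///
/// Usage sketch, again assuming a caller-provided `client: &dyn HttpClient`
/// and `api_key`:
///
/// ```ignore
/// let response = embed(
///     client,
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["first chunk", "second chunk"],
/// )
/// .await?;
/// let dimensions = response.data[0].embedding.len();
/// ```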
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

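/// Consumes a completion event stream and extracts the streamed argument
/// chunks for the first tool call whose function name matches `tool_name`,
/// returning an error if the tool is never used.
///
/// Usage sketch (the `events` stream would typically come from
/// [`stream_completion`]; `"my_tool"` is a placeholder name):
///
/// ```ignore
/// let mut args = extract_tool_args_from_events("my_tool".into(), events).await?;
/// let mut json = String::new();
/// while let Some(chunk) = args.next().await {
///     json.push_str(&chunk?);
/// }
/// ```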
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}

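/// Maps a completion event stream to a stream of the text fragments carried in
/// each event's last choice, passing errors through unchanged.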
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}