1mod supported_countries;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{
5 AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
6 io::BufReader,
7 stream::{self, BoxStream},
8};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::{
13 convert::TryFrom,
14 future::{self, Future},
15 pin::Pin,
16};
17use strum::EnumIter;
18
19pub use supported_countries::*;
20
/// Base URL of the OpenAI HTTP API, without a trailing slash.
///
/// The request helpers in this file append endpoint paths such as
/// `/chat/completions` to whatever `api_url` they are given.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
22
/// Returns `true` when `opt` is `None` or wraps an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate so that absent and empty
/// collections are both left out of the serialized output.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(items) => items.as_ref().is_empty(),
    }
}
26
27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum Role {
30 User,
31 Assistant,
32 System,
33 Tool,
34}
35
36impl TryFrom<String> for Role {
37 type Error = anyhow::Error;
38
39 fn try_from(value: String) -> Result<Self> {
40 match value.as_str() {
41 "user" => Ok(Self::User),
42 "assistant" => Ok(Self::Assistant),
43 "system" => Ok(Self::System),
44 "tool" => Ok(Self::Tool),
45 _ => Err(anyhow!("invalid role '{value}'")),
46 }
47 }
48}
49
50impl From<Role> for String {
51 fn from(val: Role) -> Self {
52 match val {
53 Role::User => "user".to_owned(),
54 Role::Assistant => "assistant".to_owned(),
55 Role::System => "system".to_owned(),
56 Role::Tool => "tool".to_owned(),
57 }
58 }
59}
60
61#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
62#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
63pub enum Model {
64 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
65 ThreePointFiveTurbo,
66 #[serde(rename = "gpt-4", alias = "gpt-4")]
67 Four,
68 #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
69 FourTurbo,
70 #[serde(rename = "gpt-4o", alias = "gpt-4o")]
71 #[default]
72 FourOmni,
73 #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
74 FourOmniMini,
75 #[serde(rename = "o1", alias = "o1")]
76 O1,
77 #[serde(rename = "o1-preview", alias = "o1-preview")]
78 O1Preview,
79 #[serde(rename = "o1-mini", alias = "o1-mini")]
80 O1Mini,
81 #[serde(rename = "o3-mini", alias = "o3-mini")]
82 O3Mini,
83
84 #[serde(rename = "custom")]
85 Custom {
86 name: String,
87 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
88 display_name: Option<String>,
89 max_tokens: usize,
90 max_output_tokens: Option<u32>,
91 max_completion_tokens: Option<u32>,
92 },
93}
94
95impl Model {
96 pub fn from_id(id: &str) -> Result<Self> {
97 match id {
98 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
99 "gpt-4" => Ok(Self::Four),
100 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
101 "gpt-4o" => Ok(Self::FourOmni),
102 "gpt-4o-mini" => Ok(Self::FourOmniMini),
103 "o1" => Ok(Self::O1),
104 "o1-preview" => Ok(Self::O1Preview),
105 "o1-mini" => Ok(Self::O1Mini),
106 "o3-mini" => Ok(Self::O3Mini),
107 _ => Err(anyhow!("invalid model id")),
108 }
109 }
110
111 pub fn id(&self) -> &str {
112 match self {
113 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
114 Self::Four => "gpt-4",
115 Self::FourTurbo => "gpt-4-turbo",
116 Self::FourOmni => "gpt-4o",
117 Self::FourOmniMini => "gpt-4o-mini",
118 Self::O1 => "o1",
119 Self::O1Preview => "o1-preview",
120 Self::O1Mini => "o1-mini",
121 Self::O3Mini => "o3-mini",
122 Self::Custom { name, .. } => name,
123 }
124 }
125
126 pub fn display_name(&self) -> &str {
127 match self {
128 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
129 Self::Four => "gpt-4",
130 Self::FourTurbo => "gpt-4-turbo",
131 Self::FourOmni => "gpt-4o",
132 Self::FourOmniMini => "gpt-4o-mini",
133 Self::O1 => "o1",
134 Self::O1Preview => "o1-preview",
135 Self::O1Mini => "o1-mini",
136 Self::O3Mini => "o3-mini",
137 Self::Custom {
138 name, display_name, ..
139 } => display_name.as_ref().unwrap_or(name),
140 }
141 }
142
143 pub fn max_token_count(&self) -> usize {
144 match self {
145 Self::ThreePointFiveTurbo => 16_385,
146 Self::Four => 8_192,
147 Self::FourTurbo => 128_000,
148 Self::FourOmni => 128_000,
149 Self::FourOmniMini => 128_000,
150 Self::O1 => 200_000,
151 Self::O1Preview => 128_000,
152 Self::O1Mini => 128_000,
153 Self::O3Mini => 200_000,
154 Self::Custom { max_tokens, .. } => *max_tokens,
155 }
156 }
157
158 pub fn max_output_tokens(&self) -> Option<u32> {
159 match self {
160 Self::Custom {
161 max_output_tokens, ..
162 } => *max_output_tokens,
163 _ => None,
164 }
165 }
166}
167
168#[derive(Debug, Serialize, Deserialize)]
169pub struct Request {
170 pub model: String,
171 pub messages: Vec<RequestMessage>,
172 pub stream: bool,
173 #[serde(default, skip_serializing_if = "Option::is_none")]
174 pub max_tokens: Option<u32>,
175 #[serde(default, skip_serializing_if = "Vec::is_empty")]
176 pub stop: Vec<String>,
177 pub temperature: f32,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
179 pub tool_choice: Option<ToolChoice>,
180 #[serde(default, skip_serializing_if = "Vec::is_empty")]
181 pub tools: Vec<ToolDefinition>,
182}
183
184#[derive(Debug, Serialize, Deserialize)]
185pub struct CompletionRequest {
186 pub model: String,
187 pub prompt: String,
188 pub max_tokens: u32,
189 pub temperature: f32,
190 #[serde(default, skip_serializing_if = "Option::is_none")]
191 pub prediction: Option<Prediction>,
192 #[serde(default, skip_serializing_if = "Option::is_none")]
193 pub rewrite_speculation: Option<bool>,
194}
195
196#[derive(Clone, Deserialize, Serialize, Debug)]
197#[serde(tag = "type", rename_all = "snake_case")]
198pub enum Prediction {
199 Content { content: String },
200}
201
202#[derive(Debug, Serialize, Deserialize)]
203#[serde(untagged)]
204pub enum ToolChoice {
205 Auto,
206 Required,
207 None,
208 Other(ToolDefinition),
209}
210
211#[derive(Clone, Deserialize, Serialize, Debug)]
212#[serde(tag = "type", rename_all = "snake_case")]
213pub enum ToolDefinition {
214 #[allow(dead_code)]
215 Function { function: FunctionDefinition },
216}
217
218#[derive(Clone, Debug, Serialize, Deserialize)]
219pub struct FunctionDefinition {
220 pub name: String,
221 pub description: Option<String>,
222 pub parameters: Option<Value>,
223}
224
225#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
226#[serde(tag = "role", rename_all = "lowercase")]
227pub enum RequestMessage {
228 Assistant {
229 content: Option<String>,
230 #[serde(default, skip_serializing_if = "Vec::is_empty")]
231 tool_calls: Vec<ToolCall>,
232 },
233 User {
234 content: String,
235 },
236 System {
237 content: String,
238 },
239 Tool {
240 content: String,
241 tool_call_id: String,
242 },
243}
244
245#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
246pub struct ToolCall {
247 pub id: String,
248 #[serde(flatten)]
249 pub content: ToolCallContent,
250}
251
252#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
253#[serde(tag = "type", rename_all = "lowercase")]
254pub enum ToolCallContent {
255 Function { function: FunctionContent },
256}
257
258#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
259pub struct FunctionContent {
260 pub name: String,
261 pub arguments: String,
262}
263
264#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
265pub struct ResponseMessageDelta {
266 pub role: Option<Role>,
267 pub content: Option<String>,
268 #[serde(default, skip_serializing_if = "is_none_or_empty")]
269 pub tool_calls: Option<Vec<ToolCallChunk>>,
270}
271
272#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
273pub struct ToolCallChunk {
274 pub index: usize,
275 pub id: Option<String>,
276
277 // There is also an optional `type` field that would determine if a
278 // function is there. Sometimes this streams in with the `function` before
279 // it streams in the `type`
280 pub function: Option<FunctionChunk>,
281}
282
283#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
284pub struct FunctionChunk {
285 pub name: Option<String>,
286 pub arguments: Option<String>,
287}
288
289#[derive(Serialize, Deserialize, Debug)]
290pub struct Usage {
291 pub prompt_tokens: u32,
292 pub completion_tokens: u32,
293 pub total_tokens: u32,
294}
295
296#[derive(Serialize, Deserialize, Debug)]
297pub struct ChoiceDelta {
298 pub index: u32,
299 pub delta: ResponseMessageDelta,
300 pub finish_reason: Option<String>,
301}
302
303#[derive(Serialize, Deserialize, Debug)]
304#[serde(untagged)]
305pub enum ResponseStreamResult {
306 Ok(ResponseStreamEvent),
307 Err { error: String },
308}
309
310#[derive(Serialize, Deserialize, Debug)]
311pub struct ResponseStreamEvent {
312 pub created: u32,
313 pub model: String,
314 pub choices: Vec<ChoiceDelta>,
315 pub usage: Option<Usage>,
316}
317
318#[derive(Serialize, Deserialize, Debug)]
319pub struct CompletionResponse {
320 pub id: String,
321 pub object: String,
322 pub created: u64,
323 pub model: String,
324 pub choices: Vec<CompletionChoice>,
325 pub usage: Usage,
326}
327
328#[derive(Serialize, Deserialize, Debug)]
329pub struct CompletionChoice {
330 pub text: String,
331}
332
333#[derive(Serialize, Deserialize, Debug)]
334pub struct Response {
335 pub id: String,
336 pub object: String,
337 pub created: u64,
338 pub model: String,
339 pub choices: Vec<Choice>,
340 pub usage: Usage,
341}
342
343#[derive(Serialize, Deserialize, Debug)]
344pub struct Choice {
345 pub index: u32,
346 pub message: RequestMessage,
347 pub finish_reason: Option<String>,
348}
349
350pub async fn complete(
351 client: &dyn HttpClient,
352 api_url: &str,
353 api_key: &str,
354 request: Request,
355) -> Result<Response> {
356 let uri = format!("{api_url}/chat/completions");
357 let request_builder = HttpRequest::builder()
358 .method(Method::POST)
359 .uri(uri)
360 .header("Content-Type", "application/json")
361 .header("Authorization", format!("Bearer {}", api_key));
362
363 let mut request_body = request;
364 request_body.stream = false;
365
366 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
367 let mut response = client.send(request).await?;
368
369 if response.status().is_success() {
370 let mut body = String::new();
371 response.body_mut().read_to_string(&mut body).await?;
372 let response: Response = serde_json::from_str(&body)?;
373 Ok(response)
374 } else {
375 let mut body = String::new();
376 response.body_mut().read_to_string(&mut body).await?;
377
378 #[derive(Deserialize)]
379 struct OpenAiResponse {
380 error: OpenAiError,
381 }
382
383 #[derive(Deserialize)]
384 struct OpenAiError {
385 message: String,
386 }
387
388 match serde_json::from_str::<OpenAiResponse>(&body) {
389 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
390 "Failed to connect to OpenAI API: {}",
391 response.error.message,
392 )),
393
394 _ => Err(anyhow!(
395 "Failed to connect to OpenAI API: {} {}",
396 response.status(),
397 body,
398 )),
399 }
400 }
401}
402
403pub async fn complete_text(
404 client: &dyn HttpClient,
405 api_url: &str,
406 api_key: &str,
407 request: CompletionRequest,
408) -> Result<CompletionResponse> {
409 let uri = format!("{api_url}/completions");
410 let request_builder = HttpRequest::builder()
411 .method(Method::POST)
412 .uri(uri)
413 .header("Content-Type", "application/json")
414 .header("Authorization", format!("Bearer {}", api_key));
415
416 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
417 let mut response = client.send(request).await?;
418
419 if response.status().is_success() {
420 let mut body = String::new();
421 response.body_mut().read_to_string(&mut body).await?;
422 let response = serde_json::from_str(&body)?;
423 Ok(response)
424 } else {
425 let mut body = String::new();
426 response.body_mut().read_to_string(&mut body).await?;
427
428 #[derive(Deserialize)]
429 struct OpenAiResponse {
430 error: OpenAiError,
431 }
432
433 #[derive(Deserialize)]
434 struct OpenAiError {
435 message: String,
436 }
437
438 match serde_json::from_str::<OpenAiResponse>(&body) {
439 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
440 "Failed to connect to OpenAI API: {}",
441 response.error.message,
442 )),
443
444 _ => Err(anyhow!(
445 "Failed to connect to OpenAI API: {} {}",
446 response.status(),
447 body,
448 )),
449 }
450 }
451}
452
453fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
454 ResponseStreamEvent {
455 created: response.created as u32,
456 model: response.model,
457 choices: response
458 .choices
459 .into_iter()
460 .map(|choice| ChoiceDelta {
461 index: choice.index,
462 delta: ResponseMessageDelta {
463 role: Some(match choice.message {
464 RequestMessage::Assistant { .. } => Role::Assistant,
465 RequestMessage::User { .. } => Role::User,
466 RequestMessage::System { .. } => Role::System,
467 RequestMessage::Tool { .. } => Role::Tool,
468 }),
469 content: match choice.message {
470 RequestMessage::Assistant { content, .. } => content,
471 RequestMessage::User { content } => Some(content),
472 RequestMessage::System { content } => Some(content),
473 RequestMessage::Tool { content, .. } => Some(content),
474 },
475 tool_calls: None,
476 },
477 finish_reason: choice.finish_reason,
478 })
479 .collect(),
480 usage: Some(response.usage),
481 }
482}
483
484pub async fn stream_completion(
485 client: &dyn HttpClient,
486 api_url: &str,
487 api_key: &str,
488 request: Request,
489) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
490 if request.model.starts_with("o1") {
491 let response = complete(client, api_url, api_key, request).await;
492 let response_stream_event = response.map(adapt_response_to_stream);
493 return Ok(stream::once(future::ready(response_stream_event)).boxed());
494 }
495
496 let uri = format!("{api_url}/chat/completions");
497 let request_builder = HttpRequest::builder()
498 .method(Method::POST)
499 .uri(uri)
500 .header("Content-Type", "application/json")
501 .header("Authorization", format!("Bearer {}", api_key));
502
503 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
504 let mut response = client.send(request).await?;
505 if response.status().is_success() {
506 let reader = BufReader::new(response.into_body());
507 Ok(reader
508 .lines()
509 .filter_map(|line| async move {
510 match line {
511 Ok(line) => {
512 let line = line.strip_prefix("data: ")?;
513 if line == "[DONE]" {
514 None
515 } else {
516 match serde_json::from_str(line) {
517 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
518 Ok(ResponseStreamResult::Err { error }) => {
519 Some(Err(anyhow!(error)))
520 }
521 Err(error) => Some(Err(anyhow!(error))),
522 }
523 }
524 }
525 Err(error) => Some(Err(anyhow!(error))),
526 }
527 })
528 .boxed())
529 } else {
530 let mut body = String::new();
531 response.body_mut().read_to_string(&mut body).await?;
532
533 #[derive(Deserialize)]
534 struct OpenAiResponse {
535 error: OpenAiError,
536 }
537
538 #[derive(Deserialize)]
539 struct OpenAiError {
540 message: String,
541 }
542
543 match serde_json::from_str::<OpenAiResponse>(&body) {
544 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
545 "Failed to connect to OpenAI API: {}",
546 response.error.message,
547 )),
548
549 _ => Err(anyhow!(
550 "Failed to connect to OpenAI API: {} {}",
551 response.status(),
552 body,
553 )),
554 }
555 }
556}
557
558#[derive(Copy, Clone, Serialize, Deserialize)]
559pub enum OpenAiEmbeddingModel {
560 #[serde(rename = "text-embedding-3-small")]
561 TextEmbedding3Small,
562 #[serde(rename = "text-embedding-3-large")]
563 TextEmbedding3Large,
564}
565
566#[derive(Serialize)]
567struct OpenAiEmbeddingRequest<'a> {
568 model: OpenAiEmbeddingModel,
569 input: Vec<&'a str>,
570}
571
572#[derive(Deserialize)]
573pub struct OpenAiEmbeddingResponse {
574 pub data: Vec<OpenAiEmbedding>,
575}
576
577#[derive(Deserialize)]
578pub struct OpenAiEmbedding {
579 pub embedding: Vec<f32>,
580}
581
582pub fn embed<'a>(
583 client: &dyn HttpClient,
584 api_url: &str,
585 api_key: &str,
586 model: OpenAiEmbeddingModel,
587 texts: impl IntoIterator<Item = &'a str>,
588) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
589 let uri = format!("{api_url}/embeddings");
590
591 let request = OpenAiEmbeddingRequest {
592 model,
593 input: texts.into_iter().collect(),
594 };
595 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
596 let request = HttpRequest::builder()
597 .method(Method::POST)
598 .uri(uri)
599 .header("Content-Type", "application/json")
600 .header("Authorization", format!("Bearer {}", api_key))
601 .body(body)
602 .map(|request| client.send(request));
603
604 async move {
605 let mut response = request?.await?;
606 let mut body = String::new();
607 response.body_mut().read_to_string(&mut body).await?;
608
609 if response.status().is_success() {
610 let response: OpenAiEmbeddingResponse =
611 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
612 Ok(response)
613 } else {
614 Err(anyhow!(
615 "error during embedding, status: {:?}, body: {:?}",
616 response.status(),
617 body
618 ))
619 }
620 }
621}
622
623pub async fn extract_tool_args_from_events(
624 tool_name: String,
625 mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
626) -> Result<impl Send + Stream<Item = Result<String>>> {
627 let mut tool_use_index = None;
628 let mut first_chunk = None;
629 while let Some(event) = events.next().await {
630 let call = event?.choices.into_iter().find_map(|choice| {
631 choice.delta.tool_calls?.into_iter().find_map(|call| {
632 if call.function.as_ref()?.name.as_deref()? == tool_name {
633 Some(call)
634 } else {
635 None
636 }
637 })
638 });
639 if let Some(call) = call {
640 tool_use_index = Some(call.index);
641 first_chunk = call.function.and_then(|func| func.arguments);
642 break;
643 }
644 }
645
646 let Some(tool_use_index) = tool_use_index else {
647 return Err(anyhow!("tool not used"));
648 };
649
650 Ok(events.filter_map(move |event| {
651 let result = match event {
652 Err(error) => Some(Err(error)),
653 Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
654 choice.delta.tool_calls?.into_iter().find_map(|call| {
655 if call.index == tool_use_index {
656 let func = call.function?;
657 let mut arguments = func.arguments?;
658 if let Some(mut first_chunk) = first_chunk.take() {
659 first_chunk.push_str(&arguments);
660 arguments = first_chunk
661 }
662 Some(Ok(arguments))
663 } else {
664 None
665 }
666 })
667 }),
668 };
669
670 async move { result }
671 }))
672}
673
674pub fn extract_text_from_events(
675 response: impl Stream<Item = Result<ResponseStreamEvent>>,
676) -> impl Stream<Item = Result<String>> {
677 response.filter_map(|response| async move {
678 match response {
679 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
680 Err(error) => Some(Err(error)),
681 }
682 })
683}