vercel.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, future::Future};
  7use strum::EnumIter;
  8
/// Base URL of the Vercel v0 API, without a trailing slash.
///
/// `stream_completion` and `embed` append their endpoint paths
/// (`/chat/completions`, `/embeddings`) to the `api_url` they receive;
/// presumably callers pass this constant unless a custom URL is configured.
pub const VERCEL_API_URL: &str = "https://api.v0.dev/v1";
 10
 11fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 12    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 13}
 14
/// The author role attached to a chat message.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 23
 24impl TryFrom<String> for Role {
 25    type Error = anyhow::Error;
 26
 27    fn try_from(value: String) -> Result<Self> {
 28        match value.as_str() {
 29            "user" => Ok(Self::User),
 30            "assistant" => Ok(Self::Assistant),
 31            "system" => Ok(Self::System),
 32            "tool" => Ok(Self::Tool),
 33            _ => anyhow::bail!("invalid role '{value}'"),
 34        }
 35    }
 36}
 37
 38impl From<Role> for String {
 39    fn from(val: Role) -> Self {
 40        match val {
 41            Role::User => "user".to_owned(),
 42            Role::Assistant => "assistant".to_owned(),
 43            Role::System => "system".to_owned(),
 44            Role::Tool => "tool".to_owned(),
 45        }
 46    }
 47}
 48
// A model that can be requested from the Vercel API: either the built-in
// `v-0` model or a user-defined custom one.
//
// NOTE(review): the notes added here are plain `//` comments rather than doc
// comments because this type derives `schemars::JsonSchema`, which presumably
// copies doc comments into the generated schema — confirm before converting.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // The built-in model; serialized as "v-0" and used as the `Default`.
    #[serde(rename = "v-0")]
    #[default]
    VZero,

    // A user-defined model; serialized under the "custom" name.
    #[serde(rename = "custom")]
    Custom {
        // Identifier returned by `Model::id`.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Value returned by `Model::max_token_count`.
        max_tokens: u64,
        // Value returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        // Not read anywhere in this file; presumably consumed by callers
        // when filling `Request::max_completion_tokens` — TODO confirm.
        max_completion_tokens: Option<u64>,
    },
}
 66
 67impl Model {
 68    pub fn default_fast() -> Self {
 69        Self::VZero
 70    }
 71
 72    pub fn from_id(id: &str) -> Result<Self> {
 73        match id {
 74            "v-0" => Ok(Self::VZero),
 75            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
 76        }
 77    }
 78
 79    pub fn id(&self) -> &str {
 80        match self {
 81            Self::VZero => "v-0",
 82            Self::Custom { name, .. } => name,
 83        }
 84    }
 85
 86    pub fn display_name(&self) -> &str {
 87        match self {
 88            Self::VZero => "Vercel v0",
 89            Self::Custom {
 90                name, display_name, ..
 91            } => display_name.as_ref().unwrap_or(name),
 92        }
 93    }
 94
 95    pub fn max_token_count(&self) -> u64 {
 96        match self {
 97            Self::VZero => 128_000,
 98            Self::Custom { max_tokens, .. } => *max_tokens,
 99        }
100    }
101
102    pub fn max_output_tokens(&self) -> Option<u64> {
103        match self {
104            Self::Custom {
105                max_output_tokens, ..
106            } => *max_output_tokens,
107            Self::VZero => Some(32_768),
108        }
109    }
110
111    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
112    ///
113    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
114    pub fn supports_parallel_tool_calls(&self) -> bool {
115        match self {
116            Self::VZero => true,
117            Model::Custom { .. } => false,
118        }
119    }
120}
121
/// JSON body of a chat-completion request.
///
/// Optional and empty fields are left out of the serialized body rather than
/// being sent as `null` or `[]`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use (presumably a value from `Model::id`).
    pub model: String,
    /// The conversation so far.
    pub messages: Vec<RequestMessage>,
    /// Sent as the `stream` flag.
    ///
    /// NOTE(review): `stream_completion` always parses the response as a
    /// line-based event stream, so callers presumably set this to `true`.
    pub stream: bool,
    /// Omitted from the body when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the body when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Always serialized; there is no default on deserialization.
    pub temperature: f32,
    /// Omitted from the body when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the body when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
140
141#[derive(Debug, Serialize, Deserialize)]
142#[serde(untagged)]
143pub enum ToolChoice {
144    Auto,
145    Required,
146    None,
147    Other(ToolDefinition),
148}
149
/// A tool made available to the model.
///
/// Internally tagged with a `type` field whose value is the snake_case variant
/// name, so `Function` serializes as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow` is not visible in this file;
    // presumably the variant is not constructed inside this crate — confirm.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
156
/// Describes a callable function exposed to the model as a tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Name of the function.
    pub name: String,
    /// Optional free-form description; serialized as `null` when absent.
    pub description: Option<String>,
    /// Arbitrary JSON describing the parameters (presumably a JSON Schema
    /// object — TODO confirm); serialized as `null` when absent.
    pub parameters: Option<Value>,
}
163
/// One message in the conversation sent with a [`Request`].
///
/// Internally tagged with a `role` field whose value is the lowercase variant
/// name (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A message authored by the assistant, optionally carrying tool calls.
    Assistant {
        /// Message content; serialized as `null` when `None`.
        content: Option<MessageContent>,
        /// Tool calls attached to the message; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// A message authored by the user.
    User {
        content: MessageContent,
    },
    /// A system message.
    System {
        content: MessageContent,
    },
    /// A tool message tied to a call by `tool_call_id`.
    Tool {
        content: MessageContent,
        /// Identifier of the tool call this message belongs to (presumably
        /// matching a `ToolCall::id` — TODO confirm).
        tool_call_id: String,
    },
}
183
/// The content of a message: either a bare string or a list of typed parts.
///
/// Untagged, so `Plain` serializes as a JSON string and `Multipart` as a JSON
/// array of [`MessagePart`] objects.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Text-only content.
    Plain(String),
    /// A sequence of text and/or image parts.
    Multipart(Vec<MessagePart>),
}
190
191impl MessageContent {
192    pub fn empty() -> Self {
193        MessageContent::Multipart(vec![])
194    }
195
196    pub fn push_part(&mut self, part: MessagePart) {
197        match self {
198            MessageContent::Plain(text) => {
199                *self =
200                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
201            }
202            MessageContent::Multipart(parts) if parts.is_empty() => match part {
203                MessagePart::Text { text } => *self = MessageContent::Plain(text),
204                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
205            },
206            MessageContent::Multipart(parts) => parts.push(part),
207        }
208    }
209}
210
211impl From<Vec<MessagePart>> for MessageContent {
212    fn from(mut parts: Vec<MessagePart>) -> Self {
213        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
214            MessageContent::Plain(std::mem::take(text))
215        } else {
216            MessageContent::Multipart(parts)
217        }
218    }
219}
220
/// A single typed piece of multipart message content.
///
/// Internally tagged with a `type` field: `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    /// A run of text.
    #[serde(rename = "text")]
    Text { text: String },
    /// An image referenced by URL.
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
229
/// Reference to an image used in a [`MessagePart::Image`].
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Location of the image (presumably an http(s) or data URL — TODO confirm).
    pub url: String,
    /// Optional detail hint; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
236
/// A complete tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// The call payload; flattened, so its `type` tag and fields appear
    /// alongside `id` in the same JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
243
/// The payload of a [`ToolCall`].
///
/// Internally tagged with a `type` field holding the lowercase variant name
/// (`"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
249
/// The function invoked by a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function being called.
    pub name: String,
    /// Call arguments as a string (presumably JSON-encoded — TODO confirm).
    pub arguments: String,
}
255
/// Incremental update to a message, as carried by one streamed choice.
///
/// Every field is optional because a single chunk may carry only part of the
/// message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Tool-call fragments; defaults to `None` when missing and is omitted
    /// from the serialized form when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
263
/// A streamed fragment of a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call this fragment belongs to (presumably used by
    /// consumers to merge fragments of the same call — TODO confirm).
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`, so the field is not modeled here and
    // `function` is read on its own.
    pub function: Option<FunctionChunk>,
}
274
/// A streamed fragment of a function call; both fields are optional because
/// a chunk may carry only one of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// A piece of the arguments string (presumably to be concatenated across
    /// chunks by the consumer — TODO confirm).
    pub arguments: Option<String>,
}
280
/// Token accounting reported alongside a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    /// Presumably `prompt_tokens + completion_tokens`; reported by the server
    /// and not recomputed here.
    pub total_tokens: u32,
}
287
/// One choice within a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Position of this choice among the event's choices.
    pub index: u32,
    /// The incremental message update for this choice.
    pub delta: ResponseMessageDelta,
    /// Why generation stopped, if it has; the set of possible values is not
    /// constrained by this type.
    pub finish_reason: Option<String>,
}
294
/// The decoded payload of a single streamed line: an event or an error.
///
/// Untagged, so deserialization first tries to read a [`ResponseStreamEvent`]
/// and, failing that, an object of the form `{"error": "<message>"}`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
301
/// A successfully decoded event from the completion stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Model identifier reported in the event.
    pub model: String,
    /// Incremental updates, one per choice.
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the event includes it.
    pub usage: Option<Usage>,
}
308
309pub async fn stream_completion(
310    client: &dyn HttpClient,
311    api_url: &str,
312    api_key: &str,
313    request: Request,
314) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
315    let uri = format!("{api_url}/chat/completions");
316    let request_builder = HttpRequest::builder()
317        .method(Method::POST)
318        .uri(uri)
319        .header("Content-Type", "application/json")
320        .header("Authorization", format!("Bearer {}", api_key));
321
322    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
323    let mut response = client.send(request).await?;
324    if response.status().is_success() {
325        let reader = BufReader::new(response.into_body());
326        Ok(reader
327            .lines()
328            .filter_map(|line| async move {
329                match line {
330                    Ok(line) => {
331                        let line = line.strip_prefix("data: ")?;
332                        if line == "[DONE]" {
333                            None
334                        } else {
335                            match serde_json::from_str(line) {
336                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
337                                Ok(ResponseStreamResult::Err { error }) => {
338                                    Some(Err(anyhow!(error)))
339                                }
340                                Err(error) => Some(Err(anyhow!(error))),
341                            }
342                        }
343                    }
344                    Err(error) => Some(Err(anyhow!(error))),
345                }
346            })
347            .boxed())
348    } else {
349        let mut body = String::new();
350        response.body_mut().read_to_string(&mut body).await?;
351
352        #[derive(Deserialize)]
353        struct VercelResponse {
354            error: VercelError,
355        }
356
357        #[derive(Deserialize)]
358        struct VercelError {
359            message: String,
360        }
361
362        match serde_json::from_str::<VercelResponse>(&body) {
363            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
364                "Failed to connect to Vercel API: {}",
365                response.error.message,
366            )),
367
368            _ => anyhow::bail!(
369                "Failed to connect to Vercel API: {} {}",
370                response.status(),
371                body,
372            ),
373        }
374    }
375}
376
/// Embedding models selectable in an embeddings request.
///
/// Each variant is serialized as the model identifier given in its
/// `serde(rename)` attribute.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum VercelEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
384
/// JSON body of an embeddings request, built by `embed`.
#[derive(Serialize)]
struct VercelEmbeddingRequest<'a> {
    /// The embedding model to use.
    model: VercelEmbeddingModel,
    /// The texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
390
/// Parsed body of a successful embeddings response.
#[derive(Deserialize)]
pub struct VercelEmbeddingResponse {
    /// The returned embeddings (presumably one per input text, in input
    /// order — TODO confirm against the API).
    pub data: Vec<VercelEmbedding>,
}
395
/// A single embedding vector from an embeddings response.
#[derive(Deserialize)]
pub struct VercelEmbedding {
    /// The components of the embedding vector.
    pub embedding: Vec<f32>,
}
400
401pub fn embed<'a>(
402    client: &dyn HttpClient,
403    api_url: &str,
404    api_key: &str,
405    model: VercelEmbeddingModel,
406    texts: impl IntoIterator<Item = &'a str>,
407) -> impl 'static + Future<Output = Result<VercelEmbeddingResponse>> {
408    let uri = format!("{api_url}/embeddings");
409
410    let request = VercelEmbeddingRequest {
411        model,
412        input: texts.into_iter().collect(),
413    };
414    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
415    let request = HttpRequest::builder()
416        .method(Method::POST)
417        .uri(uri)
418        .header("Content-Type", "application/json")
419        .header("Authorization", format!("Bearer {}", api_key))
420        .body(body)
421        .map(|request| client.send(request));
422
423    async move {
424        let mut response = request?.await?;
425        let mut body = String::new();
426        response.body_mut().read_to_string(&mut body).await?;
427
428        anyhow::ensure!(
429            response.status().is_success(),
430            "error during embedding, status: {:?}, body: {:?}",
431            response.status(),
432            body
433        );
434        let response: VercelEmbeddingResponse =
435            serde_json::from_str(&body).context("failed to parse Vercel embedding response")?;
436        Ok(response)
437    }
438}