cloud_model.rs

 1use schemars::JsonSchema;
 2use serde::{Deserialize, Serialize};
 3use strum::EnumIter;
 4
// A model reachable through the cloud, with one variant per upstream provider.
//
// Serialized as an internally tagged enum: a `provider` field carries the
// lowercased variant name (`anthropic`, `openai`, `google`, `zed`) next to the
// wrapped provider-specific model.
//
// NOTE: plain `//` comments are used deliberately. `///` doc comments on a
// `JsonSchema`-derived type are emitted into the generated schema as
// descriptions, which would change the schema output.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
    // Zed-specific models; see `ZedModel`.
    Zed(ZedModel),
}
13
// Models selectable under the `zed` provider of `CloudModel`.
//
// Each variant's `#[serde(rename)]` is its serialized identifier. Keep it in
// sync with the string returned by `ZedModel::id`. `EnumIter` lets callers
// enumerate every variant in declaration order.
//
// NOTE: plain `//` comments on purpose. `///` docs would become JSON-schema
// descriptions via the `JsonSchema` derive.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    #[serde(rename = "qwen2-7b-instruct")]
    Qwen2_7bInstruct,
}
19
20impl ZedModel {
21    pub fn id(&self) -> &str {
22        match self {
23            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
24        }
25    }
26
27    pub fn display_name(&self) -> &str {
28        match self {
29            ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
30        }
31    }
32
33    pub fn max_token_count(&self) -> usize {
34        match self {
35            ZedModel::Qwen2_7bInstruct => 28000,
36        }
37    }
38}
39
40impl Default for CloudModel {
41    fn default() -> Self {
42        Self::Anthropic(anthropic::Model::default())
43    }
44}
45
46impl CloudModel {
47    pub fn id(&self) -> &str {
48        match self {
49            CloudModel::Anthropic(model) => model.id(),
50            CloudModel::OpenAi(model) => model.id(),
51            CloudModel::Google(model) => model.id(),
52            CloudModel::Zed(model) => model.id(),
53        }
54    }
55
56    pub fn display_name(&self) -> &str {
57        match self {
58            CloudModel::Anthropic(model) => model.display_name(),
59            CloudModel::OpenAi(model) => model.display_name(),
60            CloudModel::Google(model) => model.display_name(),
61            CloudModel::Zed(model) => model.display_name(),
62        }
63    }
64
65    pub fn max_token_count(&self) -> usize {
66        match self {
67            CloudModel::Anthropic(model) => model.max_token_count(),
68            CloudModel::OpenAi(model) => model.max_token_count(),
69            CloudModel::Google(model) => model.max_token_count(),
70            CloudModel::Zed(model) => model.max_token_count(),
71        }
72    }
73}