assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9pub use assistant_panel::AssistantPanel;
 10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 11use chrono::{DateTime, Local};
 12use client::{proto, Client};
 13pub(crate) use completion_provider::*;
 14use gpui::{actions, AppContext, SharedString};
 15pub(crate) use saved_conversation::*;
 16use serde::{Deserialize, Serialize};
 17use settings::Settings;
 18use std::{
 19    fmt::{self, Display},
 20    sync::Arc,
 21};
 22
 23actions!(
 24    assistant,
 25    [
 26        NewConversation,
 27        Assist,
 28        Split,
 29        CycleMessageRole,
 30        QuoteSelection,
 31        ToggleFocus,
 32        ResetKey,
 33        InlineAssist,
 34        ToggleIncludeConversation,
 35    ]
 36);
 37
 38#[derive(
 39    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 40)]
 41struct MessageId(usize);
 42
 43#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 44#[serde(rename_all = "lowercase")]
 45pub enum Role {
 46    User,
 47    Assistant,
 48    System,
 49}
 50
 51impl Role {
 52    pub fn cycle(&mut self) {
 53        *self = match self {
 54            Role::User => Role::Assistant,
 55            Role::Assistant => Role::System,
 56            Role::System => Role::User,
 57        }
 58    }
 59}
 60
 61impl Display for Role {
 62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 63        match self {
 64            Role::User => write!(f, "user"),
 65            Role::Assistant => write!(f, "assistant"),
 66            Role::System => write!(f, "system"),
 67        }
 68    }
 69}
 70
 71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 72pub enum LanguageModel {
 73    ZedDotDev(ZedDotDevModel),
 74    OpenAi(OpenAiModel),
 75}
 76
 77impl Default for LanguageModel {
 78    fn default() -> Self {
 79        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 80    }
 81}
 82
 83impl LanguageModel {
 84    pub fn telemetry_id(&self) -> String {
 85        match self {
 86            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 87            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 88        }
 89    }
 90
 91    pub fn display_name(&self) -> String {
 92        match self {
 93            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
 94            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
 95        }
 96    }
 97
 98    pub fn max_token_count(&self) -> usize {
 99        match self {
100            LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
101            LanguageModel::ZedDotDev(model) => match model {
102                ZedDotDevModel::GptThreePointFiveTurbo
103                | ZedDotDevModel::GptFour
104                | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
105                ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
106            },
107        }
108    }
109
110    pub fn id(&self) -> &str {
111        match self {
112            LanguageModel::OpenAi(model) => model.id(),
113            LanguageModel::ZedDotDev(model) => model.id(),
114        }
115    }
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119pub struct LanguageModelRequestMessage {
120    pub role: Role,
121    pub content: String,
122}
123
124impl LanguageModelRequestMessage {
125    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
126        proto::LanguageModelRequestMessage {
127            role: match self.role {
128                Role::User => proto::LanguageModelRole::LanguageModelUser,
129                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
130                Role::System => proto::LanguageModelRole::LanguageModelSystem,
131            } as i32,
132            content: self.content.clone(),
133        }
134    }
135}
136
137#[derive(Debug, Default, Serialize)]
138pub struct LanguageModelRequest {
139    pub model: LanguageModel,
140    pub messages: Vec<LanguageModelRequestMessage>,
141    pub stop: Vec<String>,
142    pub temperature: f32,
143}
144
145impl LanguageModelRequest {
146    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
147        proto::CompleteWithLanguageModel {
148            model: self.model.id().to_string(),
149            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
150            stop: self.stop.clone(),
151            temperature: self.temperature,
152        }
153    }
154}
155
156#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
157pub struct LanguageModelResponseMessage {
158    pub role: Option<Role>,
159    pub content: Option<String>,
160}
161
162#[derive(Deserialize, Debug)]
163pub struct LanguageModelUsage {
164    pub prompt_tokens: u32,
165    pub completion_tokens: u32,
166    pub total_tokens: u32,
167}
168
169#[derive(Deserialize, Debug)]
170pub struct LanguageModelChoiceDelta {
171    pub index: u32,
172    pub delta: LanguageModelResponseMessage,
173    pub finish_reason: Option<String>,
174}
175
176#[derive(Clone, Debug, Serialize, Deserialize)]
177struct MessageMetadata {
178    role: Role,
179    sent_at: DateTime<Local>,
180    status: MessageStatus,
181}
182
183#[derive(Clone, Debug, Serialize, Deserialize)]
184enum MessageStatus {
185    Pending,
186    Done,
187    Error(SharedString),
188}
189
190pub fn init(client: Arc<Client>, cx: &mut AppContext) {
191    AssistantSettings::register(cx);
192    completion_provider::init(client, cx);
193    assistant_panel::init(cx);
194}
195
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG`
/// environment variable can be read, and does nothing otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}