assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9pub use assistant_panel::AssistantPanel;
 10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 11use chrono::{DateTime, Local};
 12use client::{proto, Client};
 13use command_palette_hooks::CommandPaletteFilter;
 14pub(crate) use completion_provider::*;
 15use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
 16pub(crate) use saved_conversation::*;
 17use serde::{Deserialize, Serialize};
 18use settings::{Settings, SettingsStore};
 19use std::{
 20    fmt::{self, Display},
 21    sync::Arc,
 22};
 23
// Declares the assistant's gpui actions under the `assistant` namespace, the
// same name as `Assistant::NAMESPACE`, which `init`/`Assistant::set_enabled`
// show or hide in the command palette.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);
 38
 39#[derive(
 40    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 41)]
 42struct MessageId(usize);
 43
 44#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 45#[serde(rename_all = "lowercase")]
 46pub enum Role {
 47    User,
 48    Assistant,
 49    System,
 50}
 51
 52impl Role {
 53    pub fn cycle(&mut self) {
 54        *self = match self {
 55            Role::User => Role::Assistant,
 56            Role::Assistant => Role::System,
 57            Role::System => Role::User,
 58        }
 59    }
 60}
 61
 62impl Display for Role {
 63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 64        match self {
 65            Role::User => write!(f, "user"),
 66            Role::Assistant => write!(f, "assistant"),
 67            Role::System => write!(f, "system"),
 68        }
 69    }
 70}
 71
 72#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 73pub enum LanguageModel {
 74    ZedDotDev(ZedDotDevModel),
 75    OpenAi(OpenAiModel),
 76}
 77
 78impl Default for LanguageModel {
 79    fn default() -> Self {
 80        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 81    }
 82}
 83
 84impl LanguageModel {
 85    pub fn telemetry_id(&self) -> String {
 86        match self {
 87            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 88            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 89        }
 90    }
 91
 92    pub fn display_name(&self) -> String {
 93        match self {
 94            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
 95            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
 96        }
 97    }
 98
 99    pub fn max_token_count(&self) -> usize {
100        match self {
101            LanguageModel::OpenAi(model) => model.max_token_count(),
102            LanguageModel::ZedDotDev(model) => model.max_token_count(),
103        }
104    }
105
106    pub fn id(&self) -> &str {
107        match self {
108            LanguageModel::OpenAi(model) => model.id(),
109            LanguageModel::ZedDotDev(model) => model.id(),
110        }
111    }
112}
113
114#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
115pub struct LanguageModelRequestMessage {
116    pub role: Role,
117    pub content: String,
118}
119
120impl LanguageModelRequestMessage {
121    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
122        proto::LanguageModelRequestMessage {
123            role: match self.role {
124                Role::User => proto::LanguageModelRole::LanguageModelUser,
125                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
126                Role::System => proto::LanguageModelRole::LanguageModelSystem,
127            } as i32,
128            content: self.content.clone(),
129        }
130    }
131}
132
133#[derive(Debug, Default, Serialize)]
134pub struct LanguageModelRequest {
135    pub model: LanguageModel,
136    pub messages: Vec<LanguageModelRequestMessage>,
137    pub stop: Vec<String>,
138    pub temperature: f32,
139}
140
141impl LanguageModelRequest {
142    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
143        proto::CompleteWithLanguageModel {
144            model: self.model.id().to_string(),
145            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
146            stop: self.stop.clone(),
147            temperature: self.temperature,
148        }
149    }
150}
151
152#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
153pub struct LanguageModelResponseMessage {
154    pub role: Option<Role>,
155    pub content: Option<String>,
156}
157
158#[derive(Deserialize, Debug)]
159pub struct LanguageModelUsage {
160    pub prompt_tokens: u32,
161    pub completion_tokens: u32,
162    pub total_tokens: u32,
163}
164
165#[derive(Deserialize, Debug)]
166pub struct LanguageModelChoiceDelta {
167    pub index: u32,
168    pub delta: LanguageModelResponseMessage,
169    pub finish_reason: Option<String>,
170}
171
172#[derive(Clone, Debug, Serialize, Deserialize)]
173struct MessageMetadata {
174    role: Role,
175    sent_at: DateTime<Local>,
176    status: MessageStatus,
177}
178
179#[derive(Clone, Debug, Serialize, Deserialize)]
180enum MessageStatus {
181    Pending,
182    Done,
183    Error(SharedString),
184}
185
186/// The state pertaining to the Assistant.
187#[derive(Default)]
188struct Assistant {
189    /// Whether the Assistant is enabled.
190    enabled: bool,
191}
192
193impl Global for Assistant {}
194
195impl Assistant {
196    const NAMESPACE: &'static str = "assistant";
197
198    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
199        if self.enabled == enabled {
200            return;
201        }
202
203        self.enabled = enabled;
204
205        if !enabled {
206            CommandPaletteFilter::update_global(cx, |filter, _cx| {
207                filter.hide_namespace(Self::NAMESPACE);
208            });
209
210            return;
211        }
212
213        CommandPaletteFilter::update_global(cx, |filter, _cx| {
214            filter.show_namespace(Self::NAMESPACE);
215        });
216    }
217}
218
219pub fn init(client: Arc<Client>, cx: &mut AppContext) {
220    cx.set_global(Assistant::default());
221    AssistantSettings::register(cx);
222    completion_provider::init(client, cx);
223    assistant_panel::init(cx);
224
225    CommandPaletteFilter::update_global(cx, |filter, _cx| {
226        filter.hide_namespace(Assistant::NAMESPACE);
227    });
228    cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
229        let settings = AssistantSettings::get_global(cx);
230
231        assistant.set_enabled(settings.enabled, cx);
232    });
233    cx.observe_global::<SettingsStore>(|cx| {
234        cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
235            let settings = AssistantSettings::get_global(cx);
236
237            assistant.set_enabled(settings.enabled, cx);
238        });
239    })
240    .detach();
241}
242
/// Test builds only: runs as a `ctor` constructor and initializes
/// `env_logger`, but only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}