assistant.rs

  1mod ambient_context;
  2pub mod assistant_panel;
  3pub mod assistant_settings;
  4mod codegen;
  5mod completion_provider;
  6mod prompt_library;
  7mod prompts;
  8mod saved_conversation;
  9mod streaming_diff;
 10
 11pub use assistant_panel::AssistantPanel;
 12use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 13use client::{proto, Client};
 14use command_palette_hooks::CommandPaletteFilter;
 15pub(crate) use completion_provider::*;
 16use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 17pub(crate) use saved_conversation::*;
 18use serde::{Deserialize, Serialize};
 19use settings::{Settings, SettingsStore};
 20use std::{
 21    fmt::{self, Display},
 22    sync::Arc,
 23};
 24
// Declares the assistant's actions under the `assistant` namespace. This is the same
// namespace string as `Assistant::NAMESPACE`, which `Assistant::set_enabled` shows or
// hides in the command palette.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
    ]
);
 40
 41#[derive(
 42    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 43)]
 44struct MessageId(usize);
 45
 46#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 47#[serde(rename_all = "lowercase")]
 48pub enum Role {
 49    User,
 50    Assistant,
 51    System,
 52}
 53
 54impl Role {
 55    pub fn cycle(&mut self) {
 56        *self = match self {
 57            Role::User => Role::Assistant,
 58            Role::Assistant => Role::System,
 59            Role::System => Role::User,
 60        }
 61    }
 62}
 63
 64impl Display for Role {
 65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 66        match self {
 67            Role::User => write!(f, "user"),
 68            Role::Assistant => write!(f, "assistant"),
 69            Role::System => write!(f, "system"),
 70        }
 71    }
 72}
 73
 74#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 75pub enum LanguageModel {
 76    ZedDotDev(ZedDotDevModel),
 77    OpenAi(OpenAiModel),
 78    Anthropic(AnthropicModel),
 79}
 80
 81impl Default for LanguageModel {
 82    fn default() -> Self {
 83        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 84    }
 85}
 86
 87impl LanguageModel {
 88    pub fn telemetry_id(&self) -> String {
 89        match self {
 90            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 91            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 92            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 93        }
 94    }
 95
 96    pub fn display_name(&self) -> String {
 97        match self {
 98            LanguageModel::OpenAi(model) => model.display_name().into(),
 99            LanguageModel::Anthropic(model) => model.display_name().into(),
100            LanguageModel::ZedDotDev(model) => model.display_name().into(),
101        }
102    }
103
104    pub fn max_token_count(&self) -> usize {
105        match self {
106            LanguageModel::OpenAi(model) => model.max_token_count(),
107            LanguageModel::Anthropic(model) => model.max_token_count(),
108            LanguageModel::ZedDotDev(model) => model.max_token_count(),
109        }
110    }
111
112    pub fn id(&self) -> &str {
113        match self {
114            LanguageModel::OpenAi(model) => model.id(),
115            LanguageModel::Anthropic(model) => model.id(),
116            LanguageModel::ZedDotDev(model) => model.id(),
117        }
118    }
119}
120
121#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
122pub struct LanguageModelRequestMessage {
123    pub role: Role,
124    pub content: String,
125}
126
127impl LanguageModelRequestMessage {
128    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
129        proto::LanguageModelRequestMessage {
130            role: match self.role {
131                Role::User => proto::LanguageModelRole::LanguageModelUser,
132                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
133                Role::System => proto::LanguageModelRole::LanguageModelSystem,
134            } as i32,
135            content: self.content.clone(),
136            tool_calls: Vec::new(),
137            tool_call_id: None,
138        }
139    }
140}
141
142#[derive(Debug, Default, Serialize)]
143pub struct LanguageModelRequest {
144    pub model: LanguageModel,
145    pub messages: Vec<LanguageModelRequestMessage>,
146    pub stop: Vec<String>,
147    pub temperature: f32,
148}
149
150impl LanguageModelRequest {
151    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
152        proto::CompleteWithLanguageModel {
153            model: self.model.id().to_string(),
154            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
155            stop: self.stop.clone(),
156            temperature: self.temperature,
157            tool_choice: None,
158            tools: Vec::new(),
159        }
160    }
161}
162
163#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
164pub struct LanguageModelResponseMessage {
165    pub role: Option<Role>,
166    pub content: Option<String>,
167}
168
169#[derive(Deserialize, Debug)]
170pub struct LanguageModelUsage {
171    pub prompt_tokens: u32,
172    pub completion_tokens: u32,
173    pub total_tokens: u32,
174}
175
176#[derive(Deserialize, Debug)]
177pub struct LanguageModelChoiceDelta {
178    pub index: u32,
179    pub delta: LanguageModelResponseMessage,
180    pub finish_reason: Option<String>,
181}
182
183#[derive(Clone, Debug, Serialize, Deserialize)]
184struct MessageMetadata {
185    role: Role,
186    status: MessageStatus,
187}
188
189#[derive(Clone, Debug, Serialize, Deserialize)]
190enum MessageStatus {
191    Pending,
192    Done,
193    Error(SharedString),
194}
195
196/// The state pertaining to the Assistant.
197#[derive(Default)]
198struct Assistant {
199    /// Whether the Assistant is enabled.
200    enabled: bool,
201}
202
203impl Global for Assistant {}
204
205impl Assistant {
206    const NAMESPACE: &'static str = "assistant";
207
208    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
209        if self.enabled == enabled {
210            return;
211        }
212
213        self.enabled = enabled;
214
215        if !enabled {
216            CommandPaletteFilter::update_global(cx, |filter, _cx| {
217                filter.hide_namespace(Self::NAMESPACE);
218            });
219
220            return;
221        }
222
223        CommandPaletteFilter::update_global(cx, |filter, _cx| {
224            filter.show_namespace(Self::NAMESPACE);
225        });
226    }
227}
228
229pub fn init(client: Arc<Client>, cx: &mut AppContext) {
230    cx.set_global(Assistant::default());
231    AssistantSettings::register(cx);
232    completion_provider::init(client, cx);
233    assistant_panel::init(cx);
234
235    CommandPaletteFilter::update_global(cx, |filter, _cx| {
236        filter.hide_namespace(Assistant::NAMESPACE);
237    });
238    Assistant::update_global(cx, |assistant, cx| {
239        let settings = AssistantSettings::get_global(cx);
240
241        assistant.set_enabled(settings.enabled, cx);
242    });
243    cx.observe_global::<SettingsStore>(|cx| {
244        Assistant::update_global(cx, |assistant, cx| {
245            let settings = AssistantSettings::get_global(cx);
246
247            assistant.set_enabled(settings.enabled, cx);
248        });
249    })
250    .detach();
251}
252
/// Test-only logger setup, run by `ctor` when the test binary is loaded.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // A logger is installed only when `RUST_LOG` is set to a valid Unicode value;
    // otherwise no logger is registered.
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}