assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod model_selector;
  6mod prompt_library;
  7mod prompts;
  8mod saved_conversation;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13pub use assistant_panel::AssistantPanel;
 14
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 21pub(crate) use model_selector::*;
 22use prompt_library::PromptStore;
 23pub(crate) use saved_conversation::*;
 24use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 25use serde::{Deserialize, Serialize};
 26use settings::{Settings, SettingsStore};
 27use slash_command::{
 28    active_command, file_command, project_command, prompt_command, rustdoc_command, search_command,
 29    tabs_command,
 30};
 31use std::{
 32    fmt::{self, Display},
 33    sync::Arc,
 34};
 35use util::paths::EMBEDDINGS_DIR;
 36
// Declares the assistant's actions through gpui's `actions!` macro: the first
// argument (`assistant`) is the namespace, the bracketed list names each action.
// NOTE(review): the namespace presumably corresponds to `Assistant::NAMESPACE`
// ("assistant"), which the command palette filter hides/shows — confirm against gpui.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 54
/// Identifier for a message, a newtype around `usize`.
///
/// The derives make it copyable, orderable, hashable and (de)serializable.
/// NOTE(review): how ids are allocated is not visible in this file.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 59
/// The role attached to a message (see `LanguageModelRequestMessage::role` and
/// `MessageMetadata::role`).
///
/// Serde (de)serializes the variants under their lowercase names
/// (`rename_all = "lowercase"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
}
 67
 68impl Role {
 69    pub fn cycle(&mut self) {
 70        *self = match self {
 71            Role::User => Role::Assistant,
 72            Role::Assistant => Role::System,
 73            Role::System => Role::User,
 74        }
 75    }
 76}
 77
 78impl Display for Role {
 79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 80        match self {
 81            Role::User => write!(f, "user"),
 82            Role::Assistant => write!(f, "assistant"),
 83            Role::System => write!(f, "system"),
 84        }
 85    }
 86}
 87
/// A language model, tagged by the provider-specific model type it wraps.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Wraps a `CloudModel`; reported to telemetry under the `zed.dev/` prefix.
    Cloud(CloudModel),
    /// Wraps an `OpenAiModel`; reported to telemetry under the `openai/` prefix.
    OpenAi(OpenAiModel),
    /// Wraps an `AnthropicModel`; reported to telemetry under the `anthropic/` prefix.
    Anthropic(AnthropicModel),
}
 94
 95impl Default for LanguageModel {
 96    fn default() -> Self {
 97        LanguageModel::Cloud(CloudModel::default())
 98    }
 99}
100
101impl LanguageModel {
102    pub fn telemetry_id(&self) -> String {
103        match self {
104            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
105            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
106            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
107        }
108    }
109
110    pub fn display_name(&self) -> String {
111        match self {
112            LanguageModel::OpenAi(model) => model.display_name().into(),
113            LanguageModel::Anthropic(model) => model.display_name().into(),
114            LanguageModel::Cloud(model) => model.display_name().into(),
115        }
116    }
117
118    pub fn max_token_count(&self) -> usize {
119        match self {
120            LanguageModel::OpenAi(model) => model.max_token_count(),
121            LanguageModel::Anthropic(model) => model.max_token_count(),
122            LanguageModel::Cloud(model) => model.max_token_count(),
123        }
124    }
125
126    pub fn id(&self) -> &str {
127        match self {
128            LanguageModel::OpenAi(model) => model.id(),
129            LanguageModel::Anthropic(model) => model.id(),
130            LanguageModel::Cloud(model) => model.id(),
131        }
132    }
133}
134
/// One message of a `LanguageModelRequest`: a role plus its text content.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// The role this message is attributed to.
    pub role: Role,
    /// The message text.
    pub content: String,
}
140
141impl LanguageModelRequestMessage {
142    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
143        proto::LanguageModelRequestMessage {
144            role: match self.role {
145                Role::User => proto::LanguageModelRole::LanguageModelUser,
146                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
147                Role::System => proto::LanguageModelRole::LanguageModelSystem,
148            } as i32,
149            content: self.content.clone(),
150            tool_calls: Vec::new(),
151            tool_call_id: None,
152        }
153    }
154}
155
/// A completion request for a language model.
///
/// `to_proto` copies every field into a `proto::CompleteWithLanguageModel`.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// The model the request targets; `preprocess` chooses its fixups from this.
    pub model: LanguageModel,
    /// The messages that make up the request.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Forwarded as-is in the proto request.
    /// NOTE(review): presumably stop sequences — confirm against the server API.
    pub stop: Vec<String>,
    /// Forwarded as-is in the proto request.
    /// NOTE(review): presumably the sampling temperature — confirm.
    pub temperature: f32,
}
163
164impl LanguageModelRequest {
165    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
166        proto::CompleteWithLanguageModel {
167            model: self.model.id().to_string(),
168            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
169            stop: self.stop.clone(),
170            temperature: self.temperature,
171            tool_choice: None,
172            tools: Vec::new(),
173        }
174    }
175
176    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
177    pub fn preprocess(&mut self) {
178        match &self.model {
179            LanguageModel::OpenAi(_) => {}
180            LanguageModel::Anthropic(_) => {}
181            LanguageModel::Cloud(model) => match model {
182                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
183                    preprocess_anthropic_request(self);
184                }
185                _ => {}
186            },
187        }
188    }
189}
190
/// A response message in which both the role and the content are optional.
///
/// Used as the `delta` payload of `LanguageModelChoiceDelta`.
/// NOTE(review): presumably a partial (streamed) message where either field may
/// be absent in a given chunk — confirm against the provider's response format.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// The role, when present.
    pub role: Option<Role>,
    /// The text content, when present.
    pub content: Option<String>,
}
196
/// Token counters deserialized from a response.
///
/// NOTE(review): field meanings are inferred from their names; whether
/// `total_tokens` always equals prompt + completion is not established here.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Presumably the number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Presumably the number of tokens in the completion.
    pub completion_tokens: u32,
    /// Presumably the overall token count.
    pub total_tokens: u32,
}
203
/// A deserialized choice entry carrying a `LanguageModelResponseMessage` delta.
///
/// NOTE(review): presumably one element of a streamed completion's choices —
/// confirm against the provider's response format.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of this choice.
    pub index: u32,
    /// The message delta for this choice.
    pub delta: LanguageModelResponseMessage,
    /// The finish reason, when one was provided.
    pub finish_reason: Option<String>,
}
210
/// Per-message metadata: the message's role and its current status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// The role the message is attributed to.
    role: Role,
    /// The message's status.
    status: MessageStatus,
}
216
/// The status of a message.
///
/// NOTE(review): variant meanings below are inferred from their names; the code
/// that transitions between them is not in this file.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Presumably: the message is still in progress.
    Pending,
    /// Presumably: the message completed.
    Done,
    /// Carries a `SharedString`, presumably the error text.
    Error(SharedString),
}
223
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global: `init` installs it with `cx.set_global` and updates
/// it through `Assistant::update_global`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    ///
    /// Starts as `false` under the derived `Default`.
    enabled: bool,
}

// Marker impl that lets `Assistant` be stored and updated as a gpui global.
impl Global for Assistant {}
232
233impl Assistant {
234    const NAMESPACE: &'static str = "assistant";
235
236    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
237        if self.enabled == enabled {
238            return;
239        }
240
241        self.enabled = enabled;
242
243        if !enabled {
244            CommandPaletteFilter::update_global(cx, |filter, _cx| {
245                filter.hide_namespace(Self::NAMESPACE);
246            });
247
248            return;
249        }
250
251        CommandPaletteFilter::update_global(cx, |filter, _cx| {
252            filter.show_namespace(Self::NAMESPACE);
253        });
254    }
255}
256
257pub fn init(client: Arc<Client>, cx: &mut AppContext) {
258    cx.set_global(Assistant::default());
259    AssistantSettings::register(cx);
260
261    cx.spawn(|mut cx| {
262        let client = client.clone();
263        async move {
264            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
265            let semantic_index = SemanticIndex::new(
266                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
267                Arc::new(embedding_provider),
268                &mut cx,
269            )
270            .await?;
271            cx.update(|cx| cx.set_global(semantic_index))
272        }
273    })
274    .detach();
275
276    prompt_library::init(cx);
277    completion_provider::init(client, cx);
278    assistant_slash_command::init(cx);
279    register_slash_commands(cx);
280    assistant_panel::init(cx);
281
282    CommandPaletteFilter::update_global(cx, |filter, _cx| {
283        filter.hide_namespace(Assistant::NAMESPACE);
284    });
285    Assistant::update_global(cx, |assistant, cx| {
286        let settings = AssistantSettings::get_global(cx);
287
288        assistant.set_enabled(settings.enabled, cx);
289    });
290    cx.observe_global::<SettingsStore>(|cx| {
291        Assistant::update_global(cx, |assistant, cx| {
292            let settings = AssistantSettings::get_global(cx);
293            assistant.set_enabled(settings.enabled, cx);
294        });
295    })
296    .detach();
297}
298
299fn register_slash_commands(cx: &mut AppContext) {
300    let slash_command_registry = SlashCommandRegistry::global(cx);
301    slash_command_registry.register_command(file_command::FileSlashCommand, true);
302    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
303    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
304    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
305    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
306    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
307
308    let store = PromptStore::global(cx);
309    cx.background_executor()
310        .spawn(async move {
311            let store = store.await?;
312            slash_command_registry
313                .register_command(prompt_command::PromptSlashCommand::new(store), true);
314            anyhow::Ok(())
315        })
316        .detach_and_log_err(cx);
317}
318
/// Test-only constructor (run through `ctor`) that initializes `env_logger`,
/// but only when `RUST_LOG` can be read from the environment.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Leave logging off unless the variable is present (and valid Unicode).
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
325}