assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context;
  5pub mod context_store;
  6mod inline_assistant;
  7mod model_selector;
  8mod prompt_library;
  9mod prompts;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub use completion_provider::*;
 20pub use context::*;
 21pub use context_store::*;
 22use fs::Fs;
 23use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 24use indexed_docs::IndexedDocsRegistry;
 25pub(crate) use inline_assistant::*;
 26pub(crate) use model_selector::*;
 27use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 28use serde::{Deserialize, Serialize};
 29use settings::{Settings, SettingsStore};
 30use slash_command::{
 31    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 32    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 33    tabs_command, term_command,
 34};
 35use std::{
 36    fmt::{self, Display},
 37    sync::Arc,
 38};
 39pub(crate) use streaming_diff::*;
 40
// Declares the assistant's actions through gpui's `actions!` macro, all under
// the `assistant` namespace. That is the same string as `Assistant::NAMESPACE`,
// the namespace that is shown or hidden in the command palette when the
// assistant is enabled or disabled.
 41actions!(
 42    assistant,
 43    [
 44        Assist,
 45        Split,
 46        CycleMessageRole,
 47        QuoteSelection,
 48        InsertIntoEditor,
 49        ToggleFocus,
 50        ResetKey,
 51        InlineAssist,
 52        InsertActivePrompt,
 53        DeployHistory,
 54        DeployPromptLibrary,
 55        ConfirmCommand,
 56        ToggleModelSelector,
 57        DebugEditSteps
 58    ]
 59);
 60
/// Identifier of a message, implemented as a newtype over a `clock::Lamport` value.
///
/// Ordering, equality and hashing are all derived from the wrapped value.
 61#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 62pub struct MessageId(clock::Lamport);
 63
 64impl MessageId {
 65    pub fn as_u64(self) -> u64 {
 66        self.0.as_u64()
 67    }
 68}
 69
/// The author of a message: the user, the assistant, or the system.
///
/// With serde, variants are (de)serialized under their lowercase names
/// because of `rename_all = "lowercase"`.
 70#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 71#[serde(rename_all = "lowercase")]
 72pub enum Role {
 73    User,
 74    Assistant,
 75    System,
 76}
 77
 78impl Role {
 79    pub fn from_proto(role: i32) -> Role {
 80        match proto::LanguageModelRole::from_i32(role) {
 81            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
 82            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
 83            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
 84            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
 85            None => Role::User,
 86        }
 87    }
 88
 89    pub fn to_proto(&self) -> proto::LanguageModelRole {
 90        match self {
 91            Role::User => proto::LanguageModelRole::LanguageModelUser,
 92            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
 93            Role::System => proto::LanguageModelRole::LanguageModelSystem,
 94        }
 95    }
 96
 97    pub fn cycle(self) -> Role {
 98        match self {
 99            Role::User => Role::Assistant,
100            Role::Assistant => Role::System,
101            Role::System => Role::User,
102        }
103    }
104}
105
106impl Display for Role {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
108        match self {
109            Role::User => write!(f, "user"),
110            Role::Assistant => write!(f, "assistant"),
111            Role::System => write!(f, "system"),
112        }
113    }
114}
115
/// A language model, tagged by which provider-specific model type describes it.
///
/// Each variant wraps the corresponding model type imported from
/// `assistant_settings`.
116#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
117pub enum LanguageModel {
118    Cloud(CloudModel),
119    OpenAi(OpenAiModel),
120    Anthropic(AnthropicModel),
121    Ollama(OllamaModel),
122}
123
124impl Default for LanguageModel {
125    fn default() -> Self {
126        LanguageModel::Cloud(CloudModel::default())
127    }
128}
129
130impl LanguageModel {
131    pub fn telemetry_id(&self) -> String {
132        match self {
133            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
134            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
135            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
136            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
137        }
138    }
139
140    pub fn display_name(&self) -> String {
141        match self {
142            LanguageModel::OpenAi(model) => model.display_name().into(),
143            LanguageModel::Anthropic(model) => model.display_name().into(),
144            LanguageModel::Cloud(model) => model.display_name().into(),
145            LanguageModel::Ollama(model) => model.display_name().into(),
146        }
147    }
148
149    pub fn max_token_count(&self) -> usize {
150        match self {
151            LanguageModel::OpenAi(model) => model.max_token_count(),
152            LanguageModel::Anthropic(model) => model.max_token_count(),
153            LanguageModel::Cloud(model) => model.max_token_count(),
154            LanguageModel::Ollama(model) => model.max_token_count(),
155        }
156    }
157
158    pub fn id(&self) -> &str {
159        match self {
160            LanguageModel::OpenAi(model) => model.id(),
161            LanguageModel::Anthropic(model) => model.id(),
162            LanguageModel::Cloud(model) => model.id(),
163            LanguageModel::Ollama(model) => model.id(),
164        }
165    }
166}
167
/// One message of a language model request: who authored it and its text.
168#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
169pub struct LanguageModelRequestMessage {
    /// The author of the message.
170    pub role: Role,
    /// The text of the message.
171    pub content: String,
172}
173
174impl LanguageModelRequestMessage {
175    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
176        proto::LanguageModelRequestMessage {
177            role: self.role.to_proto() as i32,
178            content: self.content.clone(),
179            tool_calls: Vec::new(),
180            tool_call_id: None,
181        }
182    }
183}
184
/// A request to a language model.
///
/// `Default` yields the default [`LanguageModel`], no messages, no stop
/// strings, and a temperature of `0.0`.
185#[derive(Debug, Default, Serialize, Deserialize)]
186pub struct LanguageModelRequest {
    /// The model the request is for; `to_proto` sends its id and `preprocess`
    /// dispatches on it.
187    pub model: LanguageModel,
    /// The messages making up the request, in order.
188    pub messages: Vec<LanguageModelRequestMessage>,
    /// Strings forwarded as the request's `stop` list.
189    pub stop: Vec<String>,
    /// The temperature forwarded with the request.
190    pub temperature: f32,
191}
192
193impl LanguageModelRequest {
194    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
195        proto::CompleteWithLanguageModel {
196            model: self.model.id().to_string(),
197            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
198            stop: self.stop.clone(),
199            temperature: self.temperature,
200            tool_choice: None,
201            tools: Vec::new(),
202        }
203    }
204
205    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
206    pub fn preprocess(&mut self) {
207        match &self.model {
208            LanguageModel::OpenAi(_) => {}
209            LanguageModel::Anthropic(_) => {}
210            LanguageModel::Ollama(_) => {}
211            LanguageModel::Cloud(model) => match model {
212                CloudModel::Claude3Opus
213                | CloudModel::Claude3Sonnet
214                | CloudModel::Claude3Haiku
215                | CloudModel::Claude3_5Sonnet => {
216                    preprocess_anthropic_request(self);
217                }
218                _ => {}
219            },
220        }
221    }
222}
223
/// A message within a language model response; both fields may be absent.
224#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
225pub struct LanguageModelResponseMessage {
    /// The author of the message, if present.
226    pub role: Option<Role>,
    /// The text of the message, if present.
227    pub content: Option<String>,
228}
229
/// Token usage counters. This type can only be deserialized, not serialized.
// NOTE(review): `total_tokens` is presumably `prompt_tokens + completion_tokens`;
// nothing in this file enforces that — confirm against the producer.
230#[derive(Deserialize, Debug)]
231pub struct LanguageModelUsage {
232    pub prompt_tokens: u32,
233    pub completion_tokens: u32,
234    pub total_tokens: u32,
235}
236
/// A delta for one choice of a response. This type can only be deserialized.
237#[derive(Deserialize, Debug)]
238pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
239    pub index: u32,
    /// The message content carried by this delta.
240    pub delta: LanguageModelResponseMessage,
    /// Why the choice finished, if a reason was given.
241    pub finish_reason: Option<String>,
242}
243
/// The status of a message: still pending, done, or failed with an error message.
244#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
245pub enum MessageStatus {
246    Pending,
247    Done,
    /// Carries the error message text.
248    Error(SharedString),
249}
250
251impl MessageStatus {
252    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
253        match status.variant {
254            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
255            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
256            Some(proto::context_message_status::Variant::Error(error)) => {
257                MessageStatus::Error(error.message.into())
258            }
259            None => MessageStatus::Pending,
260        }
261    }
262
263    pub fn to_proto(&self) -> proto::ContextMessageStatus {
264        match self {
265            MessageStatus::Pending => proto::ContextMessageStatus {
266                variant: Some(proto::context_message_status::Variant::Pending(
267                    proto::context_message_status::Pending {},
268                )),
269            },
270            MessageStatus::Done => proto::ContextMessageStatus {
271                variant: Some(proto::context_message_status::Variant::Done(
272                    proto::context_message_status::Done {},
273                )),
274            },
275            MessageStatus::Error(message) => proto::ContextMessageStatus {
276                variant: Some(proto::context_message_status::Variant::Error(
277                    proto::context_message_status::Error {
278                        message: message.to_string(),
279                    },
280                )),
281            },
282        }
283    }
284}
285
286/// The state pertaining to the Assistant.
///
/// Stored as a gpui global: `init` installs the default value (disabled) and
/// later updates it from `AssistantSettings`.
287#[derive(Default)]
288struct Assistant {
289    /// Whether the Assistant is enabled.
    /// Starts as `false` through the derived `Default`; changed only by `set_enabled`.
290    enabled: bool,
291}
292
// Marker impl that lets `Assistant` be used with `set_global` / `update_global`.
293impl Global for Assistant {}
294
295impl Assistant {
296    const NAMESPACE: &'static str = "assistant";
297
298    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
299        if self.enabled == enabled {
300            return;
301        }
302
303        self.enabled = enabled;
304
305        if !enabled {
306            CommandPaletteFilter::update_global(cx, |filter, _cx| {
307                filter.hide_namespace(Self::NAMESPACE);
308            });
309
310            return;
311        }
312
313        CommandPaletteFilter::update_global(cx, |filter, _cx| {
314            filter.show_namespace(Self::NAMESPACE);
315        });
316    }
317}
318
319pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
320    cx.set_global(Assistant::default());
321    AssistantSettings::register(cx);
322
323    cx.spawn(|mut cx| {
324        let client = client.clone();
325        async move {
326            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
327            let semantic_index = SemanticIndex::new(
328                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
329                Arc::new(embedding_provider),
330                &mut cx,
331            )
332            .await?;
333            cx.update(|cx| cx.set_global(semantic_index))
334        }
335    })
336    .detach();
337
338    context_store::init(&client);
339    prompt_library::init(cx);
340    completion_provider::init(client.clone(), cx);
341    assistant_slash_command::init(cx);
342    register_slash_commands(cx);
343    assistant_panel::init(cx);
344    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
345    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
346    IndexedDocsRegistry::init_global(cx);
347
348    CommandPaletteFilter::update_global(cx, |filter, _cx| {
349        filter.hide_namespace(Assistant::NAMESPACE);
350    });
351    Assistant::update_global(cx, |assistant, cx| {
352        let settings = AssistantSettings::get_global(cx);
353
354        assistant.set_enabled(settings.enabled, cx);
355    });
356    cx.observe_global::<SettingsStore>(|cx| {
357        Assistant::update_global(cx, |assistant, cx| {
358            let settings = AssistantSettings::get_global(cx);
359            assistant.set_enabled(settings.enabled, cx);
360        });
361    })
362    .detach();
363}
364
365fn register_slash_commands(cx: &mut AppContext) {
366    let slash_command_registry = SlashCommandRegistry::global(cx);
367    slash_command_registry.register_command(file_command::FileSlashCommand, true);
368    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
369    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
370    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
371    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
372    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
373    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
374    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
375    slash_command_registry.register_command(term_command::TermSlashCommand, true);
376    slash_command_registry.register_command(now_command::NowSlashCommand, true);
377    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
378    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
379    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
380}
381
/// Formats a token count compactly for display.
///
/// * below 1,000: the exact number, e.g. `"999"`;
/// * 1,000 to 9,999: thousands rounded half-up to one decimal, with a zero
///   decimal dropped, e.g. `"1.1k"`, `"2k"`;
/// * 10,000 and above: rounded half-up to the nearest thousand, e.g. `"11k"`.
///
/// The rounding never overflows, even for counts close to `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            // Work in tenths of a thousand so that rounding up across a
            // thousand boundary (e.g. 1950 -> "2k") needs no special case.
            let tenths = (count + 50) / 100;
            let (thousands, tenth) = (tenths / 10, tenths % 10);
            if tenth == 0 {
                format!("{}k", thousands)
            } else {
                format!("{}.{}k", thousands, tenth)
            }
        }
        // Equivalent to `(count + 500) / 1000`, but split so the addition
        // cannot overflow when `count` is within 500 of `usize::MAX`.
        _ => format!("{}k", count / 1000 + (count % 1000 + 500) / 1000),
    }
}
399
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG`
/// environment variable is set to a readable value, and does nothing otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}