// assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12mod terminal_inline_assistant;
 13
 14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20pub(crate) use context_store::*;
 21use fs::Fs;
 22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 23use indexed_docs::IndexedDocsRegistry;
 24pub(crate) use inline_assistant::*;
 25pub(crate) use model_selector::*;
 26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 27use serde::{Deserialize, Serialize};
 28use settings::{Settings, SettingsStore};
 29use slash_command::{
 30    active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
 31    project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
 32};
 33use std::{
 34    fmt::{self, Display},
 35    sync::Arc,
 36};
 37pub(crate) use streaming_diff::*;
 38
// Declares the assistant's action types in the `assistant` namespace.
// `actions!` (from gpui) generates a unit struct implementing `Action` for
// each name, so these can be bound to keystrokes and shown in the command
// palette (the namespace is hidden/shown in `Assistant::set_enabled`).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
 57
/// Identifier for a message within a context/conversation.
///
/// A newtype over `usize`; the derives make it usable as a map key and in
/// ordered or serialized data.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 62
/// The author of a conversation message.
///
/// Serialized in lowercase ("user" / "assistant" / "system"), matching the
/// wire names written by the `Display` impl below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 70
 71impl Role {
 72    pub fn cycle(&mut self) {
 73        *self = match self {
 74            Role::User => Role::Assistant,
 75            Role::Assistant => Role::System,
 76            Role::System => Role::User,
 77        }
 78    }
 79}
 80
 81impl Display for Role {
 82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 83        match self {
 84            Role::User => write!(f, "user"),
 85            Role::Assistant => write!(f, "assistant"),
 86            Role::System => write!(f, "system"),
 87        }
 88    }
 89}
 90
/// A language model that can serve completion requests, tagged by the
/// provider it comes from.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model hosted by the zed.dev cloud service.
    Cloud(CloudModel),
    /// A model accessed directly through the OpenAI API.
    OpenAi(OpenAiModel),
    /// A model accessed directly through the Anthropic API.
    Anthropic(AnthropicModel),
    /// A locally running Ollama model.
    Ollama(OllamaModel),
}
 98
impl Default for LanguageModel {
    /// Defaults to the cloud provider with that provider's default model.
    fn default() -> Self {
        LanguageModel::Cloud(CloudModel::default())
    }
}
104
105impl LanguageModel {
106    pub fn telemetry_id(&self) -> String {
107        match self {
108            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
109            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
110            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
111            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
112        }
113    }
114
115    pub fn display_name(&self) -> String {
116        match self {
117            LanguageModel::OpenAi(model) => model.display_name().into(),
118            LanguageModel::Anthropic(model) => model.display_name().into(),
119            LanguageModel::Cloud(model) => model.display_name().into(),
120            LanguageModel::Ollama(model) => model.display_name().into(),
121        }
122    }
123
124    pub fn max_token_count(&self) -> usize {
125        match self {
126            LanguageModel::OpenAi(model) => model.max_token_count(),
127            LanguageModel::Anthropic(model) => model.max_token_count(),
128            LanguageModel::Cloud(model) => model.max_token_count(),
129            LanguageModel::Ollama(model) => model.max_token_count(),
130        }
131    }
132
133    pub fn id(&self) -> &str {
134        match self {
135            LanguageModel::OpenAi(model) => model.id(),
136            LanguageModel::Anthropic(model) => model.id(),
137            LanguageModel::Cloud(model) => model.id(),
138            LanguageModel::Ollama(model) => model.id(),
139        }
140    }
141}
142
/// A single chat message inside a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored this message (user / assistant / system).
    pub role: Role,
    /// The plain-text body of the message.
    pub content: String,
}
148
149impl LanguageModelRequestMessage {
150    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
151        proto::LanguageModelRequestMessage {
152            role: match self.role {
153                Role::User => proto::LanguageModelRole::LanguageModelUser,
154                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
155                Role::System => proto::LanguageModelRole::LanguageModelSystem,
156            } as i32,
157            content: self.content.clone(),
158            tool_calls: Vec::new(),
159            tool_call_id: None,
160        }
161    }
162}
163
/// A completion request to send to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// The model that should serve this request.
    pub model: LanguageModel,
    /// The conversation messages to complete from.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences that terminate generation when produced.
    pub stop: Vec<String>,
    /// Sampling temperature passed through to the provider.
    pub temperature: f32,
}
171
172impl LanguageModelRequest {
173    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
174        proto::CompleteWithLanguageModel {
175            model: self.model.id().to_string(),
176            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
177            stop: self.stop.clone(),
178            temperature: self.temperature,
179            tool_choice: None,
180            tools: Vec::new(),
181        }
182    }
183
184    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
185    pub fn preprocess(&mut self) {
186        match &self.model {
187            LanguageModel::OpenAi(_) => {}
188            LanguageModel::Anthropic(_) => {}
189            LanguageModel::Ollama(_) => {}
190            LanguageModel::Cloud(model) => match model {
191                CloudModel::Claude3Opus
192                | CloudModel::Claude3Sonnet
193                | CloudModel::Claude3Haiku
194                | CloudModel::Claude3_5Sonnet => {
195                    preprocess_anthropic_request(self);
196                }
197                _ => {}
198            },
199        }
200    }
201}
202
/// A (possibly partial) message returned by a model.
///
/// Both fields are optional because this is also used as a streaming delta
/// (see `LanguageModelChoiceDelta`), which may carry only part of a message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
208
/// Token accounting reported by a provider for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
215
/// One streamed choice chunk from a chat-completion response.
// NOTE(review): field names mirror the OpenAI-style streaming schema
// (index / delta / finish_reason) — confirm against the deserialization site.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental message content for this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Set on the final chunk of a choice; `None` while streaming.
    pub finish_reason: Option<String>,
}
222
/// Per-message bookkeeping kept alongside the message text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    // Author of the message.
    role: Role,
    // Completion state of the message (see `MessageStatus`).
    status: MessageStatus,
}
228
/// Lifecycle state of a message whose content may still be generating.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Generation is still in progress.
    Pending,
    /// Generation finished successfully.
    Done,
    /// Generation failed; carries the error text for display.
    Error(SharedString),
}
235
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}
242
// Marks `Assistant` as a gpui global so it can be stored with `set_global`
// and accessed via `update_global`.
impl Global for Assistant {}
244
245impl Assistant {
246    const NAMESPACE: &'static str = "assistant";
247
248    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
249        if self.enabled == enabled {
250            return;
251        }
252
253        self.enabled = enabled;
254
255        if !enabled {
256            CommandPaletteFilter::update_global(cx, |filter, _cx| {
257                filter.hide_namespace(Self::NAMESPACE);
258            });
259
260            return;
261        }
262
263        CommandPaletteFilter::update_global(cx, |filter, _cx| {
264            filter.show_namespace(Self::NAMESPACE);
265        });
266    }
267}
268
/// Initializes the assistant subsystem: registers settings and globals, kicks
/// off the semantic index, wires up sub-modules (prompt library, completion
/// providers, slash commands, panels, inline assistants), and keeps the
/// enabled state in sync with the `AssistantSettings`.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background: opening its database is
    // async, and the result is stored as a global once ready.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Hide the assistant's command-palette namespace by default;
    // `set_enabled` below reveals it if the setting says the assistant is on.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
313
/// Registers all built-in slash commands with the global registry.
// NOTE(review): the boolean flag differs only for `rustdoc` and `fetch`
// (registered with `false`) — its exact semantics live in
// `SlashCommandRegistry::register_command`; confirm there before relying on it.
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}
329
/// Renders a token count compactly for display: exact below 1000 ("999"),
/// rounded to a tenth of a thousand up to 9999 ("1.5k", carrying into "10k"),
/// and rounded to the nearest thousand beyond that ("12k").
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }
    if count < 10_000 {
        // Round to the nearest hundred; `rounded` is the count in hundreds,
        // so 1_950 -> 20 -> "2k" and 1_850 -> 19 -> "1.9k".
        let rounded = (count + 50) / 100;
        let whole = rounded / 10;
        let tenth = rounded % 10;
        if tenth == 0 {
            format!("{}k", whole)
        } else {
            format!("{}.{}k", whole, tenth)
        }
    } else {
        // Round to the nearest thousand.
        format!("{}k", (count + 500) / 1000)
    }
}
347
// Test-only logger setup. `#[ctor::ctor]` runs this before the test harness
// starts, enabling `env_logger` output when RUST_LOG is set so tests can
// emit logs without each test initializing logging itself.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}