// assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod completion_provider;
  4mod inline_assistant;
  5mod model_selector;
  6mod prompt_library;
  7mod prompts;
  8mod saved_conversation;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13pub use assistant_panel::AssistantPanel;
 14
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 21pub(crate) use inline_assistant::*;
 22pub(crate) use model_selector::*;
 23pub(crate) use saved_conversation::*;
 24use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 25use serde::{Deserialize, Serialize};
 26use settings::{Settings, SettingsStore};
 27use slash_command::{
 28    active_command, default_command, fetch_command, file_command, project_command, prompt_command,
 29    rustdoc_command, search_command, tabs_command,
 30};
 31use std::{
 32    fmt::{self, Display},
 33    sync::Arc,
 34};
 35pub(crate) use streaming_diff::*;
 36use util::paths::EMBEDDINGS_DIR;
 37
 38actions!(
 39    assistant,
 40    [
 41        Assist,
 42        Split,
 43        CycleMessageRole,
 44        QuoteSelection,
 45        ToggleFocus,
 46        ResetKey,
 47        InlineAssist,
 48        InsertActivePrompt,
 49        ToggleHistory,
 50        ApplyEdit,
 51        ConfirmCommand,
 52        ToggleModelSelector
 53    ]
 54);
 55
 56#[derive(
 57    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 58)]
 59struct MessageId(usize);
 60
 61#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 62#[serde(rename_all = "lowercase")]
 63pub enum Role {
 64    User,
 65    Assistant,
 66    System,
 67}
 68
 69impl Role {
 70    pub fn cycle(&mut self) {
 71        *self = match self {
 72            Role::User => Role::Assistant,
 73            Role::Assistant => Role::System,
 74            Role::System => Role::User,
 75        }
 76    }
 77}
 78
 79impl Display for Role {
 80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 81        match self {
 82            Role::User => write!(f, "user"),
 83            Role::Assistant => write!(f, "assistant"),
 84            Role::System => write!(f, "system"),
 85        }
 86    }
 87}
 88
 89#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 90pub enum LanguageModel {
 91    Cloud(CloudModel),
 92    OpenAi(OpenAiModel),
 93    Anthropic(AnthropicModel),
 94}
 95
 96impl Default for LanguageModel {
 97    fn default() -> Self {
 98        LanguageModel::Cloud(CloudModel::default())
 99    }
100}
101
102impl LanguageModel {
103    pub fn telemetry_id(&self) -> String {
104        match self {
105            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
106            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
107            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
108        }
109    }
110
111    pub fn display_name(&self) -> String {
112        match self {
113            LanguageModel::OpenAi(model) => model.display_name().into(),
114            LanguageModel::Anthropic(model) => model.display_name().into(),
115            LanguageModel::Cloud(model) => model.display_name().into(),
116        }
117    }
118
119    pub fn max_token_count(&self) -> usize {
120        match self {
121            LanguageModel::OpenAi(model) => model.max_token_count(),
122            LanguageModel::Anthropic(model) => model.max_token_count(),
123            LanguageModel::Cloud(model) => model.max_token_count(),
124        }
125    }
126
127    pub fn id(&self) -> &str {
128        match self {
129            LanguageModel::OpenAi(model) => model.id(),
130            LanguageModel::Anthropic(model) => model.id(),
131            LanguageModel::Cloud(model) => model.id(),
132        }
133    }
134}
135
136#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
137pub struct LanguageModelRequestMessage {
138    pub role: Role,
139    pub content: String,
140}
141
142impl LanguageModelRequestMessage {
143    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
144        proto::LanguageModelRequestMessage {
145            role: match self.role {
146                Role::User => proto::LanguageModelRole::LanguageModelUser,
147                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
148                Role::System => proto::LanguageModelRole::LanguageModelSystem,
149            } as i32,
150            content: self.content.clone(),
151            tool_calls: Vec::new(),
152            tool_call_id: None,
153        }
154    }
155}
156
157#[derive(Debug, Default, Serialize)]
158pub struct LanguageModelRequest {
159    pub model: LanguageModel,
160    pub messages: Vec<LanguageModelRequestMessage>,
161    pub stop: Vec<String>,
162    pub temperature: f32,
163}
164
165impl LanguageModelRequest {
166    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
167        proto::CompleteWithLanguageModel {
168            model: self.model.id().to_string(),
169            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
170            stop: self.stop.clone(),
171            temperature: self.temperature,
172            tool_choice: None,
173            tools: Vec::new(),
174        }
175    }
176
177    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
178    pub fn preprocess(&mut self) {
179        match &self.model {
180            LanguageModel::OpenAi(_) => {}
181            LanguageModel::Anthropic(_) => {}
182            LanguageModel::Cloud(model) => match model {
183                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
184                    preprocess_anthropic_request(self);
185                }
186                _ => {}
187            },
188        }
189    }
190}
191
192#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
193pub struct LanguageModelResponseMessage {
194    pub role: Option<Role>,
195    pub content: Option<String>,
196}
197
198#[derive(Deserialize, Debug)]
199pub struct LanguageModelUsage {
200    pub prompt_tokens: u32,
201    pub completion_tokens: u32,
202    pub total_tokens: u32,
203}
204
205#[derive(Deserialize, Debug)]
206pub struct LanguageModelChoiceDelta {
207    pub index: u32,
208    pub delta: LanguageModelResponseMessage,
209    pub finish_reason: Option<String>,
210}
211
212#[derive(Clone, Debug, Serialize, Deserialize)]
213struct MessageMetadata {
214    role: Role,
215    status: MessageStatus,
216}
217
218#[derive(Clone, Debug, Serialize, Deserialize)]
219enum MessageStatus {
220    Pending,
221    Done,
222    Error(SharedString),
223}
224
225/// The state pertaining to the Assistant.
226#[derive(Default)]
227struct Assistant {
228    /// Whether the Assistant is enabled.
229    enabled: bool,
230}
231
232impl Global for Assistant {}
233
234impl Assistant {
235    const NAMESPACE: &'static str = "assistant";
236
237    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
238        if self.enabled == enabled {
239            return;
240        }
241
242        self.enabled = enabled;
243
244        if !enabled {
245            CommandPaletteFilter::update_global(cx, |filter, _cx| {
246                filter.hide_namespace(Self::NAMESPACE);
247            });
248
249            return;
250        }
251
252        CommandPaletteFilter::update_global(cx, |filter, _cx| {
253            filter.show_namespace(Self::NAMESPACE);
254        });
255    }
256}
257
258pub fn init(client: Arc<Client>, cx: &mut AppContext) {
259    cx.set_global(Assistant::default());
260    AssistantSettings::register(cx);
261
262    cx.spawn(|mut cx| {
263        let client = client.clone();
264        async move {
265            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
266            let semantic_index = SemanticIndex::new(
267                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
268                Arc::new(embedding_provider),
269                &mut cx,
270            )
271            .await?;
272            cx.update(|cx| cx.set_global(semantic_index))
273        }
274    })
275    .detach();
276
277    prompt_library::init(cx);
278    completion_provider::init(client.clone(), cx);
279    assistant_slash_command::init(cx);
280    register_slash_commands(cx);
281    assistant_panel::init(cx);
282    inline_assistant::init(client.telemetry().clone(), cx);
283
284    CommandPaletteFilter::update_global(cx, |filter, _cx| {
285        filter.hide_namespace(Assistant::NAMESPACE);
286    });
287    Assistant::update_global(cx, |assistant, cx| {
288        let settings = AssistantSettings::get_global(cx);
289
290        assistant.set_enabled(settings.enabled, cx);
291    });
292    cx.observe_global::<SettingsStore>(|cx| {
293        Assistant::update_global(cx, |assistant, cx| {
294            let settings = AssistantSettings::get_global(cx);
295            assistant.set_enabled(settings.enabled, cx);
296        });
297    })
298    .detach();
299}
300
301fn register_slash_commands(cx: &mut AppContext) {
302    let slash_command_registry = SlashCommandRegistry::global(cx);
303    slash_command_registry.register_command(file_command::FileSlashCommand, true);
304    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
305    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
306    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
307    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
308    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
309    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
310    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
311    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
312}
313
/// Test-only constructor that installs `env_logger`, but only when `RUST_LOG` is set
/// (and readable as a string).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}