assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod model_selector;
  6mod prompt_library;
  7mod prompts;
  8mod saved_conversation;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13pub use assistant_panel::AssistantPanel;
 14
 15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
 16use assistant_slash_command::SlashCommandRegistry;
 17use client::{proto, Client};
 18use command_palette_hooks::CommandPaletteFilter;
 19pub(crate) use completion_provider::*;
 20use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 21pub(crate) use model_selector::*;
 22pub(crate) use saved_conversation::*;
 23use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 24use serde::{Deserialize, Serialize};
 25use settings::{Settings, SettingsStore};
 26use slash_command::{
 27    active_command, default_command, fetch_command, file_command, project_command, prompt_command,
 28    rustdoc_command, search_command, tabs_command,
 29};
 30use std::{
 31    fmt::{self, Display},
 32    sync::Arc,
 33};
 34use util::paths::EMBEDDINGS_DIR;
 35
 36actions!(
 37    assistant,
 38    [
 39        Assist,
 40        Split,
 41        CycleMessageRole,
 42        QuoteSelection,
 43        ToggleFocus,
 44        ResetKey,
 45        InlineAssist,
 46        InsertActivePrompt,
 47        ToggleHistory,
 48        ApplyEdit,
 49        ConfirmCommand,
 50        ToggleModelSelector
 51    ]
 52);
 53
 54#[derive(
 55    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 56)]
 57struct MessageId(usize);
 58
 59#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 60#[serde(rename_all = "lowercase")]
 61pub enum Role {
 62    User,
 63    Assistant,
 64    System,
 65}
 66
 67impl Role {
 68    pub fn cycle(&mut self) {
 69        *self = match self {
 70            Role::User => Role::Assistant,
 71            Role::Assistant => Role::System,
 72            Role::System => Role::User,
 73        }
 74    }
 75}
 76
 77impl Display for Role {
 78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 79        match self {
 80            Role::User => write!(f, "user"),
 81            Role::Assistant => write!(f, "assistant"),
 82            Role::System => write!(f, "system"),
 83        }
 84    }
 85}
 86
 87#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 88pub enum LanguageModel {
 89    Cloud(CloudModel),
 90    OpenAi(OpenAiModel),
 91    Anthropic(AnthropicModel),
 92}
 93
 94impl Default for LanguageModel {
 95    fn default() -> Self {
 96        LanguageModel::Cloud(CloudModel::default())
 97    }
 98}
 99
100impl LanguageModel {
101    pub fn telemetry_id(&self) -> String {
102        match self {
103            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
104            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
105            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
106        }
107    }
108
109    pub fn display_name(&self) -> String {
110        match self {
111            LanguageModel::OpenAi(model) => model.display_name().into(),
112            LanguageModel::Anthropic(model) => model.display_name().into(),
113            LanguageModel::Cloud(model) => model.display_name().into(),
114        }
115    }
116
117    pub fn max_token_count(&self) -> usize {
118        match self {
119            LanguageModel::OpenAi(model) => model.max_token_count(),
120            LanguageModel::Anthropic(model) => model.max_token_count(),
121            LanguageModel::Cloud(model) => model.max_token_count(),
122        }
123    }
124
125    pub fn id(&self) -> &str {
126        match self {
127            LanguageModel::OpenAi(model) => model.id(),
128            LanguageModel::Anthropic(model) => model.id(),
129            LanguageModel::Cloud(model) => model.id(),
130        }
131    }
132}
133
134#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
135pub struct LanguageModelRequestMessage {
136    pub role: Role,
137    pub content: String,
138}
139
140impl LanguageModelRequestMessage {
141    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
142        proto::LanguageModelRequestMessage {
143            role: match self.role {
144                Role::User => proto::LanguageModelRole::LanguageModelUser,
145                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
146                Role::System => proto::LanguageModelRole::LanguageModelSystem,
147            } as i32,
148            content: self.content.clone(),
149            tool_calls: Vec::new(),
150            tool_call_id: None,
151        }
152    }
153}
154
155#[derive(Debug, Default, Serialize)]
156pub struct LanguageModelRequest {
157    pub model: LanguageModel,
158    pub messages: Vec<LanguageModelRequestMessage>,
159    pub stop: Vec<String>,
160    pub temperature: f32,
161}
162
163impl LanguageModelRequest {
164    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
165        proto::CompleteWithLanguageModel {
166            model: self.model.id().to_string(),
167            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
168            stop: self.stop.clone(),
169            temperature: self.temperature,
170            tool_choice: None,
171            tools: Vec::new(),
172        }
173    }
174
175    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
176    pub fn preprocess(&mut self) {
177        match &self.model {
178            LanguageModel::OpenAi(_) => {}
179            LanguageModel::Anthropic(_) => {}
180            LanguageModel::Cloud(model) => match model {
181                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
182                    preprocess_anthropic_request(self);
183                }
184                _ => {}
185            },
186        }
187    }
188}
189
190#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
191pub struct LanguageModelResponseMessage {
192    pub role: Option<Role>,
193    pub content: Option<String>,
194}
195
196#[derive(Deserialize, Debug)]
197pub struct LanguageModelUsage {
198    pub prompt_tokens: u32,
199    pub completion_tokens: u32,
200    pub total_tokens: u32,
201}
202
203#[derive(Deserialize, Debug)]
204pub struct LanguageModelChoiceDelta {
205    pub index: u32,
206    pub delta: LanguageModelResponseMessage,
207    pub finish_reason: Option<String>,
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize)]
211struct MessageMetadata {
212    role: Role,
213    status: MessageStatus,
214}
215
216#[derive(Clone, Debug, Serialize, Deserialize)]
217enum MessageStatus {
218    Pending,
219    Done,
220    Error(SharedString),
221}
222
223/// The state pertaining to the Assistant.
224#[derive(Default)]
225struct Assistant {
226    /// Whether the Assistant is enabled.
227    enabled: bool,
228}
229
230impl Global for Assistant {}
231
232impl Assistant {
233    const NAMESPACE: &'static str = "assistant";
234
235    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
236        if self.enabled == enabled {
237            return;
238        }
239
240        self.enabled = enabled;
241
242        if !enabled {
243            CommandPaletteFilter::update_global(cx, |filter, _cx| {
244                filter.hide_namespace(Self::NAMESPACE);
245            });
246
247            return;
248        }
249
250        CommandPaletteFilter::update_global(cx, |filter, _cx| {
251            filter.show_namespace(Self::NAMESPACE);
252        });
253    }
254}
255
256pub fn init(client: Arc<Client>, cx: &mut AppContext) {
257    cx.set_global(Assistant::default());
258    AssistantSettings::register(cx);
259
260    cx.spawn(|mut cx| {
261        let client = client.clone();
262        async move {
263            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
264            let semantic_index = SemanticIndex::new(
265                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
266                Arc::new(embedding_provider),
267                &mut cx,
268            )
269            .await?;
270            cx.update(|cx| cx.set_global(semantic_index))
271        }
272    })
273    .detach();
274
275    prompt_library::init(cx);
276    completion_provider::init(client, cx);
277    assistant_slash_command::init(cx);
278    register_slash_commands(cx);
279    assistant_panel::init(cx);
280
281    CommandPaletteFilter::update_global(cx, |filter, _cx| {
282        filter.hide_namespace(Assistant::NAMESPACE);
283    });
284    Assistant::update_global(cx, |assistant, cx| {
285        let settings = AssistantSettings::get_global(cx);
286
287        assistant.set_enabled(settings.enabled, cx);
288    });
289    cx.observe_global::<SettingsStore>(|cx| {
290        Assistant::update_global(cx, |assistant, cx| {
291            let settings = AssistantSettings::get_global(cx);
292            assistant.set_enabled(settings.enabled, cx);
293        });
294    })
295    .detach();
296}
297
298fn register_slash_commands(cx: &mut AppContext) {
299    let slash_command_registry = SlashCommandRegistry::global(cx);
300    slash_command_registry.register_command(file_command::FileSlashCommand, true);
301    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
302    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
303    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
304    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
305    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
306    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
307    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
308    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
309}
310
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test builds only: install env_logger before tests run, but only when RUST_LOG
    // is set (and valid Unicode), so test output stays quiet by default.
    let log_requested = std::env::var("RUST_LOG").is_ok();
    if log_requested {
        env_logger::init();
    }
}