assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod search;
  8mod slash_command;
  9mod streaming_diff;
 10
 11pub use assistant_panel::AssistantPanel;
 12
 13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 14use client::{proto, Client};
 15use command_palette_hooks::CommandPaletteFilter;
 16pub(crate) use completion_provider::*;
 17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 18pub(crate) use saved_conversation::*;
 19use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 20use serde::{Deserialize, Serialize};
 21use settings::{Settings, SettingsStore};
 22use std::{
 23    fmt::{self, Display},
 24    sync::Arc,
 25};
 26use util::paths::EMBEDDINGS_DIR;
 27
// Declares the actions exposed by the Assistant. The first argument,
// `assistant`, is the namespace they are grouped under — presumably the same
// namespace that `Assistant::NAMESPACE` hides/shows in the command palette
// (TODO confirm against the `actions!` macro definition).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand
    ]
);
 45
 46#[derive(
 47    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 48)]
 49struct MessageId(usize);
 50
 51#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 52#[serde(rename_all = "lowercase")]
 53pub enum Role {
 54    User,
 55    Assistant,
 56    System,
 57}
 58
 59impl Role {
 60    pub fn cycle(&mut self) {
 61        *self = match self {
 62            Role::User => Role::Assistant,
 63            Role::Assistant => Role::System,
 64            Role::System => Role::User,
 65        }
 66    }
 67}
 68
 69impl Display for Role {
 70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 71        match self {
 72            Role::User => write!(f, "user"),
 73            Role::Assistant => write!(f, "assistant"),
 74            Role::System => write!(f, "system"),
 75        }
 76    }
 77}
 78
 79#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 80pub enum LanguageModel {
 81    ZedDotDev(ZedDotDevModel),
 82    OpenAi(OpenAiModel),
 83    Anthropic(AnthropicModel),
 84}
 85
 86impl Default for LanguageModel {
 87    fn default() -> Self {
 88        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 89    }
 90}
 91
 92impl LanguageModel {
 93    pub fn telemetry_id(&self) -> String {
 94        match self {
 95            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 96            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 97            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 98        }
 99    }
100
101    pub fn display_name(&self) -> String {
102        match self {
103            LanguageModel::OpenAi(model) => model.display_name().into(),
104            LanguageModel::Anthropic(model) => model.display_name().into(),
105            LanguageModel::ZedDotDev(model) => model.display_name().into(),
106        }
107    }
108
109    pub fn max_token_count(&self) -> usize {
110        match self {
111            LanguageModel::OpenAi(model) => model.max_token_count(),
112            LanguageModel::Anthropic(model) => model.max_token_count(),
113            LanguageModel::ZedDotDev(model) => model.max_token_count(),
114        }
115    }
116
117    pub fn id(&self) -> &str {
118        match self {
119            LanguageModel::OpenAi(model) => model.id(),
120            LanguageModel::Anthropic(model) => model.id(),
121            LanguageModel::ZedDotDev(model) => model.id(),
122        }
123    }
124}
125
126#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
127pub struct LanguageModelRequestMessage {
128    pub role: Role,
129    pub content: String,
130}
131
132impl LanguageModelRequestMessage {
133    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
134        proto::LanguageModelRequestMessage {
135            role: match self.role {
136                Role::User => proto::LanguageModelRole::LanguageModelUser,
137                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
138                Role::System => proto::LanguageModelRole::LanguageModelSystem,
139            } as i32,
140            content: self.content.clone(),
141            tool_calls: Vec::new(),
142            tool_call_id: None,
143        }
144    }
145}
146
147#[derive(Debug, Default, Serialize)]
148pub struct LanguageModelRequest {
149    pub model: LanguageModel,
150    pub messages: Vec<LanguageModelRequestMessage>,
151    pub stop: Vec<String>,
152    pub temperature: f32,
153}
154
155impl LanguageModelRequest {
156    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
157        proto::CompleteWithLanguageModel {
158            model: self.model.id().to_string(),
159            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
160            stop: self.stop.clone(),
161            temperature: self.temperature,
162            tool_choice: None,
163            tools: Vec::new(),
164        }
165    }
166}
167
168#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
169pub struct LanguageModelResponseMessage {
170    pub role: Option<Role>,
171    pub content: Option<String>,
172}
173
174#[derive(Deserialize, Debug)]
175pub struct LanguageModelUsage {
176    pub prompt_tokens: u32,
177    pub completion_tokens: u32,
178    pub total_tokens: u32,
179}
180
181#[derive(Deserialize, Debug)]
182pub struct LanguageModelChoiceDelta {
183    pub index: u32,
184    pub delta: LanguageModelResponseMessage,
185    pub finish_reason: Option<String>,
186}
187
188#[derive(Clone, Debug, Serialize, Deserialize)]
189struct MessageMetadata {
190    role: Role,
191    status: MessageStatus,
192}
193
194#[derive(Clone, Debug, Serialize, Deserialize)]
195enum MessageStatus {
196    Pending,
197    Done,
198    Error(SharedString),
199}
200
201/// The state pertaining to the Assistant.
202#[derive(Default)]
203struct Assistant {
204    /// Whether the Assistant is enabled.
205    enabled: bool,
206}
207
208impl Global for Assistant {}
209
210impl Assistant {
211    const NAMESPACE: &'static str = "assistant";
212
213    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
214        if self.enabled == enabled {
215            return;
216        }
217
218        self.enabled = enabled;
219
220        if !enabled {
221            CommandPaletteFilter::update_global(cx, |filter, _cx| {
222                filter.hide_namespace(Self::NAMESPACE);
223            });
224
225            return;
226        }
227
228        CommandPaletteFilter::update_global(cx, |filter, _cx| {
229            filter.show_namespace(Self::NAMESPACE);
230        });
231    }
232}
233
234pub fn init(client: Arc<Client>, cx: &mut AppContext) {
235    cx.set_global(Assistant::default());
236    AssistantSettings::register(cx);
237
238    cx.spawn(|mut cx| {
239        let client = client.clone();
240        async move {
241            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
242            let semantic_index = SemanticIndex::new(
243                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
244                Arc::new(embedding_provider),
245                &mut cx,
246            )
247            .await?;
248            cx.update(|cx| cx.set_global(semantic_index))
249        }
250    })
251    .detach();
252    completion_provider::init(client, cx);
253    assistant_slash_command::init(cx);
254    assistant_panel::init(cx);
255
256    CommandPaletteFilter::update_global(cx, |filter, _cx| {
257        filter.hide_namespace(Assistant::NAMESPACE);
258    });
259    Assistant::update_global(cx, |assistant, cx| {
260        let settings = AssistantSettings::get_global(cx);
261
262        assistant.set_enabled(settings.enabled, cx);
263    });
264    cx.observe_global::<SettingsStore>(|cx| {
265        Assistant::update_global(cx, |assistant, cx| {
266            let settings = AssistantSettings::get_global(cx);
267
268            assistant.set_enabled(settings.enabled, cx);
269        });
270    })
271    .detach();
272}
273
/// Test-only `ctor` hook that initializes `env_logger`, but only when
/// `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let log_requested = std::env::var("RUST_LOG").is_ok();
    if log_requested {
        env_logger::init();
    }
}