assistant.rs

  1mod ambient_context;
  2pub mod assistant_panel;
  3pub mod assistant_settings;
  4mod codegen;
  5mod completion_provider;
  6mod omit_ranges;
  7mod prompts;
  8mod saved_conversation;
  9mod search;
 10mod slash_command;
 11mod streaming_diff;
 12
 13use ambient_context::AmbientContextSnapshot;
 14pub use assistant_panel::AssistantPanel;
 15use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 16use client::{proto, Client};
 17use command_palette_hooks::CommandPaletteFilter;
 18pub(crate) use completion_provider::*;
 19use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 20pub(crate) use prompts::prompt_library::*;
 21pub(crate) use saved_conversation::*;
 22use serde::{Deserialize, Serialize};
 23use settings::{Settings, SettingsStore};
 24use std::{
 25    fmt::{self, Display},
 26    sync::Arc,
 27};
 28
// Declares the assistant's actions through gpui's `actions!` macro.
// The first argument (`assistant`) is presumably the action namespace; it
// matches the string in `Assistant::NAMESPACE`, which is used to show/hide
// these entries in the command palette — confirm against the macro definition.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
        ApplyEdit
    ]
);
 45
 46#[derive(
 47    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 48)]
 49struct MessageId(usize);
 50
 51#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 52#[serde(rename_all = "lowercase")]
 53pub enum Role {
 54    User,
 55    Assistant,
 56    System,
 57}
 58
 59impl Role {
 60    pub fn cycle(&mut self) {
 61        *self = match self {
 62            Role::User => Role::Assistant,
 63            Role::Assistant => Role::System,
 64            Role::System => Role::User,
 65        }
 66    }
 67}
 68
 69impl Display for Role {
 70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 71        match self {
 72            Role::User => write!(f, "user"),
 73            Role::Assistant => write!(f, "assistant"),
 74            Role::System => write!(f, "system"),
 75        }
 76    }
 77}
 78
 79#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 80pub enum LanguageModel {
 81    ZedDotDev(ZedDotDevModel),
 82    OpenAi(OpenAiModel),
 83    Anthropic(AnthropicModel),
 84}
 85
 86impl Default for LanguageModel {
 87    fn default() -> Self {
 88        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 89    }
 90}
 91
 92impl LanguageModel {
 93    pub fn telemetry_id(&self) -> String {
 94        match self {
 95            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 96            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 97            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 98        }
 99    }
100
101    pub fn display_name(&self) -> String {
102        match self {
103            LanguageModel::OpenAi(model) => model.display_name().into(),
104            LanguageModel::Anthropic(model) => model.display_name().into(),
105            LanguageModel::ZedDotDev(model) => model.display_name().into(),
106        }
107    }
108
109    pub fn max_token_count(&self) -> usize {
110        match self {
111            LanguageModel::OpenAi(model) => model.max_token_count(),
112            LanguageModel::Anthropic(model) => model.max_token_count(),
113            LanguageModel::ZedDotDev(model) => model.max_token_count(),
114        }
115    }
116
117    pub fn id(&self) -> &str {
118        match self {
119            LanguageModel::OpenAi(model) => model.id(),
120            LanguageModel::Anthropic(model) => model.id(),
121            LanguageModel::ZedDotDev(model) => model.id(),
122        }
123    }
124}
125
126#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
127pub struct LanguageModelRequestMessage {
128    pub role: Role,
129    pub content: String,
130}
131
132impl LanguageModelRequestMessage {
133    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
134        proto::LanguageModelRequestMessage {
135            role: match self.role {
136                Role::User => proto::LanguageModelRole::LanguageModelUser,
137                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
138                Role::System => proto::LanguageModelRole::LanguageModelSystem,
139            } as i32,
140            content: self.content.clone(),
141            tool_calls: Vec::new(),
142            tool_call_id: None,
143        }
144    }
145}
146
147#[derive(Debug, Default, Serialize)]
148pub struct LanguageModelRequest {
149    pub model: LanguageModel,
150    pub messages: Vec<LanguageModelRequestMessage>,
151    pub stop: Vec<String>,
152    pub temperature: f32,
153}
154
155impl LanguageModelRequest {
156    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
157        proto::CompleteWithLanguageModel {
158            model: self.model.id().to_string(),
159            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
160            stop: self.stop.clone(),
161            temperature: self.temperature,
162            tool_choice: None,
163            tools: Vec::new(),
164        }
165    }
166}
167
168#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
169pub struct LanguageModelResponseMessage {
170    pub role: Option<Role>,
171    pub content: Option<String>,
172}
173
174#[derive(Deserialize, Debug)]
175pub struct LanguageModelUsage {
176    pub prompt_tokens: u32,
177    pub completion_tokens: u32,
178    pub total_tokens: u32,
179}
180
181#[derive(Deserialize, Debug)]
182pub struct LanguageModelChoiceDelta {
183    pub index: u32,
184    pub delta: LanguageModelResponseMessage,
185    pub finish_reason: Option<String>,
186}
187
188#[derive(Clone, Debug, Serialize, Deserialize)]
189struct MessageMetadata {
190    role: Role,
191    status: MessageStatus,
192    // todo!("delete this")
193    #[serde(skip)]
194    ambient_context: AmbientContextSnapshot,
195}
196
197#[derive(Clone, Debug, Serialize, Deserialize)]
198enum MessageStatus {
199    Pending,
200    Done,
201    Error(SharedString),
202}
203
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global: `init` installs the default value (disabled) with
/// `cx.set_global` and keeps `enabled` in sync with `AssistantSettings`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Marker impl that lets `Assistant` be stored and updated as a gpui global.
impl Global for Assistant {}
212
213impl Assistant {
214    const NAMESPACE: &'static str = "assistant";
215
216    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
217        if self.enabled == enabled {
218            return;
219        }
220
221        self.enabled = enabled;
222
223        if !enabled {
224            CommandPaletteFilter::update_global(cx, |filter, _cx| {
225                filter.hide_namespace(Self::NAMESPACE);
226            });
227
228            return;
229        }
230
231        CommandPaletteFilter::update_global(cx, |filter, _cx| {
232            filter.show_namespace(Self::NAMESPACE);
233        });
234    }
235}
236
237pub fn init(client: Arc<Client>, cx: &mut AppContext) {
238    cx.set_global(Assistant::default());
239    AssistantSettings::register(cx);
240    completion_provider::init(client, cx);
241    assistant_panel::init(cx);
242
243    CommandPaletteFilter::update_global(cx, |filter, _cx| {
244        filter.hide_namespace(Assistant::NAMESPACE);
245    });
246    Assistant::update_global(cx, |assistant, cx| {
247        let settings = AssistantSettings::get_global(cx);
248
249        assistant.set_enabled(settings.enabled, cx);
250    });
251    cx.observe_global::<SettingsStore>(|cx| {
252        Assistant::update_global(cx, |assistant, cx| {
253            let settings = AssistantSettings::get_global(cx);
254
255            assistant.set_enabled(settings.enabled, cx);
256        });
257    })
258    .detach();
259}
260
/// Test-only constructor: initializes `env_logger` when the `RUST_LOG`
/// environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}