assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod search;
  8mod slash_command;
  9mod streaming_diff;
 10
 11pub use assistant_panel::AssistantPanel;
 12
 13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 14use client::{proto, Client};
 15use command_palette_hooks::CommandPaletteFilter;
 16pub(crate) use completion_provider::*;
 17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 18pub(crate) use saved_conversation::*;
 19use serde::{Deserialize, Serialize};
 20use settings::{Settings, SettingsStore};
 21use std::{
 22    fmt::{self, Display},
 23    sync::Arc,
 24};
 25
 26actions!(
 27    assistant,
 28    [
 29        Assist,
 30        Split,
 31        CycleMessageRole,
 32        QuoteSelection,
 33        ToggleFocus,
 34        ResetKey,
 35        InlineAssist,
 36        InsertActivePrompt,
 37        ToggleIncludeConversation,
 38        ToggleHistory,
 39        ApplyEdit,
 40        ConfirmCommand
 41    ]
 42);
 43
 44#[derive(
 45    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 46)]
 47struct MessageId(usize);
 48
 49#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 50#[serde(rename_all = "lowercase")]
 51pub enum Role {
 52    User,
 53    Assistant,
 54    System,
 55}
 56
 57impl Role {
 58    pub fn cycle(&mut self) {
 59        *self = match self {
 60            Role::User => Role::Assistant,
 61            Role::Assistant => Role::System,
 62            Role::System => Role::User,
 63        }
 64    }
 65}
 66
 67impl Display for Role {
 68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 69        match self {
 70            Role::User => write!(f, "user"),
 71            Role::Assistant => write!(f, "assistant"),
 72            Role::System => write!(f, "system"),
 73        }
 74    }
 75}
 76
 77#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 78pub enum LanguageModel {
 79    ZedDotDev(ZedDotDevModel),
 80    OpenAi(OpenAiModel),
 81    Anthropic(AnthropicModel),
 82}
 83
 84impl Default for LanguageModel {
 85    fn default() -> Self {
 86        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 87    }
 88}
 89
 90impl LanguageModel {
 91    pub fn telemetry_id(&self) -> String {
 92        match self {
 93            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 94            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 95            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 96        }
 97    }
 98
 99    pub fn display_name(&self) -> String {
100        match self {
101            LanguageModel::OpenAi(model) => model.display_name().into(),
102            LanguageModel::Anthropic(model) => model.display_name().into(),
103            LanguageModel::ZedDotDev(model) => model.display_name().into(),
104        }
105    }
106
107    pub fn max_token_count(&self) -> usize {
108        match self {
109            LanguageModel::OpenAi(model) => model.max_token_count(),
110            LanguageModel::Anthropic(model) => model.max_token_count(),
111            LanguageModel::ZedDotDev(model) => model.max_token_count(),
112        }
113    }
114
115    pub fn id(&self) -> &str {
116        match self {
117            LanguageModel::OpenAi(model) => model.id(),
118            LanguageModel::Anthropic(model) => model.id(),
119            LanguageModel::ZedDotDev(model) => model.id(),
120        }
121    }
122}
123
124#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
125pub struct LanguageModelRequestMessage {
126    pub role: Role,
127    pub content: String,
128}
129
130impl LanguageModelRequestMessage {
131    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
132        proto::LanguageModelRequestMessage {
133            role: match self.role {
134                Role::User => proto::LanguageModelRole::LanguageModelUser,
135                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
136                Role::System => proto::LanguageModelRole::LanguageModelSystem,
137            } as i32,
138            content: self.content.clone(),
139            tool_calls: Vec::new(),
140            tool_call_id: None,
141        }
142    }
143}
144
145#[derive(Debug, Default, Serialize)]
146pub struct LanguageModelRequest {
147    pub model: LanguageModel,
148    pub messages: Vec<LanguageModelRequestMessage>,
149    pub stop: Vec<String>,
150    pub temperature: f32,
151}
152
153impl LanguageModelRequest {
154    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
155        proto::CompleteWithLanguageModel {
156            model: self.model.id().to_string(),
157            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
158            stop: self.stop.clone(),
159            temperature: self.temperature,
160            tool_choice: None,
161            tools: Vec::new(),
162        }
163    }
164}
165
166#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
167pub struct LanguageModelResponseMessage {
168    pub role: Option<Role>,
169    pub content: Option<String>,
170}
171
172#[derive(Deserialize, Debug)]
173pub struct LanguageModelUsage {
174    pub prompt_tokens: u32,
175    pub completion_tokens: u32,
176    pub total_tokens: u32,
177}
178
179#[derive(Deserialize, Debug)]
180pub struct LanguageModelChoiceDelta {
181    pub index: u32,
182    pub delta: LanguageModelResponseMessage,
183    pub finish_reason: Option<String>,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187struct MessageMetadata {
188    role: Role,
189    status: MessageStatus,
190}
191
192#[derive(Clone, Debug, Serialize, Deserialize)]
193enum MessageStatus {
194    Pending,
195    Done,
196    Error(SharedString),
197}
198
199/// The state pertaining to the Assistant.
200#[derive(Default)]
201struct Assistant {
202    /// Whether the Assistant is enabled.
203    enabled: bool,
204}
205
206impl Global for Assistant {}
207
208impl Assistant {
209    const NAMESPACE: &'static str = "assistant";
210
211    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
212        if self.enabled == enabled {
213            return;
214        }
215
216        self.enabled = enabled;
217
218        if !enabled {
219            CommandPaletteFilter::update_global(cx, |filter, _cx| {
220                filter.hide_namespace(Self::NAMESPACE);
221            });
222
223            return;
224        }
225
226        CommandPaletteFilter::update_global(cx, |filter, _cx| {
227            filter.show_namespace(Self::NAMESPACE);
228        });
229    }
230}
231
232pub fn init(client: Arc<Client>, cx: &mut AppContext) {
233    cx.set_global(Assistant::default());
234    AssistantSettings::register(cx);
235    completion_provider::init(client, cx);
236    assistant_slash_command::init(cx);
237    assistant_panel::init(cx);
238
239    CommandPaletteFilter::update_global(cx, |filter, _cx| {
240        filter.hide_namespace(Assistant::NAMESPACE);
241    });
242    Assistant::update_global(cx, |assistant, cx| {
243        let settings = AssistantSettings::get_global(cx);
244
245        assistant.set_enabled(settings.enabled, cx);
246    });
247    cx.observe_global::<SettingsStore>(|cx| {
248        Assistant::update_global(cx, |assistant, cx| {
249            let settings = AssistantSettings::get_global(cx);
250
251            assistant.set_enabled(settings.enabled, cx);
252        });
253    })
254    .detach();
255}
256
/// Test-only logger setup, registered as a process constructor via `ctor`.
///
/// `env_logger` is initialized only when the `RUST_LOG` environment variable
/// is set (and readable as Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}