assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9pub use assistant_panel::AssistantPanel;
 10use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 11use client::{proto, Client};
 12use command_palette_hooks::CommandPaletteFilter;
 13pub(crate) use completion_provider::*;
 14use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
 15pub(crate) use saved_conversation::*;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::{
 19    fmt::{self, Display},
 20    sync::Arc,
 21};
 22
 // Declares the assistant's actions under the `assistant` namespace.
 // `Assistant::NAMESPACE` (defined in this file) is the same string, and
 // `Assistant::set_enabled` hides/shows that namespace in the command palette —
 // presumably that is how these actions disappear while the assistant is disabled.
 // NOTE(review): confirm that gpui's `actions!` namespace and the command-palette
 // filter namespace are the same key.
 23actions!(
 24    assistant,
 25    [
 26        Assist,
 27        Split,
 28        CycleMessageRole,
 29        QuoteSelection,
 30        ToggleFocus,
 31        ResetKey,
 32        InlineAssist,
 33        ToggleIncludeConversation,
 34        ToggleHistory,
 35    ]
 36);
 37
 38#[derive(
 39    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 40)]
 41struct MessageId(usize);
 42
 43#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 44#[serde(rename_all = "lowercase")]
 45pub enum Role {
 46    User,
 47    Assistant,
 48    System,
 49}
 50
 51impl Role {
 52    pub fn cycle(&mut self) {
 53        *self = match self {
 54            Role::User => Role::Assistant,
 55            Role::Assistant => Role::System,
 56            Role::System => Role::User,
 57        }
 58    }
 59}
 60
 61impl Display for Role {
 62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 63        match self {
 64            Role::User => write!(f, "user"),
 65            Role::Assistant => write!(f, "assistant"),
 66            Role::System => write!(f, "system"),
 67        }
 68    }
 69}
 70
 71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 72pub enum LanguageModel {
 73    ZedDotDev(ZedDotDevModel),
 74    OpenAi(OpenAiModel),
 75    Anthropic(AnthropicModel),
 76}
 77
 78impl Default for LanguageModel {
 79    fn default() -> Self {
 80        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 81    }
 82}
 83
 84impl LanguageModel {
 85    pub fn telemetry_id(&self) -> String {
 86        match self {
 87            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 88            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 89            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 90        }
 91    }
 92
 93    pub fn display_name(&self) -> String {
 94        match self {
 95            LanguageModel::OpenAi(model) => model.display_name().into(),
 96            LanguageModel::Anthropic(model) => model.display_name().into(),
 97            LanguageModel::ZedDotDev(model) => model.display_name().into(),
 98        }
 99    }
100
101    pub fn max_token_count(&self) -> usize {
102        match self {
103            LanguageModel::OpenAi(model) => model.max_token_count(),
104            LanguageModel::Anthropic(model) => model.max_token_count(),
105            LanguageModel::ZedDotDev(model) => model.max_token_count(),
106        }
107    }
108
109    pub fn id(&self) -> &str {
110        match self {
111            LanguageModel::OpenAi(model) => model.id(),
112            LanguageModel::Anthropic(model) => model.id(),
113            LanguageModel::ZedDotDev(model) => model.id(),
114        }
115    }
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119pub struct LanguageModelRequestMessage {
120    pub role: Role,
121    pub content: String,
122}
123
124impl LanguageModelRequestMessage {
125    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
126        proto::LanguageModelRequestMessage {
127            role: match self.role {
128                Role::User => proto::LanguageModelRole::LanguageModelUser,
129                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
130                Role::System => proto::LanguageModelRole::LanguageModelSystem,
131            } as i32,
132            content: self.content.clone(),
133            tool_calls: Vec::new(),
134            tool_call_id: None,
135        }
136    }
137}
138
139#[derive(Debug, Default, Serialize)]
140pub struct LanguageModelRequest {
141    pub model: LanguageModel,
142    pub messages: Vec<LanguageModelRequestMessage>,
143    pub stop: Vec<String>,
144    pub temperature: f32,
145}
146
147impl LanguageModelRequest {
148    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
149        proto::CompleteWithLanguageModel {
150            model: self.model.id().to_string(),
151            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
152            stop: self.stop.clone(),
153            temperature: self.temperature,
154            tool_choice: None,
155            tools: Vec::new(),
156        }
157    }
158}
159
160#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
161pub struct LanguageModelResponseMessage {
162    pub role: Option<Role>,
163    pub content: Option<String>,
164}
165
166#[derive(Deserialize, Debug)]
167pub struct LanguageModelUsage {
168    pub prompt_tokens: u32,
169    pub completion_tokens: u32,
170    pub total_tokens: u32,
171}
172
173#[derive(Deserialize, Debug)]
174pub struct LanguageModelChoiceDelta {
175    pub index: u32,
176    pub delta: LanguageModelResponseMessage,
177    pub finish_reason: Option<String>,
178}
179
180#[derive(Clone, Debug, Serialize, Deserialize)]
181struct MessageMetadata {
182    role: Role,
183    status: MessageStatus,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187enum MessageStatus {
188    Pending,
189    Done,
190    Error(SharedString),
191}
192
193/// The state pertaining to the Assistant.
194#[derive(Default)]
195struct Assistant {
196    /// Whether the Assistant is enabled.
197    enabled: bool,
198}
199
200impl Global for Assistant {}
201
202impl Assistant {
203    const NAMESPACE: &'static str = "assistant";
204
205    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
206        if self.enabled == enabled {
207            return;
208        }
209
210        self.enabled = enabled;
211
212        if !enabled {
213            CommandPaletteFilter::update_global(cx, |filter, _cx| {
214                filter.hide_namespace(Self::NAMESPACE);
215            });
216
217            return;
218        }
219
220        CommandPaletteFilter::update_global(cx, |filter, _cx| {
221            filter.show_namespace(Self::NAMESPACE);
222        });
223    }
224}
225
226pub fn init(client: Arc<Client>, cx: &mut AppContext) {
227    cx.set_global(Assistant::default());
228    AssistantSettings::register(cx);
229    completion_provider::init(client, cx);
230    assistant_panel::init(cx);
231
232    CommandPaletteFilter::update_global(cx, |filter, _cx| {
233        filter.hide_namespace(Assistant::NAMESPACE);
234    });
235    cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
236        let settings = AssistantSettings::get_global(cx);
237
238        assistant.set_enabled(settings.enabled, cx);
239    });
240    cx.observe_global::<SettingsStore>(|cx| {
241        cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
242            let settings = AssistantSettings::get_global(cx);
243
244            assistant.set_enabled(settings.enabled, cx);
245        });
246    })
247    .detach();
248}
249
#[cfg(test)]
#[ctor::ctor]
/// Installs `env_logger` at test-binary startup, but only when `RUST_LOG` is set to a
/// readable (`std::env::var` succeeds) value.
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}