1mod ambient_context;
2pub mod assistant_panel;
3pub mod assistant_settings;
4mod codegen;
5mod completion_provider;
6mod prompts;
7mod saved_conversation;
8mod streaming_diff;
9
10pub use assistant_panel::AssistantPanel;
11use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
12use client::{proto, Client};
13use command_palette_hooks::CommandPaletteFilter;
14pub(crate) use completion_provider::*;
15use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
16pub(crate) use saved_conversation::*;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::{
20 fmt::{self, Display},
21 sync::Arc,
22};
23
// Actions declared under the `assistant` namespace. The namespace name matches
// `Assistant::NAMESPACE` ("assistant"), which `Assistant::set_enabled` hides from or
// shows in the command palette as a whole.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
        ToggleHistory,
    ]
);
38
/// Identifier for a message, wrapping a plain `usize`.
///
/// It is `Copy`, totally ordered, hashable, and serde-(de)serializable. `Default` yields
/// `MessageId(0)`.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
43
/// The role attached to a message: `User`, `Assistant`, or `System`.
///
/// Serde (de)serializes variants as lowercase strings (`"user"`, `"assistant"`,
/// `"system"`) because of `rename_all = "lowercase"`. Do not reorder variants: formats
/// that encode enums by index depend on declaration order.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
51
52impl Role {
53 pub fn cycle(&mut self) {
54 *self = match self {
55 Role::User => Role::Assistant,
56 Role::Assistant => Role::System,
57 Role::System => Role::User,
58 }
59 }
60}
61
62impl Display for Role {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Role::User => write!(f, "user"),
66 Role::Assistant => write!(f, "assistant"),
67 Role::System => write!(f, "system"),
68 }
69 }
70}
71
/// A language model selection, tagged by provider.
///
/// Each variant wraps the provider-specific model type from `assistant_settings`. The
/// `Default` impl selects `ZedDotDev(ZedDotDevModel::default())`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    ZedDotDev(ZedDotDevModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
}
78
79impl Default for LanguageModel {
80 fn default() -> Self {
81 LanguageModel::ZedDotDev(ZedDotDevModel::default())
82 }
83}
84
85impl LanguageModel {
86 pub fn telemetry_id(&self) -> String {
87 match self {
88 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
89 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
90 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
91 }
92 }
93
94 pub fn display_name(&self) -> String {
95 match self {
96 LanguageModel::OpenAi(model) => model.display_name().into(),
97 LanguageModel::Anthropic(model) => model.display_name().into(),
98 LanguageModel::ZedDotDev(model) => model.display_name().into(),
99 }
100 }
101
102 pub fn max_token_count(&self) -> usize {
103 match self {
104 LanguageModel::OpenAi(model) => model.max_token_count(),
105 LanguageModel::Anthropic(model) => model.max_token_count(),
106 LanguageModel::ZedDotDev(model) => model.max_token_count(),
107 }
108 }
109
110 pub fn id(&self) -> &str {
111 match self {
112 LanguageModel::OpenAi(model) => model.id(),
113 LanguageModel::Anthropic(model) => model.id(),
114 LanguageModel::ZedDotDev(model) => model.id(),
115 }
116 }
117}
118
/// One message in a [`LanguageModelRequest`]: the role it is attributed to and its text.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Role of the message (serialized as a lowercase string).
    pub role: Role,
    /// Message text.
    pub content: String,
}
124
125impl LanguageModelRequestMessage {
126 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
127 proto::LanguageModelRequestMessage {
128 role: match self.role {
129 Role::User => proto::LanguageModelRole::LanguageModelUser,
130 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
131 Role::System => proto::LanguageModelRole::LanguageModelSystem,
132 } as i32,
133 content: self.content.clone(),
134 tool_calls: Vec::new(),
135 tool_call_id: None,
136 }
137 }
138}
139
/// A completion request sent to a language model.
///
/// `Default` yields `LanguageModel::default()`, no messages, no `stop` entries, and a
/// `temperature` of `0.0`.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// Model to run the request against.
    pub model: LanguageModel,
    /// Conversation messages, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Presumably stop sequences for generation — confirm against providers.
    pub stop: Vec<String>,
    /// Temperature forwarded to the provider as-is.
    pub temperature: f32,
}
147
148impl LanguageModelRequest {
149 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
150 proto::CompleteWithLanguageModel {
151 model: self.model.id().to_string(),
152 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
153 stop: self.stop.clone(),
154 temperature: self.temperature,
155 tool_choice: None,
156 tools: Vec::new(),
157 }
158 }
159}
160
/// A message fragment returned by a model; used as the `delta` of
/// [`LanguageModelChoiceDelta`]. Either field may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Role, when present in this fragment.
    pub role: Option<Role>,
    /// Text content, when present in this fragment.
    pub content: Option<String>,
}
166
/// Token counts for a request, deserialized from a model response.
/// NOTE(review): presumably `total_tokens == prompt_tokens + completion_tokens` as
/// reported by the provider — not enforced here.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
173
/// One choice update deserialized from a model response (presumably a streamed chunk —
/// confirm against the completion provider).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this update belongs to.
    pub index: u32,
    /// The incremental message content/role for this choice.
    pub delta: LanguageModelResponseMessage,
    /// Provider-supplied reason the choice finished, if it has.
    pub finish_reason: Option<String>,
}
180
/// Per-message metadata: the message's role and its current status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
186
/// Status of a message: `Pending`, `Done`, or `Error`.
///
/// `Error` carries a `SharedString`, presumably a human-readable error description —
/// confirm at construction sites. Variant order is part of the serialized form.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
193
/// Global state for the assistant, stored in the app as a gpui [`Global`].
///
/// It currently only tracks whether the assistant is enabled. `Assistant::set_enabled`
/// keeps the command-palette visibility of the `assistant` namespace in sync with this
/// flag. `Default` starts disabled.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

impl Global for Assistant {}
202
203impl Assistant {
204 const NAMESPACE: &'static str = "assistant";
205
206 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
207 if self.enabled == enabled {
208 return;
209 }
210
211 self.enabled = enabled;
212
213 if !enabled {
214 CommandPaletteFilter::update_global(cx, |filter, _cx| {
215 filter.hide_namespace(Self::NAMESPACE);
216 });
217
218 return;
219 }
220
221 CommandPaletteFilter::update_global(cx, |filter, _cx| {
222 filter.show_namespace(Self::NAMESPACE);
223 });
224 }
225}
226
227pub fn init(client: Arc<Client>, cx: &mut AppContext) {
228 cx.set_global(Assistant::default());
229 AssistantSettings::register(cx);
230 completion_provider::init(client, cx);
231 assistant_panel::init(cx);
232
233 CommandPaletteFilter::update_global(cx, |filter, _cx| {
234 filter.hide_namespace(Assistant::NAMESPACE);
235 });
236 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
237 let settings = AssistantSettings::get_global(cx);
238
239 assistant.set_enabled(settings.enabled, cx);
240 });
241 cx.observe_global::<SettingsStore>(|cx| {
242 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
243 let settings = AssistantSettings::get_global(cx);
244
245 assistant.set_enabled(settings.enabled, cx);
246 });
247 })
248 .detach();
249}
250
/// Test-only constructor, run automatically at load time via `ctor`.
///
/// It installs `env_logger` only when `RUST_LOG` is set to a valid Unicode value, so
/// tests are quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}