pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
mod prompts;
mod saved_conversation;
mod streaming_diff;

pub use assistant_panel::AssistantPanel;
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use chrono::{DateTime, Local};
use client::{proto, Client};
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, SharedString};
pub(crate) use saved_conversation::*;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
    fmt::{self, Display},
    sync::Arc,
};

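// GPUI actions defined in the `assistant` namespace.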
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);

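/// Identifies a message within a conversation.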
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);

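/// The role a message is attributed to in a conversation with a language model.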
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
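    /// Advances to the next role, wrapping around: User -> Assistant -> System -> User.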
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::System => write!(f, "system"),
        }
    }
}

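/// The model used to generate completions: either one hosted through zed.dev or an OpenAI model accessed directly.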
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    ZedDotDev(ZedDotDevModel),
    OpenAi(OpenAiModel),
}

impl Default for LanguageModel {
    fn default() -> Self {
        LanguageModel::ZedDotDev(ZedDotDevModel::default())
    }
}

impl LanguageModel {
    pub fn telemetry_id(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
        }
    }

    pub fn display_name(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
        }
    }

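    /// Returns the model's context window size in tokens. Known models are
    /// looked up via tiktoken; custom zed.dev models currently use a
    /// hard-coded estimate.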
    pub fn max_token_count(&self) -> usize {
        match self {
            LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
            LanguageModel::ZedDotDev(model) => match model {
                ZedDotDevModel::GptThreePointFiveTurbo
                | ZedDotDevModel::GptFour
                | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
                ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
            },
        }
    }

    pub fn id(&self) -> &str {
        match self {
            LanguageModel::OpenAi(model) => model.id(),
            LanguageModel::ZedDotDev(model) => model.id(),
        }
    }
}

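/// A single message in a completion request.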
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}

impl LanguageModelRequestMessage {
    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
        proto::LanguageModelRequestMessage {
            role: match self.role {
                Role::User => proto::LanguageModelRole::LanguageModelUser,
                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
                Role::System => proto::LanguageModelRole::LanguageModelSystem,
            } as i32,
            content: self.content.clone(),
        }
    }
}

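/// A completion request for a language model, convertible to its RPC representation via `to_proto`.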
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
        proto::CompleteWithLanguageModel {
            model: self.model.id().to_string(),
            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
            stop: self.stop.clone(),
            temperature: self.temperature,
        }
    }
}

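/// A message (or partial message) returned by a model; either field may be absent in a streamed chunk.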
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

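/// Token usage reported by the provider for a completion.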
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

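/// One streamed chunk of a completion choice: the message delta and an optional finish reason.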
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}

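/// Metadata tracked for each message: its role, when it was sent, and its completion status.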
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}

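/// Whether a message is still being generated, finished successfully, or failed with an error.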
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}

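/// Initializes the assistant: registers its settings, sets up the completion provider, and initializes the panel.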
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    AssistantSettings::register(cx);
    completion_provider::init(client, cx);
    assistant_panel::init(cx);
}

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}
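
// Illustrative sanity checks for the items above: `Role::cycle` ordering,
// the lowercase `Display` output, and `LanguageModelRequest::to_proto`
// field mapping. They rely only on items defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cycling_roles_wraps_around() {
        let mut role = Role::User;
        role.cycle();
        assert_eq!(role, Role::Assistant);
        role.cycle();
        assert_eq!(role, Role::System);
        role.cycle();
        assert_eq!(role, Role::User);
    }

    #[test]
    fn roles_display_as_lowercase() {
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::System.to_string(), "system");
    }

    #[test]
    fn request_to_proto_preserves_fields() {
        let request = LanguageModelRequest {
            model: LanguageModel::default(),
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: "Hello".into(),
            }],
            stop: vec!["\n".into()],
            temperature: 1.0,
        };
        let proto_request = request.to_proto();
        assert_eq!(proto_request.model, request.model.id());
        assert_eq!(proto_request.messages.len(), 1);
        assert_eq!(proto_request.stop, request.stop);
        assert_eq!(proto_request.temperature, request.temperature);
    }
}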