1package cmd
2
3import (
4 "bytes"
5 "context"
6 "database/sql"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "html/template"
11 "os"
12 "os/user"
13 "path/filepath"
14 "strings"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/db"
19 "github.com/charmbracelet/crush/internal/event"
20 "github.com/pkg/browser"
21 "github.com/spf13/cobra"
22)
23
// statsTemplate is the html/template source of the stats report page. It is
// embedded into the binary at build time and parsed by generateHTML.
//
//go:embed stats/index.html
var statsTemplate string

// statsCSS is the report stylesheet, injected into the template as
// template.CSS by generateHTML.
//
//go:embed stats/index.css
var statsCSS string

// statsJS is the report script, injected into the template as template.JS by
// generateHTML.
//
//go:embed stats/index.js
var statsJS string

// headerSVG is the header artwork, injected unescaped (template.HTML) into the
// report by generateHTML.
//
//go:embed stats/header.svg
var headerSVG string

// heartbitSVG is an SVG graphic injected unescaped (template.HTML) into the
// report by generateHTML.
//
//go:embed stats/heartbit.svg
var heartbitSVG string

// footerSVG is the footer artwork, injected unescaped (template.HTML) into the
// report by generateHTML.
//
//go:embed stats/footer.svg
var footerSVG string
41
42var statsCmd = &cobra.Command{
43 Use: "stats",
44 Short: "Show usage statistics",
45 Long: "Generate and display usage statistics including token usage, costs, and activity patterns",
46 RunE: runStats,
47}
48
// dayNames maps a day-of-week index (0 = Sunday through 6 = Saturday) to the
// English weekday name used as the label in the day-of-week statistics.
var dayNames = []string{
	0: "Sunday",
	1: "Monday",
	2: "Tuesday",
	3: "Wednesday",
	4: "Thursday",
	5: "Friday",
	6: "Saturday",
}
51
// Stats holds all the statistics data.
//
// gatherStats populates it from the database; generateHTML serializes it with
// encoding/json (using the field tags below) and embeds the result in the
// HTML report.
type Stats struct {
	// GeneratedAt is the time the stats were gathered.
	GeneratedAt time.Time `json:"generated_at"`
	// Total holds the aggregate totals from the total-stats query.
	Total TotalStats `json:"total"`
	// UsageByDay holds per-day token, cost and session figures.
	UsageByDay []DailyUsage `json:"usage_by_day"`
	// UsageByModel holds message counts per model/provider pair.
	UsageByModel []ModelUsage `json:"usage_by_model"`
	// UsageByHour holds session counts per hour bucket.
	UsageByHour []HourlyUsage `json:"usage_by_hour"`
	// UsageByDayOfWeek holds session and token counts per weekday.
	UsageByDayOfWeek []DayOfWeekUsage `json:"usage_by_day_of_week"`
	// RecentActivity holds per-day activity for the recent window (the
	// comment in gatherStats describes it as the last 30 days).
	RecentActivity []DailyActivity `json:"recent_activity"`
	// AvgResponseTimeMs is the average response time. gatherStats multiplies
	// the raw query value by 1000 (presumably seconds to milliseconds — TODO
	// confirm against the query).
	AvgResponseTimeMs float64 `json:"avg_response_time_ms"`
	// ToolUsage holds call counts per tool name.
	ToolUsage []ToolUsage `json:"tool_usage"`
	// HourDayHeatmap holds session counts per (day of week, hour) cell.
	HourDayHeatmap []HourDayHeatmapPt `json:"hour_day_heatmap"`
}
65
// TotalStats holds the aggregate figures reported as Stats.Total.
type TotalStats struct {
	// TotalSessions is the number of sessions in the database.
	TotalSessions int64 `json:"total_sessions"`
	// TotalPromptTokens is the summed prompt (input) token count.
	TotalPromptTokens int64 `json:"total_prompt_tokens"`
	// TotalCompletionTokens is the summed completion (output) token count.
	TotalCompletionTokens int64 `json:"total_completion_tokens"`
	// TotalTokens is TotalPromptTokens + TotalCompletionTokens, computed in
	// gatherStats rather than by the query.
	TotalTokens int64 `json:"total_tokens"`
	// TotalCost is the summed cost (currency/unit as stored by the app —
	// TODO confirm).
	TotalCost float64 `json:"total_cost"`
	// TotalMessages is the total number of messages.
	TotalMessages int64 `json:"total_messages"`
	// AvgTokensPerSession is the mean token count per session.
	AvgTokensPerSession float64 `json:"avg_tokens_per_session"`
	// AvgMessagesPerSession is the mean message count per session.
	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
}
76
// DailyUsage is one row of Stats.UsageByDay: token counts, cost and number of
// sessions for a single day. Day is the query's day value rendered as text,
// and TotalTokens is PromptTokens + CompletionTokens, computed in gatherStats.
type DailyUsage struct {
	Day              string  `json:"day"`
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	TotalTokens      int64   `json:"total_tokens"`
	Cost             float64 `json:"cost"`
	SessionCount     int64   `json:"session_count"`
}
85
// ModelUsage is one row of Stats.UsageByModel: the number of messages
// attributed to a given model and provider.
type ModelUsage struct {
	Model        string `json:"model"`
	Provider     string `json:"provider"`
	MessageCount int64  `json:"message_count"`
}
91
// HourlyUsage is one row of Stats.UsageByHour: the session count for a single
// hour bucket (presumably hour of day 0–23 — TODO confirm against the
// usage-by-hour query).
type HourlyUsage struct {
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
96
// DayOfWeekUsage is one row of Stats.UsageByDayOfWeek: session and token
// counts for a single weekday. DayOfWeek is used as an index into dayNames
// (so 0 is Sunday), and DayName holds the corresponding name.
type DayOfWeekUsage struct {
	DayOfWeek        int    `json:"day_of_week"`
	DayName          string `json:"day_name"`
	SessionCount     int64  `json:"session_count"`
	PromptTokens     int64  `json:"prompt_tokens"`
	CompletionTokens int64  `json:"completion_tokens"`
}
104
// DailyActivity is one row of Stats.RecentActivity: session count, total
// tokens and cost for a single recent day. Day is the query's day value
// rendered as text.
type DailyActivity struct {
	Day          string  `json:"day"`
	SessionCount int64   `json:"session_count"`
	TotalTokens  int64   `json:"total_tokens"`
	Cost         float64 `json:"cost"`
}
111
// ToolUsage is one row of Stats.ToolUsage: how many calls were recorded for a
// tool. gatherStats omits rows that have no usable tool name.
type ToolUsage struct {
	ToolName  string `json:"tool_name"`
	CallCount int64  `json:"call_count"`
}
116
// HourDayHeatmapPt is one cell of Stats.HourDayHeatmap: the number of
// sessions at Hour on DayOfWeek (presumably the same 0 = Sunday numbering as
// DayOfWeekUsage — TODO confirm against the heatmap query).
type HourDayHeatmapPt struct {
	DayOfWeek    int   `json:"day_of_week"`
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
122
123func runStats(cmd *cobra.Command, _ []string) error {
124 dataDir, _ := cmd.Flags().GetString("data-dir")
125 ctx := cmd.Context()
126
127 cfg, err := config.Init("", dataDir, false)
128 if err != nil {
129 return fmt.Errorf("failed to initialize config: %w", err)
130 }
131 if dataDir == "" {
132 dataDir = cfg.Config().Options.DataDirectory
133 }
134 if shouldEnableMetrics(cfg.Config()) {
135 event.Init()
136 }
137
138 event.StatsViewed()
139
140 conn, err := db.Connect(ctx, dataDir)
141 if err != nil {
142 return fmt.Errorf("failed to connect to database: %w", err)
143 }
144 defer conn.Close()
145
146 stats, err := gatherStats(ctx, conn)
147 if err != nil {
148 return fmt.Errorf("failed to gather stats: %w", err)
149 }
150
151 if stats.Total.TotalSessions == 0 {
152 return fmt.Errorf("no data available: no sessions found in database")
153 }
154
155 currentUser, err := user.Current()
156 if err != nil {
157 return fmt.Errorf("failed to get current user: %w", err)
158 }
159 username := currentUser.Username
160 project, err := os.Getwd()
161 if err != nil {
162 return fmt.Errorf("failed to get current directory: %w", err)
163 }
164 project = strings.Replace(project, currentUser.HomeDir, "~", 1)
165
166 htmlPath := filepath.Join(dataDir, "stats/index.html")
167 if err := generateHTML(stats, project, username, htmlPath); err != nil {
168 return fmt.Errorf("failed to generate HTML: %w", err)
169 }
170
171 fmt.Printf("Stats generated: %s\n", htmlPath)
172
173 if err := browser.OpenFile(htmlPath); err != nil {
174 fmt.Printf("Could not open browser: %v\n", err)
175 fmt.Println("Please open the file manually.")
176 }
177
178 return nil
179}
180
181func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
182 queries := db.New(conn)
183
184 stats := &Stats{
185 GeneratedAt: time.Now(),
186 }
187
188 // Total stats.
189 total, err := queries.GetTotalStats(ctx)
190 if err != nil {
191 return nil, fmt.Errorf("get total stats: %w", err)
192 }
193 stats.Total = TotalStats{
194 TotalSessions: total.TotalSessions,
195 TotalPromptTokens: toInt64(total.TotalPromptTokens),
196 TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
197 TotalTokens: toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
198 TotalCost: toFloat64(total.TotalCost),
199 TotalMessages: toInt64(total.TotalMessages),
200 AvgTokensPerSession: toFloat64(total.AvgTokensPerSession),
201 AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
202 }
203
204 // Usage by day.
205 dailyUsage, err := queries.GetUsageByDay(ctx)
206 if err != nil {
207 return nil, fmt.Errorf("get usage by day: %w", err)
208 }
209 for _, d := range dailyUsage {
210 prompt := nullFloat64ToInt64(d.PromptTokens)
211 completion := nullFloat64ToInt64(d.CompletionTokens)
212 stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
213 Day: fmt.Sprintf("%v", d.Day),
214 PromptTokens: prompt,
215 CompletionTokens: completion,
216 TotalTokens: prompt + completion,
217 Cost: d.Cost.Float64,
218 SessionCount: d.SessionCount,
219 })
220 }
221
222 // Usage by model.
223 modelUsage, err := queries.GetUsageByModel(ctx)
224 if err != nil {
225 return nil, fmt.Errorf("get usage by model: %w", err)
226 }
227 for _, m := range modelUsage {
228 stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
229 Model: m.Model,
230 Provider: m.Provider,
231 MessageCount: m.MessageCount,
232 })
233 }
234
235 // Usage by hour.
236 hourlyUsage, err := queries.GetUsageByHour(ctx)
237 if err != nil {
238 return nil, fmt.Errorf("get usage by hour: %w", err)
239 }
240 for _, h := range hourlyUsage {
241 stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
242 Hour: int(h.Hour),
243 SessionCount: h.SessionCount,
244 })
245 }
246
247 // Usage by day of week.
248 dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
249 if err != nil {
250 return nil, fmt.Errorf("get usage by day of week: %w", err)
251 }
252 for _, d := range dowUsage {
253 stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
254 DayOfWeek: int(d.DayOfWeek),
255 DayName: dayNames[int(d.DayOfWeek)],
256 SessionCount: d.SessionCount,
257 PromptTokens: nullFloat64ToInt64(d.PromptTokens),
258 CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
259 })
260 }
261
262 // Recent activity (last 30 days).
263 recent, err := queries.GetRecentActivity(ctx)
264 if err != nil {
265 return nil, fmt.Errorf("get recent activity: %w", err)
266 }
267 for _, r := range recent {
268 stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
269 Day: fmt.Sprintf("%v", r.Day),
270 SessionCount: r.SessionCount,
271 TotalTokens: nullFloat64ToInt64(r.TotalTokens),
272 Cost: r.Cost.Float64,
273 })
274 }
275
276 // Average response time.
277 avgResp, err := queries.GetAverageResponseTime(ctx)
278 if err != nil {
279 return nil, fmt.Errorf("get average response time: %w", err)
280 }
281 stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
282
283 // Tool usage.
284 toolUsage, err := queries.GetToolUsage(ctx)
285 if err != nil {
286 return nil, fmt.Errorf("get tool usage: %w", err)
287 }
288 for _, t := range toolUsage {
289 if name, ok := t.ToolName.(string); ok && name != "" {
290 stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
291 ToolName: name,
292 CallCount: t.CallCount,
293 })
294 }
295 }
296
297 // Hour/day heatmap.
298 heatmap, err := queries.GetHourDayHeatmap(ctx)
299 if err != nil {
300 return nil, fmt.Errorf("get hour day heatmap: %w", err)
301 }
302 for _, h := range heatmap {
303 stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
304 DayOfWeek: int(h.DayOfWeek),
305 Hour: int(h.Hour),
306 SessionCount: h.SessionCount,
307 })
308 }
309
310 return stats, nil
311}
312
// toInt64 converts a loosely typed numeric value, such as an aggregate column
// scanned into an `any`, into an int64.
//
// Accepted inputs:
//   - Every built-in Go integer and floating-point type. Floating-point values
//     are truncated toward zero. Unsigned values above the int64 range
//     saturate at the int64 maximum instead of wrapping negative.
//   - sql.NullInt64 and sql.NullFloat64. An invalid (NULL) value yields 0.
//
// Any other type, including nil, yields 0.
func toInt64(v any) int64 {
	const maxInt64 = 1<<63 - 1

	switch val := v.(type) {
	case int64:
		return val
	case float64:
		return int64(val)
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case int16:
		return int64(val)
	case int8:
		return int64(val)
	case uint64:
		if val > maxInt64 {
			return maxInt64
		}
		return int64(val)
	case uint:
		return toInt64(uint64(val))
	case uint32:
		return int64(val)
	case uint16:
		return int64(val)
	case uint8:
		return int64(val)
	case float32:
		return int64(val)
	case sql.NullInt64:
		if !val.Valid {
			return 0
		}
		return val.Int64
	case sql.NullFloat64:
		if !val.Valid {
			return 0
		}
		return int64(val.Float64)
	default:
		return 0
	}
}
325
// toFloat64 converts a loosely typed numeric value, such as an aggregate column
// scanned into an `any`, into a float64.
//
// Accepted inputs:
//   - Every built-in Go integer and floating-point type.
//   - sql.NullFloat64 and sql.NullInt64. An invalid (NULL) value yields 0.
//
// Any other type, including nil, yields 0.
func toFloat64(v any) float64 {
	switch val := v.(type) {
	case float64:
		return val
	case int64:
		return float64(val)
	case int:
		return float64(val)
	case float32:
		return float64(val)
	case int32:
		return float64(val)
	case int16:
		return float64(val)
	case int8:
		return float64(val)
	case uint:
		return float64(val)
	case uint64:
		return float64(val)
	case uint32:
		return float64(val)
	case uint16:
		return float64(val)
	case uint8:
		return float64(val)
	case sql.NullFloat64:
		if !val.Valid {
			return 0
		}
		return val.Float64
	case sql.NullInt64:
		if !val.Valid {
			return 0
		}
		return float64(val.Int64)
	default:
		return 0
	}
}
338
// nullFloat64ToInt64 converts a nullable SQL float to int64. The value is
// truncated toward zero, and NULL (Valid == false) maps to 0.
func nullFloat64ToInt64(n sql.NullFloat64) int64 {
	if !n.Valid {
		return 0
	}
	return int64(n.Float64)
}
345
346func generateHTML(stats *Stats, projName, username, path string) error {
347 statsJSON, err := json.Marshal(stats)
348 if err != nil {
349 return err
350 }
351
352 tmpl, err := template.New("stats").Parse(statsTemplate)
353 if err != nil {
354 return fmt.Errorf("parse template: %w", err)
355 }
356
357 data := struct {
358 StatsJSON template.JS
359 CSS template.CSS
360 JS template.JS
361 Header template.HTML
362 Heartbit template.HTML
363 Footer template.HTML
364 GeneratedAt string
365 ProjectName string
366 Username string
367 }{
368 StatsJSON: template.JS(statsJSON),
369 CSS: template.CSS(statsCSS),
370 JS: template.JS(statsJS),
371 Header: template.HTML(headerSVG),
372 Heartbit: template.HTML(heartbitSVG),
373 Footer: template.HTML(footerSVG),
374 GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
375 ProjectName: projName,
376 Username: username,
377 }
378
379 var buf bytes.Buffer
380 if err := tmpl.Execute(&buf, data); err != nil {
381 return fmt.Errorf("execute template: %w", err)
382 }
383
384 // Ensure parent directory exists.
385 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
386 return fmt.Errorf("create directory: %w", err)
387 }
388
389 return os.WriteFile(path, buf.Bytes(), 0o644)
390}