stats.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"database/sql"
  7	_ "embed"
  8	"encoding/json"
  9	"fmt"
 10	"html/template"
 11	"os"
 12	"os/user"
 13	"path/filepath"
 14	"strings"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/db"
 19	"github.com/charmbracelet/crush/internal/event"
 20	"github.com/pkg/browser"
 21	"github.com/spf13/cobra"
 22)
 23
 24//go:embed stats/index.html
 25var statsTemplate string
 26
 27//go:embed stats/index.css
 28var statsCSS string
 29
 30//go:embed stats/index.js
 31var statsJS string
 32
 33//go:embed stats/header.svg
 34var headerSVG string
 35
 36//go:embed stats/heartbit.svg
 37var heartbitSVG string
 38
 39//go:embed stats/footer.svg
 40var footerSVG string
 41
 42var statsCmd = &cobra.Command{
 43	Use:   "stats",
 44	Short: "Show usage statistics",
 45	Long:  "Generate and display usage statistics including token usage, costs, and activity patterns",
 46	RunE:  runStats,
 47}
 48
 49// Day names for day of week statistics.
 50var dayNames = []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
 51
 52// Stats holds all the statistics data.
 53type Stats struct {
 54	GeneratedAt       time.Time          `json:"generated_at"`
 55	Total             TotalStats         `json:"total"`
 56	UsageByDay        []DailyUsage       `json:"usage_by_day"`
 57	UsageByModel      []ModelUsage       `json:"usage_by_model"`
 58	UsageByHour       []HourlyUsage      `json:"usage_by_hour"`
 59	UsageByDayOfWeek  []DayOfWeekUsage   `json:"usage_by_day_of_week"`
 60	RecentActivity    []DailyActivity    `json:"recent_activity"`
 61	AvgResponseTimeMs float64            `json:"avg_response_time_ms"`
 62	ToolUsage         []ToolUsage        `json:"tool_usage"`
 63	HourDayHeatmap    []HourDayHeatmapPt `json:"hour_day_heatmap"`
 64}
 65
 66type TotalStats struct {
 67	TotalSessions         int64   `json:"total_sessions"`
 68	TotalPromptTokens     int64   `json:"total_prompt_tokens"`
 69	TotalCompletionTokens int64   `json:"total_completion_tokens"`
 70	TotalTokens           int64   `json:"total_tokens"`
 71	TotalCost             float64 `json:"total_cost"`
 72	TotalMessages         int64   `json:"total_messages"`
 73	AvgTokensPerSession   float64 `json:"avg_tokens_per_session"`
 74	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
 75}
 76
 77type DailyUsage struct {
 78	Day              string  `json:"day"`
 79	PromptTokens     int64   `json:"prompt_tokens"`
 80	CompletionTokens int64   `json:"completion_tokens"`
 81	TotalTokens      int64   `json:"total_tokens"`
 82	Cost             float64 `json:"cost"`
 83	SessionCount     int64   `json:"session_count"`
 84}
 85
 86type ModelUsage struct {
 87	Model        string `json:"model"`
 88	Provider     string `json:"provider"`
 89	MessageCount int64  `json:"message_count"`
 90}
 91
 92type HourlyUsage struct {
 93	Hour         int   `json:"hour"`
 94	SessionCount int64 `json:"session_count"`
 95}
 96
 97type DayOfWeekUsage struct {
 98	DayOfWeek        int    `json:"day_of_week"`
 99	DayName          string `json:"day_name"`
100	SessionCount     int64  `json:"session_count"`
101	PromptTokens     int64  `json:"prompt_tokens"`
102	CompletionTokens int64  `json:"completion_tokens"`
103}
104
105type DailyActivity struct {
106	Day          string  `json:"day"`
107	SessionCount int64   `json:"session_count"`
108	TotalTokens  int64   `json:"total_tokens"`
109	Cost         float64 `json:"cost"`
110}
111
112type ToolUsage struct {
113	ToolName  string `json:"tool_name"`
114	CallCount int64  `json:"call_count"`
115}
116
117type HourDayHeatmapPt struct {
118	DayOfWeek    int   `json:"day_of_week"`
119	Hour         int   `json:"hour"`
120	SessionCount int64 `json:"session_count"`
121}
122
123func runStats(cmd *cobra.Command, _ []string) error {
124	event.StatsViewed()
125
126	dataDir, _ := cmd.Flags().GetString("data-dir")
127	ctx := cmd.Context()
128
129	if dataDir == "" {
130		cfg, err := config.Init("", "", false)
131		if err != nil {
132			return fmt.Errorf("failed to initialize config: %w", err)
133		}
134		dataDir = cfg.Config().Options.DataDirectory
135	}
136
137	conn, err := db.Connect(ctx, dataDir)
138	if err != nil {
139		return fmt.Errorf("failed to connect to database: %w", err)
140	}
141	defer conn.Close()
142
143	stats, err := gatherStats(ctx, conn)
144	if err != nil {
145		return fmt.Errorf("failed to gather stats: %w", err)
146	}
147
148	if stats.Total.TotalSessions == 0 {
149		return fmt.Errorf("no data available: no sessions found in database")
150	}
151
152	currentUser, err := user.Current()
153	if err != nil {
154		return fmt.Errorf("failed to get current user: %w", err)
155	}
156	username := currentUser.Username
157	project, err := os.Getwd()
158	if err != nil {
159		return fmt.Errorf("failed to get current directory: %w", err)
160	}
161	project = strings.Replace(project, currentUser.HomeDir, "~", 1)
162
163	htmlPath := filepath.Join(dataDir, "stats/index.html")
164	if err := generateHTML(stats, project, username, htmlPath); err != nil {
165		return fmt.Errorf("failed to generate HTML: %w", err)
166	}
167
168	fmt.Printf("Stats generated: %s\n", htmlPath)
169
170	if err := browser.OpenFile(htmlPath); err != nil {
171		fmt.Printf("Could not open browser: %v\n", err)
172		fmt.Println("Please open the file manually.")
173	}
174
175	return nil
176}
177
178func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
179	queries := db.New(conn)
180
181	stats := &Stats{
182		GeneratedAt: time.Now(),
183	}
184
185	// Total stats.
186	total, err := queries.GetTotalStats(ctx)
187	if err != nil {
188		return nil, fmt.Errorf("get total stats: %w", err)
189	}
190	stats.Total = TotalStats{
191		TotalSessions:         total.TotalSessions,
192		TotalPromptTokens:     toInt64(total.TotalPromptTokens),
193		TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
194		TotalTokens:           toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
195		TotalCost:             toFloat64(total.TotalCost),
196		TotalMessages:         toInt64(total.TotalMessages),
197		AvgTokensPerSession:   toFloat64(total.AvgTokensPerSession),
198		AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
199	}
200
201	// Usage by day.
202	dailyUsage, err := queries.GetUsageByDay(ctx)
203	if err != nil {
204		return nil, fmt.Errorf("get usage by day: %w", err)
205	}
206	for _, d := range dailyUsage {
207		prompt := nullFloat64ToInt64(d.PromptTokens)
208		completion := nullFloat64ToInt64(d.CompletionTokens)
209		stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
210			Day:              fmt.Sprintf("%v", d.Day),
211			PromptTokens:     prompt,
212			CompletionTokens: completion,
213			TotalTokens:      prompt + completion,
214			Cost:             d.Cost.Float64,
215			SessionCount:     d.SessionCount,
216		})
217	}
218
219	// Usage by model.
220	modelUsage, err := queries.GetUsageByModel(ctx)
221	if err != nil {
222		return nil, fmt.Errorf("get usage by model: %w", err)
223	}
224	for _, m := range modelUsage {
225		stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
226			Model:        m.Model,
227			Provider:     m.Provider,
228			MessageCount: m.MessageCount,
229		})
230	}
231
232	// Usage by hour.
233	hourlyUsage, err := queries.GetUsageByHour(ctx)
234	if err != nil {
235		return nil, fmt.Errorf("get usage by hour: %w", err)
236	}
237	for _, h := range hourlyUsage {
238		stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
239			Hour:         int(h.Hour),
240			SessionCount: h.SessionCount,
241		})
242	}
243
244	// Usage by day of week.
245	dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
246	if err != nil {
247		return nil, fmt.Errorf("get usage by day of week: %w", err)
248	}
249	for _, d := range dowUsage {
250		stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
251			DayOfWeek:        int(d.DayOfWeek),
252			DayName:          dayNames[int(d.DayOfWeek)],
253			SessionCount:     d.SessionCount,
254			PromptTokens:     nullFloat64ToInt64(d.PromptTokens),
255			CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
256		})
257	}
258
259	// Recent activity (last 30 days).
260	recent, err := queries.GetRecentActivity(ctx)
261	if err != nil {
262		return nil, fmt.Errorf("get recent activity: %w", err)
263	}
264	for _, r := range recent {
265		stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
266			Day:          fmt.Sprintf("%v", r.Day),
267			SessionCount: r.SessionCount,
268			TotalTokens:  nullFloat64ToInt64(r.TotalTokens),
269			Cost:         r.Cost.Float64,
270		})
271	}
272
273	// Average response time.
274	avgResp, err := queries.GetAverageResponseTime(ctx)
275	if err != nil {
276		return nil, fmt.Errorf("get average response time: %w", err)
277	}
278	stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
279
280	// Tool usage.
281	toolUsage, err := queries.GetToolUsage(ctx)
282	if err != nil {
283		return nil, fmt.Errorf("get tool usage: %w", err)
284	}
285	for _, t := range toolUsage {
286		if name, ok := t.ToolName.(string); ok && name != "" {
287			stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
288				ToolName:  name,
289				CallCount: t.CallCount,
290			})
291		}
292	}
293
294	// Hour/day heatmap.
295	heatmap, err := queries.GetHourDayHeatmap(ctx)
296	if err != nil {
297		return nil, fmt.Errorf("get hour day heatmap: %w", err)
298	}
299	for _, h := range heatmap {
300		stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
301			DayOfWeek:    int(h.DayOfWeek),
302			Hour:         int(h.Hour),
303			SessionCount: h.SessionCount,
304		})
305	}
306
307	return stats, nil
308}
309
// toInt64 converts a dynamically typed numeric value (such as the aggregate
// columns read in gatherStats) to int64.
// int64, int and float64 are supported; a float64 is truncated toward zero.
// Any other dynamic type, including nil, yields 0.
func toInt64(v any) int64 {
	if n, ok := v.(int64); ok {
		return n
	}
	if f, ok := v.(float64); ok {
		return int64(f)
	}
	if i, ok := v.(int); ok {
		return int64(i)
	}
	return 0
}
322
// toFloat64 converts a dynamically typed numeric value (such as the aggregate
// columns read in gatherStats) to float64.
// float64, int64 and int are supported; any other dynamic type, including
// nil, yields 0.
func toFloat64(v any) float64 {
	if i, ok := v.(int); ok {
		return float64(i)
	}
	if n, ok := v.(int64); ok {
		return float64(n)
	}
	if f, ok := v.(float64); ok {
		return f
	}
	return 0
}
335
// nullFloat64ToInt64 returns n's value truncated toward zero as an int64,
// or 0 when n is NULL (not Valid).
func nullFloat64ToInt64(n sql.NullFloat64) int64 {
	if !n.Valid {
		return 0
	}
	return int64(n.Float64)
}
342
343func generateHTML(stats *Stats, projName, username, path string) error {
344	statsJSON, err := json.Marshal(stats)
345	if err != nil {
346		return err
347	}
348
349	tmpl, err := template.New("stats").Parse(statsTemplate)
350	if err != nil {
351		return fmt.Errorf("parse template: %w", err)
352	}
353
354	data := struct {
355		StatsJSON   template.JS
356		CSS         template.CSS
357		JS          template.JS
358		Header      template.HTML
359		Heartbit    template.HTML
360		Footer      template.HTML
361		GeneratedAt string
362		ProjectName string
363		Username    string
364	}{
365		StatsJSON:   template.JS(statsJSON),
366		CSS:         template.CSS(statsCSS),
367		JS:          template.JS(statsJS),
368		Header:      template.HTML(headerSVG),
369		Heartbit:    template.HTML(heartbitSVG),
370		Footer:      template.HTML(footerSVG),
371		GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
372		ProjectName: projName,
373		Username:    username,
374	}
375
376	var buf bytes.Buffer
377	if err := tmpl.Execute(&buf, data); err != nil {
378		return fmt.Errorf("execute template: %w", err)
379	}
380
381	// Ensure parent directory exists.
382	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
383		return fmt.Errorf("create directory: %w", err)
384	}
385
386	return os.WriteFile(path, buf.Bytes(), 0o644)
387}