1package server
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "net/rpc"
9 "os"
10 "os/user"
11 "path/filepath"
12 "runtime"
13 "strings"
14 "sync/atomic"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/app"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20
21 msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
22)
23
var (
	// ErrServerClosed is the sentinel error reported by [Server.Serve] (and
	// therefore [Server.ListenAndServe]) after the server has been closed.
	ErrServerClosed = fmt.Errorf("server closed")
)
26
// InstanceState describes where an [app.App] instance is in its lifecycle.
// The zero value is InstanceStateCreated.
type InstanceState uint8

// Lifecycle states, in the order an instance moves through them.
const (
	// InstanceStateCreated marks an instance that exists but has not been
	// started.
	InstanceStateCreated InstanceState = iota
	// InstanceStateStarted marks an instance that is running.
	InstanceStateStarted
	// InstanceStateStopped marks an instance that has been stopped.
	InstanceStateStopped
)
38
// Instance represents a running [app.App] instance with its associated
// resources and state.
type Instance struct {
	// App is embedded, so its exported fields and methods are promoted onto
	// Instance.
	*app.App
	// State is the instance's current lifecycle state (see [InstanceState]).
	State InstanceState
	// id is the identifier returned by ID.
	id string
	// path is the filesystem path returned by Path.
	path string
}
47
48// ID returns the unique identifier of the instance.
49func (i *Instance) ID() string {
50 return i.id
51}
52
53// Path returns the filesystem path associated with the instance.
54func (i *Instance) Path() string {
55 return i.path
56}
57
// DefaultAddr computes the address the Crush server uses by default.
//
// The socket name is "crush.sock", or "crush-<uid>.sock" when the current
// user and a non-empty UID can be looked up. On Windows the name is placed in
// the named-pipe namespace (`\\.\pipe\...`); on every other OS it is joined
// onto the OS temporary directory.
func DefaultAddr() string {
	name := "crush.sock"
	if u, err := user.Current(); err == nil && u.Uid != "" {
		name = fmt.Sprintf("crush-%s.sock", u.Uid)
	}

	if runtime.GOOS != "windows" {
		return filepath.Join(os.TempDir(), name)
	}
	return fmt.Sprintf(`\\.\pipe\%s`, name)
}
71
// Server represents a Crush server instance bound to a specific address.
type Server struct {
	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
	// NOTE(review): ListenAndServe always passes the "unix" network to listen,
	// so a TCP address presumably only works with a caller-supplied listener
	// handed to Serve — confirm.
	Addr string

	// instances holds the application instances managed by the server;
	// presumably keyed by instance ID — verify against callers.
	instances *csync.Map[string, *Instance]
	// listeners tracks the listeners handed to Serve, keyed by the address of
	// Serve's listener variable, so that Close can close them. Serve creates
	// the map if it is nil.
	listeners *csync.Map[*net.Listener, struct{}]
	// cfg is the configuration passed to NewServer.
	cfg *config.Config
	// logger is optional; the info/debug/warn/error helpers do nothing while
	// it is nil.
	logger *slog.Logger

	// shutdown is set by Close; Serve consults it to report ErrServerClosed.
	shutdown atomic.Bool
}
86
87// SetLogger sets the logger for the server.
88func (s *Server) SetLogger(logger *slog.Logger) {
89 s.logger = logger
90}
91
92// DefaultServer returns a new [Server] instance with the default address.
93func DefaultServer(cfg *config.Config) *Server {
94 return NewServer(cfg, "unix", DefaultAddr())
95}
96
97// NewServer is a helper to create a new [Server] instance with the given
98// address. On Windows, if the address is not a "tcp" address, it will be
99// converted to a named pipe format.
100func NewServer(cfg *config.Config, network, address string) *Server {
101 if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
102 !strings.HasPrefix(address, `\\.\pipe\`) {
103 // On Windows, convert to named pipe format if not TCP
104 // (e.g., "mypipe" -> "\\.\pipe\mypipe")
105 address = fmt.Sprintf(`\\.\pipe\%s`, address)
106 }
107
108 s := new(Server)
109 s.Addr = address
110 s.cfg = cfg
111 s.instances = csync.NewMap[string, *Instance]()
112 rpc.Register(&ServerProto{s})
113 return s
114}
115
116// Serve accepts incoming connections on the listener.
117func (s *Server) Serve(ln net.Listener) error {
118 if s.listeners == nil {
119 s.listeners = csync.NewMap[*net.Listener, struct{}]()
120 }
121 s.listeners.Set(&ln, struct{}{})
122
123 var tempDelay time.Duration // how long to sleep on accept failure
124 for {
125 conn, err := ln.Accept()
126 if err != nil {
127 if s.shuttingDown() {
128 return ErrServerClosed
129 }
130 if ne, ok := err.(net.Error); ok && ne.Temporary() {
131 if tempDelay == 0 {
132 tempDelay = 5 * time.Millisecond
133 } else {
134 tempDelay *= 2
135 }
136 if max := 1 * time.Second; tempDelay > max {
137 tempDelay = max
138 }
139 time.Sleep(tempDelay)
140 continue
141 }
142 return fmt.Errorf("failed to accept connection: %w", err)
143 }
144 go s.handleConn(conn)
145 }
146}
147
148// ListenAndServe starts the server and begins accepting connections.
149func (s *Server) ListenAndServe() error {
150 ln, err := listen("unix", s.Addr)
151 if err != nil {
152 return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
153 }
154 return s.Serve(ln)
155}
156
157// Close force close all listeners and connections.
158func (s *Server) Close() error {
159 s.shutdown.Store(true)
160 var firstErr error
161 for k := range s.listeners.Seq2() {
162 if err := (*k).Close(); err != nil && firstErr == nil {
163 firstErr = err
164 }
165 s.listeners.Del(k)
166 }
167 return firstErr
168}
169
170// Shutdown gracefully shuts down the server without interrupting active
171// connections. It stops accepting new connections and waits for existing
172// connections to finish.
173func (s *Server) Shutdown(ctx context.Context) error {
174 // TODO: implement graceful shutdown
175 return s.Close()
176}
177
178func (s *Server) handleConn(conn net.Conn) {
179 s.info("accepted connection", "remote_addr", conn.LocalAddr())
180 codec := &ServerCodec{
181 MsgpackCodec: msgpackrpc.NewCodec(true, true, conn),
182 logger: s.logger.With(
183 slog.String("remote_addr", conn.RemoteAddr().String()),
184 slog.String("local_addr", conn.LocalAddr().String()),
185 ),
186 }
187 rpc.ServeCodec(codec)
188}
189
190func (s *Server) shuttingDown() bool {
191 return s.shutdown.Load()
192}
193
194func (s *Server) info(msg string, args ...any) {
195 if s.logger != nil {
196 s.logger.Info(msg, args...)
197 }
198}
199
200func (s *Server) debug(msg string, args ...any) {
201 if s.logger != nil {
202 s.logger.Debug(msg, args...)
203 }
204}
205
206func (s *Server) error(msg string, args ...any) {
207 if s.logger != nil {
208 s.logger.Error(msg, args...)
209 }
210}
211
212func (s *Server) warn(msg string, args ...any) {
213 if s.logger != nil {
214 s.logger.Warn(msg, args...)
215 }
216}