server.go

  1package server
  2
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/rpc"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strings"
	"sync/atomic"
	"time"

	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"

	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
)
 23
 24// ErrServerClosed is returned when the server is closed.
 25var ErrServerClosed = fmt.Errorf("server closed")
 26
 27// InstanceState represents the state of a running [app.App] instance.
 28type InstanceState uint8
 29
 30const (
 31	// InstanceStateCreated indicates that the instance has been created but not yet started.
 32	InstanceStateCreated InstanceState = iota
 33	// InstanceStateStarted indicates that the instance is currently running.
 34	InstanceStateStarted
 35	// InstanceStateStopped indicates that the instance has been stopped.
 36	InstanceStateStopped
 37)
 38
 39// Instance represents a running [app.App] instance with its associated
 40// resources and state.
 41type Instance struct {
 42	*app.App
 43	State InstanceState
 44	id    string
 45	path  string
 46}
 47
 48// ID returns the unique identifier of the instance.
 49func (i *Instance) ID() string {
 50	return i.id
 51}
 52
 53// Path returns the filesystem path associated with the instance.
 54func (i *Instance) Path() string {
 55	return i.path
 56}
 57
 58// DefaultAddr returns the default address path for the Crush server based on
 59// the operating system.
 60func DefaultAddr() string {
 61	sock := "crush.sock"
 62	user, err := user.Current()
 63	if err == nil && user.Uid != "" {
 64		sock = fmt.Sprintf("crush-%s.sock", user.Uid)
 65	}
 66	if runtime.GOOS == "windows" {
 67		return fmt.Sprintf(`\\.\pipe\%s`, sock)
 68	}
 69	return filepath.Join(os.TempDir(), sock)
 70}
 71
 72// Server represents a Crush server instance bound to a specific address.
 73type Server struct {
 74	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 75	Addr string
 76
 77	// instances is a map of running applications managed by the server.
 78	instances *csync.Map[string, *Instance]
 79	// listeners is the network listener for the server.
 80	listeners *csync.Map[*net.Listener, struct{}]
 81	cfg       *config.Config
 82	logger    *slog.Logger
 83
 84	shutdown atomic.Bool
 85}
 86
 87// DefaultServer returns a new [Server] instance with the default address.
 88func DefaultServer(cfg *config.Config) *Server {
 89	return NewServer(cfg, "unix", DefaultAddr())
 90}
 91
 92// NewServer is a helper to create a new [Server] instance with the given
 93// address. On Windows, if the address is not a "tcp" address, it will be
 94// converted to a named pipe format.
 95func NewServer(cfg *config.Config, network, address string) *Server {
 96	if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
 97		!strings.HasPrefix(address, `\\.\pipe\`) {
 98		// On Windows, convert to named pipe format if not TCP
 99		// (e.g., "mypipe" -> "\\.\pipe\mypipe")
100		address = fmt.Sprintf(`\\.\pipe\%s`, address)
101	}
102
103	s := new(Server)
104	s.Addr = address
105	s.cfg = cfg
106	s.instances = csync.NewMap[string, *Instance]()
107	rpc.Register(&ServerProto{s})
108	return s
109}
110
111// Serve accepts incoming connections on the listener.
112func (s *Server) Serve(ln net.Listener) error {
113	if s.listeners == nil {
114		s.listeners = csync.NewMap[*net.Listener, struct{}]()
115	}
116	s.listeners.Set(&ln, struct{}{})
117
118	var tempDelay time.Duration // how long to sleep on accept failure
119	for {
120		conn, err := ln.Accept()
121		if err != nil {
122			if s.shuttingDown() {
123				return ErrServerClosed
124			}
125			if ne, ok := err.(net.Error); ok && ne.Temporary() {
126				if tempDelay == 0 {
127					tempDelay = 5 * time.Millisecond
128				} else {
129					tempDelay *= 2
130				}
131				if max := 1 * time.Second; tempDelay > max {
132					tempDelay = max
133				}
134				time.Sleep(tempDelay)
135				continue
136			}
137			return fmt.Errorf("failed to accept connection: %w", err)
138		}
139		go s.handleConn(conn)
140	}
141}
142
143// ListenAndServe starts the server and begins accepting connections.
144func (s *Server) ListenAndServe() error {
145	ln, err := listen("unix", s.Addr)
146	if err != nil {
147		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
148	}
149	return s.Serve(ln)
150}
151
152// Close force close all listeners and connections.
153func (s *Server) Close() error {
154	s.shutdown.Store(true)
155	var firstErr error
156	for k := range s.listeners.Seq2() {
157		if err := (*k).Close(); err != nil && firstErr == nil {
158			firstErr = err
159		}
160		s.listeners.Del(k)
161	}
162	return firstErr
163}
164
165// Shutdown gracefully shuts down the server without interrupting active
166// connections. It stops accepting new connections and waits for existing
167// connections to finish.
168func (s *Server) Shutdown(ctx context.Context) error {
169	// TODO: implement graceful shutdown
170	return s.Close()
171}
172
173func (s *Server) handleConn(conn net.Conn) {
174	s.info("accepted connection from %s", conn.RemoteAddr())
175	msgpackrpc.ServeConn(conn)
176	// var req rpc.Request
177	// codec := msgpackrpc.NewServerCodec(conn)
178	// if err := codec.ReadRequestHeader(&req); err != nil {
179	// 	s.error("failed to read request header: %v", err)
180	// }
181	// rpc.ServeCodec(codec)
182}
183
184func (s *Server) shuttingDown() bool {
185	return s.shutdown.Load()
186}
187
188func (s *Server) info(msg string, args ...any) {
189	if s.logger != nil {
190		s.logger.Info(msg, args...)
191	}
192}
193
194func (s *Server) debug(msg string, args ...any) {
195	if s.logger != nil {
196		s.logger.Debug(msg, args...)
197	}
198}
199
200func (s *Server) error(msg string, args ...any) {
201	if s.logger != nil {
202		s.logger.Error(msg, args...)
203	}
204}
205
206func (s *Server) warn(msg string, args ...any) {
207	if s.logger != nil {
208		s.logger.Warn(msg, args...)
209	}
210}