arith.go

  1// Copyright (c) 2017, Daniel MartΓ­ <mvdan@mvdan.cc>
  2// See LICENSE for licensing information
  3
  4package expand
  5
  6import (
  7	"fmt"
  8	"strconv"
  9	"strings"
 10
 11	"mvdan.cc/sh/v3/syntax"
 12)
 13
 14func Arithm(cfg *Config, expr syntax.ArithmExpr) (int, error) {
 15	switch expr := expr.(type) {
 16	case *syntax.Word:
 17		str, err := Literal(cfg, expr)
 18		if err != nil {
 19			return 0, err
 20		}
 21		// recursively fetch vars
 22		i := 0
 23		for syntax.ValidName(str) {
 24			val := cfg.envGet(str)
 25			if val == "" {
 26				break
 27			}
 28			if i++; i >= maxNameRefDepth {
 29				break
 30			}
 31			str = val
 32		}
 33		// default to 0
 34		return atoi(str), nil
 35	case *syntax.ParenArithm:
 36		return Arithm(cfg, expr.X)
 37	case *syntax.UnaryArithm:
 38		switch expr.Op {
 39		case syntax.Inc, syntax.Dec:
 40			name := expr.X.(*syntax.Word).Lit()
 41			old := atoi(cfg.envGet(name))
 42			val := old
 43			if expr.Op == syntax.Inc {
 44				val++
 45			} else {
 46				val--
 47			}
 48			if err := cfg.envSet(name, strconv.Itoa(val)); err != nil {
 49				return 0, err
 50			}
 51			if expr.Post {
 52				return old, nil
 53			}
 54			return val, nil
 55		}
 56		val, err := Arithm(cfg, expr.X)
 57		if err != nil {
 58			return 0, err
 59		}
 60		switch expr.Op {
 61		case syntax.Not:
 62			return oneIf(val == 0), nil
 63		case syntax.BitNegation:
 64			return ^val, nil
 65		case syntax.Plus:
 66			return val, nil
 67		default: // syntax.Minus
 68			return -val, nil
 69		}
 70	case *syntax.BinaryArithm:
 71		switch expr.Op {
 72		case syntax.Assgn, syntax.AddAssgn, syntax.SubAssgn,
 73			syntax.MulAssgn, syntax.QuoAssgn, syntax.RemAssgn,
 74			syntax.AndAssgn, syntax.OrAssgn, syntax.XorAssgn,
 75			syntax.ShlAssgn, syntax.ShrAssgn:
 76			return cfg.assgnArit(expr)
 77		case syntax.TernQuest: // TernColon can't happen here
 78			cond, err := Arithm(cfg, expr.X)
 79			if err != nil {
 80				return 0, err
 81			}
 82			b2 := expr.Y.(*syntax.BinaryArithm) // must have Op==TernColon
 83			if cond == 1 {
 84				return Arithm(cfg, b2.X)
 85			}
 86			return Arithm(cfg, b2.Y)
 87		}
 88		left, err := Arithm(cfg, expr.X)
 89		if err != nil {
 90			return 0, err
 91		}
 92		right, err := Arithm(cfg, expr.Y)
 93		if err != nil {
 94			return 0, err
 95		}
 96		return binArit(expr.Op, left, right)
 97	default:
 98		panic(fmt.Sprintf("unexpected arithm expr: %T", expr))
 99	}
100}
101
// oneIf converts a boolean into an arithmetic result: 1 for true and 0 for
// false.
func oneIf(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
108
// atoi parses s as a shell arithmetic integer constant. Surrounding
// whitespace is trimmed and errors are ignored.
//
// An optional leading '+' or '-' is accepted. POSIX shell arithmetic follows
// ISO C constants, so a "0x" or "0X" prefix selects hexadecimal and any other
// leading '0' selects octal; everything else is decimal. Input that is not a
// valid constant, including the empty string, yields 0. Out-of-range values
// are clamped as by [strconv.ParseInt].
func atoi(s string) int {
	s = strings.TrimSpace(s)
	sign := ""
	if s != "" && (s[0] == '+' || s[0] == '-') {
		sign, s = s[:1], s[1:]
	}
	base := 10
	switch {
	case len(s) > 2 && s[0] == '0' && (s[1] == 'x' || s[1] == 'X'):
		base, s = 16, s[2:]
	case len(s) > 1 && s[0] == '0':
		base, s = 8, s[1:]
	}
	// A second sign after the optional prefix (e.g. "0x-5" or "--5") is not a
	// valid constant, but ParseInt would accept it once we re-add sign.
	if s != "" && (s[0] == '+' || s[0] == '-') {
		return 0
	}
	// With an explicit base, ParseInt rejects Go-only syntax such as "0b",
	// "0o" and '_' digit separators, which shells don't accept either.
	// Errors are deliberately ignored: invalid input evaluates to 0.
	n, _ := strconv.ParseInt(sign+s, base, 0)
	return int(n)
}
115
116func (cfg *Config) assgnArit(b *syntax.BinaryArithm) (int, error) {
117	name := b.X.(*syntax.Word).Lit()
118	val := atoi(cfg.envGet(name))
119	arg, err := Arithm(cfg, b.Y)
120	if err != nil {
121		return 0, err
122	}
123	switch b.Op {
124	case syntax.Assgn:
125		val = arg
126	case syntax.AddAssgn:
127		val += arg
128	case syntax.SubAssgn:
129		val -= arg
130	case syntax.MulAssgn:
131		val *= arg
132	case syntax.QuoAssgn:
133		if arg == 0 {
134			return 0, fmt.Errorf("division by zero")
135		}
136		val /= arg
137	case syntax.RemAssgn:
138		if arg == 0 {
139			return 0, fmt.Errorf("division by zero")
140		}
141		val %= arg
142	case syntax.AndAssgn:
143		val &= arg
144	case syntax.OrAssgn:
145		val |= arg
146	case syntax.XorAssgn:
147		val ^= arg
148	case syntax.ShlAssgn:
149		val <<= uint(arg)
150	case syntax.ShrAssgn:
151		val >>= uint(arg)
152	}
153	if err := cfg.envSet(name, strconv.Itoa(val)); err != nil {
154		return 0, err
155	}
156	return val, nil
157}
158
// intPow returns a raised to the power b using exponentiation by squaring.
// A non-positive exponent yields 1. Overflow wraps as with ordinary int
// multiplication.
func intPow(a, b int) int {
	result := 1
	for base, exp := a, b; exp > 0; exp >>= 1 {
		// Multiply in the current square whenever the low bit of exp is set.
		if exp&1 == 1 {
			result *= base
		}
		base *= base
	}
	return result
}
170
171func binArit(op syntax.BinAritOperator, x, y int) (int, error) {
172	switch op {
173	case syntax.Add:
174		return x + y, nil
175	case syntax.Sub:
176		return x - y, nil
177	case syntax.Mul:
178		return x * y, nil
179	case syntax.Quo:
180		if y == 0 {
181			return 0, fmt.Errorf("division by zero")
182		}
183		return x / y, nil
184	case syntax.Rem:
185		if y == 0 {
186			return 0, fmt.Errorf("division by zero")
187		}
188		return x % y, nil
189	case syntax.Pow:
190		return intPow(x, y), nil
191	case syntax.Eql:
192		return oneIf(x == y), nil
193	case syntax.Gtr:
194		return oneIf(x > y), nil
195	case syntax.Lss:
196		return oneIf(x < y), nil
197	case syntax.Neq:
198		return oneIf(x != y), nil
199	case syntax.Leq:
200		return oneIf(x <= y), nil
201	case syntax.Geq:
202		return oneIf(x >= y), nil
203	case syntax.And:
204		return x & y, nil
205	case syntax.Or:
206		return x | y, nil
207	case syntax.Xor:
208		return x ^ y, nil
209	case syntax.Shr:
210		return x >> uint(y), nil
211	case syntax.Shl:
212		return x << uint(y), nil
213	case syntax.AndArit:
214		return oneIf(x != 0 && y != 0), nil
215	case syntax.OrArit:
216		return oneIf(x != 0 || y != 0), nil
217	default: // syntax.Comma
218		// x is executed but its result discarded
219		return y, nil
220	}
221}