1// Copyright (c) 2017, Daniel MartΓ <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4package expand
5
6import (
7 "fmt"
8 "strconv"
9 "strings"
10
11 "mvdan.cc/sh/v3/syntax"
12)
13
14func Arithm(cfg *Config, expr syntax.ArithmExpr) (int, error) {
15 switch expr := expr.(type) {
16 case *syntax.Word:
17 str, err := Literal(cfg, expr)
18 if err != nil {
19 return 0, err
20 }
21 // recursively fetch vars
22 i := 0
23 for syntax.ValidName(str) {
24 val := cfg.envGet(str)
25 if val == "" {
26 break
27 }
28 if i++; i >= maxNameRefDepth {
29 break
30 }
31 str = val
32 }
33 // default to 0
34 return atoi(str), nil
35 case *syntax.ParenArithm:
36 return Arithm(cfg, expr.X)
37 case *syntax.UnaryArithm:
38 switch expr.Op {
39 case syntax.Inc, syntax.Dec:
40 name := expr.X.(*syntax.Word).Lit()
41 old := atoi(cfg.envGet(name))
42 val := old
43 if expr.Op == syntax.Inc {
44 val++
45 } else {
46 val--
47 }
48 if err := cfg.envSet(name, strconv.Itoa(val)); err != nil {
49 return 0, err
50 }
51 if expr.Post {
52 return old, nil
53 }
54 return val, nil
55 }
56 val, err := Arithm(cfg, expr.X)
57 if err != nil {
58 return 0, err
59 }
60 switch expr.Op {
61 case syntax.Not:
62 return oneIf(val == 0), nil
63 case syntax.BitNegation:
64 return ^val, nil
65 case syntax.Plus:
66 return val, nil
67 default: // syntax.Minus
68 return -val, nil
69 }
70 case *syntax.BinaryArithm:
71 switch expr.Op {
72 case syntax.Assgn, syntax.AddAssgn, syntax.SubAssgn,
73 syntax.MulAssgn, syntax.QuoAssgn, syntax.RemAssgn,
74 syntax.AndAssgn, syntax.OrAssgn, syntax.XorAssgn,
75 syntax.ShlAssgn, syntax.ShrAssgn:
76 return cfg.assgnArit(expr)
77 case syntax.TernQuest: // TernColon can't happen here
78 cond, err := Arithm(cfg, expr.X)
79 if err != nil {
80 return 0, err
81 }
82 b2 := expr.Y.(*syntax.BinaryArithm) // must have Op==TernColon
83 if cond == 1 {
84 return Arithm(cfg, b2.X)
85 }
86 return Arithm(cfg, b2.Y)
87 }
88 left, err := Arithm(cfg, expr.X)
89 if err != nil {
90 return 0, err
91 }
92 right, err := Arithm(cfg, expr.Y)
93 if err != nil {
94 return 0, err
95 }
96 return binArit(expr.Op, left, right)
97 default:
98 panic(fmt.Sprintf("unexpected arithm expr: %T", expr))
99 }
100}
101
// oneIf converts a Go boolean into a shell arithmetic truth value:
// 1 for true and 0 for false.
func oneIf(b bool) int {
	if !b {
		return 0
	}
	return 1
}
108
// atoi is like [strconv.Atoi], but it first trims surrounding whitespace and
// ignores any error. A malformed string therefore yields 0, while an
// out-of-range one yields the clamped value strconv.Atoi returns alongside
// its range error.
func atoi(s string) int {
	num, _ := strconv.Atoi(strings.TrimSpace(s)) // error dropped on purpose
	return num
}
115
116func (cfg *Config) assgnArit(b *syntax.BinaryArithm) (int, error) {
117 name := b.X.(*syntax.Word).Lit()
118 val := atoi(cfg.envGet(name))
119 arg, err := Arithm(cfg, b.Y)
120 if err != nil {
121 return 0, err
122 }
123 switch b.Op {
124 case syntax.Assgn:
125 val = arg
126 case syntax.AddAssgn:
127 val += arg
128 case syntax.SubAssgn:
129 val -= arg
130 case syntax.MulAssgn:
131 val *= arg
132 case syntax.QuoAssgn:
133 if arg == 0 {
134 return 0, fmt.Errorf("division by zero")
135 }
136 val /= arg
137 case syntax.RemAssgn:
138 if arg == 0 {
139 return 0, fmt.Errorf("division by zero")
140 }
141 val %= arg
142 case syntax.AndAssgn:
143 val &= arg
144 case syntax.OrAssgn:
145 val |= arg
146 case syntax.XorAssgn:
147 val ^= arg
148 case syntax.ShlAssgn:
149 val <<= uint(arg)
150 case syntax.ShrAssgn:
151 val >>= uint(arg)
152 }
153 if err := cfg.envSet(name, strconv.Itoa(val)); err != nil {
154 return 0, err
155 }
156 return val, nil
157}
158
// intPow returns a raised to the power b using exponentiation by squaring.
// Overflow wraps like ordinary Go int multiplication, and any b <= 0
// yields 1.
func intPow(a, b int) int {
	result := 1
	for base, exp := a, b; exp > 0; exp >>= 1 {
		if exp&1 == 1 {
			result *= base
		}
		base *= base
	}
	return result
}
170
171func binArit(op syntax.BinAritOperator, x, y int) (int, error) {
172 switch op {
173 case syntax.Add:
174 return x + y, nil
175 case syntax.Sub:
176 return x - y, nil
177 case syntax.Mul:
178 return x * y, nil
179 case syntax.Quo:
180 if y == 0 {
181 return 0, fmt.Errorf("division by zero")
182 }
183 return x / y, nil
184 case syntax.Rem:
185 if y == 0 {
186 return 0, fmt.Errorf("division by zero")
187 }
188 return x % y, nil
189 case syntax.Pow:
190 return intPow(x, y), nil
191 case syntax.Eql:
192 return oneIf(x == y), nil
193 case syntax.Gtr:
194 return oneIf(x > y), nil
195 case syntax.Lss:
196 return oneIf(x < y), nil
197 case syntax.Neq:
198 return oneIf(x != y), nil
199 case syntax.Leq:
200 return oneIf(x <= y), nil
201 case syntax.Geq:
202 return oneIf(x >= y), nil
203 case syntax.And:
204 return x & y, nil
205 case syntax.Or:
206 return x | y, nil
207 case syntax.Xor:
208 return x ^ y, nil
209 case syntax.Shr:
210 return x >> uint(y), nil
211 case syntax.Shl:
212 return x << uint(y), nil
213 case syntax.AndArit:
214 return oneIf(x != 0 && y != 0), nil
215 case syntax.OrArit:
216 return oneIf(x != 0 || y != 0), nil
217 default: // syntax.Comma
218 // x is executed but its result discarded
219 return y, nil
220 }
221}