1 package util
2
import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"hash"
	"io"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"github.com/theupdateframework/go-tuf/data"
)
21
// ErrWrongLength reports that some content did not have the expected length.
type ErrWrongLength struct {
	Expected int64 // the length that was expected
	Actual   int64 // the length that was observed
}

// Error renders both lengths as base-10 integers.
func (e ErrWrongLength) Error() string {
	want := strconv.FormatInt(e.Expected, 10)
	got := strconv.FormatInt(e.Actual, 10)
	return "wrong length, expected " + want + " got " + got
}
30
// ErrWrongVersion reports that a version number differs from the expected one.
type ErrWrongVersion struct {
	Expected int64 // the version that was expected
	Actual   int64 // the version that was observed
}

// Error renders both versions as base-10 integers.
func (e ErrWrongVersion) Error() string {
	var b strings.Builder
	b.WriteString("wrong version, expected ")
	b.WriteString(strconv.FormatInt(e.Expected, 10))
	b.WriteString(" got ")
	b.WriteString(strconv.FormatInt(e.Actual, 10))
	return b.String()
}
39
40 type ErrWrongHash struct {
41 Type string
42 Expected data.HexBytes
43 Actual data.HexBytes
44 }
45
46 func (e ErrWrongHash) Error() string {
47 return fmt.Sprintf("wrong %s hash, expected %s got %s", e.Type, hex.EncodeToString(e.Expected), hex.EncodeToString(e.Actual))
48 }
49
50 type ErrNoCommonHash struct {
51 Expected data.Hashes
52 Actual data.Hashes
53 }
54
55 func (e ErrNoCommonHash) Error() string {
56 types := func(a data.Hashes) []string {
57 t := make([]string, 0, len(a))
58 for typ := range a {
59 t = append(t, typ)
60 }
61 return t
62 }
63 return fmt.Sprintf("no common hash function, expected one of %s, got %s", types(e.Expected), types(e.Actual))
64 }
65
// ErrUnknownHashAlgorithm reports a hash algorithm name that cannot be
// computed. The hashing helpers in this file only recognise "sha256" and
// "sha512".
type ErrUnknownHashAlgorithm struct {
	Name string // the unrecognised algorithm name
}

// Error names the unrecognised algorithm.
func (e ErrUnknownHashAlgorithm) Error() string {
	return "unknown hash algorithm: " + e.Name
}
73
// PassphraseFunc is a callback that returns the passphrase for role.
// NOTE(review): the meaning of confirm and change is not visible in this
// file — presumably confirm asks for the passphrase to be entered twice and
// change signals that a new passphrase is being set; verify against callers.
type PassphraseFunc func(role string, confirm, change bool) ([]byte, error)
75
76 func FileMetaEqual(actual data.FileMeta, expected data.FileMeta) error {
77 if actual.Length != expected.Length {
78 return ErrWrongLength{expected.Length, actual.Length}
79 }
80
81 if err := hashEqual(actual.Hashes, expected.Hashes); err != nil {
82 return err
83 }
84
85 return nil
86 }
87
88 func BytesMatchLenAndHashes(fetched []byte, length int64, hashes data.Hashes) error {
89 flen := int64(len(fetched))
90 if length != 0 && flen != length {
91 return ErrWrongLength{length, flen}
92 }
93
94 for alg, expected := range hashes {
95 var h hash.Hash
96 switch alg {
97 case "sha256":
98 h = sha256.New()
99 case "sha512":
100 h = sha512.New()
101 default:
102 return ErrUnknownHashAlgorithm{alg}
103 }
104 h.Write(fetched)
105 hash := h.Sum(nil)
106 if !hmac.Equal(hash, expected) {
107 return ErrWrongHash{alg, expected, hash}
108 }
109 }
110
111 return nil
112 }
113
114 func hashEqual(actual data.Hashes, expected data.Hashes) error {
115 hashChecked := false
116 for typ, hash := range expected {
117 if h, ok := actual[typ]; ok {
118 hashChecked = true
119 if !hmac.Equal(h, hash) {
120 return ErrWrongHash{typ, hash, h}
121 }
122 }
123 }
124 if !hashChecked {
125 return ErrNoCommonHash{expected, actual}
126 }
127 return nil
128 }
129
130 func VersionEqual(actual int64, expected int64) error {
131 if actual != expected {
132 return ErrWrongVersion{expected, actual}
133 }
134 return nil
135 }
136
137 func SnapshotFileMetaEqual(actual data.SnapshotFileMeta, expected data.SnapshotFileMeta) error {
138
139
140
141
142
143 if expected.Length != 0 && actual.Length != expected.Length {
144 return ErrWrongLength{expected.Length, actual.Length}
145 }
146
147 if len(expected.Hashes) != 0 {
148 if err := hashEqual(actual.Hashes, expected.Hashes); err != nil {
149 return err
150 }
151 }
152
153 if err := VersionEqual(actual.Version, expected.Version); err != nil {
154 return err
155 }
156
157 return nil
158 }
159
160 func TargetFileMetaEqual(actual data.TargetFileMeta, expected data.TargetFileMeta) error {
161 return FileMetaEqual(actual.FileMeta, expected.FileMeta)
162 }
163
164 func TimestampFileMetaEqual(actual data.TimestampFileMeta, expected data.TimestampFileMeta) error {
165
166
167 if expected.Length != 0 && actual.Length != expected.Length {
168 return ErrWrongLength{expected.Length, actual.Length}
169 }
170
171 if len(expected.Hashes) != 0 {
172 if err := hashEqual(actual.Hashes, expected.Hashes); err != nil {
173 return err
174 }
175 }
176
177 if err := VersionEqual(actual.Version, expected.Version); err != nil {
178 return err
179 }
180
181 return nil
182 }
183
const (
	// defaultHashAlgorithm is the algorithm GenerateFileMeta uses when the
	// caller does not request any.
	defaultHashAlgorithm = "sha512"
)
185
186 func GenerateFileMeta(r io.Reader, hashAlgorithms ...string) (data.FileMeta, error) {
187 if len(hashAlgorithms) == 0 {
188 hashAlgorithms = []string{defaultHashAlgorithm}
189 }
190 hashes := make(map[string]hash.Hash, len(hashAlgorithms))
191 for _, hashAlgorithm := range hashAlgorithms {
192 var h hash.Hash
193 switch hashAlgorithm {
194 case "sha256":
195 h = sha256.New()
196 case "sha512":
197 h = sha512.New()
198 default:
199 return data.FileMeta{}, ErrUnknownHashAlgorithm{hashAlgorithm}
200 }
201 hashes[hashAlgorithm] = h
202 r = io.TeeReader(r, h)
203 }
204 n, err := io.Copy(io.Discard, r)
205 if err != nil {
206 return data.FileMeta{}, err
207 }
208 m := data.FileMeta{Length: n, Hashes: make(data.Hashes, len(hashes))}
209 for hashAlgorithm, h := range hashes {
210 m.Hashes[hashAlgorithm] = h.Sum(nil)
211 }
212 return m, nil
213 }
214
// versionedMeta decodes only the "version" field of a signed metadata
// payload. Any other JSON fields are ignored during decoding.
type versionedMeta struct {
	Version int64 `json:"version"`
}
218
219 func generateVersionedFileMeta(r io.Reader, hashAlgorithms ...string) (data.FileMeta, int64, error) {
220 b, err := io.ReadAll(r)
221 if err != nil {
222 return data.FileMeta{}, 0, err
223 }
224
225 m, err := GenerateFileMeta(bytes.NewReader(b), hashAlgorithms...)
226 if err != nil {
227 return data.FileMeta{}, 0, err
228 }
229
230 s := data.Signed{}
231 if err := json.Unmarshal(b, &s); err != nil {
232 return data.FileMeta{}, 0, err
233 }
234
235 vm := versionedMeta{}
236 if err := json.Unmarshal(s.Signed, &vm); err != nil {
237 return data.FileMeta{}, 0, err
238 }
239
240 return m, vm.Version, nil
241 }
242
243 func GenerateSnapshotFileMeta(r io.Reader, hashAlgorithms ...string) (data.SnapshotFileMeta, error) {
244 m, v, err := generateVersionedFileMeta(r, hashAlgorithms...)
245 if err != nil {
246 return data.SnapshotFileMeta{}, err
247 }
248 return data.SnapshotFileMeta{
249 Length: m.Length,
250 Hashes: m.Hashes,
251 Version: v,
252 }, nil
253 }
254
255 func GenerateTargetFileMeta(r io.Reader, hashAlgorithms ...string) (data.TargetFileMeta, error) {
256 m, err := GenerateFileMeta(r, hashAlgorithms...)
257 if err != nil {
258 return data.TargetFileMeta{}, err
259 }
260 return data.TargetFileMeta{
261 FileMeta: m,
262 }, nil
263 }
264
265 func GenerateTimestampFileMeta(r io.Reader, hashAlgorithms ...string) (data.TimestampFileMeta, error) {
266 m, v, err := generateVersionedFileMeta(r, hashAlgorithms...)
267 if err != nil {
268 return data.TimestampFileMeta{}, err
269 }
270 return data.TimestampFileMeta{
271 Length: m.Length,
272 Hashes: m.Hashes,
273 Version: v,
274 }, nil
275 }
276
// NormalizeTarget canonicalizes a target path. The path is cleaned as if it
// were rooted at "/", so "." and ".." elements are resolved and cannot climb
// above the root, and repeated slashes collapse. The leading "/" is then
// removed. For example, "/a//b/../c" and "a/c/" both become "a/c", and ""
// becomes "".
func NormalizeTarget(p string) string {
	// Leading slashes are accepted and stripped rather than rejected.
	// Rooting the path before cleaning makes ".." elements stop at the root
	// instead of escaping it.
	rooted := path.Clean("/" + p)
	return strings.TrimPrefix(rooted, "/")
}
287
// VersionedPath prefixes the final element of p with "<version>.", for
// example VersionedPath("dir/root.json", 2) == "dir/2.root.json".
func VersionedPath(p string, version int64) string {
	dir := path.Dir(p)
	name := strconv.FormatInt(version, 10) + "." + path.Base(p)
	return path.Join(dir, name)
}
291
292 func HashedPaths(p string, hashes data.Hashes) []string {
293 paths := make([]string, 0, len(hashes))
294 for _, hash := range hashes {
295 hashedPath := path.Join(path.Dir(p), hash.String()+"."+path.Base(p))
296 paths = append(paths, hashedPath)
297 }
298 return paths
299 }
300
// AtomicallyWriteFile replaces filename with data, using permission bits
// perm.
//
// The data is written to a temporary file in the same directory as
// filename, synced to stable storage, given mode perm, and then renamed
// over filename. On platforms where rename is atomic (e.g. POSIX), readers
// therefore see either the old contents or the complete new contents. On
// any failure the temporary file is removed, the error is returned, and
// filename is not replaced.
func AtomicallyWriteFile(filename string, data []byte, perm os.FileMode) error {
	dir, name := filepath.Split(filename)
	if dir == "" {
		// os.CreateTemp treats "" as os.TempDir(), which may be on another
		// filesystem than filename; os.Rename cannot move files across
		// filesystems. Keep the temp file next to the destination.
		dir = "."
	}
	f, err := os.CreateTemp(dir, name)
	if err != nil {
		return err
	}
	tmpName := f.Name()

	// Remove the temporary file on every failure path below.
	committed := false
	defer func() {
		if !committed {
			os.Remove(tmpName)
		}
	}()

	if _, err := f.Write(data); err != nil {
		f.Close()
		return err
	}
	if err := f.Chmod(perm); err != nil {
		f.Close()
		return err
	}
	// Flush file contents to disk before the rename, so a crash right after
	// the rename is less likely to leave filename empty or truncated.
	if err := f.Sync(); err != nil {
		f.Close()
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpName, filename); err != nil {
		return err
	}
	committed = true
	return nil
}
333
View as plain text