1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package x509tools
18
19 import (
20 "crypto/ecdsa"
21 "crypto/elliptic"
22 "encoding/asn1"
23 "errors"
24 "fmt"
25 "math/big"
26 "strconv"
27 "strings"
28 )
29
30
31 type CurveDefinition struct {
32 Bits uint
33 Curve elliptic.Curve
34 Oid asn1.ObjectIdentifier
35 }
36
37 var DefinedCurves = []CurveDefinition{
38 {256, elliptic.P256(), asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}},
39 {384, elliptic.P384(), asn1.ObjectIdentifier{1, 3, 132, 0, 34}},
40 {521, elliptic.P521(), asn1.ObjectIdentifier{1, 3, 132, 0, 35}},
41 }
42
43
44 func (def *CurveDefinition) ToDer() []byte {
45 der, err := asn1.Marshal(def.Oid)
46 if err != nil {
47 panic(err)
48 }
49 return der
50 }
51
52
53 func SupportedCurves() string {
54 curves := make([]string, len(DefinedCurves))
55 for i, def := range DefinedCurves {
56 curves[i] = strconv.FormatUint(uint64(def.Bits), 10)
57 }
58 return strings.Join(curves, ", ")
59 }
60
61
// CurveByOid returns the DefinedCurves entry whose Oid equals oid.
//
// The returned pointer refers to a copy of the table entry, so callers may
// modify it without affecting DefinedCurves. When nothing matches, the error
// names the OID and lists the supported curve sizes.
func CurveByOid(oid asn1.ObjectIdentifier) (*CurveDefinition, error) {
	for i := range DefinedCurves {
		if candidate := DefinedCurves[i]; candidate.Oid.Equal(oid) {
			return &candidate, nil
		}
	}
	return nil, fmt.Errorf("Unsupported ECDSA curve with OID: %s\nSupported curves: %s", oid, SupportedCurves())
}
70
71
// CurveByOidString parses oidstr as a dotted-decimal object identifier (for
// example "1.3.132.0.34") and resolves it with CurveByOid.
//
// Each dot-separated arc must be a non-empty run of decimal digits whose value
// fits in an int. Sign prefixes such as "+1" or "-1" are not valid OID syntax
// and are rejected. Any malformed arc yields an "invalid OID" error.
func CurveByOidString(oidstr string) (*CurveDefinition, error) {
	arcs := strings.Split(oidstr, ".")
	oid := make(asn1.ObjectIdentifier, 0, len(arcs))
	for _, arc := range arcs {
		// ParseUint (unlike Atoi) refuses a leading '+' or '-'. Using bitSize
		// IntSize-1 caps the value at the largest int, so the conversion
		// below cannot overflow.
		v, err := strconv.ParseUint(arc, 10, strconv.IntSize-1)
		if err != nil {
			return nil, errors.New("invalid OID")
		}
		oid = append(oid, int(v))
	}
	return CurveByOid(oid)
}
84
85
// CurveByDer decodes der as a DER-encoded OBJECT IDENTIFIER and resolves it
// with CurveByOid.
//
// A decoding failure from encoding/asn1 is returned unchanged. Bytes left over
// after the OID mean the input was not a single OID encoding, so they are
// rejected instead of being silently ignored.
func CurveByDer(der []byte) (*CurveDefinition, error) {
	var oid asn1.ObjectIdentifier
	rest, err := asn1.Unmarshal(der, &oid)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, errors.New("trailing data after curve OID")
	}
	return CurveByOid(oid)
}
94
95
// CurveByCurve returns the DefinedCurves entry whose Curve is == to curve.
// The returned pointer refers to a copy of the table entry.
//
// When nothing matches, the error identifies the curve by its Params().Name
// (or its dynamic type if unnamed) and lists the supported curve sizes.
func CurveByCurve(curve elliptic.Curve) (*CurveDefinition, error) {
	for _, def := range DefinedCurves {
		if curve == def.Curve {
			return &def, nil
		}
	}
	// Describe the curve by name. Formatting the interface value with %v
	// can print internal pointers or the raw parameter fields, which is
	// unreadable in an error message.
	name := "<nil>"
	if curve != nil {
		name = fmt.Sprintf("%T", curve)
		if params := curve.Params(); params != nil && params.Name != "" {
			name = params.Name
		}
	}
	return nil, fmt.Errorf("Unsupported ECDSA curve: %s\nSupported curves: %s", name, SupportedCurves())
}
104
105
// CurveByBits returns the DefinedCurves entry whose Bits equals bits, as a
// pointer to a copy of the table entry. When no entry has that size, the error
// reports the requested size and lists the supported ones.
func CurveByBits(bits uint) (*CurveDefinition, error) {
	for i := range DefinedCurves {
		if DefinedCurves[i].Bits != bits {
			continue
		}
		match := DefinedCurves[i]
		return &match, nil
	}
	return nil, fmt.Errorf("Unsupported ECDSA curve: %v\nSupported curves: %s", bits, SupportedCurves())
}
114
115
116
// DerToPoint extracts an encoded curve point from der and decodes it with
// elliptic.Unmarshal.
//
// der must be exactly one DER OCTET STRING or BIT STRING whose payload is the
// point encoding. It returns (nil, nil) in these cases:
//   - der is empty
//   - der starts with any other tag
//   - the element fails to parse
//   - bytes follow the element
//   - elliptic.Unmarshal rejects the payload for curve
func DerToPoint(curve elliptic.Curve, der []byte) (*big.Int, *big.Int) {
	// Guard before peeking at the tag byte; indexing empty input would panic.
	if len(der) == 0 {
		return nil, nil
	}
	var (
		blob []byte
		rest []byte
		err  error
	)
	switch der[0] {
	case asn1.TagOctetString:
		rest, err = asn1.Unmarshal(der, &blob)
	case asn1.TagBitString:
		var bits asn1.BitString
		rest, err = asn1.Unmarshal(der, &bits)
		blob = bits.Bytes
	default:
		return nil, nil
	}
	// Trailing bytes mean der was not a single encoded element.
	if err != nil || len(rest) != 0 {
		return nil, nil
	}
	return elliptic.Unmarshal(curve, blob)
}
137
// PointToDer encodes pub's point with elliptic.Marshal and wraps the result
// in a DER OCTET STRING. It returns nil if asn1.Marshal reports an error.
func PointToDer(pub *ecdsa.PublicKey) []byte {
	point := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
	if der, err := asn1.Marshal(point); err == nil {
		return der
	}
	return nil
}
146
147
148 type EcdsaSignature struct {
149 R, S *big.Int
150 }
151
152
153 func UnmarshalEcdsaSignature(der []byte) (sig EcdsaSignature, err error) {
154 der, err = asn1.Unmarshal(der, &sig)
155 if err != nil || len(der) != 0 {
156 err = errors.New("invalid ECDSA signature")
157 }
158 return
159 }
160
161
162 func UnpackEcdsaSignature(packed []byte) (sig EcdsaSignature, err error) {
163 byteLen := len(packed) / 2
164 if len(packed) != byteLen*2 {
165 err = errors.New("ecdsa signature is incorrect size")
166 } else {
167 sig.R = new(big.Int).SetBytes(packed[:byteLen])
168 sig.S = new(big.Int).SetBytes(packed[byteLen:])
169 }
170 return
171 }
172
173
174 func (sig EcdsaSignature) Marshal() []byte {
175 ret, _ := asn1.Marshal(sig)
176 return ret
177 }
178
179
180 func (sig EcdsaSignature) Pack() []byte {
181 rbytes := sig.R.Bytes()
182 sbytes := sig.S.Bytes()
183 byteLen := len(rbytes)
184 if len(sbytes) > byteLen {
185 byteLen = len(sbytes)
186 }
187 ret := make([]byte, byteLen*2)
188 copy(ret[byteLen-len(rbytes):], rbytes)
189 copy(ret[2*byteLen-len(sbytes):], sbytes)
190 return ret
191 }
192
View as plain text