Commit 984b5e55a47715dac9dfb0dca873d8461ad39410

Authored by Marko Tikvić
1 parent 0dd8dda340
Exists in master

DecodeJWT

Showing 1 changed file with 26 additions and 0 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var _issuer = "webutility" 15 var _issuer = "webutility"
16 var _secret = "webutility" 16 var _secret = "webutility"
17 17
18 // TokenClaims are JWT token claims. 18 // TokenClaims are JWT token claims.
19 type TokenClaims struct { 19 type TokenClaims struct {
20 // extending a struct 20 // extending a struct
21 jwt.StandardClaims 21 jwt.StandardClaims
22 22
23 // custom claims 23 // custom claims
24 Token string `json:"access_token"` 24 Token string `json:"access_token"`
25 TokenType string `json:"token_type"` 25 TokenType string `json:"token_type"`
26 Username string `json:"username"` 26 Username string `json:"username"`
27 RoleName string `json:"role"` 27 RoleName string `json:"role"`
28 RoleID int64 `json:"role_id"` 28 RoleID int64 `json:"role_id"`
29 ExpiresIn int64 `json:"expires_in"` 29 ExpiresIn int64 `json:"expires_in"`
30 } 30 }
31 31
32 // InitJWT ... 32 // InitJWT ...
33 func InitJWT(issuer, secret string) { 33 func InitJWT(issuer, secret string) {
34 _issuer = issuer 34 _issuer = issuer
35 _secret = secret 35 _secret = secret
36 } 36 }
37 37
38 // ValidateHash hashes pass and salt and returns comparison result with resultHash 38 // ValidateHash hashes pass and salt and returns comparison result with resultHash
39 func ValidateHash(pass, salt, resultHash string) (bool, error) { 39 func ValidateHash(pass, salt, resultHash string) (bool, error) {
40 hash, _, err := CreateHash(pass, salt) 40 hash, _, err := CreateHash(pass, salt)
41 if err != nil { 41 if err != nil {
42 return false, err 42 return false, err
43 } 43 }
44 res := hash == resultHash 44 res := hash == resultHash
45 return res, nil 45 return res, nil
46 } 46 }
47 47
48 // CreateHash hashes str using SHA256. 48 // CreateHash hashes str using SHA256.
49 // If the presalt parameter is not provided CreateHash will generate new salt string. 49 // If the presalt parameter is not provided CreateHash will generate new salt string.
50 // Returns hash and salt strings or an error if it fails. 50 // Returns hash and salt strings or an error if it fails.
51 func CreateHash(str, presalt string) (hash, salt string, err error) { 51 func CreateHash(str, presalt string) (hash, salt string, err error) {
52 // chech if message is presalted 52 // chech if message is presalted
53 if presalt == "" { 53 if presalt == "" {
54 salt, err = randomSalt() 54 salt, err = randomSalt()
55 if err != nil { 55 if err != nil {
56 return "", "", err 56 return "", "", err
57 } 57 }
58 } else { 58 } else {
59 salt = presalt 59 salt = presalt
60 } 60 }
61 61
62 // convert strings to raw byte slices 62 // convert strings to raw byte slices
63 rawstr := []byte(str) 63 rawstr := []byte(str)
64 rawsalt, err := hex.DecodeString(salt) 64 rawsalt, err := hex.DecodeString(salt)
65 if err != nil { 65 if err != nil {
66 return "", "", err 66 return "", "", err
67 } 67 }
68 68
69 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 69 rawdata := make([]byte, len(rawstr)+len(rawsalt))
70 rawdata = append(rawdata, rawstr...) 70 rawdata = append(rawdata, rawstr...)
71 rawdata = append(rawdata, rawsalt...) 71 rawdata = append(rawdata, rawsalt...)
72 72
73 // hash message + salt 73 // hash message + salt
74 hasher := sha256.New() 74 hasher := sha256.New()
75 hasher.Write(rawdata) 75 hasher.Write(rawdata)
76 rawhash := hasher.Sum(nil) 76 rawhash := hasher.Sum(nil)
77 77
78 hash = hex.EncodeToString(rawhash) 78 hash = hex.EncodeToString(rawhash)
79 return hash, salt, nil 79 return hash, salt, nil
80 } 80 }
81 81
82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
83 // It returns an error if it fails. 83 // It returns an error if it fails.
84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { 84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
85 t0 := (time.Now()).Unix() 85 t0 := (time.Now()).Unix()
86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
87 claims := TokenClaims{ 87 claims := TokenClaims{
88 TokenType: "Bearer", 88 TokenType: "Bearer",
89 Username: username, 89 Username: username,
90 RoleName: roleName, 90 RoleName: roleName,
91 RoleID: roleID, 91 RoleID: roleID,
92 ExpiresIn: t1 - t0, 92 ExpiresIn: t1 - t0,
93 } 93 }
94 // initialize jwt.StandardClaims fields (anonymous struct) 94 // initialize jwt.StandardClaims fields (anonymous struct)
95 claims.IssuedAt = t0 95 claims.IssuedAt = t0
96 claims.ExpiresAt = t1 96 claims.ExpiresAt = t1
97 claims.Issuer = _issuer 97 claims.Issuer = _issuer
98 98
99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
100 token, err := jwtToken.SignedString([]byte(_secret)) 100 token, err := jwtToken.SignedString([]byte(_secret))
101 if err != nil { 101 if err != nil {
102 return TokenClaims{}, err 102 return TokenClaims{}, err
103 } 103 }
104 claims.Token = token 104 claims.Token = token
105 return claims, nil 105 return claims, nil
106 } 106 }
107 107
108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
109 // It returns an error if it fails. 109 // It returns an error if it fails.
110 func RefreshAuthToken(tok string) (TokenClaims, error) { 110 func RefreshAuthToken(tok string) (TokenClaims, error) {
111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
112 if err != nil { 112 if err != nil {
113 if validation, ok := err.(*jwt.ValidationError); ok { 113 if validation, ok := err.(*jwt.ValidationError); ok {
114 // don't return error if token is expired, just extend it 114 // don't return error if token is expired, just extend it
115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
116 return TokenClaims{}, err 116 return TokenClaims{}, err
117 } 117 }
118 } else { 118 } else {
119 return TokenClaims{}, err 119 return TokenClaims{}, err
120 } 120 }
121 } 121 }
122 122
123 // type assertion 123 // type assertion
124 claims, ok := token.Claims.(*TokenClaims) 124 claims, ok := token.Claims.(*TokenClaims)
125 if !ok { 125 if !ok {
126 return TokenClaims{}, errors.New("token is not valid") 126 return TokenClaims{}, errors.New("token is not valid")
127 } 127 }
128 128
129 // extend token expiration date 129 // extend token expiration date
130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) 130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
131 } 131 }
132 132
133 // AuthCheck ... 133 // AuthCheck ...
134 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 134 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
135 // validate token and check expiration date 135 // validate token and check expiration date
136 claims, err := GetTokenClaims(req) 136 claims, err := GetTokenClaims(req)
137 if err != nil { 137 if err != nil {
138 return claims, err 138 return claims, err
139 } 139 }
140 140
141 if roles == "" { 141 if roles == "" {
142 return claims, nil 142 return claims, nil
143 } 143 }
144 144
145 // check if token has expired 145 // check if token has expired
146 if claims.ExpiresAt < (time.Now()).Unix() { 146 if claims.ExpiresAt < (time.Now()).Unix() {
147 return claims, errors.New("token has expired") 147 return claims, errors.New("token has expired")
148 } 148 }
149 149
150 if roles == "*" { 150 if roles == "*" {
151 return claims, nil 151 return claims, nil
152 } 152 }
153 153
154 parts := strings.Split(roles, ",") 154 parts := strings.Split(roles, ",")
155 for i := range parts { 155 for i := range parts {
156 r := strings.Trim(parts[i], " ") 156 r := strings.Trim(parts[i], " ")
157 if claims.RoleName == r { 157 if claims.RoleName == r {
158 return claims, nil 158 return claims, nil
159 } 159 }
160 } 160 }
161 161
162 return claims, errors.New("unauthorized role access") 162 return claims, errors.New("unauthorized role access")
163 } 163 }
164 164
165 // GetTokenClaims extracts JWT claims from Authorization header of req. 165 // GetTokenClaims extracts JWT claims from Authorization header of req.
166 // Returns token claims or an error. 166 // Returns token claims or an error.
167 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 167 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
168 // check for and strip 'Bearer' prefix 168 // check for and strip 'Bearer' prefix
169 var tokstr string 169 var tokstr string
170 authHead := req.Header.Get("Authorization") 170 authHead := req.Header.Get("Authorization")
171 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 171 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
172 tokstr = strings.TrimPrefix(authHead, "Bearer ") 172 tokstr = strings.TrimPrefix(authHead, "Bearer ")
173 } else { 173 } else {
174 return &TokenClaims{}, errors.New("authorization header is incomplete") 174 return &TokenClaims{}, errors.New("authorization header is incomplete")
175 } 175 }
176 176
177 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 177 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
178 if err != nil { 178 if err != nil {
179 return &TokenClaims{}, err 179 return &TokenClaims{}, err
180 } 180 }
181 181
182 // type assertion 182 // type assertion
183 claims, ok := token.Claims.(*TokenClaims) 183 claims, ok := token.Claims.(*TokenClaims)
184 if !ok || !token.Valid { 184 if !ok || !token.Valid {
185 return &TokenClaims{}, errors.New("token is not valid") 185 return &TokenClaims{}, errors.New("token is not valid")
186 } 186 }
187 187
188 return claims, nil 188 return claims, nil
189 } 189 }
190 190
191 func DecodeJWT(secret, token string) (*TokenClaims, error) {
192 secretfunc := func(*jwt.Token) (interface{}, error) {
193 return []byte(secret), nil
194 }
195
196 tok, err := jwt.ParseWithClaims(token, &TokenClaims{}, secretfunc)
197 if err != nil {
198 if validation, ok := err.(*jwt.ValidationError); ok {
199 // don't return error if token is expired
200 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
201 return nil, err
202 }
203 } else {
204 return nil, err
205 }
206 }
207
208 // type assertion
209 claims, ok := tok.Claims.(*TokenClaims)
210 if !ok {
211 return &TokenClaims{}, errors.New("token is not valid")
212 }
213
214 return claims, nil
215 }
216
191 // randomSalt returns a string of 32 random characters. 217 // randomSalt returns a string of 32 random characters.
192 func randomSalt() (s string, err error) { 218 func randomSalt() (s string, err error) {
193 const saltSize = 32 219 const saltSize = 32
194 220
195 rawsalt := make([]byte, saltSize) 221 rawsalt := make([]byte, saltSize)
196 222
197 _, err = rand.Read(rawsalt) 223 _, err = rand.Read(rawsalt)
198 if err != nil { 224 if err != nil {
199 return "", err 225 return "", err
200 } 226 }
201 227
202 s = hex.EncodeToString(rawsalt) 228 s = hex.EncodeToString(rawsalt)
203 return s, nil 229 return s, nil
204 } 230 }
205 231
206 // secretFunc returns byte slice of API secret keyword. 232 // secretFunc returns byte slice of API secret keyword.
207 func secretFunc(token *jwt.Token) (interface{}, error) { 233 func secretFunc(token *jwt.Token) (interface{}, error) {
208 return []byte(_secret), nil 234 return []byte(_secret), nil
209 } 235 }
210 236