Commit 6faf94f857296273d979d4adbbec99f3cbb1fcbe

Authored by Marko Tikvić
1 parent 11933054ac
Exists in master

Fixed auth check to return error if no role has been recognized

Showing 1 changed file with 2 additions and 3 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var _issuer = "webutility" 15 var _issuer = "webutility"
16 var _secret = "webutility" 16 var _secret = "webutility"
17 17
18 // TokenClaims are JWT token claims. 18 // TokenClaims are JWT token claims.
19 type TokenClaims struct { 19 type TokenClaims struct {
20 // extending a struct 20 // extending a struct
21 jwt.StandardClaims 21 jwt.StandardClaims
22 22
23 // custom claims 23 // custom claims
24 Token string `json:"access_token"` 24 Token string `json:"access_token"`
25 TokenType string `json:"token_type"` 25 TokenType string `json:"token_type"`
26 Username string `json:"username"` 26 Username string `json:"username"`
27 RoleName string `json:"role"` 27 RoleName string `json:"role"`
28 RoleID int64 `json:"role_id"` 28 RoleID int64 `json:"role_id"`
29 ExpiresIn int64 `json:"expires_in"` 29 ExpiresIn int64 `json:"expires_in"`
30 } 30 }
31 31
32 // InitJWT ... 32 // InitJWT ...
33 func InitJWT(issuer, secret string) { 33 func InitJWT(issuer, secret string) {
34 _issuer = issuer 34 _issuer = issuer
35 _secret = secret 35 _secret = secret
36 } 36 }
37 37
38 // ValidateHash hashes pass and salt and returns comparison result with resultHash 38 // ValidateHash hashes pass and salt and returns comparison result with resultHash
39 func ValidateHash(pass, salt, resultHash string) (bool, error) { 39 func ValidateHash(pass, salt, resultHash string) (bool, error) {
40 hash, _, err := CreateHash(pass, salt) 40 hash, _, err := CreateHash(pass, salt)
41 if err != nil { 41 if err != nil {
42 return false, err 42 return false, err
43 } 43 }
44 res := hash == resultHash 44 res := hash == resultHash
45 return res, nil 45 return res, nil
46 } 46 }
47 47
48 // CreateHash hashes str using SHA256. 48 // CreateHash hashes str using SHA256.
49 // If the presalt parameter is not provided CreateHash will generate new salt string. 49 // If the presalt parameter is not provided CreateHash will generate new salt string.
50 // Returns hash and salt strings or an error if it fails. 50 // Returns hash and salt strings or an error if it fails.
51 func CreateHash(str, presalt string) (hash, salt string, err error) { 51 func CreateHash(str, presalt string) (hash, salt string, err error) {
52 // chech if message is presalted 52 // chech if message is presalted
53 if presalt == "" { 53 if presalt == "" {
54 salt, err = randomSalt() 54 salt, err = randomSalt()
55 if err != nil { 55 if err != nil {
56 return "", "", err 56 return "", "", err
57 } 57 }
58 } else { 58 } else {
59 salt = presalt 59 salt = presalt
60 } 60 }
61 61
62 // convert strings to raw byte slices 62 // convert strings to raw byte slices
63 rawstr := []byte(str) 63 rawstr := []byte(str)
64 rawsalt, err := hex.DecodeString(salt) 64 rawsalt, err := hex.DecodeString(salt)
65 if err != nil { 65 if err != nil {
66 return "", "", err 66 return "", "", err
67 } 67 }
68 68
69 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 69 rawdata := make([]byte, len(rawstr)+len(rawsalt))
70 rawdata = append(rawdata, rawstr...) 70 rawdata = append(rawdata, rawstr...)
71 rawdata = append(rawdata, rawsalt...) 71 rawdata = append(rawdata, rawsalt...)
72 72
73 // hash message + salt 73 // hash message + salt
74 hasher := sha256.New() 74 hasher := sha256.New()
75 hasher.Write(rawdata) 75 hasher.Write(rawdata)
76 rawhash := hasher.Sum(nil) 76 rawhash := hasher.Sum(nil)
77 77
78 hash = hex.EncodeToString(rawhash) 78 hash = hex.EncodeToString(rawhash)
79 return hash, salt, nil 79 return hash, salt, nil
80 } 80 }
81 81
82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 82 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
83 // It returns an error if it fails. 83 // It returns an error if it fails.
84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { 84 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
85 t0 := (time.Now()).Unix() 85 t0 := (time.Now()).Unix()
86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 86 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
87 claims := TokenClaims{ 87 claims := TokenClaims{
88 TokenType: "Bearer", 88 TokenType: "Bearer",
89 Username: username, 89 Username: username,
90 RoleName: roleName, 90 RoleName: roleName,
91 RoleID: roleID, 91 RoleID: roleID,
92 ExpiresIn: t1 - t0, 92 ExpiresIn: t1 - t0,
93 } 93 }
94 // initialize jwt.StandardClaims fields (anonymous struct) 94 // initialize jwt.StandardClaims fields (anonymous struct)
95 claims.IssuedAt = t0 95 claims.IssuedAt = t0
96 claims.ExpiresAt = t1 96 claims.ExpiresAt = t1
97 claims.Issuer = _issuer 97 claims.Issuer = _issuer
98 98
99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 99 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
100 token, err := jwtToken.SignedString([]byte(_secret)) 100 token, err := jwtToken.SignedString([]byte(_secret))
101 if err != nil { 101 if err != nil {
102 return TokenClaims{}, err 102 return TokenClaims{}, err
103 } 103 }
104 claims.Token = token 104 claims.Token = token
105 return claims, nil 105 return claims, nil
106 } 106 }
107 107
108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 108 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
109 // It returns an error if it fails. 109 // It returns an error if it fails.
110 func RefreshAuthToken(tok string) (TokenClaims, error) { 110 func RefreshAuthToken(tok string) (TokenClaims, error) {
111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 111 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
112 if err != nil { 112 if err != nil {
113 if validation, ok := err.(*jwt.ValidationError); ok { 113 if validation, ok := err.(*jwt.ValidationError); ok {
114 // don't return error if token is expired 114 // don't return error if token is expired, just extend it
115 // just extend it
116 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
117 return TokenClaims{}, err 116 return TokenClaims{}, err
118 } 117 }
119 } else { 118 } else {
120 return TokenClaims{}, err 119 return TokenClaims{}, err
121 } 120 }
122 } 121 }
123 122
124 // type assertion 123 // type assertion
125 claims, ok := token.Claims.(*TokenClaims) 124 claims, ok := token.Claims.(*TokenClaims)
126 if !ok { 125 if !ok {
127 return TokenClaims{}, errors.New("token is not valid") 126 return TokenClaims{}, errors.New("token is not valid")
128 } 127 }
129 128
130 // extend token expiration date 129 // extend token expiration date
131 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) 130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
132 } 131 }
133 132
134 // AuthCheck ... 133 // AuthCheck ...
135 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 134 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
136 // validate token and check expiration date 135 // validate token and check expiration date
137 claims, err := GetTokenClaims(req) 136 claims, err := GetTokenClaims(req)
138 if err != nil { 137 if err != nil {
139 return claims, err 138 return claims, err
140 } 139 }
141 140
142 if roles == "" { 141 if roles == "" {
143 return claims, nil 142 return claims, nil
144 } 143 }
145 144
146 // check if token has expired 145 // check if token has expired
147 if claims.ExpiresAt < (time.Now()).Unix() { 146 if claims.ExpiresAt < (time.Now()).Unix() {
148 return claims, errors.New("token has expired") 147 return claims, errors.New("token has expired")
149 } 148 }
150 149
151 if roles == "*" { 150 if roles == "*" {
152 return claims, nil 151 return claims, nil
153 } 152 }
154 153
155 parts := strings.Split(roles, ",") 154 parts := strings.Split(roles, ",")
156 for i := range parts { 155 for i := range parts {
157 r := strings.Trim(parts[i], " ") 156 r := strings.Trim(parts[i], " ")
158 if claims.RoleName == r { 157 if claims.RoleName == r {
159 return claims, nil 158 return claims, nil
160 } 159 }
161 } 160 }
162 161
163 return claims, nil 162 return claims, errors.New("unauthorized role access")
164 } 163 }
165 164
166 // GetTokenClaims extracts JWT claims from Authorization header of req. 165 // GetTokenClaims extracts JWT claims from Authorization header of req.
167 // Returns token claims or an error. 166 // Returns token claims or an error.
168 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 167 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
169 // check for and strip 'Bearer' prefix 168 // check for and strip 'Bearer' prefix
170 var tokstr string 169 var tokstr string
171 authHead := req.Header.Get("Authorization") 170 authHead := req.Header.Get("Authorization")
172 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 171 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
173 tokstr = strings.TrimPrefix(authHead, "Bearer ") 172 tokstr = strings.TrimPrefix(authHead, "Bearer ")
174 } else { 173 } else {
175 return &TokenClaims{}, errors.New("authorization header in incomplete") 174 return &TokenClaims{}, errors.New("authorization header in incomplete")
176 } 175 }
177 176
178 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 177 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
179 if err != nil { 178 if err != nil {
180 return &TokenClaims{}, err 179 return &TokenClaims{}, err
181 } 180 }
182 181
183 // type assertion 182 // type assertion
184 claims, ok := token.Claims.(*TokenClaims) 183 claims, ok := token.Claims.(*TokenClaims)
185 if !ok || !token.Valid { 184 if !ok || !token.Valid {
186 return &TokenClaims{}, errors.New("token is not valid") 185 return &TokenClaims{}, errors.New("token is not valid")
187 } 186 }
188 187
189 return claims, nil 188 return claims, nil
190 } 189 }
191 190
192 // randomSalt returns a string of 32 random characters. 191 // randomSalt returns a string of 32 random characters.
193 func randomSalt() (s string, err error) { 192 func randomSalt() (s string, err error) {
194 const saltSize = 32 193 const saltSize = 32
195 194
196 rawsalt := make([]byte, saltSize) 195 rawsalt := make([]byte, saltSize)
197 196
198 _, err = rand.Read(rawsalt) 197 _, err = rand.Read(rawsalt)
199 if err != nil { 198 if err != nil {
200 return "", err 199 return "", err
201 } 200 }
202 201
203 s = hex.EncodeToString(rawsalt) 202 s = hex.EncodeToString(rawsalt)
204 return s, nil 203 return s, nil
205 } 204 }
206 205
207 // secretFunc returns byte slice of API secret keyword. 206 // secretFunc returns byte slice of API secret keyword.
208 func secretFunc(token *jwt.Token) (interface{}, error) { 207 func secretFunc(token *jwt.Token) (interface{}, error) {
209 return []byte(_secret), nil 208 return []byte(_secret), nil
210 } 209 }
211 210