Commit 052f8a3a63bf8c52b72b5e20ffd0a5d22372e5ab

Authored by Marko Tikvić
1 parent 61efd58cdb
Exists in master and in 1 other branch v2

token validation extended

Showing 1 changed file with 10 additions and 29 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 const OneDay = time.Hour * 24 15 const OneDay = time.Hour * 24
16 const OneWeek = OneDay * 7 16 const OneWeek = OneDay * 7
17 const saltSize = 32 17 const saltSize = 32
18 const appName = "korisnicki-centar" 18 const appName = "korisnicki-centar"
19 const secret = "korisnicki-centar-api" 19 const secret = "korisnicki-centar-api"
20 20
21 type Role struct { 21 type Role struct {
22 Name string `json:"name"` 22 Name string `json:"name"`
23 ID uint32 `json:"id"` 23 ID uint32 `json:"id"`
24 } 24 }
25 25
26 // TokenClaims are JWT token claims. 26 // TokenClaims are JWT token claims.
27 type TokenClaims struct { 27 type TokenClaims struct {
28 Token string `json:"access_token"` 28 Token string `json:"access_token"`
29 TokenType string `json:"token_type"` 29 TokenType string `json:"token_type"`
30 Username string `json:"username"` 30 Username string `json:"username"`
31 Role string `json:"role"` 31 Role string `json:"role"`
32 RoleID uint32 `json:"role_id"` 32 RoleID uint32 `json:"role_id"`
33 ExpiresIn int64 `json:"expires_in"` 33 ExpiresIn int64 `json:"expires_in"`
34 34
35 // extending a struct 35 // extending a struct
36 jwt.StandardClaims 36 jwt.StandardClaims
37 } 37 }
38 38
39 // CredentialsStruct is an instace of username/password values. 39 // CredentialsStruct is an instace of username/password values.
40 type CredentialsStruct struct { 40 type CredentialsStruct struct {
41 Username string `json:"username"` 41 Username string `json:"username"`
42 Password string `json:"password"` 42 Password string `json:"password"`
43 RoleID uint32 `json:"roleID"` 43 RoleID uint32 `json:"roleID"`
44 } 44 }
45 45
46 // ValidateCredentials hashes pass and salt and returns comparison result with resultHash 46 // ValidateCredentials hashes pass and salt and returns comparison result with resultHash
47 func ValidateCredentials(pass, salt, resultHash string) bool { 47 func ValidateCredentials(pass, salt, resultHash string) bool {
48 hash, _, err := CreateHash(pass, salt) 48 hash, _, err := CreateHash(pass, salt)
49 if err != nil { 49 if err != nil {
50 return false 50 return false
51 } 51 }
52 return hash == resultHash 52 return hash == resultHash
53 } 53 }
54 54
55 // CreateHash hashes str using SHA256. 55 // CreateHash hashes str using SHA256.
56 // If the presalt parameter is not provided CreateHash will generate new salt string. 56 // If the presalt parameter is not provided CreateHash will generate new salt string.
57 // Returns hash and salt strings or an error if it fails. 57 // Returns hash and salt strings or an error if it fails.
58 func CreateHash(str, presalt string) (hash, salt string, err error) { 58 func CreateHash(str, presalt string) (hash, salt string, err error) {
59 // chech if message is presalted 59 // chech if message is presalted
60 if presalt == "" { 60 if presalt == "" {
61 salt, err = randomSalt() 61 salt, err = randomSalt()
62 if err != nil { 62 if err != nil {
63 return "", "", err 63 return "", "", err
64 } 64 }
65 } else { 65 } else {
66 salt = presalt 66 salt = presalt
67 } 67 }
68 68
69 // convert strings to raw byte slices 69 // convert strings to raw byte slices
70 rawstr := []byte(str) 70 rawstr := []byte(str)
71 rawsalt, err := hex.DecodeString(salt) 71 rawsalt, err := hex.DecodeString(salt)
72 if err != nil { 72 if err != nil {
73 return "", "", err 73 return "", "", err
74 } 74 }
75 75
76 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 76 rawdata := make([]byte, len(rawstr)+len(rawsalt))
77 rawdata = append(rawdata, rawstr...) 77 rawdata = append(rawdata, rawstr...)
78 rawdata = append(rawdata, rawsalt...) 78 rawdata = append(rawdata, rawsalt...)
79 79
80 // hash message + salt 80 // hash message + salt
81 hasher := sha256.New() 81 hasher := sha256.New()
82 hasher.Write(rawdata) 82 hasher.Write(rawdata)
83 rawhash := hasher.Sum(nil) 83 rawhash := hasher.Sum(nil)
84 84
85 hash = hex.EncodeToString(rawhash) 85 hash = hex.EncodeToString(rawhash)
86 return hash, salt, nil 86 return hash, salt, nil
87 } 87 }
88 88
89 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 89 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
90 // It returns an error if it fails. 90 // It returns an error if it fails.
91 func CreateAuthToken(username string, role Role) (TokenClaims, error) { 91 func CreateAuthToken(username string, role Role) (TokenClaims, error) {
92 t0 := (time.Now()).Unix() 92 t0 := (time.Now()).Unix()
93 t1 := (time.Now().Add(OneWeek)).Unix() 93 t1 := (time.Now().Add(OneWeek)).Unix()
94 claims := TokenClaims{ 94 claims := TokenClaims{
95 TokenType: "Bearer", 95 TokenType: "Bearer",
96 Username: username, 96 Username: username,
97 Role: role.Name, 97 Role: role.Name,
98 RoleID: role.ID, 98 RoleID: role.ID,
99 ExpiresIn: t1 - t0, 99 ExpiresIn: t1 - t0,
100 } 100 }
101 // initialize jwt.StandardClaims fields (anonymous struct) 101 // initialize jwt.StandardClaims fields (anonymous struct)
102 claims.IssuedAt = t0 102 claims.IssuedAt = t0
103 claims.ExpiresAt = t1 103 claims.ExpiresAt = t1
104 claims.Issuer = appName 104 claims.Issuer = appName
105 105
106 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 106 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
107 token, err := jwtToken.SignedString([]byte(secret)) 107 token, err := jwtToken.SignedString([]byte(secret))
108 if err != nil { 108 if err != nil {
109 return TokenClaims{}, err 109 return TokenClaims{}, err
110 } 110 }
111 claims.Token = token 111 claims.Token = token
112 return claims, nil 112 return claims, nil
113 } 113 }
114 114
115 // RefreshAuthToken prolongs JWT token's expiration date for one week. 115 // RefreshAuthToken prolongs JWT token's expiration date for one week.
116 // It returns new JWT token or an error if it fails. 116 // It returns new JWT token or an error if it fails.
117 func RefreshAuthToken(req *http.Request) (TokenClaims, error) { 117 func RefreshAuthToken(req *http.Request) (TokenClaims, error) {
118 authHead := req.Header.Get("Authorization") 118 authHead := req.Header.Get("Authorization")
119 tokenstr := strings.TrimPrefix(authHead, "Bearer ") 119 tokenstr := strings.TrimPrefix(authHead, "Bearer ")
120 token, err := jwt.ParseWithClaims(tokenstr, &TokenClaims{}, secretFunc) 120 token, err := jwt.ParseWithClaims(tokenstr, &TokenClaims{}, secretFunc)
121 if err != nil { 121 if err != nil {
122 return TokenClaims{}, err 122 if validation, ok := err.(*jwt.ValidationError); ok {
123 // don't return error if token is expired
124 // just extend it
125 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
126 return TokenClaims{}, err
127 }
128 } else {
129 return TokenClaims{}, err
130 }
123 } 131 }
124 132
125 // type assertion 133 // type assertion
126 claims, ok := token.Claims.(*TokenClaims) 134 claims, ok := token.Claims.(*TokenClaims)
127 if !ok || !token.Valid { 135 if !ok {
128 return TokenClaims{}, errors.New("token is not valid") 136 return TokenClaims{}, errors.New("token is not valid")
129 } 137 }
130 138
131 // extend token expiration date 139 // extend token expiration date
132 return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) 140 return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID})
133 } 141 }
134 142
135 // RbacCheck returns true if role that made HTTP request is authorized to 143 // RbacCheck returns true if role that made HTTP request is authorized to
136 // access the resource it is targeting. 144 // access the resource it is targeting.
137 // It exctracts user's role from the JWT token located in Authorization header of 145 // It exctracts user's role from the JWT token located in Authorization header of
138 // http.Request and then compares it with the list of supplied roles and returns 146 // http.Request and then compares it with the list of supplied roles and returns
139 // true if there's a match, if "*" is provided or if the authRoles is nil. 147 // true if there's a match, if "*" is provided or if the authRoles is nil.
140 // Otherwise it returns false. 148 // Otherwise it returns false.
141 func RbacCheck(req *http.Request, authRoles []string) bool { 149 func RbacCheck(req *http.Request, authRoles []string) bool {
142 if authRoles == nil { 150 if authRoles == nil {
143 return true 151 return true
144 } 152 }
145 153
146 // validate token and check expiration date 154 // validate token and check expiration date
147 claims, err := GetTokenClaims(req) 155 claims, err := GetTokenClaims(req)
148 if err != nil { 156 if err != nil {
149 return false 157 return false
150 } 158 }
151 // check if token has expired 159 // check if token has expired
152 if claims.ExpiresAt < (time.Now()).Unix() { 160 if claims.ExpiresAt < (time.Now()).Unix() {
153 return false 161 return false
154 } 162 }
155 163
156 // check if role extracted from token matches 164 // check if role extracted from token matches
157 // any of the provided (allowed) ones 165 // any of the provided (allowed) ones
158 for _, r := range authRoles { 166 for _, r := range authRoles {
159 if claims.Role == r || r == "*" { 167 if claims.Role == r || r == "*" {
160 return true 168 return true
161 } 169 }
162 } 170 }
163 171
164 return false 172 return false
165 } 173 }
166 174
167 // TODO
168 func AuthCheck(req *http.Request, authRoles []string) (*TokenClaims, error) {
169 if authRoles == nil {
170 return &TokenClaims{}, nil
171 }
172
173 // validate token and check expiration date
174 claims, err := GetTokenClaims(req)
175 if err != nil {
176 return &TokenClaims{}, err
177 }
178 // check if token has expired
179 if claims.ExpiresAt < (time.Now()).Unix() {
180 return &TokenClaims{}, errors.New("token has expired")
181 }
182
183 // check if role extracted from token matches
184 // any of the provided (allowed) ones
185 for _, r := range authRoles {
186 if claims.Role == r || r == "*" {
187 return claims, nil
188 }
189 }
190
191 return claims, errors.New("role is not authorized")
192 }
193
194 // GetTokenClaims extracts JWT claims from Authorization header of the request. 175 // GetTokenClaims extracts JWT claims from Authorization header of the request.
195 // Returns token claims or an error. 176 // Returns token claims or an error.
196 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 177 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
197 // check for and strip 'Bearer' prefix 178 // check for and strip 'Bearer' prefix
198 var tokstr string 179 var tokstr string
199 authHead := req.Header.Get("Authorization") 180 authHead := req.Header.Get("Authorization")
200 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 181 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
201 tokstr = strings.TrimPrefix(authHead, "Bearer ") 182 tokstr = strings.TrimPrefix(authHead, "Bearer ")
202 } else { 183 } else {
203 return &TokenClaims{}, errors.New("authorization header in incomplete") 184 return &TokenClaims{}, errors.New("authorization header in incomplete")
204 } 185 }
205 186
206 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 187 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
207 if err != nil { 188 if err != nil {
208 return &TokenClaims{}, err 189 return &TokenClaims{}, err
209 } 190 }
210 191
211 // type assertion 192 // type assertion
212 claims, ok := token.Claims.(*TokenClaims) 193 claims, ok := token.Claims.(*TokenClaims)
213 if !ok || !token.Valid { 194 if !ok || !token.Valid {
214 return &TokenClaims{}, errors.New("token is not valid") 195 return &TokenClaims{}, errors.New("token is not valid")
215 } 196 }
216 197
217 return claims, nil 198 return claims, nil
218 } 199 }
219 200
220 // randomSalt returns a string of random characters of 'saltSize' length. 201 // randomSalt returns a string of random characters of 'saltSize' length.
221 func randomSalt() (s string, err error) { 202 func randomSalt() (s string, err error) {
222 rawsalt := make([]byte, saltSize) 203 rawsalt := make([]byte, saltSize)
223 204
224 _, err = rand.Read(rawsalt) 205 _, err = rand.Read(rawsalt)
225 if err != nil { 206 if err != nil {
226 return "", err 207 return "", err
227 } 208 }
228 209
229 s = hex.EncodeToString(rawsalt) 210 s = hex.EncodeToString(rawsalt)