Commit 0feac505905b9435723c3a058453e61a59cd7dc7

Authored by Marko Tikvić
1 parent 31a4e13027
Exists in master

renamed ValidateCredentials to ValidateHash

Showing 1 changed file with 2 additions and 2 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var appName = "webutility" 15 var appName = "webutility"
16 var secret = "webutility" 16 var secret = "webutility"
17 17
18 type Role struct { 18 type Role struct {
19 Name string `json:"name"` 19 Name string `json:"name"`
20 ID int64 `json:"id"` 20 ID int64 `json:"id"`
21 } 21 }
22 22
23 // TokenClaims are JWT token claims. 23 // TokenClaims are JWT token claims.
24 type TokenClaims struct { 24 type TokenClaims struct {
25 // extending a struct 25 // extending a struct
26 jwt.StandardClaims 26 jwt.StandardClaims
27 27
28 // custom claims 28 // custom claims
29 Token string `json:"access_token"` 29 Token string `json:"access_token"`
30 TokenType string `json:"token_type"` 30 TokenType string `json:"token_type"`
31 Username string `json:"username"` 31 Username string `json:"username"`
32 Role string `json:"role"` 32 Role string `json:"role"`
33 RoleID int64 `json:"role_id"` 33 RoleID int64 `json:"role_id"`
34 ExpiresIn int64 `json:"expires_in"` 34 ExpiresIn int64 `json:"expires_in"`
35 } 35 }
36 36
37 func InitJWT(appName, secret string) { 37 func InitJWT(appName, secret string) {
38 appName = appName 38 appName = appName
39 secret = secret 39 secret = secret
40 } 40 }
41 41
42 // ValidateCredentials hashes pass and salt and returns comparison result with resultHash 42 // ValidateHash hashes pass and salt and returns comparison result with resultHash
43 func ValidateCredentials(pass, salt, resultHash string) (bool, error) { 43 func ValidateHash(pass, salt, resultHash string) (bool, error) {
44 hash, _, err := CreateHash(pass, salt) 44 hash, _, err := CreateHash(pass, salt)
45 if err != nil { 45 if err != nil {
46 return false, err 46 return false, err
47 } 47 }
48 res := hash == resultHash 48 res := hash == resultHash
49 return res, nil 49 return res, nil
50 } 50 }
51 51
52 // CreateHash hashes str using SHA256. 52 // CreateHash hashes str using SHA256.
53 // If the presalt parameter is not provided CreateHash will generate new salt string. 53 // If the presalt parameter is not provided CreateHash will generate new salt string.
54 // Returns hash and salt strings or an error if it fails. 54 // Returns hash and salt strings or an error if it fails.
55 func CreateHash(str, presalt string) (hash, salt string, err error) { 55 func CreateHash(str, presalt string) (hash, salt string, err error) {
56 // chech if message is presalted 56 // chech if message is presalted
57 if presalt == "" { 57 if presalt == "" {
58 salt, err = randomSalt() 58 salt, err = randomSalt()
59 if err != nil { 59 if err != nil {
60 return "", "", err 60 return "", "", err
61 } 61 }
62 } else { 62 } else {
63 salt = presalt 63 salt = presalt
64 } 64 }
65 65
66 // convert strings to raw byte slices 66 // convert strings to raw byte slices
67 rawstr := []byte(str) 67 rawstr := []byte(str)
68 rawsalt, err := hex.DecodeString(salt) 68 rawsalt, err := hex.DecodeString(salt)
69 if err != nil { 69 if err != nil {
70 return "", "", err 70 return "", "", err
71 } 71 }
72 72
73 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 73 rawdata := make([]byte, len(rawstr)+len(rawsalt))
74 rawdata = append(rawdata, rawstr...) 74 rawdata = append(rawdata, rawstr...)
75 rawdata = append(rawdata, rawsalt...) 75 rawdata = append(rawdata, rawsalt...)
76 76
77 // hash message + salt 77 // hash message + salt
78 hasher := sha256.New() 78 hasher := sha256.New()
79 hasher.Write(rawdata) 79 hasher.Write(rawdata)
80 rawhash := hasher.Sum(nil) 80 rawhash := hasher.Sum(nil)
81 81
82 hash = hex.EncodeToString(rawhash) 82 hash = hex.EncodeToString(rawhash)
83 return hash, salt, nil 83 return hash, salt, nil
84 } 84 }
85 85
86 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 86 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
87 // It returns an error if it fails. 87 // It returns an error if it fails.
88 func CreateAuthToken(username string, role Role) (TokenClaims, error) { 88 func CreateAuthToken(username string, role Role) (TokenClaims, error) {
89 t0 := (time.Now()).Unix() 89 t0 := (time.Now()).Unix()
90 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 90 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
91 claims := TokenClaims{ 91 claims := TokenClaims{
92 TokenType: "Bearer", 92 TokenType: "Bearer",
93 Username: username, 93 Username: username,
94 Role: role.Name, 94 Role: role.Name,
95 RoleID: role.ID, 95 RoleID: role.ID,
96 ExpiresIn: t1 - t0, 96 ExpiresIn: t1 - t0,
97 } 97 }
98 // initialize jwt.StandardClaims fields (anonymous struct) 98 // initialize jwt.StandardClaims fields (anonymous struct)
99 claims.IssuedAt = t0 99 claims.IssuedAt = t0
100 claims.ExpiresAt = t1 100 claims.ExpiresAt = t1
101 claims.Issuer = appName 101 claims.Issuer = appName
102 102
103 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 103 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
104 token, err := jwtToken.SignedString([]byte(secret)) 104 token, err := jwtToken.SignedString([]byte(secret))
105 if err != nil { 105 if err != nil {
106 return TokenClaims{}, err 106 return TokenClaims{}, err
107 } 107 }
108 claims.Token = token 108 claims.Token = token
109 return claims, nil 109 return claims, nil
110 } 110 }
111 111
112 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 112 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
113 // It returns an error if it fails. 113 // It returns an error if it fails.
114 func RefreshAuthToken(tok string) (TokenClaims, error) { 114 func RefreshAuthToken(tok string) (TokenClaims, error) {
115 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 115 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
116 if err != nil { 116 if err != nil {
117 if validation, ok := err.(*jwt.ValidationError); ok { 117 if validation, ok := err.(*jwt.ValidationError); ok {
118 // don't return error if token is expired 118 // don't return error if token is expired
119 // just extend it 119 // just extend it
120 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 120 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
121 return TokenClaims{}, err 121 return TokenClaims{}, err
122 } 122 }
123 } else { 123 } else {
124 return TokenClaims{}, err 124 return TokenClaims{}, err
125 } 125 }
126 } 126 }
127 127
128 // type assertion 128 // type assertion
129 claims, ok := token.Claims.(*TokenClaims) 129 claims, ok := token.Claims.(*TokenClaims)
130 if !ok { 130 if !ok {
131 return TokenClaims{}, errors.New("token is not valid") 131 return TokenClaims{}, errors.New("token is not valid")
132 } 132 }
133 133
134 // extend token expiration date 134 // extend token expiration date
135 return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) 135 return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID})
136 } 136 }
137 137
138 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 138 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
139 // validate token and check expiration date 139 // validate token and check expiration date
140 claims, err := GetTokenClaims(req) 140 claims, err := GetTokenClaims(req)
141 if err != nil { 141 if err != nil {
142 return claims, err 142 return claims, err
143 } 143 }
144 144
145 if roles == "" { 145 if roles == "" {
146 return claims, nil 146 return claims, nil
147 } 147 }
148 148
149 // check if token has expired 149 // check if token has expired
150 if claims.ExpiresAt < (time.Now()).Unix() { 150 if claims.ExpiresAt < (time.Now()).Unix() {
151 return claims, errors.New("token has expired") 151 return claims, errors.New("token has expired")
152 } 152 }
153 153
154 if roles == "*" { 154 if roles == "*" {
155 return claims, nil 155 return claims, nil
156 } 156 }
157 157
158 parts := strings.Split(roles, ",") 158 parts := strings.Split(roles, ",")
159 for i, _ := range parts { 159 for i, _ := range parts {
160 r := strings.Trim(parts[i], " ") 160 r := strings.Trim(parts[i], " ")
161 if claims.Role == r { 161 if claims.Role == r {
162 return claims, nil 162 return claims, nil
163 } 163 }
164 } 164 }
165 165
166 return claims, nil 166 return claims, nil
167 } 167 }
168 168
169 // GetTokenClaims extracts JWT claims from Authorization header of req. 169 // GetTokenClaims extracts JWT claims from Authorization header of req.
170 // Returns token claims or an error. 170 // Returns token claims or an error.
171 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 171 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
172 // check for and strip 'Bearer' prefix 172 // check for and strip 'Bearer' prefix
173 var tokstr string 173 var tokstr string
174 authHead := req.Header.Get("Authorization") 174 authHead := req.Header.Get("Authorization")
175 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 175 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
176 tokstr = strings.TrimPrefix(authHead, "Bearer ") 176 tokstr = strings.TrimPrefix(authHead, "Bearer ")
177 } else { 177 } else {
178 return &TokenClaims{}, errors.New("authorization header in incomplete") 178 return &TokenClaims{}, errors.New("authorization header in incomplete")
179 } 179 }
180 180
181 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 181 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
182 if err != nil { 182 if err != nil {
183 return &TokenClaims{}, err 183 return &TokenClaims{}, err
184 } 184 }
185 185
186 // type assertion 186 // type assertion
187 claims, ok := token.Claims.(*TokenClaims) 187 claims, ok := token.Claims.(*TokenClaims)
188 if !ok || !token.Valid { 188 if !ok || !token.Valid {
189 return &TokenClaims{}, errors.New("token is not valid") 189 return &TokenClaims{}, errors.New("token is not valid")
190 } 190 }
191 191
192 return claims, nil 192 return claims, nil
193 } 193 }
194 194
195 // randomSalt returns a string of 32 random characters. 195 // randomSalt returns a string of 32 random characters.
196 const saltSize = 32 196 const saltSize = 32
197 197
198 func randomSalt() (s string, err error) { 198 func randomSalt() (s string, err error) {
199 rawsalt := make([]byte, saltSize) 199 rawsalt := make([]byte, saltSize)
200 200
201 _, err = rand.Read(rawsalt) 201 _, err = rand.Read(rawsalt)
202 if err != nil { 202 if err != nil {
203 return "", err 203 return "", err
204 } 204 }
205 205
206 s = hex.EncodeToString(rawsalt) 206 s = hex.EncodeToString(rawsalt)
207 return s, nil 207 return s, nil
208 } 208 }
209 209
210 // secretFunc returns byte slice of API secret keyword. 210 // secretFunc returns byte slice of API secret keyword.
211 func secretFunc(token *jwt.Token) (interface{}, error) { 211 func secretFunc(token *jwt.Token) (interface{}, error) {
212 return []byte(secret), nil 212 return []byte(secret), nil
213 } 213 }
214 214