Commit de25e1deb6dc5ea3eac3668d0f05500763dfb3dc

Authored by Marko Tikvić
1 parent 0feac50590
Exists in master

removed redundant role struct

Showing 1 changed file with 6 additions and 11 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var appName = "webutility" 15 var appName = "webutility"
16 var secret = "webutility" 16 var secret = "webutility"
17 17
18 type Role struct {
19 Name string `json:"name"`
20 ID int64 `json:"id"`
21 }
22
23 // TokenClaims are JWT token claims. 18 // TokenClaims are JWT token claims.
24 type TokenClaims struct { 19 type TokenClaims struct {
25 // extending a struct 20 // extending a struct
26 jwt.StandardClaims 21 jwt.StandardClaims
27 22
28 // custom claims 23 // custom claims
29 Token string `json:"access_token"` 24 Token string `json:"access_token"`
30 TokenType string `json:"token_type"` 25 TokenType string `json:"token_type"`
31 Username string `json:"username"` 26 Username string `json:"username"`
32 Role string `json:"role"` 27 RoleName string `json:"role"`
33 RoleID int64 `json:"role_id"` 28 RoleID int64 `json:"role_id"`
34 ExpiresIn int64 `json:"expires_in"` 29 ExpiresIn int64 `json:"expires_in"`
35 } 30 }
36 31
37 func InitJWT(appName, secret string) { 32 func InitJWT(appName, secret string) {
38 appName = appName 33 appName = appName
39 secret = secret 34 secret = secret
40 } 35 }
41 36
42 // ValidateHash hashes pass and salt and returns comparison result with resultHash 37 // ValidateHash hashes pass and salt and returns comparison result with resultHash
43 func ValidateHash(pass, salt, resultHash string) (bool, error) { 38 func ValidateHash(pass, salt, resultHash string) (bool, error) {
44 hash, _, err := CreateHash(pass, salt) 39 hash, _, err := CreateHash(pass, salt)
45 if err != nil { 40 if err != nil {
46 return false, err 41 return false, err
47 } 42 }
48 res := hash == resultHash 43 res := hash == resultHash
49 return res, nil 44 return res, nil
50 } 45 }
51 46
52 // CreateHash hashes str using SHA256. 47 // CreateHash hashes str using SHA256.
53 // If the presalt parameter is not provided CreateHash will generate new salt string. 48 // If the presalt parameter is not provided CreateHash will generate new salt string.
54 // Returns hash and salt strings or an error if it fails. 49 // Returns hash and salt strings or an error if it fails.
55 func CreateHash(str, presalt string) (hash, salt string, err error) { 50 func CreateHash(str, presalt string) (hash, salt string, err error) {
56 // chech if message is presalted 51 // chech if message is presalted
57 if presalt == "" { 52 if presalt == "" {
58 salt, err = randomSalt() 53 salt, err = randomSalt()
59 if err != nil { 54 if err != nil {
60 return "", "", err 55 return "", "", err
61 } 56 }
62 } else { 57 } else {
63 salt = presalt 58 salt = presalt
64 } 59 }
65 60
66 // convert strings to raw byte slices 61 // convert strings to raw byte slices
67 rawstr := []byte(str) 62 rawstr := []byte(str)
68 rawsalt, err := hex.DecodeString(salt) 63 rawsalt, err := hex.DecodeString(salt)
69 if err != nil { 64 if err != nil {
70 return "", "", err 65 return "", "", err
71 } 66 }
72 67
73 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 68 rawdata := make([]byte, len(rawstr)+len(rawsalt))
74 rawdata = append(rawdata, rawstr...) 69 rawdata = append(rawdata, rawstr...)
75 rawdata = append(rawdata, rawsalt...) 70 rawdata = append(rawdata, rawsalt...)
76 71
77 // hash message + salt 72 // hash message + salt
78 hasher := sha256.New() 73 hasher := sha256.New()
79 hasher.Write(rawdata) 74 hasher.Write(rawdata)
80 rawhash := hasher.Sum(nil) 75 rawhash := hasher.Sum(nil)
81 76
82 hash = hex.EncodeToString(rawhash) 77 hash = hex.EncodeToString(rawhash)
83 return hash, salt, nil 78 return hash, salt, nil
84 } 79 }
85 80
86 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 81 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
87 // It returns an error if it fails. 82 // It returns an error if it fails.
88 func CreateAuthToken(username string, role Role) (TokenClaims, error) { 83 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
89 t0 := (time.Now()).Unix() 84 t0 := (time.Now()).Unix()
90 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 85 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
91 claims := TokenClaims{ 86 claims := TokenClaims{
92 TokenType: "Bearer", 87 TokenType: "Bearer",
93 Username: username, 88 Username: username,
94 Role: role.Name, 89 RoleName: roleName,
95 RoleID: role.ID, 90 RoleID: roleID,
96 ExpiresIn: t1 - t0, 91 ExpiresIn: t1 - t0,
97 } 92 }
98 // initialize jwt.StandardClaims fields (anonymous struct) 93 // initialize jwt.StandardClaims fields (anonymous struct)
99 claims.IssuedAt = t0 94 claims.IssuedAt = t0
100 claims.ExpiresAt = t1 95 claims.ExpiresAt = t1
101 claims.Issuer = appName 96 claims.Issuer = appName
102 97
103 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 98 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
104 token, err := jwtToken.SignedString([]byte(secret)) 99 token, err := jwtToken.SignedString([]byte(secret))
105 if err != nil { 100 if err != nil {
106 return TokenClaims{}, err 101 return TokenClaims{}, err
107 } 102 }
108 claims.Token = token 103 claims.Token = token
109 return claims, nil 104 return claims, nil
110 } 105 }
111 106
112 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 107 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
113 // It returns an error if it fails. 108 // It returns an error if it fails.
114 func RefreshAuthToken(tok string) (TokenClaims, error) { 109 func RefreshAuthToken(tok string) (TokenClaims, error) {
115 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 110 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
116 if err != nil { 111 if err != nil {
117 if validation, ok := err.(*jwt.ValidationError); ok { 112 if validation, ok := err.(*jwt.ValidationError); ok {
118 // don't return error if token is expired 113 // don't return error if token is expired
119 // just extend it 114 // just extend it
120 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
121 return TokenClaims{}, err 116 return TokenClaims{}, err
122 } 117 }
123 } else { 118 } else {
124 return TokenClaims{}, err 119 return TokenClaims{}, err
125 } 120 }
126 } 121 }
127 122
128 // type assertion 123 // type assertion
129 claims, ok := token.Claims.(*TokenClaims) 124 claims, ok := token.Claims.(*TokenClaims)
130 if !ok { 125 if !ok {
131 return TokenClaims{}, errors.New("token is not valid") 126 return TokenClaims{}, errors.New("token is not valid")
132 } 127 }
133 128
134 // extend token expiration date 129 // extend token expiration date
135 return CreateAuthToken(claims.Username, Role{claims.Role, claims.RoleID}) 130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
136 } 131 }
137 132
138 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 133 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
139 // validate token and check expiration date 134 // validate token and check expiration date
140 claims, err := GetTokenClaims(req) 135 claims, err := GetTokenClaims(req)
141 if err != nil { 136 if err != nil {
142 return claims, err 137 return claims, err
143 } 138 }
144 139
145 if roles == "" { 140 if roles == "" {
146 return claims, nil 141 return claims, nil
147 } 142 }
148 143
149 // check if token has expired 144 // check if token has expired
150 if claims.ExpiresAt < (time.Now()).Unix() { 145 if claims.ExpiresAt < (time.Now()).Unix() {
151 return claims, errors.New("token has expired") 146 return claims, errors.New("token has expired")
152 } 147 }
153 148
154 if roles == "*" { 149 if roles == "*" {
155 return claims, nil 150 return claims, nil
156 } 151 }
157 152
158 parts := strings.Split(roles, ",") 153 parts := strings.Split(roles, ",")
159 for i, _ := range parts { 154 for i, _ := range parts {
160 r := strings.Trim(parts[i], " ") 155 r := strings.Trim(parts[i], " ")
161 if claims.Role == r { 156 if claims.RoleName == r {
162 return claims, nil 157 return claims, nil
163 } 158 }
164 } 159 }
165 160
166 return claims, nil 161 return claims, nil
167 } 162 }
168 163
169 // GetTokenClaims extracts JWT claims from Authorization header of req. 164 // GetTokenClaims extracts JWT claims from Authorization header of req.
170 // Returns token claims or an error. 165 // Returns token claims or an error.
171 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 166 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
172 // check for and strip 'Bearer' prefix 167 // check for and strip 'Bearer' prefix
173 var tokstr string 168 var tokstr string
174 authHead := req.Header.Get("Authorization") 169 authHead := req.Header.Get("Authorization")
175 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 170 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
176 tokstr = strings.TrimPrefix(authHead, "Bearer ") 171 tokstr = strings.TrimPrefix(authHead, "Bearer ")
177 } else { 172 } else {
178 return &TokenClaims{}, errors.New("authorization header in incomplete") 173 return &TokenClaims{}, errors.New("authorization header in incomplete")
179 } 174 }
180 175
181 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 176 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
182 if err != nil { 177 if err != nil {
183 return &TokenClaims{}, err 178 return &TokenClaims{}, err
184 } 179 }
185 180
186 // type assertion 181 // type assertion
187 claims, ok := token.Claims.(*TokenClaims) 182 claims, ok := token.Claims.(*TokenClaims)
188 if !ok || !token.Valid { 183 if !ok || !token.Valid {
189 return &TokenClaims{}, errors.New("token is not valid") 184 return &TokenClaims{}, errors.New("token is not valid")
190 } 185 }
191 186
192 return claims, nil 187 return claims, nil
193 } 188 }
194 189
195 // randomSalt returns a string of 32 random characters. 190 // randomSalt returns a string of 32 random characters.
196 const saltSize = 32 191 const saltSize = 32
197 192
198 func randomSalt() (s string, err error) { 193 func randomSalt() (s string, err error) {
199 rawsalt := make([]byte, saltSize) 194 rawsalt := make([]byte, saltSize)
200 195
201 _, err = rand.Read(rawsalt) 196 _, err = rand.Read(rawsalt)
202 if err != nil { 197 if err != nil {
203 return "", err 198 return "", err
204 } 199 }
205 200
206 s = hex.EncodeToString(rawsalt) 201 s = hex.EncodeToString(rawsalt)
207 return s, nil 202 return s, nil
208 } 203 }
209 204
210 // secretFunc returns byte slice of API secret keyword. 205 // secretFunc returns byte slice of API secret keyword.
211 func secretFunc(token *jwt.Token) (interface{}, error) { 206 func secretFunc(token *jwt.Token) (interface{}, error) {
212 return []byte(secret), nil 207 return []byte(secret), nil
213 } 208 }
214 209