Commit fbf92700f4c7403d1086f3e97373a686081cfe34

Authored by Marko Tikvić
1 parent 2b62f61cc8
Exists in master

renamed package variables

Showing 1 changed file with 8 additions and 8 deletions   Show diff stats
1 package webutility 1 package webutility
2 2
3 import ( 3 import (
4 "crypto/rand" 4 "crypto/rand"
5 "crypto/sha256" 5 "crypto/sha256"
6 "encoding/hex" 6 "encoding/hex"
7 "errors" 7 "errors"
8 "net/http" 8 "net/http"
9 "strings" 9 "strings"
10 "time" 10 "time"
11 11
12 "github.com/dgrijalva/jwt-go" 12 "github.com/dgrijalva/jwt-go"
13 ) 13 )
14 14
15 var appName = "webutility" 15 var _issuer = "webutility"
16 var secret = "webutility" 16 var _secret = "webutility"
17 17
18 // TokenClaims are JWT token claims. 18 // TokenClaims are JWT token claims.
19 type TokenClaims struct { 19 type TokenClaims struct {
20 // extending a struct 20 // extending a struct
21 jwt.StandardClaims 21 jwt.StandardClaims
22 22
23 // custom claims 23 // custom claims
24 Token string `json:"access_token"` 24 Token string `json:"access_token"`
25 TokenType string `json:"token_type"` 25 TokenType string `json:"token_type"`
26 Username string `json:"username"` 26 Username string `json:"username"`
27 RoleName string `json:"role"` 27 RoleName string `json:"role"`
28 RoleID int64 `json:"role_id"` 28 RoleID int64 `json:"role_id"`
29 ExpiresIn int64 `json:"expires_in"` 29 ExpiresIn int64 `json:"expires_in"`
30 } 30 }
31 31
32 func InitJWT(appName, secret string) { 32 func InitJWT(issuer, secret string) {
33 appName = appName 33 _issuer = issuer
34 secret = secret 34 _secret = secret
35 } 35 }
36 36
37 // ValidateHash hashes pass and salt and returns comparison result with resultHash 37 // ValidateHash hashes pass and salt and returns comparison result with resultHash
38 func ValidateHash(pass, salt, resultHash string) (bool, error) { 38 func ValidateHash(pass, salt, resultHash string) (bool, error) {
39 hash, _, err := CreateHash(pass, salt) 39 hash, _, err := CreateHash(pass, salt)
40 if err != nil { 40 if err != nil {
41 return false, err 41 return false, err
42 } 42 }
43 res := hash == resultHash 43 res := hash == resultHash
44 return res, nil 44 return res, nil
45 } 45 }
46 46
47 // CreateHash hashes str using SHA256. 47 // CreateHash hashes str using SHA256.
48 // If the presalt parameter is not provided CreateHash will generate new salt string. 48 // If the presalt parameter is not provided CreateHash will generate new salt string.
49 // Returns hash and salt strings or an error if it fails. 49 // Returns hash and salt strings or an error if it fails.
50 func CreateHash(str, presalt string) (hash, salt string, err error) { 50 func CreateHash(str, presalt string) (hash, salt string, err error) {
51 // chech if message is presalted 51 // chech if message is presalted
52 if presalt == "" { 52 if presalt == "" {
53 salt, err = randomSalt() 53 salt, err = randomSalt()
54 if err != nil { 54 if err != nil {
55 return "", "", err 55 return "", "", err
56 } 56 }
57 } else { 57 } else {
58 salt = presalt 58 salt = presalt
59 } 59 }
60 60
61 // convert strings to raw byte slices 61 // convert strings to raw byte slices
62 rawstr := []byte(str) 62 rawstr := []byte(str)
63 rawsalt, err := hex.DecodeString(salt) 63 rawsalt, err := hex.DecodeString(salt)
64 if err != nil { 64 if err != nil {
65 return "", "", err 65 return "", "", err
66 } 66 }
67 67
68 rawdata := make([]byte, len(rawstr)+len(rawsalt)) 68 rawdata := make([]byte, len(rawstr)+len(rawsalt))
69 rawdata = append(rawdata, rawstr...) 69 rawdata = append(rawdata, rawstr...)
70 rawdata = append(rawdata, rawsalt...) 70 rawdata = append(rawdata, rawsalt...)
71 71
72 // hash message + salt 72 // hash message + salt
73 hasher := sha256.New() 73 hasher := sha256.New()
74 hasher.Write(rawdata) 74 hasher.Write(rawdata)
75 rawhash := hasher.Sum(nil) 75 rawhash := hasher.Sum(nil)
76 76
77 hash = hex.EncodeToString(rawhash) 77 hash = hex.EncodeToString(rawhash)
78 return hash, salt, nil 78 return hash, salt, nil
79 } 79 }
80 80
81 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims. 81 // CreateAuthToken returns JWT token with encoded username, role, expiration date and issuer claims.
82 // It returns an error if it fails. 82 // It returns an error if it fails.
83 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) { 83 func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
84 t0 := (time.Now()).Unix() 84 t0 := (time.Now()).Unix()
85 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix() 85 t1 := (time.Now().Add(time.Hour * 24 * 7)).Unix()
86 claims := TokenClaims{ 86 claims := TokenClaims{
87 TokenType: "Bearer", 87 TokenType: "Bearer",
88 Username: username, 88 Username: username,
89 RoleName: roleName, 89 RoleName: roleName,
90 RoleID: roleID, 90 RoleID: roleID,
91 ExpiresIn: t1 - t0, 91 ExpiresIn: t1 - t0,
92 } 92 }
93 // initialize jwt.StandardClaims fields (anonymous struct) 93 // initialize jwt.StandardClaims fields (anonymous struct)
94 claims.IssuedAt = t0 94 claims.IssuedAt = t0
95 claims.ExpiresAt = t1 95 claims.ExpiresAt = t1
96 claims.Issuer = appName 96 claims.Issuer = _issuer
97 97
98 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 98 jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
99 token, err := jwtToken.SignedString([]byte(secret)) 99 token, err := jwtToken.SignedString([]byte(_secret))
100 if err != nil { 100 if err != nil {
101 return TokenClaims{}, err 101 return TokenClaims{}, err
102 } 102 }
103 claims.Token = token 103 claims.Token = token
104 return claims, nil 104 return claims, nil
105 } 105 }
106 106
107 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date. 107 // RefreshAuthToken returns new JWT token with same claims contained in tok but with prolonged expiration date.
108 // It returns an error if it fails. 108 // It returns an error if it fails.
109 func RefreshAuthToken(tok string) (TokenClaims, error) { 109 func RefreshAuthToken(tok string) (TokenClaims, error) {
110 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc) 110 token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
111 if err != nil { 111 if err != nil {
112 if validation, ok := err.(*jwt.ValidationError); ok { 112 if validation, ok := err.(*jwt.ValidationError); ok {
113 // don't return error if token is expired 113 // don't return error if token is expired
114 // just extend it 114 // just extend it
115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) { 115 if !(validation.Errors&jwt.ValidationErrorExpired != 0) {
116 return TokenClaims{}, err 116 return TokenClaims{}, err
117 } 117 }
118 } else { 118 } else {
119 return TokenClaims{}, err 119 return TokenClaims{}, err
120 } 120 }
121 } 121 }
122 122
123 // type assertion 123 // type assertion
124 claims, ok := token.Claims.(*TokenClaims) 124 claims, ok := token.Claims.(*TokenClaims)
125 if !ok { 125 if !ok {
126 return TokenClaims{}, errors.New("token is not valid") 126 return TokenClaims{}, errors.New("token is not valid")
127 } 127 }
128 128
129 // extend token expiration date 129 // extend token expiration date
130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID) 130 return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
131 } 131 }
132 132
133 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) { 133 func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
134 // validate token and check expiration date 134 // validate token and check expiration date
135 claims, err := GetTokenClaims(req) 135 claims, err := GetTokenClaims(req)
136 if err != nil { 136 if err != nil {
137 return claims, err 137 return claims, err
138 } 138 }
139 139
140 if roles == "" { 140 if roles == "" {
141 return claims, nil 141 return claims, nil
142 } 142 }
143 143
144 // check if token has expired 144 // check if token has expired
145 if claims.ExpiresAt < (time.Now()).Unix() { 145 if claims.ExpiresAt < (time.Now()).Unix() {
146 return claims, errors.New("token has expired") 146 return claims, errors.New("token has expired")
147 } 147 }
148 148
149 if roles == "*" { 149 if roles == "*" {
150 return claims, nil 150 return claims, nil
151 } 151 }
152 152
153 parts := strings.Split(roles, ",") 153 parts := strings.Split(roles, ",")
154 for i, _ := range parts { 154 for i, _ := range parts {
155 r := strings.Trim(parts[i], " ") 155 r := strings.Trim(parts[i], " ")
156 if claims.RoleName == r { 156 if claims.RoleName == r {
157 return claims, nil 157 return claims, nil
158 } 158 }
159 } 159 }
160 160
161 return claims, nil 161 return claims, nil
162 } 162 }
163 163
164 // GetTokenClaims extracts JWT claims from Authorization header of req. 164 // GetTokenClaims extracts JWT claims from Authorization header of req.
165 // Returns token claims or an error. 165 // Returns token claims or an error.
166 func GetTokenClaims(req *http.Request) (*TokenClaims, error) { 166 func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
167 // check for and strip 'Bearer' prefix 167 // check for and strip 'Bearer' prefix
168 var tokstr string 168 var tokstr string
169 authHead := req.Header.Get("Authorization") 169 authHead := req.Header.Get("Authorization")
170 if ok := strings.HasPrefix(authHead, "Bearer "); ok { 170 if ok := strings.HasPrefix(authHead, "Bearer "); ok {
171 tokstr = strings.TrimPrefix(authHead, "Bearer ") 171 tokstr = strings.TrimPrefix(authHead, "Bearer ")
172 } else { 172 } else {
173 return &TokenClaims{}, errors.New("authorization header in incomplete") 173 return &TokenClaims{}, errors.New("authorization header in incomplete")
174 } 174 }
175 175
176 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) 176 token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
177 if err != nil { 177 if err != nil {
178 return &TokenClaims{}, err 178 return &TokenClaims{}, err
179 } 179 }
180 180
181 // type assertion 181 // type assertion
182 claims, ok := token.Claims.(*TokenClaims) 182 claims, ok := token.Claims.(*TokenClaims)
183 if !ok || !token.Valid { 183 if !ok || !token.Valid {
184 return &TokenClaims{}, errors.New("token is not valid") 184 return &TokenClaims{}, errors.New("token is not valid")
185 } 185 }
186 186
187 return claims, nil 187 return claims, nil
188 } 188 }
189 189
190 // randomSalt returns a string of 32 random characters. 190 // randomSalt returns a string of 32 random characters.
191 const saltSize = 32 191 const saltSize = 32
192 192
193 func randomSalt() (s string, err error) { 193 func randomSalt() (s string, err error) {
194 rawsalt := make([]byte, saltSize) 194 rawsalt := make([]byte, saltSize)
195 195
196 _, err = rand.Read(rawsalt) 196 _, err = rand.Read(rawsalt)
197 if err != nil { 197 if err != nil {
198 return "", err 198 return "", err
199 } 199 }
200 200
201 s = hex.EncodeToString(rawsalt) 201 s = hex.EncodeToString(rawsalt)
202 return s, nil 202 return s, nil
203 } 203 }
204 204
205 // secretFunc returns byte slice of API secret keyword. 205 // secretFunc returns byte slice of API secret keyword.
206 func secretFunc(token *jwt.Token) (interface{}, error) { 206 func secretFunc(token *jwt.Token) (interface{}, error) {
207 return []byte(secret), nil 207 return []byte(_secret), nil
208 } 208 }
209 209