auth.go 5.65 KB
package webutility

import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/dgrijalva/jwt-go"
)

// Package-level JWT settings. Both start out as "webutility" and are
// overwritten by InitJWT.
//
// NOTE(review): the default secret is a fixed, guessable string; tokens signed
// before InitJWT is called with a real secret should not be trusted.
var (
	// _issuer is written into the issuer claim of every token this package creates.
	_issuer = "webutility"
	// _secret is the HMAC key used to sign and verify tokens.
	_secret = "webutility"
)

// TokenClaims is the claim set carried by the JWT tokens of this package.
// It doubles as the value handed back to callers after a token is created or
// parsed, and is serialized to JSON using the tags below.
//
// Fields:
//   - StandardClaims: embedded jwt-go standard claims; CreateAuthToken fills in
//     IssuedAt, ExpiresAt and Issuer.
//   - Token: the signed JWT string. CreateAuthToken sets it only after signing,
//     so it is empty inside the signed payload itself.
//   - TokenType: token scheme; CreateAuthToken always sets "Bearer".
//   - Username, RoleName, RoleID: identity and role of the token owner.
//   - ExpiresIn: difference between ExpiresAt and IssuedAt, in seconds.
type TokenClaims struct {
	// embedded standard claims (fields are promoted onto TokenClaims)
	jwt.StandardClaims

	// custom claims
	Token     string `json:"access_token"`
	TokenType string `json:"token_type"`
	Username  string `json:"username"`
	RoleName  string `json:"role"`
	RoleID    int64  `json:"role_id"`
	ExpiresIn int64  `json:"expires_in"`
}

// InitJWT sets the issuer name and the signing secret used by this package
// when creating and verifying JWT tokens. Both values replace the package
// defaults and stay in effect until InitJWT is called again.
func InitJWT(issuer, secret string) {
	_issuer, _secret = issuer, secret
}

// ValidateHash hashes pass together with salt (via CreateHash) and reports
// whether the result equals resultHash.
//
// salt must be the hex-encoded salt that was returned when resultHash was
// created. If salt is empty CreateHash generates a fresh random salt, so the
// comparison will practically never succeed.
//
// It returns an error only if hashing fails (for example when salt is not
// valid hex); a mismatch is reported as (false, nil).
func ValidateHash(pass, salt, resultHash string) (bool, error) {
	hash, _, err := CreateHash(pass, salt)
	if err != nil {
		return false, err
	}

	// Compare in constant time so the time taken does not reveal how many
	// leading characters of the stored hash matched.
	match := subtle.ConstantTimeCompare([]byte(hash), []byte(resultHash)) == 1
	return match, nil
}

// CreateHash computes a salted SHA-256 digest of str.
//
// presalt is the hex-encoded salt to use; when it is empty a new random salt
// is generated with randomSalt. The function returns the hex-encoded digest
// and the hex-encoded salt that was used, or an error if salt generation
// fails or the salt is not valid hex.
//
// NOTE: the digest input is len(str)+len(salt) zero bytes followed by str and
// the raw salt bytes. The zero prefix comes from how the buffer is
// allocated and must be kept, otherwise hashes produced earlier would no
// longer verify.
func CreateHash(str, presalt string) (hash, salt string, err error) {
	// use the caller's salt when given, otherwise generate a new one
	salt = presalt
	if salt == "" {
		if salt, err = randomSalt(); err != nil {
			return "", "", err
		}
	}

	saltBytes, decodeErr := hex.DecodeString(salt)
	if decodeErr != nil {
		return "", "", decodeErr
	}

	// zero-filled prefix of length n, then message, then salt (see NOTE above)
	n := len(str) + len(saltBytes)
	payload := make([]byte, n, 2*n)
	payload = append(payload, str...)
	payload = append(payload, saltBytes...)

	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:]), salt, nil
}

// CreateAuthToken builds and signs a JWT for the given user.
//
// The token carries username, roleName and roleID as custom claims, uses the
// package issuer, and expires seven days after it is issued. It is signed with
// HMAC-SHA256 using the package secret (see InitJWT).
//
// The returned TokenClaims contains the same claims plus the signed token
// string in Token. On a signing failure an empty TokenClaims and the error are
// returned.
func CreateAuthToken(username string, roleName string, roleID int64) (TokenClaims, error) {
	const lifetime = 7 * 24 * time.Hour

	// Take a single timestamp so that IssuedAt, ExpiresAt and ExpiresIn are
	// mutually consistent even if the clock ticks over a second boundary.
	now := time.Now()
	issuedAt := now.Unix()
	expiresAt := now.Add(lifetime).Unix()

	claims := TokenClaims{
		TokenType: "Bearer",
		Username:  username,
		RoleName:  roleName,
		RoleID:    roleID,
		ExpiresIn: expiresAt - issuedAt,
	}
	// standard claims live on the embedded jwt.StandardClaims
	claims.IssuedAt = issuedAt
	claims.ExpiresAt = expiresAt
	claims.Issuer = _issuer

	jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	token, err := jwtToken.SignedString([]byte(_secret))
	if err != nil {
		return TokenClaims{}, err
	}

	// Token is filled in after signing, so it is not part of the signed payload.
	claims.Token = token
	return claims, nil
}

// RefreshAuthToken issues a new JWT carrying the username and role claims of
// tok, with a fresh expiration date.
//
// tok must be a token signed with the package secret. A token whose only
// problem is that it has expired is accepted and renewed; any other
// validation failure (bad signature, malformed token, ...) is returned as an
// error together with an empty TokenClaims.
func RefreshAuthToken(tok string) (TokenClaims, error) {
	token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
	if err != nil {
		// ValidationError.Errors is a bitfield that can report several
		// failures at once (e.g. expired AND invalid signature). Only tolerate
		// the error when expiry is the sole failure; checking just the expired
		// bit would let a forged, expired token be re-signed as valid.
		validation, ok := err.(*jwt.ValidationError)
		if !ok || validation.Errors != jwt.ValidationErrorExpired {
			return TokenClaims{}, err
		}
	}
	if token == nil {
		return TokenClaims{}, errors.New("token is not valid")
	}

	claims, ok := token.Claims.(*TokenClaims)
	if !ok {
		return TokenClaims{}, errors.New("token is not valid")
	}

	// re-issue the token with the same identity and a new expiration date
	return CreateAuthToken(claims.Username, claims.RoleName, claims.RoleID)
}

// AuthCheck authenticates req by its bearer token and checks the token's role
// against roles.
//
// roles is interpreted as follows:
//   - "":  any token accepted by GetTokenClaims passes; no further checks.
//   - "*": any role passes, provided the token's ExpiresAt is not in the past.
//   - otherwise a comma-separated list of role names (spaces around names are
//     ignored); the token's role must equal one of them and ExpiresAt must not
//     be in the past.
//
// The claims obtained from GetTokenClaims are returned in every case, along
// with an error when the token is rejected.
func AuthCheck(req *http.Request, roles string) (*TokenClaims, error) {
	claims, err := GetTokenClaims(req)
	if err != nil {
		return claims, err
	}

	// no role restriction requested
	if len(roles) == 0 {
		return claims, nil
	}

	// explicit expiry check on the claim itself
	if now := time.Now().Unix(); claims.ExpiresAt < now {
		return claims, errors.New("token has expired")
	}

	// wildcard: every role is allowed
	if roles == "*" {
		return claims, nil
	}

	for _, allowed := range strings.Split(roles, ",") {
		if strings.Trim(allowed, " ") == claims.RoleName {
			return claims, nil
		}
	}

	return claims, errors.New("unauthorized role access")
}

// GetTokenClaims reads the JWT from the Authorization header of req and
// returns its claims.
//
// The header value must start with "Bearer " followed by the token, which is
// parsed and verified with the package secret. On any failure an empty
// TokenClaims and an error are returned.
func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
	const scheme = "Bearer "

	header := req.Header.Get("Authorization")
	if !strings.HasPrefix(header, scheme) {
		return &TokenClaims{}, errors.New("authorization header is incomplete")
	}
	rawToken := strings.TrimPrefix(header, scheme)

	parsed, err := jwt.ParseWithClaims(rawToken, &TokenClaims{}, secretFunc)
	if err != nil {
		return &TokenClaims{}, err
	}

	// the claims must have the expected type and the token must be marked valid
	if claims, ok := parsed.Claims.(*TokenClaims); ok && parsed.Valid {
		return claims, nil
	}
	return &TokenClaims{}, errors.New("token is not valid")
}

// DecodeJWT parses token, verifies it against secret and returns its claims.
//
// Unlike GetTokenClaims it uses the secret passed in rather than the package
// secret, and it tolerates expired tokens: a token whose only problem is that
// it has expired is still decoded. Any other validation failure (bad
// signature, malformed token, unexpected signing method, ...) is returned as
// an error.
func DecodeJWT(secret, token string) (*TokenClaims, error) {
	secretfunc := func(t *jwt.Token) (interface{}, error) {
		// the secret is an HMAC key; refuse tokens that claim another algorithm
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("unexpected signing method")
		}
		return []byte(secret), nil
	}

	tok, err := jwt.ParseWithClaims(token, &TokenClaims{}, secretfunc)
	if err != nil {
		// ValidationError.Errors is a bitfield that can report several
		// failures at once (e.g. expired AND invalid signature). Only tolerate
		// the error when expiry is the sole failure; checking just the expired
		// bit would return the claims of a forged, expired token.
		validation, ok := err.(*jwt.ValidationError)
		if !ok || validation.Errors != jwt.ValidationErrorExpired {
			return nil, err
		}
	}
	if tok == nil {
		return nil, errors.New("token is not valid")
	}

	claims, ok := tok.Claims.(*TokenClaims)
	if !ok {
		return &TokenClaims{}, errors.New("token is not valid")
	}

	return claims, nil
}

// randomSalt generates 32 cryptographically random bytes and returns them
// hex-encoded, i.e. as a string of 64 hexadecimal characters.
// It returns an error if the random source cannot be read.
func randomSalt() (s string, err error) {
	const saltSize = 32

	var buf [saltSize]byte
	if _, err = rand.Read(buf[:]); err != nil {
		return "", err
	}

	return hex.EncodeToString(buf[:]), nil
}

// secretFunc is the jwt key function used when parsing tokens: it returns the
// package secret as the HMAC verification key.
//
// Tokens whose header names a non-HMAC signing method are rejected, so the
// secret is never offered as a key for an algorithm it was not meant for.
func secretFunc(token *jwt.Token) (interface{}, error) {
	if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
		return nil, errors.New("unexpected signing method")
	}
	return []byte(_secret), nil
}