xsrftoken: add custom timeout support for valid func
Added new function 'ValidFor' with custom token timeout support.
Function 'Valid' will use default token timeout.
Fixes golang/go#41438
Change-Id: I5cf0388aeed7ca34edcb0d3493c3e79c8ce19938
GitHub-Last-Rev: 3e3b5817964aebf5b804ceec6694ed500f439c1e
GitHub-Pull-Request: golang/net#86
Reviewed-on: https://go-review.googlesource.com/c/net/+/260317
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Filippo Valsorda <filippo@golang.org>
diff --git a/xsrftoken/xsrf.go b/xsrftoken/xsrf.go
index 4f66adf..3ca5d5b 100644
--- a/xsrftoken/xsrf.go
+++ b/xsrftoken/xsrf.go
@@ -54,12 +54,19 @@
}
// Valid reports whether a token is a valid, unexpired token returned by Generate.
+// The token is considered to be expired and invalid if it is older than the default Timeout.
func Valid(token, key, userID, actionID string) bool {
- return validTokenAtTime(token, key, userID, actionID, time.Now())
+ return validTokenAtTime(token, key, userID, actionID, time.Now(), Timeout)
+}
+
+// ValidFor reports whether a token is a valid, unexpired token returned by Generate.
+// The token is considered to be expired and invalid if it is older than the timeout duration.
+func ValidFor(token, key, userID, actionID string, timeout time.Duration) bool {
+ return validTokenAtTime(token, key, userID, actionID, time.Now(), timeout)
}
// validTokenAtTime reports whether a token is valid at the given time.
-func validTokenAtTime(token, key, userID, actionID string, now time.Time) bool {
+func validTokenAtTime(token, key, userID, actionID string, now time.Time, timeout time.Duration) bool {
if len(key) == 0 {
panic("zero length xsrf secret key")
}
@@ -75,7 +82,7 @@
issueTime := time.Unix(0, millis*1e6)
// Check that the token is not expired.
- if now.Sub(issueTime) >= Timeout {
+ if now.Sub(issueTime) >= timeout {
return false
}
diff --git a/xsrftoken/xsrf_test.go b/xsrftoken/xsrf_test.go
index fc0a48a..60ff84a 100644
--- a/xsrftoken/xsrf_test.go
+++ b/xsrftoken/xsrf_test.go
@@ -23,13 +23,22 @@
func TestValidToken(t *testing.T) {
tok := generateTokenAtTime(key, userID, actionID, now)
- if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow) {
+ if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow, Timeout) {
t.Error("One second later: Expected token to be valid")
}
- if !validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond)) {
+ if !validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond), Timeout) {
t.Error("Just before timeout: Expected token to be valid")
}
- if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond)) {
+ if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond), Timeout) {
+ t.Error("One minute in the past: Expected token to be valid")
+ }
+ if !validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow, time.Hour) {
+ t.Error("One second later: Expected token to be valid")
+ }
+ if !validTokenAtTime(tok, key, userID, actionID, now.Add(time.Minute-1*time.Nanosecond), time.Minute) {
+ t.Error("Just before timeout: Expected token to be valid")
+ }
+ if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond), time.Hour) {
t.Error("One minute in the past: Expected token to be valid")
}
}
@@ -69,17 +78,19 @@
invalidTokenTests := []struct {
name, key, userID, actionID string
t time.Time
+ timeout time.Duration
}{
- {"Bad key", "foobar", userID, actionID, oneMinuteFromNow},
- {"Bad userID", key, "foobar", actionID, oneMinuteFromNow},
- {"Bad actionID", key, userID, "foobar", oneMinuteFromNow},
- {"Expired", key, userID, actionID, now.Add(Timeout + 1*time.Millisecond)},
- {"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)},
+ {"Bad key", "foobar", userID, actionID, oneMinuteFromNow, Timeout},
+ {"Bad userID", key, "foobar", actionID, oneMinuteFromNow, Timeout},
+ {"Bad actionID", key, userID, "foobar", oneMinuteFromNow, Timeout},
+ {"Expired", key, userID, actionID, now.Add(Timeout + 1*time.Millisecond), Timeout},
+ {"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute), Timeout},
+ {"Expired with 1 minute timeout", key, userID, actionID, now.Add(time.Minute + 1*time.Millisecond), time.Minute},
}
tok := generateTokenAtTime(key, userID, actionID, now)
for _, itt := range invalidTokenTests {
- if validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t) {
+ if validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t, itt.timeout) {
t.Errorf("%v: Expected token to be invalid", itt.name)
}
}
@@ -98,7 +109,7 @@
}
for _, bdt := range badDataTests {
- if validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow) {
+ if validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow, Timeout) {
t.Errorf("%v: Expected token to be invalid", bdt.name)
}
}