diff --git a/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/ScopePredicate.java b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/ScopePredicate.java new file mode 100644 index 000000000..ca55e43da --- /dev/null +++ b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/ScopePredicate.java @@ -0,0 +1,10 @@ +package org.hswebframework.web.oauth2.server; + +import java.util.function.BiPredicate; + +@FunctionalInterface +public interface ScopePredicate extends BiPredicate { + + boolean test(String permission, String... actions); + +} diff --git a/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranter.java b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranter.java index cd21aba29..a74c13be4 100644 --- a/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranter.java +++ b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranter.java @@ -2,24 +2,22 @@ package org.hswebframework.web.oauth2.server.code; import lombok.AllArgsConstructor; import org.hswebframework.web.authorization.Authentication; -import org.hswebframework.web.authorization.Permission; import org.hswebframework.web.id.IDGenerator; import org.hswebframework.web.oauth2.ErrorType; import org.hswebframework.web.oauth2.OAuth2Exception; import org.hswebframework.web.oauth2.server.AccessToken; import org.hswebframework.web.oauth2.server.AccessTokenManager; import org.hswebframework.web.oauth2.server.OAuth2Client; +import org.hswebframework.web.oauth2.server.ScopePredicate; +import org.hswebframework.web.oauth2.server.utils.OAuth2ScopeUtils; import org.springframework.data.redis.connection.ReactiveRedisConnectionFactory; import org.springframework.data.redis.core.ReactiveRedisOperations; import org.springframework.data.redis.core.ReactiveRedisTemplate; import org.springframework.data.redis.serializer.RedisSerializationContext; import org.springframework.data.redis.serializer.RedisSerializer; -import org.springframework.util.StringUtils; import reactor.core.publisher.Mono; import java.time.Duration; -import java.util.*; -import java.util.function.BiPredicate; @AllArgsConstructor public class DefaultAuthorizationCodeGranter implements AuthorizationCodeGranter { @@ -49,9 +47,10 @@ public class DefaultAuthorizationCodeGranter implements AuthorizationCodeGranter request.getParameter("scope").map(String::valueOf).ifPresent(codeCache::setScope); codeCache.setCode(code); codeCache.setClientId(client.getClientId()); - codeCache.setAuthentication(authentication.copy(createPredicate(codeCache.getScope()), dimension -> true)); + ScopePredicate permissionPredicate = OAuth2ScopeUtils.createScopePredicate(codeCache.getScope()); + + codeCache.setAuthentication(authentication.copy((permission, action) -> permissionPredicate.test(permission.getId(), action), dimension -> true)); - createPredicate(codeCache.getScope()); return redis .opsForValue() @@ -59,24 +58,6 @@ public class DefaultAuthorizationCodeGranter implements AuthorizationCodeGranter .thenReturn(new AuthorizationCodeResponse(code)); } - static BiPredicate createPredicate(String scopeStr) { - if (StringUtils.isEmpty(scopeStr)) { - return ((permission, s) -> false); - } - String[] scopes = scopeStr.split("[ ,\n]"); - Map> actions = new HashMap<>(); - for (String scope : scopes) { - String[] permissions = scope.split("[:]"); - String per = permissions[0]; - Set acts = actions.computeIfAbsent(per, k -> new HashSet<>()); - acts.addAll(Arrays.asList(permissions).subList(1, permissions.length)); - } - - return ((permission, action) -> Optional - .ofNullable(actions.get(permission.getId())) - .map(acts -> acts.contains(action)) - .orElse(false)); - } private String getRedisKey(String code) { return "oauth2-code:" + code; diff --git a/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtils.java b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtils.java new file mode 100644 index 000000000..4ac30fff0 --- /dev/null +++ b/hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtils.java @@ -0,0 +1,32 @@ +package org.hswebframework.web.oauth2.server.utils; + +import org.hswebframework.web.oauth2.server.ScopePredicate; +import org.springframework.util.StringUtils; + +import java.util.*; + +/** + * @author zhouhao + * @since 4.0.8 + */ +public class OAuth2ScopeUtils { + + public static ScopePredicate createScopePredicate(String scopeStr) { + if (StringUtils.isEmpty(scopeStr)) { + return ((permission, action) -> false); + } + String[] scopes = scopeStr.split("[ ,\n]"); + Map> actions = new HashMap<>(); + for (String scope : scopes) { + String[] permissions = scope.split("[:]"); + String per = permissions[0]; + Set acts = actions.computeIfAbsent(per, k -> new HashSet<>()); + acts.addAll(Arrays.asList(permissions).subList(1, permissions.length)); + } + + return ((permission, action) -> Optional + .ofNullable(actions.get(permission)) + .map(acts -> action.length == 0 || acts.containsAll(Arrays.asList(action))) + .orElse(false)); + } +} diff --git a/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranterTest.java b/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranterTest.java index 6fe99d62e..4f9fdaacc 100644 --- a/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranterTest.java +++ b/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/code/DefaultAuthorizationCodeGranterTest.java @@ -16,33 +16,6 @@ import static org.junit.Assert.*; public class DefaultAuthorizationCodeGranterTest { - - @Test - public void testPermission() { - BiPredicate predicate = DefaultAuthorizationCodeGranter.createPredicate("user:info device:query"); - - { - SimplePermission permission=new SimplePermission(); - permission.setId("user"); - permission.setActions(Collections.singleton("info")); - - - assertTrue(predicate.test(permission,"info")); - assertFalse(predicate.test(permission,"info2")); - } - - { - SimplePermission permission=new SimplePermission(); - permission.setId("device"); - permission.setActions(Collections.singleton("query")); - - - assertTrue(predicate.test(permission,"query")); - assertFalse(predicate.test(permission,"query2")); - } - - } - @Test public void testRequestToken() { diff --git a/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtilsTest.java b/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtilsTest.java new file mode 100644 index 000000000..83cb9787a --- /dev/null +++ b/hsweb-authorization/hsweb-authorization-oauth2/src/test/java/org/hswebframework/web/oauth2/server/utils/OAuth2ScopeUtilsTest.java @@ -0,0 +1,35 @@ +package org.hswebframework.web.oauth2.server.utils; + +import org.hswebframework.web.oauth2.server.ScopePredicate; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class OAuth2ScopeUtilsTest { + + + @Test + public void testEmpty() { + ScopePredicate predicate = OAuth2ScopeUtils.createScopePredicate(null); + assertFalse(predicate.test("basic")); + } + + @Test + public void testScope() { + ScopePredicate predicate = OAuth2ScopeUtils.createScopePredicate("basic user:info device:query"); + + assertTrue(predicate.test("basic")); + { + + assertTrue(predicate.test("user", "info")); + assertFalse(predicate.test("user", "info2")); + } + + { + assertTrue(predicate.test("device", "query")); + assertFalse(predicate.test("device", "query2")); + } + + } +} \ No newline at end of file