diff --git a/hsweb-authorization/hsweb-authorization-api/src/main/java/org/hswebframework/web/authorization/AuthenticationHolder.java b/hsweb-authorization/hsweb-authorization-api/src/main/java/org/hswebframework/web/authorization/AuthenticationHolder.java index ece7aac99..b70f99a16 100644 --- a/hsweb-authorization/hsweb-authorization-api/src/main/java/org/hswebframework/web/authorization/AuthenticationHolder.java +++ b/hsweb-authorization/hsweb-authorization-api/src/main/java/org/hswebframework/web/authorization/AuthenticationHolder.java @@ -20,6 +20,13 @@ package org.hswebframework.web.authorization; import org.hswebframework.web.ThreadLocalUtils; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Function; + /** * 权限获取器,用于静态方式获取当前登录用户的权限信息. * 例如: @@ -36,22 +43,33 @@ import org.hswebframework.web.ThreadLocalUtils; * @since 3.0 */ public final class AuthenticationHolder { - private static AuthenticationSupplier supplier; + private static final List suppliers = new ArrayList<>(); - public static final String CURRENT_USER_ID_KEY = Authentication.class.getName() + "_current_id"; + private static final String CURRENT_USER_ID_KEY = Authentication.class.getName() + "_current_id"; + + private static final ReadWriteLock lock = new ReentrantReadWriteLock(); + + private static Authentication get(Function function) { + lock.readLock().lock(); + try { + return suppliers.stream() + .map(function) + .filter(Objects::nonNull) + .findFirst().orElse(null); + } finally { + lock.readLock().unlock(); + } + } /** * @return 当前登录的用户权限信息 */ public static Authentication get() { - if (null == supplier) { - throw new UnsupportedOperationException("supplier is null!"); - } String currentId = ThreadLocalUtils.get(CURRENT_USER_ID_KEY); if (currentId != null) { - return supplier.get(currentId); + return get(currentId); } - return supplier.get(); + return get(AuthenticationSupplier::get); } /** @@ -61,10 +79,7 @@ public final class AuthenticationHolder { * @return 权限信息 */ public static Authentication get(String userId) { - if (null == supplier) { - throw new UnsupportedOperationException("supplier is null!"); - } - return supplier.get(userId); + return get(supplier -> supplier.get(userId)); } /** @@ -72,9 +87,13 @@ public final class AuthenticationHolder { * * @param supplier */ - public static void setSupplier(AuthenticationSupplier supplier) { - if (null == AuthenticationHolder.supplier) - AuthenticationHolder.supplier = supplier; + public static void addSupplier(AuthenticationSupplier supplier) { + lock.writeLock().lock(); + try { + suppliers.add(supplier); + } finally { + lock.writeLock().unlock(); + } } public static void setCureentUserId(String id) {