refactor: change forward agent group association from single to multiple

- Change ForwardAgent.GroupID (single) to GroupIDs (array) to support
  multi-group membership
- Add migration script 047 to convert group_id column to group_ids array
- Update repository layer with array-based queries using ANY operator
- Update ResourceGroupRepository to handle array-based forward agent
  group associations
- Update all related use cases and DTOs to work with group arrays
- Fix payment repository to use explicit column selection
- Update handlers and tests to reflect new multi-group structure
This commit is contained in:
orris-inc
2026-02-03 14:01:26 +08:00
parent a645917430
commit c4dc968301
28 changed files with 789 additions and 351 deletions

View File

@@ -16,9 +16,9 @@ type ForwardAgentDTO struct {
TunnelAddress string `json:"tunnel_address,omitempty"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule
Status string `json:"status"`
Remark string `json:"remark"`
GroupSID *string `json:"group_id,omitempty"` // Resource group SID this agent belongs to
AgentVersion string `json:"agent_version"` // Agent software version (e.g., "1.2.3"), extracted from system_status for easy display
HasUpdate bool `json:"has_update"` // True if a newer version is available
GroupSIDs []string `json:"group_sids,omitempty"` // Resource group SIDs this agent belongs to
AgentVersion string `json:"agent_version"` // Agent software version (e.g., "1.2.3"), extracted from system_status for easy display
HasUpdate bool `json:"has_update"` // True if a newer version is available
AllowedPortRange string `json:"allowed_port_range,omitempty"`
BlockedProtocols []string `json:"blocked_protocols,omitempty"` // Protocols blocked by this agent
SortOrder int `json:"sort_order"` // Custom sort order for UI display
@@ -28,11 +28,12 @@ type ForwardAgentDTO struct {
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
SystemStatus *AgentStatusDTO `json:"system_status,omitempty"`
internalGroupIDs []uint `json:"-"` // internal resource group IDs for lookup
}
// ToForwardAgentDTO converts a domain forward agent to DTO.
// groupInfo is optional, can be nil if group information is not available.
func ToForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *ForwardAgentDTO {
func ToForwardAgentDTO(agent *forward.ForwardAgent) *ForwardAgentDTO {
if agent == nil {
return nil
}
@@ -58,25 +59,52 @@ func ToForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *Forwa
LastSeenAt: agent.LastSeenAt(),
CreatedAt: agent.CreatedAt().Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: agent.UpdatedAt().Format("2006-01-02T15:04:05Z07:00"),
}
if groupInfo != nil {
dto.GroupSID = &groupInfo.SID
internalGroupIDs: agent.GroupIDs(),
}
return dto
}
// ToForwardAgentDTOs converts a slice of domain forward agents to DTOs.
// groupInfoMap is optional, can be nil if group information is not available.
func ToForwardAgentDTOs(agents []*forward.ForwardAgent, groupInfoMap GroupInfoMap) []*ForwardAgentDTO {
func ToForwardAgentDTOs(agents []*forward.ForwardAgent) []*ForwardAgentDTO {
dtos := make([]*ForwardAgentDTO, len(agents))
for i, agent := range agents {
var groupInfo *GroupInfo
if groupInfoMap != nil && agent.GroupID() != nil {
groupInfo = groupInfoMap[*agent.GroupID()]
}
dtos[i] = ToForwardAgentDTO(agent, groupInfo)
dtos[i] = ToForwardAgentDTO(agent)
}
return dtos
}
// PopulateGroupSIDs fills in the group SIDs field using the SID map.
func (d *ForwardAgentDTO) PopulateGroupSIDs(groupMap GroupSIDMap) {
if len(d.internalGroupIDs) == 0 {
return
}
d.GroupSIDs = make([]string, 0, len(d.internalGroupIDs))
for _, groupID := range d.internalGroupIDs {
if sid, ok := groupMap[groupID]; ok && sid != "" {
d.GroupSIDs = append(d.GroupSIDs, sid)
}
}
}
// InternalGroupIDs returns the internal resource group IDs for repository lookups.
func (d *ForwardAgentDTO) InternalGroupIDs() []uint {
return d.internalGroupIDs
}
// CollectAgentGroupIDs collects unique resource group IDs from agent DTOs for batch lookup.
func CollectAgentGroupIDs(dtos []*ForwardAgentDTO) []uint {
idSet := make(map[uint]struct{})
for _, dto := range dtos {
for _, groupID := range dto.internalGroupIDs {
if groupID != 0 {
idSet[groupID] = struct{}{}
}
}
}
ids := make([]uint, 0, len(idSet))
for id := range idSet {
ids = append(ids, id)
}
return ids
}

View File

@@ -10,13 +10,20 @@ import (
// UserForwardAgentDTO represents a forward agent from user's perspective.
// This DTO hides sensitive fields like token_hash and api_token.
type UserForwardAgentDTO struct {
ID string `json:"id"` // Stripe-style prefixed ID (e.g., "fa_xK9mP2vL3nQ")
Name string `json:"name"` // Agent name
PublicAddress string `json:"public_address,omitempty"` // Public address for client connections
Status string `json:"status"` // enabled or disabled
GroupSID string `json:"group_id,omitempty"` // Resource group SID (e.g., "rg_xK9mP2vL3nQ")
GroupName string `json:"group_name,omitempty"` // Resource group name for display
CreatedAt time.Time `json:"created_at"`
ID string `json:"id"` // Stripe-style prefixed ID (e.g., "fa_xK9mP2vL3nQ")
Name string `json:"name"` // Agent name
PublicAddress string `json:"public_address,omitempty"` // Public address for client connections
Status string `json:"status"` // enabled or disabled
Groups []UserGroupInfoDTO `json:"groups,omitempty"` // Resource groups this agent belongs to
CreatedAt time.Time `json:"created_at"`
internalGroupIDs []uint `json:"-"` // internal resource group IDs for lookup
}
// UserGroupInfoDTO represents group info for user-facing DTOs.
type UserGroupInfoDTO struct {
SID string `json:"id"` // Resource group SID (e.g., "rg_xK9mP2vL3nQ")
Name string `json:"name"` // Resource group name for display
}
// GroupInfo holds resource group information for populating DTOs.
@@ -29,36 +36,66 @@ type GroupInfo struct {
type GroupInfoMap map[uint]*GroupInfo
// ToUserForwardAgentDTO converts a domain forward agent to user-facing DTO.
func ToUserForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *UserForwardAgentDTO {
func ToUserForwardAgentDTO(agent *forward.ForwardAgent) *UserForwardAgentDTO {
if agent == nil {
return nil
}
dto := &UserForwardAgentDTO{
ID: agent.SID(),
Name: agent.Name(),
PublicAddress: agent.PublicAddress(),
Status: string(agent.Status()),
CreatedAt: agent.CreatedAt(),
}
if groupInfo != nil {
dto.GroupSID = groupInfo.SID
dto.GroupName = groupInfo.Name
ID: agent.SID(),
Name: agent.Name(),
PublicAddress: agent.PublicAddress(),
Status: string(agent.Status()),
CreatedAt: agent.CreatedAt(),
internalGroupIDs: agent.GroupIDs(),
}
return dto
}
// ToUserForwardAgentDTOs converts a slice of domain forward agents to user-facing DTOs.
func ToUserForwardAgentDTOs(agents []*forward.ForwardAgent, groupInfoMap GroupInfoMap) []*UserForwardAgentDTO {
func ToUserForwardAgentDTOs(agents []*forward.ForwardAgent) []*UserForwardAgentDTO {
dtos := make([]*UserForwardAgentDTO, len(agents))
for i, agent := range agents {
var groupInfo *GroupInfo
if agent.GroupID() != nil {
groupInfo = groupInfoMap[*agent.GroupID()]
}
dtos[i] = ToUserForwardAgentDTO(agent, groupInfo)
dtos[i] = ToUserForwardAgentDTO(agent)
}
return dtos
}
// PopulateGroups fills in the groups field using the group info map.
func (d *UserForwardAgentDTO) PopulateGroups(groupInfoMap GroupInfoMap) {
if len(d.internalGroupIDs) == 0 {
return
}
d.Groups = make([]UserGroupInfoDTO, 0, len(d.internalGroupIDs))
for _, groupID := range d.internalGroupIDs {
if info, ok := groupInfoMap[groupID]; ok && info != nil {
d.Groups = append(d.Groups, UserGroupInfoDTO{
SID: info.SID,
Name: info.Name,
})
}
}
}
// InternalGroupIDs returns the internal resource group IDs for repository lookups.
func (d *UserForwardAgentDTO) InternalGroupIDs() []uint {
return d.internalGroupIDs
}
// CollectUserAgentGroupIDs collects unique resource group IDs from user agent DTOs for batch lookup.
func CollectUserAgentGroupIDs(dtos []*UserForwardAgentDTO) []uint {
idSet := make(map[uint]struct{})
for _, dto := range dtos {
for _, groupID := range dto.internalGroupIDs {
if groupID != 0 {
idSet[groupID] = struct{}{}
}
}
}
ids := make([]uint, 0, len(idSet))
for id := range idSet {
ids = append(ids, id)
}
return ids
}

View File

@@ -18,7 +18,7 @@ type CreateForwardAgentCommand struct {
PublicAddress string
TunnelAddress string
Remark string
GroupSID string // Resource group SID to associate with (empty means no association)
GroupSIDs []string // Resource group SIDs to associate with (empty means no association)
AllowedPortRange string // Port range string (e.g., "80,443,8000-9000"), empty means all ports allowed
BlockedProtocols []string // Protocols to block (e.g., ["socks5", "http_connect"]), empty means no blocking
SortOrder *int // Custom sort order for UI display (nil: use default 0, non-nil: set explicitly)
@@ -124,18 +124,46 @@ func (uc *CreateForwardAgentUseCase) Execute(ctx context.Context, cmd CreateForw
agent.UpdateSortOrder(*cmd.SortOrder)
}
// Handle GroupSID (resolve SID to internal ID)
if cmd.GroupSID != "" {
group, err := uc.resourceGroupRepo.GetBySID(ctx, cmd.GroupSID)
if err != nil {
uc.logger.Errorw("failed to get resource group by SID", "group_sid", cmd.GroupSID, "error", err)
return nil, errors.NewNotFoundError("resource group", cmd.GroupSID)
// Handle GroupSIDs (resolve SIDs to internal IDs)
// Limit to 10 groups to prevent DoS attacks
const maxGroupSIDs = 10
if len(cmd.GroupSIDs) > maxGroupSIDs {
return nil, errors.NewValidationError(fmt.Sprintf("too many group_sids, maximum allowed is %d", maxGroupSIDs))
}
if len(cmd.GroupSIDs) > 0 {
// Deduplicate and filter empty SIDs
uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs))
seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs))
for _, sid := range cmd.GroupSIDs {
if sid == "" {
continue
}
if _, exists := seenSIDs[sid]; exists {
continue
}
seenSIDs[sid] = struct{}{}
uniqueSIDs = append(uniqueSIDs, sid)
}
if group == nil {
return nil, errors.NewNotFoundError("resource group", cmd.GroupSID)
if len(uniqueSIDs) > 0 {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
// Resolve SIDs to internal IDs
resolvedIDs := make([]uint, 0, len(uniqueSIDs))
for _, sid := range uniqueSIDs {
group, ok := groupMap[sid]
if !ok || group == nil {
return nil, errors.NewNotFoundError("resource group", sid)
}
resolvedIDs = append(resolvedIDs, group.ID())
}
agent.SetGroupIDs(resolvedIDs)
}
groupID := group.ID()
agent.SetGroupID(&groupID)
}
// Persist

View File

@@ -143,6 +143,33 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
cmd.ListenPort, agent.AllowedPortRange().String()))
}
// Collect all agent SIDs that need to be fetched (to avoid N+1 queries)
allAgentSIDs := make([]string, 0)
if cmd.ExitAgentShortID != "" {
allAgentSIDs = append(allAgentSIDs, cmd.ExitAgentShortID)
}
for _, input := range cmd.ExitAgents {
allAgentSIDs = append(allAgentSIDs, input.AgentSID)
}
allAgentSIDs = append(allAgentSIDs, cmd.ChainAgentShortIDs...)
for shortID := range cmd.ChainPortConfig {
allAgentSIDs = append(allAgentSIDs, shortID)
}
// Batch fetch all agents
var agentMap map[string]*forward.ForwardAgent
if len(allAgentSIDs) > 0 {
agents, err := uc.agentRepo.GetBySIDs(ctx, allAgentSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get agents", "error", err)
return nil, fmt.Errorf("failed to get agents: %w", err)
}
agentMap = make(map[string]*forward.ForwardAgent, len(agents))
for _, a := range agents {
agentMap[a.SID()] = a
}
}
// Resolve exit agent configuration (single exitAgentID OR multiple exitAgents)
var exitAgentID uint
var exitAgents []vo.AgentWeight
@@ -150,12 +177,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
return nil, errors.NewValidationError("exit_agent_id and exit_agents are mutually exclusive")
}
if cmd.ExitAgentShortID != "" {
exitAgent, err := uc.agentRepo.GetBySID(ctx, cmd.ExitAgentShortID)
if err != nil {
uc.logger.Errorw("failed to get exit agent", "exit_agent_short_id", cmd.ExitAgentShortID, "error", err)
return nil, fmt.Errorf("failed to validate exit agent: %w", err)
}
if exitAgent == nil {
exitAgent, ok := agentMap[cmd.ExitAgentShortID]
if !ok || exitAgent == nil {
return nil, errors.NewNotFoundError("exit forward agent", cmd.ExitAgentShortID)
}
exitAgentID = exitAgent.ID()
@@ -163,12 +186,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
// Resolve multiple exit agents with weights
exitAgents = make([]vo.AgentWeight, 0, len(cmd.ExitAgents))
for _, input := range cmd.ExitAgents {
exitAgent, err := uc.agentRepo.GetBySID(ctx, input.AgentSID)
if err != nil {
uc.logger.Errorw("failed to get exit agent", "exit_agent_sid", input.AgentSID, "error", err)
return nil, fmt.Errorf("failed to validate exit agent: %w", err)
}
if exitAgent == nil {
exitAgent, ok := agentMap[input.AgentSID]
if !ok || exitAgent == nil {
return nil, errors.NewNotFoundError("exit forward agent", input.AgentSID)
}
// Use provided weight or default
@@ -193,12 +212,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
if len(cmd.ChainAgentShortIDs) > 0 {
chainAgentIDs = make([]uint, len(cmd.ChainAgentShortIDs))
for i, shortID := range cmd.ChainAgentShortIDs {
chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID)
if err != nil {
uc.logger.Errorw("failed to get chain agent", "chain_agent_short_id", shortID, "error", err)
return nil, fmt.Errorf("failed to validate chain agent: %w", err)
}
if chainAgent == nil {
chainAgent, ok := agentMap[shortID]
if !ok || chainAgent == nil {
return nil, errors.NewNotFoundError("chain forward agent", shortID)
}
chainAgentIDs[i] = chainAgent.ID()
@@ -212,12 +227,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
if len(cmd.ChainPortConfig) > 0 {
chainPortConfig = make(map[uint]uint16, len(cmd.ChainPortConfig))
for shortID, port := range cmd.ChainPortConfig {
chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID)
if err != nil {
uc.logger.Errorw("failed to get chain agent for port config", "chain_agent_short_id", shortID, "error", err)
return nil, fmt.Errorf("failed to validate chain agent in chain_port_config: %w", err)
}
if chainAgent == nil {
chainAgent, ok := agentMap[shortID]
if !ok || chainAgent == nil {
return nil, errors.NewNotFoundError("chain forward agent in chain_port_config", shortID)
}
// Validate port against chain agent's allowed port range
@@ -299,29 +310,47 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
// Resolve GroupSIDs to internal IDs and validate plan types (if provided)
var groupIDs []uint
if len(cmd.GroupSIDs) > 0 {
groupIDs = make([]uint, 0, len(cmd.GroupSIDs))
// Validate SID formats first
for _, groupSID := range cmd.GroupSIDs {
// Validate the SID format (rg_xxx)
if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil {
return nil, errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID))
}
}
group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID)
if err != nil {
uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err)
return nil, fmt.Errorf("failed to validate resource group: %w", err)
}
if group == nil {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, cmd.GroupSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
// Collect plan IDs for batch fetch
planIDs := make([]uint, 0, len(cmd.GroupSIDs))
for _, groupSID := range cmd.GroupSIDs {
group, ok := groupMap[groupSID]
if !ok || group == nil {
return nil, errors.NewNotFoundError("resource group", groupSID)
}
planIDs = append(planIDs, group.PlanID())
}
// Verify the plan type supports forward rules binding (node and hybrid only, not forward)
plan, err := uc.planRepo.GetByID(ctx, group.PlanID())
if err != nil {
uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err)
return nil, fmt.Errorf("failed to validate resource group plan: %w", err)
}
if plan == nil {
// Batch fetch all plans
plans, err := uc.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
uc.logger.Errorw("failed to batch get plans", "error", err)
return nil, fmt.Errorf("failed to get plans: %w", err)
}
planMap := make(map[uint]*subscription.Plan, len(plans))
for _, p := range plans {
planMap[p.ID()] = p
}
// Validate each group and build groupIDs
groupIDs = make([]uint, 0, len(cmd.GroupSIDs))
for _, groupSID := range cmd.GroupSIDs {
group := groupMap[groupSID]
plan, ok := planMap[group.PlanID()]
if !ok || plan == nil {
return nil, fmt.Errorf("plan not found for resource group %s", groupSID)
}
if plan.PlanType().IsForward() {
@@ -332,7 +361,6 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa
return nil, errors.NewValidationError(
fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID))
}
groupIDs = append(groupIDs, group.ID())
}
}
@@ -658,28 +686,47 @@ func (uc *CreateForwardRuleUseCase) executeExternalRule(ctx context.Context, cmd
// Resolve GroupSIDs to internal IDs and validate plan types (if provided)
var groupIDs []uint
if len(cmd.GroupSIDs) > 0 {
groupIDs = make([]uint, 0, len(cmd.GroupSIDs))
// Validate SID formats first
for _, groupSID := range cmd.GroupSIDs {
if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil {
return nil, errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID))
}
}
group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID)
if err != nil {
uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err)
return nil, fmt.Errorf("failed to validate resource group: %w", err)
}
if group == nil {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, cmd.GroupSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
// Collect plan IDs for batch fetch
planIDs := make([]uint, 0, len(cmd.GroupSIDs))
for _, groupSID := range cmd.GroupSIDs {
group, ok := groupMap[groupSID]
if !ok || group == nil {
return nil, errors.NewNotFoundError("resource group", groupSID)
}
planIDs = append(planIDs, group.PlanID())
}
// Verify the plan type supports forward rules binding
plan, err := uc.planRepo.GetByID(ctx, group.PlanID())
if err != nil {
uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err)
return nil, fmt.Errorf("failed to validate resource group plan: %w", err)
}
if plan == nil {
// Batch fetch all plans
plans, err := uc.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
uc.logger.Errorw("failed to batch get plans", "error", err)
return nil, fmt.Errorf("failed to get plans: %w", err)
}
planMap := make(map[uint]*subscription.Plan, len(plans))
for _, p := range plans {
planMap[p.ID()] = p
}
// Validate each group and build groupIDs
groupIDs = make([]uint, 0, len(cmd.GroupSIDs))
for _, groupSID := range cmd.GroupSIDs {
group := groupMap[groupSID]
plan, ok := planMap[group.PlanID()]
if !ok || plan == nil {
return nil, fmt.Errorf("plan not found for resource group %s", groupSID)
}
if plan.PlanType().IsForward() {
@@ -690,7 +737,6 @@ func (uc *CreateForwardRuleUseCase) executeExternalRule(ctx context.Context, cmd
return nil, errors.NewValidationError(
fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID))
}
groupIDs = append(groupIDs, group.ID())
}
}

View File

@@ -97,40 +97,30 @@ func (uc *ListForwardAgentsUseCase) Execute(ctx context.Context, query ListForwa
pages++
}
// Collect unique group IDs for batch query
groupIDSet := make(map[uint]struct{})
for _, agent := range agents {
if agent.GroupID() != nil {
groupIDSet[*agent.GroupID()] = struct{}{}
}
}
// Convert to DTOs
dtos := dto.ToForwardAgentDTOs(agents)
// Query resource groups and build GroupInfoMap
var groupInfoMap dto.GroupInfoMap
if len(groupIDSet) > 0 && uc.resourceGroupRepo != nil {
groupIDs := make([]uint, 0, len(groupIDSet))
for id := range groupIDSet {
groupIDs = append(groupIDs, id)
}
// Collect unique group IDs from DTOs for batch query
groupIDs := dto.CollectAgentGroupIDs(dtos)
// Query resource groups and populate GroupSIDs
if len(groupIDs) > 0 && uc.resourceGroupRepo != nil {
groups, err := uc.resourceGroupRepo.GetByIDs(ctx, groupIDs)
if err != nil {
uc.logger.Warnw("failed to get resource groups, continuing without group info",
"error", err,
)
} else {
groupInfoMap = make(dto.GroupInfoMap, len(groups))
groupSIDMap := make(dto.GroupSIDMap, len(groups))
for _, group := range groups {
groupInfoMap[group.ID()] = &dto.GroupInfo{
SID: group.SID(),
Name: group.Name(),
}
groupSIDMap[group.ID()] = group.SID()
}
for _, d := range dtos {
d.PopulateGroupSIDs(groupSIDMap)
}
}
}
dtos := dto.ToForwardAgentDTOs(agents, groupInfoMap)
// Collect agent IDs for batch status query and create ID mapping
agentIDs := make([]uint, 0, len(agents))
idToIndexMap := make(map[uint]int, len(agents))

View File

@@ -136,20 +136,21 @@ func (uc *ListUserForwardAgentsUseCase) Execute(ctx context.Context, query ListU
}, nil
}
// Step 4: Get active resource groups for these plans
// Step 4: Get active resource groups for these plans (batch query to avoid N+1)
groupsByPlan, err := uc.resourceGroupRepo.GetByPlanIDs(ctx, forwardPlanIDs)
if err != nil {
// Log warning but continue with empty result (consistent with original behavior)
uc.logger.Warnw("failed to batch get resource groups for plans",
"user_id", query.UserID,
"error", err,
)
groupsByPlan = make(map[uint][]*resource.ResourceGroup)
}
groupIDs := make([]uint, 0)
groupInfoMap := make(dto.GroupInfoMap)
for _, planID := range forwardPlanIDs {
groups, err := uc.resourceGroupRepo.GetByPlanID(ctx, planID)
if err != nil {
uc.logger.Warnw("failed to get resource groups for plan",
"plan_id", planID,
"error", err,
)
continue
}
for _, groups := range groupsByPlan {
for _, group := range groups {
if group.IsActive() {
groupIDs = append(groupIDs, group.ID())
@@ -197,8 +198,11 @@ func (uc *ListUserForwardAgentsUseCase) Execute(ctx context.Context, query ListU
pages++
}
// Step 6: Convert to user-facing DTOs
dtos := dto.ToUserForwardAgentDTOs(agents, groupInfoMap)
// Step 6: Convert to user-facing DTOs and populate groups
dtos := dto.ToUserForwardAgentDTOs(agents)
for _, d := range dtos {
d.PopulateGroups(groupInfoMap)
}
uc.logger.Infow("user forward agents listed successfully",
"user_id", query.UserID,

View File

@@ -18,7 +18,7 @@ type UpdateForwardAgentCommand struct {
PublicAddress *string
TunnelAddress *string
Remark *string
GroupSID *string // Resource group SID (empty string to remove association)
GroupSIDs []string // Resource group SIDs (empty slice to remove all associations)
AllowedPortRange *string // nil: no update, empty string: clear (allow all), non-empty: set new range
BlockedProtocols *[]string // nil: no update, empty slice: clear (allow all), non-empty: set new protocols
SortOrder *int // nil: no update, non-nil: set new sort order
@@ -98,23 +98,49 @@ func (uc *UpdateForwardAgentUseCase) Execute(ctx context.Context, cmd UpdateForw
}
}
// Handle GroupSID update (resolve SID to internal ID)
if cmd.GroupSID != nil {
if *cmd.GroupSID == "" {
// Empty string means remove the association
agent.SetGroupID(nil)
// Handle GroupSIDs update (resolve SIDs to internal IDs)
// Limit to 10 groups to prevent DoS attacks
const maxGroupSIDs = 10
if len(cmd.GroupSIDs) > maxGroupSIDs {
return errors.NewValidationError(fmt.Sprintf("too many group_sids, maximum allowed is %d", maxGroupSIDs))
}
// Note: We always update if GroupSIDs is provided (even empty slice means clear all)
if cmd.GroupSIDs != nil {
if len(cmd.GroupSIDs) == 0 {
// Empty slice means remove all associations
agent.SetGroupIDs(nil)
} else {
// Resolve group SID to internal ID
group, err := uc.resourceGroupRepo.GetBySID(ctx, *cmd.GroupSID)
// Deduplicate and filter empty SIDs
uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs))
seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs))
for _, sid := range cmd.GroupSIDs {
if sid == "" {
continue
}
if _, exists := seenSIDs[sid]; exists {
continue
}
seenSIDs[sid] = struct{}{}
uniqueSIDs = append(uniqueSIDs, sid)
}
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs)
if err != nil {
uc.logger.Errorw("failed to get resource group by SID", "group_sid", *cmd.GroupSID, "error", err)
return errors.NewNotFoundError("resource group", *cmd.GroupSID)
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return fmt.Errorf("failed to get resource groups: %w", err)
}
if group == nil {
return errors.NewNotFoundError("resource group", *cmd.GroupSID)
// Resolve SIDs to internal IDs
resolvedIDs := make([]uint, 0, len(uniqueSIDs))
for _, sid := range uniqueSIDs {
group, ok := groupMap[sid]
if !ok || group == nil {
return errors.NewNotFoundError("resource group", sid)
}
resolvedIDs = append(resolvedIDs, group.ID())
}
groupID := group.ID()
agent.SetGroupID(&groupID)
agent.SetGroupIDs(resolvedIDs)
}
}

View File

@@ -96,6 +96,36 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
originalExitAgentIDs := rule.GetAllExitAgentIDs() // Includes both single and multiple exit agents
originalChainAgentIDs := rule.ChainAgentIDs()
// Collect all agent SIDs that need to be fetched (to avoid N+1 queries)
allAgentSIDs := make([]string, 0)
if cmd.AgentShortID != nil {
allAgentSIDs = append(allAgentSIDs, *cmd.AgentShortID)
}
if cmd.ExitAgentShortID != nil {
allAgentSIDs = append(allAgentSIDs, *cmd.ExitAgentShortID)
}
for _, input := range cmd.ExitAgents {
allAgentSIDs = append(allAgentSIDs, input.AgentSID)
}
allAgentSIDs = append(allAgentSIDs, cmd.ChainAgentShortIDs...)
for shortID := range cmd.ChainPortConfig {
allAgentSIDs = append(allAgentSIDs, shortID)
}
// Batch fetch all agents to avoid N+1 queries
var agentMap map[string]*forward.ForwardAgent
if len(allAgentSIDs) > 0 {
agents, err := uc.agentRepo.GetBySIDs(ctx, allAgentSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get agents", "error", err)
return fmt.Errorf("failed to get agents: %w", err)
}
agentMap = make(map[string]*forward.ForwardAgent, len(agents))
for _, a := range agents {
agentMap[a.SID()] = a
}
}
// Update fields
if cmd.Name != nil {
if err := rule.UpdateName(*cmd.Name); err != nil {
@@ -105,12 +135,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
// Update entry agent ID
if cmd.AgentShortID != nil {
agent, err := uc.agentRepo.GetBySID(ctx, *cmd.AgentShortID)
if err != nil {
uc.logger.Errorw("failed to get agent", "agent_short_id", *cmd.AgentShortID, "error", err)
return fmt.Errorf("failed to validate agent: %w", err)
}
if agent == nil {
agent, ok := agentMap[*cmd.AgentShortID]
if !ok || agent == nil {
return errors.NewNotFoundError("forward agent", *cmd.AgentShortID)
}
@@ -160,12 +186,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
// Update single exit agent ID (for entry type rules)
if cmd.ExitAgentShortID != nil {
exitAgent, err := uc.agentRepo.GetBySID(ctx, *cmd.ExitAgentShortID)
if err != nil {
uc.logger.Errorw("failed to get exit agent", "exit_agent_short_id", *cmd.ExitAgentShortID, "error", err)
return fmt.Errorf("failed to validate exit agent: %w", err)
}
if exitAgent == nil {
exitAgent, ok := agentMap[*cmd.ExitAgentShortID]
if !ok || exitAgent == nil {
return errors.NewNotFoundError("exit forward agent", *cmd.ExitAgentShortID)
}
@@ -185,12 +207,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
if len(cmd.ExitAgents) > 0 {
exitAgents := make([]vo.AgentWeight, 0, len(cmd.ExitAgents))
for _, input := range cmd.ExitAgents {
exitAgent, err := uc.agentRepo.GetBySID(ctx, input.AgentSID)
if err != nil {
uc.logger.Errorw("failed to get exit agent", "exit_agent_sid", input.AgentSID, "error", err)
return fmt.Errorf("failed to validate exit agent: %w", err)
}
if exitAgent == nil {
exitAgent, ok := agentMap[input.AgentSID]
if !ok || exitAgent == nil {
return errors.NewNotFoundError("exit forward agent", input.AgentSID)
}
@@ -226,12 +244,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
if cmd.ChainAgentShortIDs != nil {
chainAgentIDs := make([]uint, len(cmd.ChainAgentShortIDs))
for i, shortID := range cmd.ChainAgentShortIDs {
chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID)
if err != nil {
uc.logger.Errorw("failed to get chain agent", "chain_agent_short_id", shortID, "error", err)
return fmt.Errorf("failed to validate chain agent: %w", err)
}
if chainAgent == nil {
chainAgent, ok := agentMap[shortID]
if !ok || chainAgent == nil {
return errors.NewNotFoundError("chain forward agent", shortID)
}
@@ -256,12 +270,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
oldChainPortConfig := rule.ChainPortConfig()
chainPortConfig := make(map[uint]uint16, len(cmd.ChainPortConfig))
for shortID, port := range cmd.ChainPortConfig {
chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID)
if err != nil {
uc.logger.Errorw("failed to get chain agent for port config", "chain_agent_short_id", shortID, "error", err)
return fmt.Errorf("failed to validate chain agent in chain_port_config: %w", err)
}
if chainAgent == nil {
chainAgent, ok := agentMap[shortID]
if !ok || chainAgent == nil {
return errors.NewNotFoundError("chain forward agent in chain_port_config", shortID)
}
@@ -474,29 +484,47 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
if cmd.GroupSIDs != nil {
var groupIDs []uint
if len(*cmd.GroupSIDs) > 0 {
groupIDs = make([]uint, 0, len(*cmd.GroupSIDs))
// Validate SID formats first
for _, groupSID := range *cmd.GroupSIDs {
// Validate the SID format (rg_xxx)
if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil {
return errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID))
}
}
group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID)
if err != nil {
uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err)
return fmt.Errorf("failed to validate resource group: %w", err)
}
if group == nil {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, *cmd.GroupSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return fmt.Errorf("failed to get resource groups: %w", err)
}
// Collect plan IDs for batch fetch
planIDs := make([]uint, 0, len(*cmd.GroupSIDs))
for _, groupSID := range *cmd.GroupSIDs {
group, ok := groupMap[groupSID]
if !ok || group == nil {
return errors.NewNotFoundError("resource group", groupSID)
}
planIDs = append(planIDs, group.PlanID())
}
// Verify the plan type supports forward rules binding (node and hybrid only, not forward)
plan, err := uc.planRepo.GetByID(ctx, group.PlanID())
if err != nil {
uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err)
return fmt.Errorf("failed to validate resource group plan: %w", err)
}
if plan == nil {
// Batch fetch all plans to avoid N+1 queries
plans, err := uc.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
uc.logger.Errorw("failed to batch get plans", "error", err)
return fmt.Errorf("failed to get plans: %w", err)
}
planMap := make(map[uint]*subscription.Plan, len(plans))
for _, p := range plans {
planMap[p.ID()] = p
}
// Validate each group and build groupIDs
groupIDs = make([]uint, 0, len(*cmd.GroupSIDs))
for _, groupSID := range *cmd.GroupSIDs {
group := groupMap[groupSID]
plan, ok := planMap[group.PlanID()]
if !ok || plan == nil {
return fmt.Errorf("plan not found for resource group %s", groupSID)
}
if plan.PlanType().IsForward() {
@@ -507,7 +535,6 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa
return errors.NewValidationError(
fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID))
}
groupIDs = append(groupIDs, group.ID())
}
}
@@ -632,14 +659,15 @@ func (uc *UpdateForwardRuleUseCase) getAccessibleGroupIDs(ctx context.Context, u
return nil, nil
}
// Step 4: Get active resource groups for these plans
// Step 4: Get active resource groups for these plans (batch query to avoid N+1)
groupsByPlan, err := uc.resourceGroupRepo.GetByPlanIDs(ctx, forwardPlanIDs)
if err != nil {
uc.logger.Warnw("failed to batch get resource groups for plans", "error", err)
return nil, nil
}
groupIDs := make([]uint, 0)
for _, planID := range forwardPlanIDs {
groups, err := uc.resourceGroupRepo.GetByPlanID(ctx, planID)
if err != nil {
uc.logger.Warnw("failed to get resource groups for plan", "plan_id", planID, "error", err)
continue
}
for _, groups := range groupsByPlan {
for _, group := range groups {
if group.IsActive() {
groupIDs = append(groupIDs, group.ID())
@@ -652,6 +680,7 @@ func (uc *UpdateForwardRuleUseCase) getAccessibleGroupIDs(ctx context.Context, u
// validateUserAgentAccess checks if the user has access to the specified agent.
// Returns nil if access is allowed, or an error if access is denied.
// Access is granted if any of the agent's group IDs is in the user's accessible group IDs.
func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context, userID uint, agent *forward.ForwardAgent) error {
// Get user's accessible group IDs
accessibleGroupIDs, err := uc.getAccessibleGroupIDs(ctx, userID)
@@ -659,9 +688,9 @@ func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context,
return fmt.Errorf("failed to get accessible groups: %w", err)
}
// Check if agent's group ID is in the accessible list
agentGroupID := agent.GroupID()
if agentGroupID == nil {
// Check if any of agent's group IDs is in the accessible list
agentGroupIDs := agent.GroupIDs()
if len(agentGroupIDs) == 0 {
// Agent has no group assigned, deny access for user endpoints
uc.logger.Warnw("user attempted to access agent without group",
"user_id", userID,
@@ -669,11 +698,20 @@ func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context,
return errors.NewForbiddenError("agent is not accessible to user")
}
if !containsUint(accessibleGroupIDs, *agentGroupID) {
// Check if there's any intersection between agent's groups and user's accessible groups
hasAccess := false
for _, agentGroupID := range agentGroupIDs {
if containsUint(accessibleGroupIDs, agentGroupID) {
hasAccess = true
break
}
}
if !hasAccess {
uc.logger.Warnw("user attempted to access unauthorized agent",
"user_id", userID,
"agent_sid", agent.SID(),
"agent_group_id", *agentGroupID,
"agent_group_ids", agentGroupIDs,
"accessible_groups", accessibleGroupIDs)
return errors.NewForbiddenError("user does not have access to this agent")
}

View File

@@ -378,30 +378,37 @@ func (uc *CreateNodeUseCase) Execute(ctx context.Context, cmd CreateNodeCommand)
// Handle GroupSIDs (resolve SIDs to internal IDs)
if len(cmd.GroupSIDs) > 0 {
// Deduplicate and filter empty SIDs
uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs))
seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs))
resolvedIDs := make([]uint, 0, len(cmd.GroupSIDs))
for _, sid := range cmd.GroupSIDs {
// Skip empty strings
if sid == "" {
continue
}
// Skip duplicate SIDs
if _, exists := seenSIDs[sid]; exists {
continue
}
seenSIDs[sid] = struct{}{}
group, err := uc.resourceGroupRepo.GetBySID(ctx, sid)
if err != nil {
uc.logger.Errorw("failed to get resource group by SID", "group_sid", sid, "error", err)
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
if group == nil {
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
resolvedIDs = append(resolvedIDs, group.ID())
uniqueSIDs = append(uniqueSIDs, sid)
}
if len(resolvedIDs) > 0 {
if len(uniqueSIDs) > 0 {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
// Resolve SIDs to internal IDs
resolvedIDs := make([]uint, 0, len(uniqueSIDs))
for _, sid := range uniqueSIDs {
group, ok := groupMap[sid]
if !ok || group == nil {
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
resolvedIDs = append(resolvedIDs, group.ID())
}
nodeEntity.SetGroupIDs(resolvedIDs)
}
}

View File

@@ -69,15 +69,26 @@ func (uc *GetUserNodeUsageUseCase) Execute(ctx context.Context, query GetUserNod
maxNodeLimit := 0
hasUnlimitedNodes := false
// Batch fetch all plans to avoid N+1 queries
planIDs := make([]uint, 0, len(subscriptions))
for _, sub := range subscriptions {
planIDs = append(planIDs, sub.PlanID())
}
plans, err := uc.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
uc.logger.Errorw("failed to batch fetch plans", "user_id", query.UserID, "error", err)
return nil, fmt.Errorf("failed to get plans: %w", err)
}
planMap := make(map[uint]*subscription.Plan, len(plans))
for _, plan := range plans {
planMap[plan.ID()] = plan
}
// Find the highest limit among all active subscriptions
for _, sub := range subscriptions {
plan, err := uc.planRepo.GetByID(ctx, sub.PlanID())
if err != nil {
uc.logger.Warnw("failed to get plan for subscription", "subscription_id", sub.ID(), "plan_id", sub.PlanID(), "error", err)
continue
}
if plan == nil {
plan, ok := planMap[sub.PlanID()]
if !ok {
uc.logger.Warnw("plan not found for subscription", "subscription_id", sub.ID(), "plan_id", sub.PlanID())
continue
}

View File

@@ -186,33 +186,37 @@ func (uc *UpdateNodeUseCase) Execute(ctx context.Context, cmd UpdateNodeCommand)
// Empty slice means remove all group associations
existingNode.SetGroupIDs(nil)
} else {
// Resolve each group SID to internal ID (with deduplication)
// Deduplicate and filter empty SIDs
uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs))
seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs))
resolvedIDs := make([]uint, 0, len(cmd.GroupSIDs))
for _, sid := range cmd.GroupSIDs {
// Skip empty strings
if sid == "" {
continue
}
// Skip duplicate SIDs
if _, exists := seenSIDs[sid]; exists {
continue
}
seenSIDs[sid] = struct{}{}
group, err := uc.resourceGroupRepo.GetBySID(ctx, sid)
if err != nil {
uc.logger.Errorw("failed to get resource group by SID", "group_sid", sid, "error", err)
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
if group == nil {
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
resolvedIDs = append(resolvedIDs, group.ID())
uniqueSIDs = append(uniqueSIDs, sid)
}
// Only update if we have valid SIDs after filtering
// (prevents accidental clear when all SIDs are empty strings)
if len(resolvedIDs) > 0 {
if len(uniqueSIDs) > 0 {
// Batch fetch all groups to avoid N+1 queries
groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get resource groups", "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
// Resolve SIDs to internal IDs
resolvedIDs := make([]uint, 0, len(uniqueSIDs))
for _, sid := range uniqueSIDs {
group, ok := groupMap[sid]
if !ok || group == nil {
return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid))
}
resolvedIDs = append(resolvedIDs, group.ID())
}
existingNode.SetGroupIDs(resolvedIDs)
}
}

View File

@@ -75,16 +75,25 @@ func (uc *CancelUnpaidSubscriptionsUseCase) Execute(ctx context.Context) (int, e
now := biztime.NowUTC()
cancelledCount := 0
// Batch fetch subscription IDs with pending payments to avoid N+1 queries
subscriptionIDs := make([]uint, 0, len(subscriptions))
for _, sub := range subscriptions {
subscriptionIDs = append(subscriptionIDs, sub.ID())
}
subsWithPendingPayments, err := uc.paymentRepo.GetSubscriptionIDsWithPendingPayments(ctx, subscriptionIDs)
if err != nil {
uc.logger.Errorw("failed to batch check pending payments", "error", err)
return 0, fmt.Errorf("failed to check pending payments: %w", err)
}
pendingPaymentSet := make(map[uint]bool, len(subsWithPendingPayments))
for _, subID := range subsWithPendingPayments {
pendingPaymentSet[subID] = true
}
for _, sub := range subscriptions {
// Check if there are any pending payments for this subscription first
// A new payment may have been created, skip cancellation
hasPending, err := uc.paymentRepo.HasPendingPaymentBySubscriptionID(ctx, sub.ID())
if err != nil {
uc.logger.Errorw("failed to check pending payments",
"subscription_id", sub.ID(),
"error", err)
continue
}
hasPending := pendingPaymentSet[sub.ID()]
if hasPending {
// New pending payment exists, skip cancellation

View File

@@ -42,6 +42,18 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) {
uc.logger.Infow("processing expired payments", "count", len(expiredPayments))
// Batch fetch all subscriptions to avoid N+1 queries
subscriptionIDs := make([]uint, 0, len(expiredPayments))
for _, p := range expiredPayments {
subscriptionIDs = append(subscriptionIDs, p.SubscriptionID())
}
subscriptionMap, err := uc.subscriptionRepo.GetByIDs(ctx, subscriptionIDs)
if err != nil {
uc.logger.Warnw("failed to batch fetch subscriptions", "error", err)
// Continue with empty map, will log warnings for individual lookups
subscriptionMap = make(map[uint]*subscription.Subscription)
}
expiredCount := 0
for _, p := range expiredPayments {
if err := p.MarkAsExpired(); err != nil {
@@ -61,11 +73,19 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) {
}
// Record payment expiration time on subscription for auto-cancel grace period
if err := uc.markSubscriptionPaymentExpired(ctx, p.SubscriptionID()); err != nil {
uc.logger.Warnw("failed to mark subscription payment expired",
"error", err,
sub, ok := subscriptionMap[p.SubscriptionID()]
if !ok || sub == nil {
uc.logger.Warnw("subscription not found for payment",
"payment_id", p.ID(),
"subscription_id", p.SubscriptionID())
// Continue processing other payments even if this fails
} else {
// Record the payment expiration time for grace period calculation
sub.SetMetadata("payment_expired_at", biztime.FormatMetadataTime(biztime.NowUTC()))
if err := uc.subscriptionRepo.Update(ctx, sub); err != nil {
uc.logger.Warnw("failed to update subscription payment_expired_at",
"error", err,
"subscription_id", p.SubscriptionID())
}
}
expiredCount++
@@ -81,24 +101,3 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) {
return expiredCount, nil
}
// markSubscriptionPaymentExpired records the payment expiration time on the subscription
// This is used by the auto-cancel logic to determine the grace period
func (uc *ExpirePaymentsUseCase) markSubscriptionPaymentExpired(ctx context.Context, subscriptionID uint) error {
sub, err := uc.subscriptionRepo.GetByID(ctx, subscriptionID)
if err != nil {
return fmt.Errorf("failed to get subscription: %w", err)
}
if sub == nil {
return fmt.Errorf("subscription not found: %d", subscriptionID)
}
// Record the payment expiration time for grace period calculation
sub.SetMetadata("payment_expired_at", biztime.FormatMetadataTime(biztime.NowUTC()))
if err := uc.subscriptionRepo.Update(ctx, sub); err != nil {
return fmt.Errorf("failed to update subscription: %w", err)
}
return nil
}

View File

@@ -129,15 +129,23 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeAddAgents(ctx context.
}
// Check if already in this group
if agent.GroupID() != nil && *agent.GroupID() == groupID {
currentGroupIDs := agent.GroupIDs()
alreadyInGroup := false
for _, gid := range currentGroupIDs {
if gid == groupID {
alreadyInGroup = true
break
}
}
if alreadyInGroup {
// Already in this group, count as success
result.Succeeded = append(result.Succeeded, agentSID)
continue
}
// Set group ID and update
gid := groupID
agent.SetGroupID(&gid)
// Add group ID to the list and update
newGroupIDs := append(currentGroupIDs, groupID)
agent.SetGroupIDs(newGroupIDs)
if err := uc.agentRepo.Update(ctx, agent); err != nil {
uc.logger.Errorw("failed to update forward agent", "error", err, "agent_sid", agentSID)
result.Failed = append(result.Failed, dto.BatchOperationErr{
@@ -236,7 +244,15 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeRemoveAgents(ctx conte
}
// Check if the agent belongs to this group
if agent.GroupID() == nil || *agent.GroupID() != groupID {
currentGroupIDs := agent.GroupIDs()
foundIndex := -1
for i, gid := range currentGroupIDs {
if gid == groupID {
foundIndex = i
break
}
}
if foundIndex == -1 {
result.Failed = append(result.Failed, dto.BatchOperationErr{
ID: agentSID,
Reason: "forward agent does not belong to this group",
@@ -244,8 +260,14 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeRemoveAgents(ctx conte
continue
}
// Remove group ID
agent.SetGroupID(nil)
// Remove group ID from the list
newGroupIDs := make([]uint, 0, len(currentGroupIDs)-1)
for i, gid := range currentGroupIDs {
if i != foundIndex {
newGroupIDs = append(newGroupIDs, gid)
}
}
agent.SetGroupIDs(newGroupIDs)
if err := uc.agentRepo.Update(ctx, agent); err != nil {
uc.logger.Errorw("failed to update forward agent", "error", err, "agent_sid", agentSID)
result.Failed = append(result.Failed, dto.BatchOperationErr{

View File

@@ -85,18 +85,22 @@ func (uc *ManageResourceGroupNodesUseCase) executeAddNodes(ctx context.Context,
Failed: make([]dto.BatchOperationErr, 0),
}
// Batch fetch all nodes to avoid N+1 queries
nodes, err := uc.nodeRepo.GetBySIDs(ctx, nodeSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get nodes", "error", err)
return nil, fmt.Errorf("failed to get nodes: %w", err)
}
// Build a map for quick lookup
nodeMap := make(map[string]*node.Node, len(nodes))
for _, n := range nodes {
nodeMap[n.SID()] = n
}
for _, nodeSID := range nodeSIDs {
// Get node by SID
n, err := uc.nodeRepo.GetBySID(ctx, nodeSID)
if err != nil {
uc.logger.Warnw("failed to get node", "error", err, "node_sid", nodeSID)
result.Failed = append(result.Failed, dto.BatchOperationErr{
ID: nodeSID,
Reason: "failed to get node",
})
continue
}
if n == nil {
n, ok := nodeMap[nodeSID]
if !ok {
result.Failed = append(result.Failed, dto.BatchOperationErr{
ID: nodeSID,
Reason: "node not found",
@@ -167,18 +171,22 @@ func (uc *ManageResourceGroupNodesUseCase) executeRemoveNodes(ctx context.Contex
Failed: make([]dto.BatchOperationErr, 0),
}
// Batch fetch all nodes to avoid N+1 queries
nodes, err := uc.nodeRepo.GetBySIDs(ctx, nodeSIDs)
if err != nil {
uc.logger.Errorw("failed to batch get nodes", "error", err)
return nil, fmt.Errorf("failed to get nodes: %w", err)
}
// Build a map for quick lookup
nodeMap := make(map[string]*node.Node, len(nodes))
for _, n := range nodes {
nodeMap[n.SID()] = n
}
for _, nodeSID := range nodeSIDs {
// Get node by SID
n, err := uc.nodeRepo.GetBySID(ctx, nodeSID)
if err != nil {
uc.logger.Warnw("failed to get node", "error", err, "node_sid", nodeSID)
result.Failed = append(result.Failed, dto.BatchOperationErr{
ID: nodeSID,
Reason: "failed to get node",
})
continue
}
if n == nil {
n, ok := nodeMap[nodeSID]
if !ok {
result.Failed = append(result.Failed, dto.BatchOperationErr{
ID: nodeSID,
Reason: "node not found",
@@ -265,13 +273,23 @@ func (uc *ManageResourceGroupNodesUseCase) executeListNodes(ctx context.Context,
for _, n := range nodes {
groupIDSet.AddAll(n.GroupIDs())
}
// Batch fetch group SIDs to avoid N+1 queries
groupIDToSID := make(map[uint]string)
groupIDToSID[groupID] = group.SID() // Current group is already loaded
otherGroupIDs := make([]uint, 0)
for _, gid := range groupIDSet.ToSlice() {
if gid != groupID {
g, err := uc.resourceGroupRepo.GetByID(ctx, gid)
if err == nil && g != nil {
groupIDToSID[gid] = g.SID()
otherGroupIDs = append(otherGroupIDs, gid)
}
}
if len(otherGroupIDs) > 0 {
sidMap, err := uc.resourceGroupRepo.GetSIDsByIDs(ctx, otherGroupIDs)
if err != nil {
uc.logger.Warnw("failed to batch get group SIDs", "error", err)
} else {
for gid, sid := range sidMap {
groupIDToSID[gid] = sid
}
}
}

View File

@@ -100,27 +100,46 @@ func (uc *ProcessReminderUseCase) processExpiringSubscriptions(ctx context.Conte
return 0, 1
}
// Collect unique expiringDays values to avoid N+1 queries
expiringDaysSet := make(map[int]struct{})
validBindings := make([]*telegram.TelegramBinding, 0, len(bindings))
for _, binding := range bindings {
if !binding.CanNotifyExpiring() {
continue
}
expiringDaysSet[binding.ExpiringDays()] = struct{}{}
validBindings = append(validBindings, binding)
}
// Find expiring subscriptions for this user
subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(ctx, binding.ExpiringDays())
if len(validBindings) == 0 {
return 0, 0
}
// Batch fetch expiring subscriptions for each unique expiringDays value
// Key: expiringDays, Value: subscriptions grouped by userID
subscriptionsByDaysAndUser := make(map[int]map[uint][]*subscription.Subscription)
for days := range expiringDaysSet {
subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(ctx, days)
if err != nil {
uc.logger.Errorw("failed to find expiring subscriptions", "user_id", binding.UserID(), "error", err)
uc.logger.Errorw("failed to find expiring subscriptions", "days", days, "error", err)
errors++
continue
}
// Filter to only this user's subscriptions
var userSubs []*subscription.Subscription
// Group subscriptions by userID
userSubsMap := make(map[uint][]*subscription.Subscription)
for _, sub := range subs {
if sub.UserID() == binding.UserID() {
userSubs = append(userSubs, sub)
}
userSubsMap[sub.UserID()] = append(userSubsMap[sub.UserID()], sub)
}
subscriptionsByDaysAndUser[days] = userSubsMap
}
for _, binding := range validBindings {
// Get pre-fetched subscriptions for this binding's expiringDays and userID
userSubsMap, ok := subscriptionsByDaysAndUser[binding.ExpiringDays()]
if !ok {
continue
}
userSubs := userSubsMap[binding.UserID()]
if len(userSubs) == 0 {
continue
}
@@ -185,9 +204,24 @@ func (uc *ProcessReminderUseCase) processTrafficUsage(ctx context.Context) (int,
planSubscriptions[sub.PlanID()] = append(planSubscriptions[sub.PlanID()], sub)
}
// Batch fetch all plans to avoid N+1 queries
planIDs := make([]uint, 0, len(planSubscriptions))
for planID := range planSubscriptions {
planIDs = append(planIDs, planID)
}
plans, err := uc.planRepo.GetByIDs(ctx, planIDs)
if err != nil {
uc.logger.Warnw("failed to batch fetch plans", "error", err)
continue
}
planMap := make(map[uint]*subscription.Plan, len(plans))
for _, plan := range plans {
planMap[plan.ID()] = plan
}
for planID, planSubs := range planSubscriptions {
plan, err := uc.planRepo.GetByID(ctx, planID)
if err != nil {
plan, ok := planMap[planID]
if !ok {
continue
}

View File

@@ -45,7 +45,7 @@ type ForwardAgent struct {
publicAddress string // optional public address for Entry to obtain Exit connection information
tunnelAddress string // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule
remark string
groupID *uint // resource group ID
groupIDs []uint // resource group IDs
agentVersion string // agent software version (e.g., "1.2.3")
platform string // OS platform (linux, darwin, windows)
arch string // CPU architecture (amd64, arm64, arm, 386)
@@ -121,7 +121,7 @@ func ReconstructForwardAgent(
publicAddress string,
tunnelAddress string,
remark string,
groupID *uint,
groupIDs []uint,
agentVersion string,
platform string,
arch string,
@@ -172,7 +172,7 @@ func ReconstructForwardAgent(
publicAddress: publicAddress,
tunnelAddress: tunnelAddress,
remark: remark,
groupID: groupID,
groupIDs: groupIDs,
agentVersion: agentVersion,
platform: platform,
arch: arch,
@@ -236,14 +236,14 @@ func (a *ForwardAgent) GetEffectiveTunnelAddress() string {
return a.publicAddress
}
// GroupID returns the resource group ID
func (a *ForwardAgent) GroupID() *uint {
return a.groupID
// GroupIDs returns the resource group IDs
func (a *ForwardAgent) GroupIDs() []uint {
return a.groupIDs
}
// SetGroupID sets the resource group ID
func (a *ForwardAgent) SetGroupID(groupID *uint) {
a.groupID = groupID
// SetGroupIDs sets the resource group IDs
func (a *ForwardAgent) SetGroupIDs(groupIDs []uint) {
a.groupIDs = groupIDs
a.updatedAt = biztime.NowUTC()
}

View File

@@ -12,6 +12,9 @@ type PaymentRepository interface {
GetPendingBySubscriptionID(ctx context.Context, subscriptionID uint) (*Payment, error)
// HasPendingPaymentBySubscriptionID checks if there are any pending payments for a subscription
HasPendingPaymentBySubscriptionID(ctx context.Context, subscriptionID uint) (bool, error)
// GetSubscriptionIDsWithPendingPayments returns subscription IDs that have pending payments
// from the given list of subscription IDs
GetSubscriptionIDsWithPendingPayments(ctx context.Context, subscriptionIDs []uint) ([]uint, error)
GetExpiredPayments(ctx context.Context) ([]*Payment, error)
GetPendingUSDTPayments(ctx context.Context) ([]*Payment, error)
// GetConfirmedUSDTPaymentsNeedingActivation returns confirmed USDT payments

View File

@@ -22,12 +22,20 @@ type Repository interface {
// GetBySID retrieves a resource group by Stripe-style ID
GetBySID(ctx context.Context, sid string) (*ResourceGroup, error)
// GetBySIDs retrieves resource groups by their Stripe-style IDs
// Returns a map from SID to ResourceGroup for efficient lookup
GetBySIDs(ctx context.Context, sids []string) (map[string]*ResourceGroup, error)
// GetSIDsByIDs retrieves a map of resource group IDs to their SIDs
GetSIDsByIDs(ctx context.Context, ids []uint) (map[uint]string, error)
// GetByPlanID retrieves all resource groups for a plan
GetByPlanID(ctx context.Context, planID uint) ([]*ResourceGroup, error)
// GetByPlanIDs retrieves all resource groups for multiple plans
// Returns a map from planID to list of ResourceGroups
GetByPlanIDs(ctx context.Context, planIDs []uint) (map[uint][]*ResourceGroup, error)
// List retrieves resource groups with optional filters
List(ctx context.Context, filter ListFilter) ([]*ResourceGroup, int64, error)

View File

@@ -0,0 +1,28 @@
-- +goose Up
-- Migration: Convert forward_agents.group_id to group_ids JSON array
-- Description: Support forward agents belonging to multiple resource groups
-- Step 1: Add new group_ids JSON column
ALTER TABLE forward_agents ADD COLUMN group_ids JSON DEFAULT NULL;
-- Step 2: Migrate existing data from group_id to group_ids
UPDATE forward_agents SET group_ids = JSON_ARRAY(group_id) WHERE group_id IS NOT NULL;
-- Step 3: Drop the old group_id column and its index
DROP INDEX idx_forward_agent_group_id ON forward_agents;
ALTER TABLE forward_agents DROP COLUMN group_id;
-- +goose Down
-- Rollback: Convert group_ids JSON array back to single group_id
-- Step 1: Add back group_id column
ALTER TABLE forward_agents ADD COLUMN group_id BIGINT UNSIGNED NULL;
-- Step 2: Migrate data back (take first element from JSON array)
UPDATE forward_agents SET group_id = JSON_EXTRACT(group_ids, '$[0]') WHERE group_ids IS NOT NULL AND JSON_LENGTH(group_ids) > 0;
-- Step 3: Drop group_ids column
ALTER TABLE forward_agents DROP COLUMN group_ids;
-- Step 4: Recreate index
CREATE INDEX idx_forward_agent_group_id ON forward_agents (group_id);

View File

@@ -60,6 +60,14 @@ func (m *ForwardAgentMapperImpl) ToEntity(model *models.ForwardAgentModel) (*for
blockedProtocols = vo.NewBlockedProtocols(protocols)
}
// Parse group_ids from JSON
var groupIDs []uint
if len(model.GroupIDs) > 0 {
if err := json.Unmarshal(model.GroupIDs, &groupIDs); err != nil {
return nil, fmt.Errorf("failed to parse group_ids: %w", err)
}
}
entity, err := forward.ReconstructForwardAgent(
model.ID,
model.SID,
@@ -70,7 +78,7 @@ func (m *ForwardAgentMapperImpl) ToEntity(model *models.ForwardAgentModel) (*for
model.PublicAddress,
model.TunnelAddress,
model.Remark,
model.GroupID,
groupIDs,
model.AgentVersion,
model.Platform,
model.Arch,
@@ -116,6 +124,16 @@ func (m *ForwardAgentMapperImpl) ToModel(entity *forward.ForwardAgent) (*models.
}
}
// Serialize group_ids to JSON
var groupIDsJSON []byte
if len(entity.GroupIDs()) > 0 {
var err error
groupIDsJSON, err = json.Marshal(entity.GroupIDs())
if err != nil {
return nil, fmt.Errorf("failed to serialize group_ids: %w", err)
}
}
return &models.ForwardAgentModel{
ID: entity.ID(),
SID: entity.SID(),
@@ -126,7 +144,7 @@ func (m *ForwardAgentMapperImpl) ToModel(entity *forward.ForwardAgent) (*models.
TunnelAddress: entity.TunnelAddress(),
Status: string(entity.Status()),
Remark: entity.Remark(),
GroupID: entity.GroupID(),
GroupIDs: groupIDsJSON,
AgentVersion: entity.AgentVersion(),
Platform: entity.Platform(),
Arch: entity.Arch(),

View File

@@ -20,7 +20,7 @@ type ForwardAgentModel struct {
TunnelAddress string `gorm:"size:255"` // tunnel address for entry to connect to exit (nullable, overrides public_address)
Status string `gorm:"not null;default:enabled;size:20;index:idx_forward_agent_status"`
Remark string `gorm:"size:500"`
GroupID *uint `gorm:"index:idx_forward_agent_group_id"` // resource group ID
GroupIDs datatypes.JSON `gorm:"column:group_ids"` // resource group IDs (JSON array)
AgentVersion string `gorm:"size:50"` // agent software version (e.g., "1.2.3")
Platform string `gorm:"size:20"` // OS platform (linux, darwin, windows)
Arch string `gorm:"size:20"` // CPU architecture (amd64, arm64, arm, 386)

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"encoding/json"
"fmt"
"strings"
@@ -22,7 +23,6 @@ var allowedAgentOrderByFields = map[string]bool{
"sid": true,
"name": true,
"status": true,
"group_id": true,
"sort_order": true,
"last_seen_at": true,
"created_at": true,
@@ -225,7 +225,7 @@ func (r *ForwardAgentRepositoryImpl) Update(ctx context.Context, agent *forward.
"public_address": model.PublicAddress,
"tunnel_address": model.TunnelAddress,
"remark": model.Remark,
"group_id": model.GroupID,
"group_ids": model.GroupIDs,
"allowed_port_range": model.AllowedPortRange,
"blocked_protocols": model.BlockedProtocols,
"sort_order": model.SortOrder,
@@ -277,7 +277,9 @@ func (r *ForwardAgentRepositoryImpl) List(ctx context.Context, filter forward.Ag
query = query.Where("status = ?", filter.Status)
}
if len(filter.GroupIDs) > 0 {
query = query.Where("group_id IN ?", filter.GroupIDs)
// Use JSON_OVERLAPS to check if group_ids array contains any of the filter group IDs
groupIDsJSON, _ := json.Marshal(filter.GroupIDs)
query = query.Where("JSON_OVERLAPS(group_ids, ?)", string(groupIDsJSON))
}
// Count total records

View File

@@ -165,6 +165,26 @@ func (r *PaymentRepository) HasPendingPaymentBySubscriptionID(ctx context.Contex
return count > 0, nil
}
// GetSubscriptionIDsWithPendingPayments returns subscription IDs that have pending payments
// from the given list of subscription IDs
func (r *PaymentRepository) GetSubscriptionIDsWithPendingPayments(ctx context.Context, subscriptionIDs []uint) ([]uint, error) {
if len(subscriptionIDs) == 0 {
return []uint{}, nil
}
var results []uint
if err := db.GetTxFromContext(ctx, r.db).
Model(&models.PaymentModel{}).
Select("DISTINCT subscription_id").
Where("subscription_id IN ? AND payment_status = ?", subscriptionIDs, vo.PaymentStatusPending).
Pluck("subscription_id", &results).Error; err != nil {
return nil, fmt.Errorf("failed to get subscription IDs with pending payments: %w", err)
}
return results, nil
}
func (r *PaymentRepository) GetExpiredPayments(ctx context.Context) ([]*payment.Payment, error) {
var paymentModels []models.PaymentModel

View File

@@ -225,6 +225,60 @@ func (r *ResourceGroupRepositoryImpl) GetByPlanID(ctx context.Context, planID ui
return entities, nil
}
// GetBySIDs retrieves resource groups by their Stripe-style IDs.
// Returns a map from SID to ResourceGroup for efficient lookup.
func (r *ResourceGroupRepositoryImpl) GetBySIDs(ctx context.Context, sids []string) (map[string]*resource.ResourceGroup, error) {
if len(sids) == 0 {
return make(map[string]*resource.ResourceGroup), nil
}
var modelList []*models.ResourceGroupModel
if err := r.db.WithContext(ctx).Where("sid IN ?", sids).Find(&modelList).Error; err != nil {
r.logger.Errorw("failed to get resource groups by SIDs", "sids", sids, "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
result := make(map[string]*resource.ResourceGroup, len(modelList))
for _, model := range modelList {
entity, err := r.mapper.ToEntity(model)
if err != nil {
r.logger.Errorw("failed to map resource group model to entity", "sid", model.SID, "error", err)
return nil, fmt.Errorf("failed to map resource group: %w", err)
}
result[model.SID] = entity
}
return result, nil
}
// GetByPlanIDs retrieves all resource groups for multiple plans.
// Returns a map from planID to list of ResourceGroups.
func (r *ResourceGroupRepositoryImpl) GetByPlanIDs(ctx context.Context, planIDs []uint) (map[uint][]*resource.ResourceGroup, error) {
if len(planIDs) == 0 {
return make(map[uint][]*resource.ResourceGroup), nil
}
var modelList []*models.ResourceGroupModel
if err := r.db.WithContext(ctx).Where("plan_id IN ?", planIDs).Find(&modelList).Error; err != nil {
r.logger.Errorw("failed to get resource groups by plan IDs", "plan_ids", planIDs, "error", err)
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
result := make(map[uint][]*resource.ResourceGroup)
for _, model := range modelList {
entity, err := r.mapper.ToEntity(model)
if err != nil {
r.logger.Errorw("failed to map resource group model to entity", "id", model.ID, "error", err)
return nil, fmt.Errorf("failed to map resource group: %w", err)
}
result[model.PlanID] = append(result[model.PlanID], entity)
}
return result, nil
}
// List retrieves resource groups with optional filters.
func (r *ResourceGroupRepositoryImpl) List(ctx context.Context, filter resource.ListFilter) ([]*resource.ResourceGroup, int64, error) {
query := r.db.WithContext(ctx).Model(&models.ResourceGroupModel{})

View File

@@ -75,7 +75,7 @@ type CreateForwardAgentRequest struct {
PublicAddress string `json:"public_address,omitempty" example:"203.0.113.1"`
TunnelAddress string `json:"tunnel_address,omitempty" example:"192.168.1.100"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule
Remark string `json:"remark,omitempty" example:"Forward agent for production environment"`
GroupSID string `json:"group_sid,omitempty" example:"rg_xK9mP2vL3nQ"` // Resource group SID to associate with
GroupSIDs []string `json:"group_sids,omitempty" example:"[\"rg_xK9mP2vL3nQ\"]"` // Resource group SIDs to associate with
AllowedPortRange string `json:"allowed_port_range,omitempty" example:"80,443,8000-9000"`
BlockedProtocols []string `json:"blocked_protocols,omitempty" example:"socks5,http_connect"` // Protocols to block (e.g., socks5, http_connect, ssh)
SortOrder *int `json:"sort_order,omitempty" example:"100"` // Custom sort order for UI display (lower values appear first)
@@ -87,7 +87,7 @@ type UpdateForwardAgentRequest struct {
PublicAddress *string `json:"public_address,omitempty" example:"203.0.113.2"`
TunnelAddress *string `json:"tunnel_address,omitempty" example:"192.168.1.100"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule
Remark *string `json:"remark,omitempty" example:"Updated remark"`
GroupSID *string `json:"group_sid,omitempty" example:"rg_xK9mP2vL3nQ"` // Resource group SID to associate with (use empty string to remove)
GroupSIDs []string `json:"group_sids,omitempty" example:"[\"rg_xK9mP2vL3nQ\"]"` // Resource group SIDs to associate with (empty array to remove all)
AllowedPortRange *string `json:"allowed_port_range,omitempty" example:"80,443,8000-9000"`
BlockedProtocols *[]string `json:"blocked_protocols,omitempty"` // Protocols to block (nil: no update, empty array: clear, non-empty: set new)
SortOrder *int `json:"sort_order,omitempty" example:"100"` // Custom sort order for UI display (lower values appear first)
@@ -113,7 +113,7 @@ func (h *Handler) CreateAgent(c *gin.Context) {
PublicAddress: req.PublicAddress,
TunnelAddress: req.TunnelAddress,
Remark: req.Remark,
GroupSID: req.GroupSID,
GroupSIDs: req.GroupSIDs,
AllowedPortRange: req.AllowedPortRange,
BlockedProtocols: req.BlockedProtocols,
SortOrder: req.SortOrder, // nil if not provided, allowing explicit 0 value
@@ -197,7 +197,7 @@ func (h *Handler) UpdateAgent(c *gin.Context) {
PublicAddress: req.PublicAddress,
TunnelAddress: req.TunnelAddress,
Remark: req.Remark,
GroupSID: req.GroupSID,
GroupSIDs: req.GroupSIDs,
AllowedPortRange: req.AllowedPortRange,
BlockedProtocols: req.BlockedProtocols,
SortOrder: req.SortOrder,

View File

@@ -209,11 +209,15 @@ func detectFormatFromUserAgent(userAgent string) string {
return "base64"
}
// V2Ray clients
if strings.Contains(ua, "v2ray") {
return "v2ray"
// V2RayN/V2RayNG clients - use base64 format (supports all protocol URIs)
// These are general-purpose clients that parse base64-encoded URI lists
// (vmess://, vless://, trojan://, ss://, etc.)
if strings.Contains(ua, "v2rayn") || strings.Contains(ua, "v2rayng") {
return "base64"
}
// Default to base64 format for unknown clients
// Note: "v2ray" format (JSON) is only for Shadowsocks-only clients,
// accessible via explicit /s/:token/v2ray endpoint
return "base64"
}

View File

@@ -56,14 +56,14 @@ func TestDetectFormatFromUserAgent(t *testing.T) {
expected: "base64",
},
{
name: "V2Ray client returns v2ray format",
name: "V2RayN client returns base64 format (supports all protocols)",
userAgent: "v2rayN/1.0.0",
expected: "v2ray",
expected: "base64",
},
{
name: "V2RayNG client returns v2ray format",
name: "V2RayNG client returns base64 format (supports all protocols)",
userAgent: "V2RayNG/1.0.0",
expected: "v2ray",
expected: "base64",
},
{
name: "Unknown browser returns base64 format",