diff --git a/internal/application/forward/dto/forwardagent.go b/internal/application/forward/dto/forwardagent.go index b96c2f6..1e29822 100644 --- a/internal/application/forward/dto/forwardagent.go +++ b/internal/application/forward/dto/forwardagent.go @@ -16,9 +16,9 @@ type ForwardAgentDTO struct { TunnelAddress string `json:"tunnel_address,omitempty"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule Status string `json:"status"` Remark string `json:"remark"` - GroupSID *string `json:"group_id,omitempty"` // Resource group SID this agent belongs to - AgentVersion string `json:"agent_version"` // Agent software version (e.g., "1.2.3"), extracted from system_status for easy display - HasUpdate bool `json:"has_update"` // True if a newer version is available + GroupSIDs []string `json:"group_sids,omitempty"` // Resource group SIDs this agent belongs to + AgentVersion string `json:"agent_version"` // Agent software version (e.g., "1.2.3"), extracted from system_status for easy display + HasUpdate bool `json:"has_update"` // True if a newer version is available AllowedPortRange string `json:"allowed_port_range,omitempty"` BlockedProtocols []string `json:"blocked_protocols,omitempty"` // Protocols blocked by this agent SortOrder int `json:"sort_order"` // Custom sort order for UI display @@ -28,11 +28,12 @@ type ForwardAgentDTO struct { CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` SystemStatus *AgentStatusDTO `json:"system_status,omitempty"` + + internalGroupIDs []uint `json:"-"` // internal resource group IDs for lookup } // ToForwardAgentDTO converts a domain forward agent to DTO. -// groupInfo is optional, can be nil if group information is not available. -func ToForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *ForwardAgentDTO { +func ToForwardAgentDTO(agent *forward.ForwardAgent) *ForwardAgentDTO { if agent == nil { return nil } @@ -58,25 +59,52 @@ func ToForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *Forwa LastSeenAt: agent.LastSeenAt(), CreatedAt: agent.CreatedAt().Format("2006-01-02T15:04:05Z07:00"), UpdatedAt: agent.UpdatedAt().Format("2006-01-02T15:04:05Z07:00"), - } - - if groupInfo != nil { - dto.GroupSID = &groupInfo.SID + internalGroupIDs: agent.GroupIDs(), } return dto } // ToForwardAgentDTOs converts a slice of domain forward agents to DTOs. -// groupInfoMap is optional, can be nil if group information is not available. -func ToForwardAgentDTOs(agents []*forward.ForwardAgent, groupInfoMap GroupInfoMap) []*ForwardAgentDTO { +func ToForwardAgentDTOs(agents []*forward.ForwardAgent) []*ForwardAgentDTO { dtos := make([]*ForwardAgentDTO, len(agents)) for i, agent := range agents { - var groupInfo *GroupInfo - if groupInfoMap != nil && agent.GroupID() != nil { - groupInfo = groupInfoMap[*agent.GroupID()] - } - dtos[i] = ToForwardAgentDTO(agent, groupInfo) + dtos[i] = ToForwardAgentDTO(agent) } return dtos } + +// PopulateGroupSIDs fills in the group SIDs field using the SID map. +func (d *ForwardAgentDTO) PopulateGroupSIDs(groupMap GroupSIDMap) { + if len(d.internalGroupIDs) == 0 { + return + } + d.GroupSIDs = make([]string, 0, len(d.internalGroupIDs)) + for _, groupID := range d.internalGroupIDs { + if sid, ok := groupMap[groupID]; ok && sid != "" { + d.GroupSIDs = append(d.GroupSIDs, sid) + } + } +} + +// InternalGroupIDs returns the internal resource group IDs for repository lookups. +func (d *ForwardAgentDTO) InternalGroupIDs() []uint { + return d.internalGroupIDs +} + +// CollectAgentGroupIDs collects unique resource group IDs from agent DTOs for batch lookup. +func CollectAgentGroupIDs(dtos []*ForwardAgentDTO) []uint { + idSet := make(map[uint]struct{}) + for _, dto := range dtos { + for _, groupID := range dto.internalGroupIDs { + if groupID != 0 { + idSet[groupID] = struct{}{} + } + } + } + ids := make([]uint, 0, len(idSet)) + for id := range idSet { + ids = append(ids, id) + } + return ids +} diff --git a/internal/application/forward/dto/userforwardagent.go b/internal/application/forward/dto/userforwardagent.go index 6aed87f..10dab47 100644 --- a/internal/application/forward/dto/userforwardagent.go +++ b/internal/application/forward/dto/userforwardagent.go @@ -10,13 +10,20 @@ import ( // UserForwardAgentDTO represents a forward agent from user's perspective. // This DTO hides sensitive fields like token_hash and api_token. type UserForwardAgentDTO struct { - ID string `json:"id"` // Stripe-style prefixed ID (e.g., "fa_xK9mP2vL3nQ") - Name string `json:"name"` // Agent name - PublicAddress string `json:"public_address,omitempty"` // Public address for client connections - Status string `json:"status"` // enabled or disabled - GroupSID string `json:"group_id,omitempty"` // Resource group SID (e.g., "rg_xK9mP2vL3nQ") - GroupName string `json:"group_name,omitempty"` // Resource group name for display - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` // Stripe-style prefixed ID (e.g., "fa_xK9mP2vL3nQ") + Name string `json:"name"` // Agent name + PublicAddress string `json:"public_address,omitempty"` // Public address for client connections + Status string `json:"status"` // enabled or disabled + Groups []UserGroupInfoDTO `json:"groups,omitempty"` // Resource groups this agent belongs to + CreatedAt time.Time `json:"created_at"` + + internalGroupIDs []uint `json:"-"` // internal resource group IDs for lookup +} + +// UserGroupInfoDTO represents group info for user-facing DTOs. +type UserGroupInfoDTO struct { + SID string `json:"id"` // Resource group SID (e.g., "rg_xK9mP2vL3nQ") + Name string `json:"name"` // Resource group name for display } // GroupInfo holds resource group information for populating DTOs. @@ -29,36 +36,66 @@ type GroupInfo struct { type GroupInfoMap map[uint]*GroupInfo // ToUserForwardAgentDTO converts a domain forward agent to user-facing DTO. -func ToUserForwardAgentDTO(agent *forward.ForwardAgent, groupInfo *GroupInfo) *UserForwardAgentDTO { +func ToUserForwardAgentDTO(agent *forward.ForwardAgent) *UserForwardAgentDTO { if agent == nil { return nil } dto := &UserForwardAgentDTO{ - ID: agent.SID(), - Name: agent.Name(), - PublicAddress: agent.PublicAddress(), - Status: string(agent.Status()), - CreatedAt: agent.CreatedAt(), - } - - if groupInfo != nil { - dto.GroupSID = groupInfo.SID - dto.GroupName = groupInfo.Name + ID: agent.SID(), + Name: agent.Name(), + PublicAddress: agent.PublicAddress(), + Status: string(agent.Status()), + CreatedAt: agent.CreatedAt(), + internalGroupIDs: agent.GroupIDs(), } return dto } // ToUserForwardAgentDTOs converts a slice of domain forward agents to user-facing DTOs. -func ToUserForwardAgentDTOs(agents []*forward.ForwardAgent, groupInfoMap GroupInfoMap) []*UserForwardAgentDTO { +func ToUserForwardAgentDTOs(agents []*forward.ForwardAgent) []*UserForwardAgentDTO { dtos := make([]*UserForwardAgentDTO, len(agents)) for i, agent := range agents { - var groupInfo *GroupInfo - if agent.GroupID() != nil { - groupInfo = groupInfoMap[*agent.GroupID()] - } - dtos[i] = ToUserForwardAgentDTO(agent, groupInfo) + dtos[i] = ToUserForwardAgentDTO(agent) } return dtos } + +// PopulateGroups fills in the groups field using the group info map. +func (d *UserForwardAgentDTO) PopulateGroups(groupInfoMap GroupInfoMap) { + if len(d.internalGroupIDs) == 0 { + return + } + d.Groups = make([]UserGroupInfoDTO, 0, len(d.internalGroupIDs)) + for _, groupID := range d.internalGroupIDs { + if info, ok := groupInfoMap[groupID]; ok && info != nil { + d.Groups = append(d.Groups, UserGroupInfoDTO{ + SID: info.SID, + Name: info.Name, + }) + } + } +} + +// InternalGroupIDs returns the internal resource group IDs for repository lookups. +func (d *UserForwardAgentDTO) InternalGroupIDs() []uint { + return d.internalGroupIDs +} + +// CollectUserAgentGroupIDs collects unique resource group IDs from user agent DTOs for batch lookup. +func CollectUserAgentGroupIDs(dtos []*UserForwardAgentDTO) []uint { + idSet := make(map[uint]struct{}) + for _, dto := range dtos { + for _, groupID := range dto.internalGroupIDs { + if groupID != 0 { + idSet[groupID] = struct{}{} + } + } + } + ids := make([]uint, 0, len(idSet)) + for id := range idSet { + ids = append(ids, id) + } + return ids +} diff --git a/internal/application/forward/usecases/createforwardagent.go b/internal/application/forward/usecases/createforwardagent.go index 20e92ea..abce4a2 100644 --- a/internal/application/forward/usecases/createforwardagent.go +++ b/internal/application/forward/usecases/createforwardagent.go @@ -18,7 +18,7 @@ type CreateForwardAgentCommand struct { PublicAddress string TunnelAddress string Remark string - GroupSID string // Resource group SID to associate with (empty means no association) + GroupSIDs []string // Resource group SIDs to associate with (empty means no association) AllowedPortRange string // Port range string (e.g., "80,443,8000-9000"), empty means all ports allowed BlockedProtocols []string // Protocols to block (e.g., ["socks5", "http_connect"]), empty means no blocking SortOrder *int // Custom sort order for UI display (nil: use default 0, non-nil: set explicitly) @@ -124,18 +124,46 @@ func (uc *CreateForwardAgentUseCase) Execute(ctx context.Context, cmd CreateForw agent.UpdateSortOrder(*cmd.SortOrder) } - // Handle GroupSID (resolve SID to internal ID) - if cmd.GroupSID != "" { - group, err := uc.resourceGroupRepo.GetBySID(ctx, cmd.GroupSID) - if err != nil { - uc.logger.Errorw("failed to get resource group by SID", "group_sid", cmd.GroupSID, "error", err) - return nil, errors.NewNotFoundError("resource group", cmd.GroupSID) + // Handle GroupSIDs (resolve SIDs to internal IDs) + // Limit to 10 groups to prevent DoS attacks + const maxGroupSIDs = 10 + if len(cmd.GroupSIDs) > maxGroupSIDs { + return nil, errors.NewValidationError(fmt.Sprintf("too many group_sids, maximum allowed is %d", maxGroupSIDs)) + } + if len(cmd.GroupSIDs) > 0 { + // Deduplicate and filter empty SIDs + uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs)) + seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs)) + for _, sid := range cmd.GroupSIDs { + if sid == "" { + continue + } + if _, exists := seenSIDs[sid]; exists { + continue + } + seenSIDs[sid] = struct{}{} + uniqueSIDs = append(uniqueSIDs, sid) } - if group == nil { - return nil, errors.NewNotFoundError("resource group", cmd.GroupSID) + + if len(uniqueSIDs) > 0 { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + // Resolve SIDs to internal IDs + resolvedIDs := make([]uint, 0, len(uniqueSIDs)) + for _, sid := range uniqueSIDs { + group, ok := groupMap[sid] + if !ok || group == nil { + return nil, errors.NewNotFoundError("resource group", sid) + } + resolvedIDs = append(resolvedIDs, group.ID()) + } + agent.SetGroupIDs(resolvedIDs) } - groupID := group.ID() - agent.SetGroupID(&groupID) } // Persist diff --git a/internal/application/forward/usecases/createforwardrule.go b/internal/application/forward/usecases/createforwardrule.go index ffb7ee2..8e946ae 100644 --- a/internal/application/forward/usecases/createforwardrule.go +++ b/internal/application/forward/usecases/createforwardrule.go @@ -143,6 +143,33 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa cmd.ListenPort, agent.AllowedPortRange().String())) } + // Collect all agent SIDs that need to be fetched (to avoid N+1 queries) + allAgentSIDs := make([]string, 0) + if cmd.ExitAgentShortID != "" { + allAgentSIDs = append(allAgentSIDs, cmd.ExitAgentShortID) + } + for _, input := range cmd.ExitAgents { + allAgentSIDs = append(allAgentSIDs, input.AgentSID) + } + allAgentSIDs = append(allAgentSIDs, cmd.ChainAgentShortIDs...) + for shortID := range cmd.ChainPortConfig { + allAgentSIDs = append(allAgentSIDs, shortID) + } + + // Batch fetch all agents + var agentMap map[string]*forward.ForwardAgent + if len(allAgentSIDs) > 0 { + agents, err := uc.agentRepo.GetBySIDs(ctx, allAgentSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get agents", "error", err) + return nil, fmt.Errorf("failed to get agents: %w", err) + } + agentMap = make(map[string]*forward.ForwardAgent, len(agents)) + for _, a := range agents { + agentMap[a.SID()] = a + } + } + // Resolve exit agent configuration (single exitAgentID OR multiple exitAgents) var exitAgentID uint var exitAgents []vo.AgentWeight @@ -150,12 +177,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa return nil, errors.NewValidationError("exit_agent_id and exit_agents are mutually exclusive") } if cmd.ExitAgentShortID != "" { - exitAgent, err := uc.agentRepo.GetBySID(ctx, cmd.ExitAgentShortID) - if err != nil { - uc.logger.Errorw("failed to get exit agent", "exit_agent_short_id", cmd.ExitAgentShortID, "error", err) - return nil, fmt.Errorf("failed to validate exit agent: %w", err) - } - if exitAgent == nil { + exitAgent, ok := agentMap[cmd.ExitAgentShortID] + if !ok || exitAgent == nil { return nil, errors.NewNotFoundError("exit forward agent", cmd.ExitAgentShortID) } exitAgentID = exitAgent.ID() @@ -163,12 +186,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa // Resolve multiple exit agents with weights exitAgents = make([]vo.AgentWeight, 0, len(cmd.ExitAgents)) for _, input := range cmd.ExitAgents { - exitAgent, err := uc.agentRepo.GetBySID(ctx, input.AgentSID) - if err != nil { - uc.logger.Errorw("failed to get exit agent", "exit_agent_sid", input.AgentSID, "error", err) - return nil, fmt.Errorf("failed to validate exit agent: %w", err) - } - if exitAgent == nil { + exitAgent, ok := agentMap[input.AgentSID] + if !ok || exitAgent == nil { return nil, errors.NewNotFoundError("exit forward agent", input.AgentSID) } // Use provided weight or default @@ -193,12 +212,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa if len(cmd.ChainAgentShortIDs) > 0 { chainAgentIDs = make([]uint, len(cmd.ChainAgentShortIDs)) for i, shortID := range cmd.ChainAgentShortIDs { - chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID) - if err != nil { - uc.logger.Errorw("failed to get chain agent", "chain_agent_short_id", shortID, "error", err) - return nil, fmt.Errorf("failed to validate chain agent: %w", err) - } - if chainAgent == nil { + chainAgent, ok := agentMap[shortID] + if !ok || chainAgent == nil { return nil, errors.NewNotFoundError("chain forward agent", shortID) } chainAgentIDs[i] = chainAgent.ID() @@ -212,12 +227,8 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa if len(cmd.ChainPortConfig) > 0 { chainPortConfig = make(map[uint]uint16, len(cmd.ChainPortConfig)) for shortID, port := range cmd.ChainPortConfig { - chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID) - if err != nil { - uc.logger.Errorw("failed to get chain agent for port config", "chain_agent_short_id", shortID, "error", err) - return nil, fmt.Errorf("failed to validate chain agent in chain_port_config: %w", err) - } - if chainAgent == nil { + chainAgent, ok := agentMap[shortID] + if !ok || chainAgent == nil { return nil, errors.NewNotFoundError("chain forward agent in chain_port_config", shortID) } // Validate port against chain agent's allowed port range @@ -299,29 +310,47 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa // Resolve GroupSIDs to internal IDs and validate plan types (if provided) var groupIDs []uint if len(cmd.GroupSIDs) > 0 { - groupIDs = make([]uint, 0, len(cmd.GroupSIDs)) + // Validate SID formats first for _, groupSID := range cmd.GroupSIDs { - // Validate the SID format (rg_xxx) if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil { return nil, errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID)) } + } - group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID) - if err != nil { - uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err) - return nil, fmt.Errorf("failed to validate resource group: %w", err) - } - if group == nil { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, cmd.GroupSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + // Collect plan IDs for batch fetch + planIDs := make([]uint, 0, len(cmd.GroupSIDs)) + for _, groupSID := range cmd.GroupSIDs { + group, ok := groupMap[groupSID] + if !ok || group == nil { return nil, errors.NewNotFoundError("resource group", groupSID) } + planIDs = append(planIDs, group.PlanID()) + } - // Verify the plan type supports forward rules binding (node and hybrid only, not forward) - plan, err := uc.planRepo.GetByID(ctx, group.PlanID()) - if err != nil { - uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err) - return nil, fmt.Errorf("failed to validate resource group plan: %w", err) - } - if plan == nil { + // Batch fetch all plans + plans, err := uc.planRepo.GetByIDs(ctx, planIDs) + if err != nil { + uc.logger.Errorw("failed to batch get plans", "error", err) + return nil, fmt.Errorf("failed to get plans: %w", err) + } + planMap := make(map[uint]*subscription.Plan, len(plans)) + for _, p := range plans { + planMap[p.ID()] = p + } + + // Validate each group and build groupIDs + groupIDs = make([]uint, 0, len(cmd.GroupSIDs)) + for _, groupSID := range cmd.GroupSIDs { + group := groupMap[groupSID] + plan, ok := planMap[group.PlanID()] + if !ok || plan == nil { return nil, fmt.Errorf("plan not found for resource group %s", groupSID) } if plan.PlanType().IsForward() { @@ -332,7 +361,6 @@ func (uc *CreateForwardRuleUseCase) Execute(ctx context.Context, cmd CreateForwa return nil, errors.NewValidationError( fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID)) } - groupIDs = append(groupIDs, group.ID()) } } @@ -658,28 +686,47 @@ func (uc *CreateForwardRuleUseCase) executeExternalRule(ctx context.Context, cmd // Resolve GroupSIDs to internal IDs and validate plan types (if provided) var groupIDs []uint if len(cmd.GroupSIDs) > 0 { - groupIDs = make([]uint, 0, len(cmd.GroupSIDs)) + // Validate SID formats first for _, groupSID := range cmd.GroupSIDs { if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil { return nil, errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID)) } + } - group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID) - if err != nil { - uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err) - return nil, fmt.Errorf("failed to validate resource group: %w", err) - } - if group == nil { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, cmd.GroupSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + // Collect plan IDs for batch fetch + planIDs := make([]uint, 0, len(cmd.GroupSIDs)) + for _, groupSID := range cmd.GroupSIDs { + group, ok := groupMap[groupSID] + if !ok || group == nil { return nil, errors.NewNotFoundError("resource group", groupSID) } + planIDs = append(planIDs, group.PlanID()) + } - // Verify the plan type supports forward rules binding - plan, err := uc.planRepo.GetByID(ctx, group.PlanID()) - if err != nil { - uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err) - return nil, fmt.Errorf("failed to validate resource group plan: %w", err) - } - if plan == nil { + // Batch fetch all plans + plans, err := uc.planRepo.GetByIDs(ctx, planIDs) + if err != nil { + uc.logger.Errorw("failed to batch get plans", "error", err) + return nil, fmt.Errorf("failed to get plans: %w", err) + } + planMap := make(map[uint]*subscription.Plan, len(plans)) + for _, p := range plans { + planMap[p.ID()] = p + } + + // Validate each group and build groupIDs + groupIDs = make([]uint, 0, len(cmd.GroupSIDs)) + for _, groupSID := range cmd.GroupSIDs { + group := groupMap[groupSID] + plan, ok := planMap[group.PlanID()] + if !ok || plan == nil { return nil, fmt.Errorf("plan not found for resource group %s", groupSID) } if plan.PlanType().IsForward() { @@ -690,7 +737,6 @@ func (uc *CreateForwardRuleUseCase) executeExternalRule(ctx context.Context, cmd return nil, errors.NewValidationError( fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID)) } - groupIDs = append(groupIDs, group.ID()) } } diff --git a/internal/application/forward/usecases/listforwardagents.go b/internal/application/forward/usecases/listforwardagents.go index a6a2a4d..ecc1f05 100644 --- a/internal/application/forward/usecases/listforwardagents.go +++ b/internal/application/forward/usecases/listforwardagents.go @@ -97,40 +97,30 @@ func (uc *ListForwardAgentsUseCase) Execute(ctx context.Context, query ListForwa pages++ } - // Collect unique group IDs for batch query - groupIDSet := make(map[uint]struct{}) - for _, agent := range agents { - if agent.GroupID() != nil { - groupIDSet[*agent.GroupID()] = struct{}{} - } - } + // Convert to DTOs + dtos := dto.ToForwardAgentDTOs(agents) - // Query resource groups and build GroupInfoMap - var groupInfoMap dto.GroupInfoMap - if len(groupIDSet) > 0 && uc.resourceGroupRepo != nil { - groupIDs := make([]uint, 0, len(groupIDSet)) - for id := range groupIDSet { - groupIDs = append(groupIDs, id) - } + // Collect unique group IDs from DTOs for batch query + groupIDs := dto.CollectAgentGroupIDs(dtos) + // Query resource groups and populate GroupSIDs + if len(groupIDs) > 0 && uc.resourceGroupRepo != nil { groups, err := uc.resourceGroupRepo.GetByIDs(ctx, groupIDs) if err != nil { uc.logger.Warnw("failed to get resource groups, continuing without group info", "error", err, ) } else { - groupInfoMap = make(dto.GroupInfoMap, len(groups)) + groupSIDMap := make(dto.GroupSIDMap, len(groups)) for _, group := range groups { - groupInfoMap[group.ID()] = &dto.GroupInfo{ - SID: group.SID(), - Name: group.Name(), - } + groupSIDMap[group.ID()] = group.SID() + } + for _, d := range dtos { + d.PopulateGroupSIDs(groupSIDMap) } } } - dtos := dto.ToForwardAgentDTOs(agents, groupInfoMap) - // Collect agent IDs for batch status query and create ID mapping agentIDs := make([]uint, 0, len(agents)) idToIndexMap := make(map[uint]int, len(agents)) diff --git a/internal/application/forward/usecases/listuserforwardagents.go b/internal/application/forward/usecases/listuserforwardagents.go index 1220f08..dad2a5d 100644 --- a/internal/application/forward/usecases/listuserforwardagents.go +++ b/internal/application/forward/usecases/listuserforwardagents.go @@ -136,20 +136,21 @@ func (uc *ListUserForwardAgentsUseCase) Execute(ctx context.Context, query ListU }, nil } - // Step 4: Get active resource groups for these plans + // Step 4: Get active resource groups for these plans (batch query to avoid N+1) + groupsByPlan, err := uc.resourceGroupRepo.GetByPlanIDs(ctx, forwardPlanIDs) + if err != nil { + // Log warning but continue with empty result (consistent with original behavior) + uc.logger.Warnw("failed to batch get resource groups for plans", + "user_id", query.UserID, + "error", err, + ) + groupsByPlan = make(map[uint][]*resource.ResourceGroup) + } + groupIDs := make([]uint, 0) groupInfoMap := make(dto.GroupInfoMap) - for _, planID := range forwardPlanIDs { - groups, err := uc.resourceGroupRepo.GetByPlanID(ctx, planID) - if err != nil { - uc.logger.Warnw("failed to get resource groups for plan", - "plan_id", planID, - "error", err, - ) - continue - } - + for _, groups := range groupsByPlan { for _, group := range groups { if group.IsActive() { groupIDs = append(groupIDs, group.ID()) @@ -197,8 +198,11 @@ func (uc *ListUserForwardAgentsUseCase) Execute(ctx context.Context, query ListU pages++ } - // Step 6: Convert to user-facing DTOs - dtos := dto.ToUserForwardAgentDTOs(agents, groupInfoMap) + // Step 6: Convert to user-facing DTOs and populate groups + dtos := dto.ToUserForwardAgentDTOs(agents) + for _, d := range dtos { + d.PopulateGroups(groupInfoMap) + } uc.logger.Infow("user forward agents listed successfully", "user_id", query.UserID, diff --git a/internal/application/forward/usecases/updateforwardagent.go b/internal/application/forward/usecases/updateforwardagent.go index 69bd34d..eee7cc1 100644 --- a/internal/application/forward/usecases/updateforwardagent.go +++ b/internal/application/forward/usecases/updateforwardagent.go @@ -18,7 +18,7 @@ type UpdateForwardAgentCommand struct { PublicAddress *string TunnelAddress *string Remark *string - GroupSID *string // Resource group SID (empty string to remove association) + GroupSIDs []string // Resource group SIDs (empty slice to remove all associations) AllowedPortRange *string // nil: no update, empty string: clear (allow all), non-empty: set new range BlockedProtocols *[]string // nil: no update, empty slice: clear (allow all), non-empty: set new protocols SortOrder *int // nil: no update, non-nil: set new sort order @@ -98,23 +98,49 @@ func (uc *UpdateForwardAgentUseCase) Execute(ctx context.Context, cmd UpdateForw } } - // Handle GroupSID update (resolve SID to internal ID) - if cmd.GroupSID != nil { - if *cmd.GroupSID == "" { - // Empty string means remove the association - agent.SetGroupID(nil) + // Handle GroupSIDs update (resolve SIDs to internal IDs) + // Limit to 10 groups to prevent DoS attacks + const maxGroupSIDs = 10 + if len(cmd.GroupSIDs) > maxGroupSIDs { + return errors.NewValidationError(fmt.Sprintf("too many group_sids, maximum allowed is %d", maxGroupSIDs)) + } + // Note: We always update if GroupSIDs is provided (even empty slice means clear all) + if cmd.GroupSIDs != nil { + if len(cmd.GroupSIDs) == 0 { + // Empty slice means remove all associations + agent.SetGroupIDs(nil) } else { - // Resolve group SID to internal ID - group, err := uc.resourceGroupRepo.GetBySID(ctx, *cmd.GroupSID) + // Deduplicate and filter empty SIDs + uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs)) + seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs)) + for _, sid := range cmd.GroupSIDs { + if sid == "" { + continue + } + if _, exists := seenSIDs[sid]; exists { + continue + } + seenSIDs[sid] = struct{}{} + uniqueSIDs = append(uniqueSIDs, sid) + } + + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs) if err != nil { - uc.logger.Errorw("failed to get resource group by SID", "group_sid", *cmd.GroupSID, "error", err) - return errors.NewNotFoundError("resource group", *cmd.GroupSID) + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return fmt.Errorf("failed to get resource groups: %w", err) } - if group == nil { - return errors.NewNotFoundError("resource group", *cmd.GroupSID) + + // Resolve SIDs to internal IDs + resolvedIDs := make([]uint, 0, len(uniqueSIDs)) + for _, sid := range uniqueSIDs { + group, ok := groupMap[sid] + if !ok || group == nil { + return errors.NewNotFoundError("resource group", sid) + } + resolvedIDs = append(resolvedIDs, group.ID()) } - groupID := group.ID() - agent.SetGroupID(&groupID) + agent.SetGroupIDs(resolvedIDs) } } diff --git a/internal/application/forward/usecases/updateforwardrule.go b/internal/application/forward/usecases/updateforwardrule.go index 2c8a1d8..1dea588 100644 --- a/internal/application/forward/usecases/updateforwardrule.go +++ b/internal/application/forward/usecases/updateforwardrule.go @@ -96,6 +96,36 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa originalExitAgentIDs := rule.GetAllExitAgentIDs() // Includes both single and multiple exit agents originalChainAgentIDs := rule.ChainAgentIDs() + // Collect all agent SIDs that need to be fetched (to avoid N+1 queries) + allAgentSIDs := make([]string, 0) + if cmd.AgentShortID != nil { + allAgentSIDs = append(allAgentSIDs, *cmd.AgentShortID) + } + if cmd.ExitAgentShortID != nil { + allAgentSIDs = append(allAgentSIDs, *cmd.ExitAgentShortID) + } + for _, input := range cmd.ExitAgents { + allAgentSIDs = append(allAgentSIDs, input.AgentSID) + } + allAgentSIDs = append(allAgentSIDs, cmd.ChainAgentShortIDs...) + for shortID := range cmd.ChainPortConfig { + allAgentSIDs = append(allAgentSIDs, shortID) + } + + // Batch fetch all agents to avoid N+1 queries + var agentMap map[string]*forward.ForwardAgent + if len(allAgentSIDs) > 0 { + agents, err := uc.agentRepo.GetBySIDs(ctx, allAgentSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get agents", "error", err) + return fmt.Errorf("failed to get agents: %w", err) + } + agentMap = make(map[string]*forward.ForwardAgent, len(agents)) + for _, a := range agents { + agentMap[a.SID()] = a + } + } + // Update fields if cmd.Name != nil { if err := rule.UpdateName(*cmd.Name); err != nil { @@ -105,12 +135,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa // Update entry agent ID if cmd.AgentShortID != nil { - agent, err := uc.agentRepo.GetBySID(ctx, *cmd.AgentShortID) - if err != nil { - uc.logger.Errorw("failed to get agent", "agent_short_id", *cmd.AgentShortID, "error", err) - return fmt.Errorf("failed to validate agent: %w", err) - } - if agent == nil { + agent, ok := agentMap[*cmd.AgentShortID] + if !ok || agent == nil { return errors.NewNotFoundError("forward agent", *cmd.AgentShortID) } @@ -160,12 +186,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa // Update single exit agent ID (for entry type rules) if cmd.ExitAgentShortID != nil { - exitAgent, err := uc.agentRepo.GetBySID(ctx, *cmd.ExitAgentShortID) - if err != nil { - uc.logger.Errorw("failed to get exit agent", "exit_agent_short_id", *cmd.ExitAgentShortID, "error", err) - return fmt.Errorf("failed to validate exit agent: %w", err) - } - if exitAgent == nil { + exitAgent, ok := agentMap[*cmd.ExitAgentShortID] + if !ok || exitAgent == nil { return errors.NewNotFoundError("exit forward agent", *cmd.ExitAgentShortID) } @@ -185,12 +207,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa if len(cmd.ExitAgents) > 0 { exitAgents := make([]vo.AgentWeight, 0, len(cmd.ExitAgents)) for _, input := range cmd.ExitAgents { - exitAgent, err := uc.agentRepo.GetBySID(ctx, input.AgentSID) - if err != nil { - uc.logger.Errorw("failed to get exit agent", "exit_agent_sid", input.AgentSID, "error", err) - return fmt.Errorf("failed to validate exit agent: %w", err) - } - if exitAgent == nil { + exitAgent, ok := agentMap[input.AgentSID] + if !ok || exitAgent == nil { return errors.NewNotFoundError("exit forward agent", input.AgentSID) } @@ -226,12 +244,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa if cmd.ChainAgentShortIDs != nil { chainAgentIDs := make([]uint, len(cmd.ChainAgentShortIDs)) for i, shortID := range cmd.ChainAgentShortIDs { - chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID) - if err != nil { - uc.logger.Errorw("failed to get chain agent", "chain_agent_short_id", shortID, "error", err) - return fmt.Errorf("failed to validate chain agent: %w", err) - } - if chainAgent == nil { + chainAgent, ok := agentMap[shortID] + if !ok || chainAgent == nil { return errors.NewNotFoundError("chain forward agent", shortID) } @@ -256,12 +270,8 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa oldChainPortConfig := rule.ChainPortConfig() chainPortConfig := make(map[uint]uint16, len(cmd.ChainPortConfig)) for shortID, port := range cmd.ChainPortConfig { - chainAgent, err := uc.agentRepo.GetBySID(ctx, shortID) - if err != nil { - uc.logger.Errorw("failed to get chain agent for port config", "chain_agent_short_id", shortID, "error", err) - return fmt.Errorf("failed to validate chain agent in chain_port_config: %w", err) - } - if chainAgent == nil { + chainAgent, ok := agentMap[shortID] + if !ok || chainAgent == nil { return errors.NewNotFoundError("chain forward agent in chain_port_config", shortID) } @@ -474,29 +484,47 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa if cmd.GroupSIDs != nil { var groupIDs []uint if len(*cmd.GroupSIDs) > 0 { - groupIDs = make([]uint, 0, len(*cmd.GroupSIDs)) + // Validate SID formats first for _, groupSID := range *cmd.GroupSIDs { - // Validate the SID format (rg_xxx) if err := id.ValidatePrefix(groupSID, id.PrefixResourceGroup); err != nil { return errors.NewValidationError(fmt.Sprintf("invalid resource group ID format: %s", groupSID)) } + } - group, err := uc.resourceGroupRepo.GetBySID(ctx, groupSID) - if err != nil { - uc.logger.Errorw("failed to get resource group", "group_sid", groupSID, "error", err) - return fmt.Errorf("failed to validate resource group: %w", err) - } - if group == nil { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, *cmd.GroupSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return fmt.Errorf("failed to get resource groups: %w", err) + } + + // Collect plan IDs for batch fetch + planIDs := make([]uint, 0, len(*cmd.GroupSIDs)) + for _, groupSID := range *cmd.GroupSIDs { + group, ok := groupMap[groupSID] + if !ok || group == nil { return errors.NewNotFoundError("resource group", groupSID) } + planIDs = append(planIDs, group.PlanID()) + } - // Verify the plan type supports forward rules binding (node and hybrid only, not forward) - plan, err := uc.planRepo.GetByID(ctx, group.PlanID()) - if err != nil { - uc.logger.Errorw("failed to get plan for resource group", "plan_id", group.PlanID(), "error", err) - return fmt.Errorf("failed to validate resource group plan: %w", err) - } - if plan == nil { + // Batch fetch all plans to avoid N+1 queries + plans, err := uc.planRepo.GetByIDs(ctx, planIDs) + if err != nil { + uc.logger.Errorw("failed to batch get plans", "error", err) + return fmt.Errorf("failed to get plans: %w", err) + } + planMap := make(map[uint]*subscription.Plan, len(plans)) + for _, p := range plans { + planMap[p.ID()] = p + } + + // Validate each group and build groupIDs + groupIDs = make([]uint, 0, len(*cmd.GroupSIDs)) + for _, groupSID := range *cmd.GroupSIDs { + group := groupMap[groupSID] + plan, ok := planMap[group.PlanID()] + if !ok || plan == nil { return fmt.Errorf("plan not found for resource group %s", groupSID) } if plan.PlanType().IsForward() { @@ -507,7 +535,6 @@ func (uc *UpdateForwardRuleUseCase) Execute(ctx context.Context, cmd UpdateForwa return errors.NewValidationError( fmt.Sprintf("resource group %s belongs to a forward plan and cannot bind forward rules", groupSID)) } - groupIDs = append(groupIDs, group.ID()) } } @@ -632,14 +659,15 @@ func (uc *UpdateForwardRuleUseCase) getAccessibleGroupIDs(ctx context.Context, u return nil, nil } - // Step 4: Get active resource groups for these plans + // Step 4: Get active resource groups for these plans (batch query to avoid N+1) + groupsByPlan, err := uc.resourceGroupRepo.GetByPlanIDs(ctx, forwardPlanIDs) + if err != nil { + uc.logger.Warnw("failed to batch get resource groups for plans", "error", err) + return nil, nil + } + groupIDs := make([]uint, 0) - for _, planID := range forwardPlanIDs { - groups, err := uc.resourceGroupRepo.GetByPlanID(ctx, planID) - if err != nil { - uc.logger.Warnw("failed to get resource groups for plan", "plan_id", planID, "error", err) - continue - } + for _, groups := range groupsByPlan { for _, group := range groups { if group.IsActive() { groupIDs = append(groupIDs, group.ID()) @@ -652,6 +680,7 @@ func (uc *UpdateForwardRuleUseCase) getAccessibleGroupIDs(ctx context.Context, u // validateUserAgentAccess checks if the user has access to the specified agent. // Returns nil if access is allowed, or an error if access is denied. +// Access is granted if any of the agent's group IDs is in the user's accessible group IDs. func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context, userID uint, agent *forward.ForwardAgent) error { // Get user's accessible group IDs accessibleGroupIDs, err := uc.getAccessibleGroupIDs(ctx, userID) @@ -659,9 +688,9 @@ func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context, return fmt.Errorf("failed to get accessible groups: %w", err) } - // Check if agent's group ID is in the accessible list - agentGroupID := agent.GroupID() - if agentGroupID == nil { + // Check if any of agent's group IDs is in the accessible list + agentGroupIDs := agent.GroupIDs() + if len(agentGroupIDs) == 0 { // Agent has no group assigned, deny access for user endpoints uc.logger.Warnw("user attempted to access agent without group", "user_id", userID, @@ -669,11 +698,20 @@ func (uc *UpdateForwardRuleUseCase) validateUserAgentAccess(ctx context.Context, return errors.NewForbiddenError("agent is not accessible to user") } - if !containsUint(accessibleGroupIDs, *agentGroupID) { + // Check if there's any intersection between agent's groups and user's accessible groups + hasAccess := false + for _, agentGroupID := range agentGroupIDs { + if containsUint(accessibleGroupIDs, agentGroupID) { + hasAccess = true + break + } + } + + if !hasAccess { uc.logger.Warnw("user attempted to access unauthorized agent", "user_id", userID, "agent_sid", agent.SID(), - "agent_group_id", *agentGroupID, + "agent_group_ids", agentGroupIDs, "accessible_groups", accessibleGroupIDs) return errors.NewForbiddenError("user does not have access to this agent") } diff --git a/internal/application/node/usecases/createnode.go b/internal/application/node/usecases/createnode.go index 8c84b9f..4e518c4 100644 --- a/internal/application/node/usecases/createnode.go +++ b/internal/application/node/usecases/createnode.go @@ -378,30 +378,37 @@ func (uc *CreateNodeUseCase) Execute(ctx context.Context, cmd CreateNodeCommand) // Handle GroupSIDs (resolve SIDs to internal IDs) if len(cmd.GroupSIDs) > 0 { + // Deduplicate and filter empty SIDs + uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs)) seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs)) - resolvedIDs := make([]uint, 0, len(cmd.GroupSIDs)) for _, sid := range cmd.GroupSIDs { - // Skip empty strings if sid == "" { continue } - // Skip duplicate SIDs if _, exists := seenSIDs[sid]; exists { continue } seenSIDs[sid] = struct{}{} - - group, err := uc.resourceGroupRepo.GetBySID(ctx, sid) - if err != nil { - uc.logger.Errorw("failed to get resource group by SID", "group_sid", sid, "error", err) - return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) - } - if group == nil { - return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) - } - resolvedIDs = append(resolvedIDs, group.ID()) + uniqueSIDs = append(uniqueSIDs, sid) } - if len(resolvedIDs) > 0 { + + if len(uniqueSIDs) > 0 { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + // Resolve SIDs to internal IDs + resolvedIDs := make([]uint, 0, len(uniqueSIDs)) + for _, sid := range uniqueSIDs { + group, ok := groupMap[sid] + if !ok || group == nil { + return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) + } + resolvedIDs = append(resolvedIDs, group.ID()) + } nodeEntity.SetGroupIDs(resolvedIDs) } } diff --git a/internal/application/node/usecases/getusernodeusage.go b/internal/application/node/usecases/getusernodeusage.go index fed7b63..4cdd5a8 100644 --- a/internal/application/node/usecases/getusernodeusage.go +++ b/internal/application/node/usecases/getusernodeusage.go @@ -69,15 +69,26 @@ func (uc *GetUserNodeUsageUseCase) Execute(ctx context.Context, query GetUserNod maxNodeLimit := 0 hasUnlimitedNodes := false + // Batch fetch all plans to avoid N+1 queries + planIDs := make([]uint, 0, len(subscriptions)) + for _, sub := range subscriptions { + planIDs = append(planIDs, sub.PlanID()) + } + plans, err := uc.planRepo.GetByIDs(ctx, planIDs) + if err != nil { + uc.logger.Errorw("failed to batch fetch plans", "user_id", query.UserID, "error", err) + return nil, fmt.Errorf("failed to get plans: %w", err) + } + planMap := make(map[uint]*subscription.Plan, len(plans)) + for _, plan := range plans { + planMap[plan.ID()] = plan + } + // Find the highest limit among all active subscriptions for _, sub := range subscriptions { - plan, err := uc.planRepo.GetByID(ctx, sub.PlanID()) - if err != nil { - uc.logger.Warnw("failed to get plan for subscription", "subscription_id", sub.ID(), "plan_id", sub.PlanID(), "error", err) - continue - } - - if plan == nil { + plan, ok := planMap[sub.PlanID()] + if !ok { + uc.logger.Warnw("plan not found for subscription", "subscription_id", sub.ID(), "plan_id", sub.PlanID()) continue } diff --git a/internal/application/node/usecases/updatenode.go b/internal/application/node/usecases/updatenode.go index d3cf5ab..18aad33 100644 --- a/internal/application/node/usecases/updatenode.go +++ b/internal/application/node/usecases/updatenode.go @@ -186,33 +186,37 @@ func (uc *UpdateNodeUseCase) Execute(ctx context.Context, cmd UpdateNodeCommand) // Empty slice means remove all group associations existingNode.SetGroupIDs(nil) } else { - // Resolve each group SID to internal ID (with deduplication) + // Deduplicate and filter empty SIDs + uniqueSIDs := make([]string, 0, len(cmd.GroupSIDs)) seenSIDs := make(map[string]struct{}, len(cmd.GroupSIDs)) - resolvedIDs := make([]uint, 0, len(cmd.GroupSIDs)) for _, sid := range cmd.GroupSIDs { - // Skip empty strings if sid == "" { continue } - // Skip duplicate SIDs if _, exists := seenSIDs[sid]; exists { continue } seenSIDs[sid] = struct{}{} - - group, err := uc.resourceGroupRepo.GetBySID(ctx, sid) - if err != nil { - uc.logger.Errorw("failed to get resource group by SID", "group_sid", sid, "error", err) - return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) - } - if group == nil { - return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) - } - resolvedIDs = append(resolvedIDs, group.ID()) + uniqueSIDs = append(uniqueSIDs, sid) } - // Only update if we have valid SIDs after filtering - // (prevents accidental clear when all SIDs are empty strings) - if len(resolvedIDs) > 0 { + + if len(uniqueSIDs) > 0 { + // Batch fetch all groups to avoid N+1 queries + groupMap, err := uc.resourceGroupRepo.GetBySIDs(ctx, uniqueSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get resource groups", "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + // Resolve SIDs to internal IDs + resolvedIDs := make([]uint, 0, len(uniqueSIDs)) + for _, sid := range uniqueSIDs { + group, ok := groupMap[sid] + if !ok || group == nil { + return nil, errors.NewNotFoundError(fmt.Sprintf("resource group not found: %s", sid)) + } + resolvedIDs = append(resolvedIDs, group.ID()) + } existingNode.SetGroupIDs(resolvedIDs) } } diff --git a/internal/application/payment/usecases/cancelunpaidsubscriptions.go b/internal/application/payment/usecases/cancelunpaidsubscriptions.go index 9f86fe8..94406bc 100644 --- a/internal/application/payment/usecases/cancelunpaidsubscriptions.go +++ b/internal/application/payment/usecases/cancelunpaidsubscriptions.go @@ -75,16 +75,25 @@ func (uc *CancelUnpaidSubscriptionsUseCase) Execute(ctx context.Context) (int, e now := biztime.NowUTC() cancelledCount := 0 + // Batch fetch subscription IDs with pending payments to avoid N+1 queries + subscriptionIDs := make([]uint, 0, len(subscriptions)) + for _, sub := range subscriptions { + subscriptionIDs = append(subscriptionIDs, sub.ID()) + } + subsWithPendingPayments, err := uc.paymentRepo.GetSubscriptionIDsWithPendingPayments(ctx, subscriptionIDs) + if err != nil { + uc.logger.Errorw("failed to batch check pending payments", "error", err) + return 0, fmt.Errorf("failed to check pending payments: %w", err) + } + pendingPaymentSet := make(map[uint]bool, len(subsWithPendingPayments)) + for _, subID := range subsWithPendingPayments { + pendingPaymentSet[subID] = true + } + for _, sub := range subscriptions { // Check if there are any pending payments for this subscription first // A new payment may have been created, skip cancellation - hasPending, err := uc.paymentRepo.HasPendingPaymentBySubscriptionID(ctx, sub.ID()) - if err != nil { - uc.logger.Errorw("failed to check pending payments", - "subscription_id", sub.ID(), - "error", err) - continue - } + hasPending := pendingPaymentSet[sub.ID()] if hasPending { // New pending payment exists, skip cancellation diff --git a/internal/application/payment/usecases/expirepayments.go b/internal/application/payment/usecases/expirepayments.go index b4e39a9..2d3272a 100644 --- a/internal/application/payment/usecases/expirepayments.go +++ b/internal/application/payment/usecases/expirepayments.go @@ -42,6 +42,18 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) { uc.logger.Infow("processing expired payments", "count", len(expiredPayments)) + // Batch fetch all subscriptions to avoid N+1 queries + subscriptionIDs := make([]uint, 0, len(expiredPayments)) + for _, p := range expiredPayments { + subscriptionIDs = append(subscriptionIDs, p.SubscriptionID()) + } + subscriptionMap, err := uc.subscriptionRepo.GetByIDs(ctx, subscriptionIDs) + if err != nil { + uc.logger.Warnw("failed to batch fetch subscriptions", "error", err) + // Continue with empty map, will log warnings for individual lookups + subscriptionMap = make(map[uint]*subscription.Subscription) + } + expiredCount := 0 for _, p := range expiredPayments { if err := p.MarkAsExpired(); err != nil { @@ -61,11 +73,19 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) { } // Record payment expiration time on subscription for auto-cancel grace period - if err := uc.markSubscriptionPaymentExpired(ctx, p.SubscriptionID()); err != nil { - uc.logger.Warnw("failed to mark subscription payment expired", - "error", err, + sub, ok := subscriptionMap[p.SubscriptionID()] + if !ok || sub == nil { + uc.logger.Warnw("subscription not found for payment", + "payment_id", p.ID(), "subscription_id", p.SubscriptionID()) - // Continue processing other payments even if this fails + } else { + // Record the payment expiration time for grace period calculation + sub.SetMetadata("payment_expired_at", biztime.FormatMetadataTime(biztime.NowUTC())) + if err := uc.subscriptionRepo.Update(ctx, sub); err != nil { + uc.logger.Warnw("failed to update subscription payment_expired_at", + "error", err, + "subscription_id", p.SubscriptionID()) + } } expiredCount++ @@ -81,24 +101,3 @@ func (uc *ExpirePaymentsUseCase) Execute(ctx context.Context) (int, error) { return expiredCount, nil } - -// markSubscriptionPaymentExpired records the payment expiration time on the subscription -// This is used by the auto-cancel logic to determine the grace period -func (uc *ExpirePaymentsUseCase) markSubscriptionPaymentExpired(ctx context.Context, subscriptionID uint) error { - sub, err := uc.subscriptionRepo.GetByID(ctx, subscriptionID) - if err != nil { - return fmt.Errorf("failed to get subscription: %w", err) - } - if sub == nil { - return fmt.Errorf("subscription not found: %d", subscriptionID) - } - - // Record the payment expiration time for grace period calculation - sub.SetMetadata("payment_expired_at", biztime.FormatMetadataTime(biztime.NowUTC())) - - if err := uc.subscriptionRepo.Update(ctx, sub); err != nil { - return fmt.Errorf("failed to update subscription: %w", err) - } - - return nil -} diff --git a/internal/application/resource/usecases/manageresourcegroupforwardagents.go b/internal/application/resource/usecases/manageresourcegroupforwardagents.go index 14e1cdd..b03f319 100644 --- a/internal/application/resource/usecases/manageresourcegroupforwardagents.go +++ b/internal/application/resource/usecases/manageresourcegroupforwardagents.go @@ -129,15 +129,23 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeAddAgents(ctx context. } // Check if already in this group - if agent.GroupID() != nil && *agent.GroupID() == groupID { + currentGroupIDs := agent.GroupIDs() + alreadyInGroup := false + for _, gid := range currentGroupIDs { + if gid == groupID { + alreadyInGroup = true + break + } + } + if alreadyInGroup { // Already in this group, count as success result.Succeeded = append(result.Succeeded, agentSID) continue } - // Set group ID and update - gid := groupID - agent.SetGroupID(&gid) + // Add group ID to the list and update + newGroupIDs := append(currentGroupIDs, groupID) + agent.SetGroupIDs(newGroupIDs) if err := uc.agentRepo.Update(ctx, agent); err != nil { uc.logger.Errorw("failed to update forward agent", "error", err, "agent_sid", agentSID) result.Failed = append(result.Failed, dto.BatchOperationErr{ @@ -236,7 +244,15 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeRemoveAgents(ctx conte } // Check if the agent belongs to this group - if agent.GroupID() == nil || *agent.GroupID() != groupID { + currentGroupIDs := agent.GroupIDs() + foundIndex := -1 + for i, gid := range currentGroupIDs { + if gid == groupID { + foundIndex = i + break + } + } + if foundIndex == -1 { result.Failed = append(result.Failed, dto.BatchOperationErr{ ID: agentSID, Reason: "forward agent does not belong to this group", @@ -244,8 +260,14 @@ func (uc *ManageResourceGroupForwardAgentsUseCase) executeRemoveAgents(ctx conte continue } - // Remove group ID - agent.SetGroupID(nil) + // Remove group ID from the list + newGroupIDs := make([]uint, 0, len(currentGroupIDs)-1) + for i, gid := range currentGroupIDs { + if i != foundIndex { + newGroupIDs = append(newGroupIDs, gid) + } + } + agent.SetGroupIDs(newGroupIDs) if err := uc.agentRepo.Update(ctx, agent); err != nil { uc.logger.Errorw("failed to update forward agent", "error", err, "agent_sid", agentSID) result.Failed = append(result.Failed, dto.BatchOperationErr{ diff --git a/internal/application/resource/usecases/manageresourcegroupnodes.go b/internal/application/resource/usecases/manageresourcegroupnodes.go index 6d1cd1b..879ab8e 100644 --- a/internal/application/resource/usecases/manageresourcegroupnodes.go +++ b/internal/application/resource/usecases/manageresourcegroupnodes.go @@ -85,18 +85,22 @@ func (uc *ManageResourceGroupNodesUseCase) executeAddNodes(ctx context.Context, Failed: make([]dto.BatchOperationErr, 0), } + // Batch fetch all nodes to avoid N+1 queries + nodes, err := uc.nodeRepo.GetBySIDs(ctx, nodeSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get nodes", "error", err) + return nil, fmt.Errorf("failed to get nodes: %w", err) + } + + // Build a map for quick lookup + nodeMap := make(map[string]*node.Node, len(nodes)) + for _, n := range nodes { + nodeMap[n.SID()] = n + } + for _, nodeSID := range nodeSIDs { - // Get node by SID - n, err := uc.nodeRepo.GetBySID(ctx, nodeSID) - if err != nil { - uc.logger.Warnw("failed to get node", "error", err, "node_sid", nodeSID) - result.Failed = append(result.Failed, dto.BatchOperationErr{ - ID: nodeSID, - Reason: "failed to get node", - }) - continue - } - if n == nil { + n, ok := nodeMap[nodeSID] + if !ok { result.Failed = append(result.Failed, dto.BatchOperationErr{ ID: nodeSID, Reason: "node not found", @@ -167,18 +171,22 @@ func (uc *ManageResourceGroupNodesUseCase) executeRemoveNodes(ctx context.Contex Failed: make([]dto.BatchOperationErr, 0), } + // Batch fetch all nodes to avoid N+1 queries + nodes, err := uc.nodeRepo.GetBySIDs(ctx, nodeSIDs) + if err != nil { + uc.logger.Errorw("failed to batch get nodes", "error", err) + return nil, fmt.Errorf("failed to get nodes: %w", err) + } + + // Build a map for quick lookup + nodeMap := make(map[string]*node.Node, len(nodes)) + for _, n := range nodes { + nodeMap[n.SID()] = n + } + for _, nodeSID := range nodeSIDs { - // Get node by SID - n, err := uc.nodeRepo.GetBySID(ctx, nodeSID) - if err != nil { - uc.logger.Warnw("failed to get node", "error", err, "node_sid", nodeSID) - result.Failed = append(result.Failed, dto.BatchOperationErr{ - ID: nodeSID, - Reason: "failed to get node", - }) - continue - } - if n == nil { + n, ok := nodeMap[nodeSID] + if !ok { result.Failed = append(result.Failed, dto.BatchOperationErr{ ID: nodeSID, Reason: "node not found", @@ -265,13 +273,23 @@ func (uc *ManageResourceGroupNodesUseCase) executeListNodes(ctx context.Context, for _, n := range nodes { groupIDSet.AddAll(n.GroupIDs()) } + + // Batch fetch group SIDs to avoid N+1 queries groupIDToSID := make(map[uint]string) groupIDToSID[groupID] = group.SID() // Current group is already loaded + otherGroupIDs := make([]uint, 0) for _, gid := range groupIDSet.ToSlice() { if gid != groupID { - g, err := uc.resourceGroupRepo.GetByID(ctx, gid) - if err == nil && g != nil { - groupIDToSID[gid] = g.SID() + otherGroupIDs = append(otherGroupIDs, gid) + } + } + if len(otherGroupIDs) > 0 { + sidMap, err := uc.resourceGroupRepo.GetSIDsByIDs(ctx, otherGroupIDs) + if err != nil { + uc.logger.Warnw("failed to batch get group SIDs", "error", err) + } else { + for gid, sid := range sidMap { + groupIDToSID[gid] = sid } } } diff --git a/internal/application/telegram/usecases/processreminder.go b/internal/application/telegram/usecases/processreminder.go index 9ad4cb2..65832e8 100644 --- a/internal/application/telegram/usecases/processreminder.go +++ b/internal/application/telegram/usecases/processreminder.go @@ -100,27 +100,46 @@ func (uc *ProcessReminderUseCase) processExpiringSubscriptions(ctx context.Conte return 0, 1 } + // Collect unique expiringDays values to avoid N+1 queries + expiringDaysSet := make(map[int]struct{}) + validBindings := make([]*telegram.TelegramBinding, 0, len(bindings)) for _, binding := range bindings { if !binding.CanNotifyExpiring() { continue } + expiringDaysSet[binding.ExpiringDays()] = struct{}{} + validBindings = append(validBindings, binding) + } - // Find expiring subscriptions for this user - subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(ctx, binding.ExpiringDays()) + if len(validBindings) == 0 { + return 0, 0 + } + + // Batch fetch expiring subscriptions for each unique expiringDays value + // Key: expiringDays, Value: subscriptions grouped by userID + subscriptionsByDaysAndUser := make(map[int]map[uint][]*subscription.Subscription) + for days := range expiringDaysSet { + subs, err := uc.subscriptionRepo.FindExpiringSubscriptions(ctx, days) if err != nil { - uc.logger.Errorw("failed to find expiring subscriptions", "user_id", binding.UserID(), "error", err) + uc.logger.Errorw("failed to find expiring subscriptions", "days", days, "error", err) errors++ continue } - - // Filter to only this user's subscriptions - var userSubs []*subscription.Subscription + // Group subscriptions by userID + userSubsMap := make(map[uint][]*subscription.Subscription) for _, sub := range subs { - if sub.UserID() == binding.UserID() { - userSubs = append(userSubs, sub) - } + userSubsMap[sub.UserID()] = append(userSubsMap[sub.UserID()], sub) } + subscriptionsByDaysAndUser[days] = userSubsMap + } + for _, binding := range validBindings { + // Get pre-fetched subscriptions for this binding's expiringDays and userID + userSubsMap, ok := subscriptionsByDaysAndUser[binding.ExpiringDays()] + if !ok { + continue + } + userSubs := userSubsMap[binding.UserID()] if len(userSubs) == 0 { continue } @@ -185,9 +204,24 @@ func (uc *ProcessReminderUseCase) processTrafficUsage(ctx context.Context) (int, planSubscriptions[sub.PlanID()] = append(planSubscriptions[sub.PlanID()], sub) } + // Batch fetch all plans to avoid N+1 queries + planIDs := make([]uint, 0, len(planSubscriptions)) + for planID := range planSubscriptions { + planIDs = append(planIDs, planID) + } + plans, err := uc.planRepo.GetByIDs(ctx, planIDs) + if err != nil { + uc.logger.Warnw("failed to batch fetch plans", "error", err) + continue + } + planMap := make(map[uint]*subscription.Plan, len(plans)) + for _, plan := range plans { + planMap[plan.ID()] = plan + } + for planID, planSubs := range planSubscriptions { - plan, err := uc.planRepo.GetByID(ctx, planID) - if err != nil { + plan, ok := planMap[planID] + if !ok { continue } diff --git a/internal/domain/forward/forwardagent.go b/internal/domain/forward/forwardagent.go index f64e44b..7255440 100644 --- a/internal/domain/forward/forwardagent.go +++ b/internal/domain/forward/forwardagent.go @@ -45,7 +45,7 @@ type ForwardAgent struct { publicAddress string // optional public address for Entry to obtain Exit connection information tunnelAddress string // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule remark string - groupID *uint // resource group ID + groupIDs []uint // resource group IDs agentVersion string // agent software version (e.g., "1.2.3") platform string // OS platform (linux, darwin, windows) arch string // CPU architecture (amd64, arm64, arm, 386) @@ -121,7 +121,7 @@ func ReconstructForwardAgent( publicAddress string, tunnelAddress string, remark string, - groupID *uint, + groupIDs []uint, agentVersion string, platform string, arch string, @@ -172,7 +172,7 @@ func ReconstructForwardAgent( publicAddress: publicAddress, tunnelAddress: tunnelAddress, remark: remark, - groupID: groupID, + groupIDs: groupIDs, agentVersion: agentVersion, platform: platform, arch: arch, @@ -236,14 +236,14 @@ func (a *ForwardAgent) GetEffectiveTunnelAddress() string { return a.publicAddress } -// GroupID returns the resource group ID -func (a *ForwardAgent) GroupID() *uint { - return a.groupID +// GroupIDs returns the resource group IDs +func (a *ForwardAgent) GroupIDs() []uint { + return a.groupIDs } -// SetGroupID sets the resource group ID -func (a *ForwardAgent) SetGroupID(groupID *uint) { - a.groupID = groupID +// SetGroupIDs sets the resource group IDs +func (a *ForwardAgent) SetGroupIDs(groupIDs []uint) { + a.groupIDs = groupIDs a.updatedAt = biztime.NowUTC() } diff --git a/internal/domain/payment/repository.go b/internal/domain/payment/repository.go index 52238ce..f8d66a7 100644 --- a/internal/domain/payment/repository.go +++ b/internal/domain/payment/repository.go @@ -12,6 +12,9 @@ type PaymentRepository interface { GetPendingBySubscriptionID(ctx context.Context, subscriptionID uint) (*Payment, error) // HasPendingPaymentBySubscriptionID checks if there are any pending payments for a subscription HasPendingPaymentBySubscriptionID(ctx context.Context, subscriptionID uint) (bool, error) + // GetSubscriptionIDsWithPendingPayments returns subscription IDs that have pending payments + // from the given list of subscription IDs + GetSubscriptionIDsWithPendingPayments(ctx context.Context, subscriptionIDs []uint) ([]uint, error) GetExpiredPayments(ctx context.Context) ([]*Payment, error) GetPendingUSDTPayments(ctx context.Context) ([]*Payment, error) // GetConfirmedUSDTPaymentsNeedingActivation returns confirmed USDT payments diff --git a/internal/domain/resource/repository.go b/internal/domain/resource/repository.go index d61ad84..0ec7308 100644 --- a/internal/domain/resource/repository.go +++ b/internal/domain/resource/repository.go @@ -22,12 +22,20 @@ type Repository interface { // GetBySID retrieves a resource group by Stripe-style ID GetBySID(ctx context.Context, sid string) (*ResourceGroup, error) + // GetBySIDs retrieves resource groups by their Stripe-style IDs + // Returns a map from SID to ResourceGroup for efficient lookup + GetBySIDs(ctx context.Context, sids []string) (map[string]*ResourceGroup, error) + // GetSIDsByIDs retrieves a map of resource group IDs to their SIDs GetSIDsByIDs(ctx context.Context, ids []uint) (map[uint]string, error) // GetByPlanID retrieves all resource groups for a plan GetByPlanID(ctx context.Context, planID uint) ([]*ResourceGroup, error) + // GetByPlanIDs retrieves all resource groups for multiple plans + // Returns a map from planID to list of ResourceGroups + GetByPlanIDs(ctx context.Context, planIDs []uint) (map[uint][]*ResourceGroup, error) + // List retrieves resource groups with optional filters List(ctx context.Context, filter ListFilter) ([]*ResourceGroup, int64, error) diff --git a/internal/infrastructure/migration/scripts/047_forward_agent_group_ids_to_array.sql b/internal/infrastructure/migration/scripts/047_forward_agent_group_ids_to_array.sql new file mode 100644 index 0000000..f6a2b07 --- /dev/null +++ b/internal/infrastructure/migration/scripts/047_forward_agent_group_ids_to_array.sql @@ -0,0 +1,28 @@ +-- +goose Up +-- Migration: Convert forward_agents.group_id to group_ids JSON array +-- Description: Support forward agents belonging to multiple resource groups + +-- Step 1: Add new group_ids JSON column +ALTER TABLE forward_agents ADD COLUMN group_ids JSON DEFAULT NULL; + +-- Step 2: Migrate existing data from group_id to group_ids +UPDATE forward_agents SET group_ids = JSON_ARRAY(group_id) WHERE group_id IS NOT NULL; + +-- Step 3: Drop the old group_id column and its index +DROP INDEX idx_forward_agent_group_id ON forward_agents; +ALTER TABLE forward_agents DROP COLUMN group_id; + +-- +goose Down +-- Rollback: Convert group_ids JSON array back to single group_id + +-- Step 1: Add back group_id column +ALTER TABLE forward_agents ADD COLUMN group_id BIGINT UNSIGNED NULL; + +-- Step 2: Migrate data back (take first element from JSON array) +UPDATE forward_agents SET group_id = JSON_EXTRACT(group_ids, '$[0]') WHERE group_ids IS NOT NULL AND JSON_LENGTH(group_ids) > 0; + +-- Step 3: Drop group_ids column +ALTER TABLE forward_agents DROP COLUMN group_ids; + +-- Step 4: Recreate index +CREATE INDEX idx_forward_agent_group_id ON forward_agents (group_id); diff --git a/internal/infrastructure/persistence/mappers/forwardagentmapper.go b/internal/infrastructure/persistence/mappers/forwardagentmapper.go index 332d307..ff6aaf6 100644 --- a/internal/infrastructure/persistence/mappers/forwardagentmapper.go +++ b/internal/infrastructure/persistence/mappers/forwardagentmapper.go @@ -60,6 +60,14 @@ func (m *ForwardAgentMapperImpl) ToEntity(model *models.ForwardAgentModel) (*for blockedProtocols = vo.NewBlockedProtocols(protocols) } + // Parse group_ids from JSON + var groupIDs []uint + if len(model.GroupIDs) > 0 { + if err := json.Unmarshal(model.GroupIDs, &groupIDs); err != nil { + return nil, fmt.Errorf("failed to parse group_ids: %w", err) + } + } + entity, err := forward.ReconstructForwardAgent( model.ID, model.SID, @@ -70,7 +78,7 @@ func (m *ForwardAgentMapperImpl) ToEntity(model *models.ForwardAgentModel) (*for model.PublicAddress, model.TunnelAddress, model.Remark, - model.GroupID, + groupIDs, model.AgentVersion, model.Platform, model.Arch, @@ -116,6 +124,16 @@ func (m *ForwardAgentMapperImpl) ToModel(entity *forward.ForwardAgent) (*models. } } + // Serialize group_ids to JSON + var groupIDsJSON []byte + if len(entity.GroupIDs()) > 0 { + var err error + groupIDsJSON, err = json.Marshal(entity.GroupIDs()) + if err != nil { + return nil, fmt.Errorf("failed to serialize group_ids: %w", err) + } + } + return &models.ForwardAgentModel{ ID: entity.ID(), SID: entity.SID(), @@ -126,7 +144,7 @@ func (m *ForwardAgentMapperImpl) ToModel(entity *forward.ForwardAgent) (*models. TunnelAddress: entity.TunnelAddress(), Status: string(entity.Status()), Remark: entity.Remark(), - GroupID: entity.GroupID(), + GroupIDs: groupIDsJSON, AgentVersion: entity.AgentVersion(), Platform: entity.Platform(), Arch: entity.Arch(), diff --git a/internal/infrastructure/persistence/models/forwardagentmodel.go b/internal/infrastructure/persistence/models/forwardagentmodel.go index 24fa74f..f71b913 100644 --- a/internal/infrastructure/persistence/models/forwardagentmodel.go +++ b/internal/infrastructure/persistence/models/forwardagentmodel.go @@ -20,7 +20,7 @@ type ForwardAgentModel struct { TunnelAddress string `gorm:"size:255"` // tunnel address for entry to connect to exit (nullable, overrides public_address) Status string `gorm:"not null;default:enabled;size:20;index:idx_forward_agent_status"` Remark string `gorm:"size:500"` - GroupID *uint `gorm:"index:idx_forward_agent_group_id"` // resource group ID + GroupIDs datatypes.JSON `gorm:"column:group_ids"` // resource group IDs (JSON array) AgentVersion string `gorm:"size:50"` // agent software version (e.g., "1.2.3") Platform string `gorm:"size:20"` // OS platform (linux, darwin, windows) Arch string `gorm:"size:20"` // CPU architecture (amd64, arm64, arm, 386) diff --git a/internal/infrastructure/repository/forwardagentrepository.go b/internal/infrastructure/repository/forwardagentrepository.go index 5200b35..a53c26c 100644 --- a/internal/infrastructure/repository/forwardagentrepository.go +++ b/internal/infrastructure/repository/forwardagentrepository.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "fmt" "strings" @@ -22,7 +23,6 @@ var allowedAgentOrderByFields = map[string]bool{ "sid": true, "name": true, "status": true, - "group_id": true, "sort_order": true, "last_seen_at": true, "created_at": true, @@ -225,7 +225,7 @@ func (r *ForwardAgentRepositoryImpl) Update(ctx context.Context, agent *forward. "public_address": model.PublicAddress, "tunnel_address": model.TunnelAddress, "remark": model.Remark, - "group_id": model.GroupID, + "group_ids": model.GroupIDs, "allowed_port_range": model.AllowedPortRange, "blocked_protocols": model.BlockedProtocols, "sort_order": model.SortOrder, @@ -277,7 +277,9 @@ func (r *ForwardAgentRepositoryImpl) List(ctx context.Context, filter forward.Ag query = query.Where("status = ?", filter.Status) } if len(filter.GroupIDs) > 0 { - query = query.Where("group_id IN ?", filter.GroupIDs) + // Use JSON_OVERLAPS to check if group_ids array contains any of the filter group IDs + groupIDsJSON, _ := json.Marshal(filter.GroupIDs) + query = query.Where("JSON_OVERLAPS(group_ids, ?)", string(groupIDsJSON)) } // Count total records diff --git a/internal/infrastructure/repository/paymentrepository.go b/internal/infrastructure/repository/paymentrepository.go index cc23bd9..26a2bca 100644 --- a/internal/infrastructure/repository/paymentrepository.go +++ b/internal/infrastructure/repository/paymentrepository.go @@ -165,6 +165,26 @@ func (r *PaymentRepository) HasPendingPaymentBySubscriptionID(ctx context.Contex return count > 0, nil } +// GetSubscriptionIDsWithPendingPayments returns subscription IDs that have pending payments +// from the given list of subscription IDs +func (r *PaymentRepository) GetSubscriptionIDsWithPendingPayments(ctx context.Context, subscriptionIDs []uint) ([]uint, error) { + if len(subscriptionIDs) == 0 { + return []uint{}, nil + } + + var results []uint + + if err := db.GetTxFromContext(ctx, r.db). + Model(&models.PaymentModel{}). + Select("DISTINCT subscription_id"). + Where("subscription_id IN ? AND payment_status = ?", subscriptionIDs, vo.PaymentStatusPending). + Pluck("subscription_id", &results).Error; err != nil { + return nil, fmt.Errorf("failed to get subscription IDs with pending payments: %w", err) + } + + return results, nil +} + func (r *PaymentRepository) GetExpiredPayments(ctx context.Context) ([]*payment.Payment, error) { var paymentModels []models.PaymentModel diff --git a/internal/infrastructure/repository/resourcegrouprepository.go b/internal/infrastructure/repository/resourcegrouprepository.go index b28b1e2..2152cd3 100644 --- a/internal/infrastructure/repository/resourcegrouprepository.go +++ b/internal/infrastructure/repository/resourcegrouprepository.go @@ -225,6 +225,60 @@ func (r *ResourceGroupRepositoryImpl) GetByPlanID(ctx context.Context, planID ui return entities, nil } +// GetBySIDs retrieves resource groups by their Stripe-style IDs. +// Returns a map from SID to ResourceGroup for efficient lookup. +func (r *ResourceGroupRepositoryImpl) GetBySIDs(ctx context.Context, sids []string) (map[string]*resource.ResourceGroup, error) { + if len(sids) == 0 { + return make(map[string]*resource.ResourceGroup), nil + } + + var modelList []*models.ResourceGroupModel + + if err := r.db.WithContext(ctx).Where("sid IN ?", sids).Find(&modelList).Error; err != nil { + r.logger.Errorw("failed to get resource groups by SIDs", "sids", sids, "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + result := make(map[string]*resource.ResourceGroup, len(modelList)) + for _, model := range modelList { + entity, err := r.mapper.ToEntity(model) + if err != nil { + r.logger.Errorw("failed to map resource group model to entity", "sid", model.SID, "error", err) + return nil, fmt.Errorf("failed to map resource group: %w", err) + } + result[model.SID] = entity + } + + return result, nil +} + +// GetByPlanIDs retrieves all resource groups for multiple plans. +// Returns a map from planID to list of ResourceGroups. +func (r *ResourceGroupRepositoryImpl) GetByPlanIDs(ctx context.Context, planIDs []uint) (map[uint][]*resource.ResourceGroup, error) { + if len(planIDs) == 0 { + return make(map[uint][]*resource.ResourceGroup), nil + } + + var modelList []*models.ResourceGroupModel + + if err := r.db.WithContext(ctx).Where("plan_id IN ?", planIDs).Find(&modelList).Error; err != nil { + r.logger.Errorw("failed to get resource groups by plan IDs", "plan_ids", planIDs, "error", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + result := make(map[uint][]*resource.ResourceGroup) + for _, model := range modelList { + entity, err := r.mapper.ToEntity(model) + if err != nil { + r.logger.Errorw("failed to map resource group model to entity", "id", model.ID, "error", err) + return nil, fmt.Errorf("failed to map resource group: %w", err) + } + result[model.PlanID] = append(result[model.PlanID], entity) + } + + return result, nil +} + // List retrieves resource groups with optional filters. func (r *ResourceGroupRepositoryImpl) List(ctx context.Context, filter resource.ListFilter) ([]*resource.ResourceGroup, int64, error) { query := r.db.WithContext(ctx).Model(&models.ResourceGroupModel{}) diff --git a/internal/interfaces/http/handlers/forward/agent/crud/handler.go b/internal/interfaces/http/handlers/forward/agent/crud/handler.go index bd179d9..581e590 100644 --- a/internal/interfaces/http/handlers/forward/agent/crud/handler.go +++ b/internal/interfaces/http/handlers/forward/agent/crud/handler.go @@ -75,7 +75,7 @@ type CreateForwardAgentRequest struct { PublicAddress string `json:"public_address,omitempty" example:"203.0.113.1"` TunnelAddress string `json:"tunnel_address,omitempty" example:"192.168.1.100"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule Remark string `json:"remark,omitempty" example:"Forward agent for production environment"` - GroupSID string `json:"group_sid,omitempty" example:"rg_xK9mP2vL3nQ"` // Resource group SID to associate with + GroupSIDs []string `json:"group_sids,omitempty" example:"[\"rg_xK9mP2vL3nQ\"]"` // Resource group SIDs to associate with AllowedPortRange string `json:"allowed_port_range,omitempty" example:"80,443,8000-9000"` BlockedProtocols []string `json:"blocked_protocols,omitempty" example:"socks5,http_connect"` // Protocols to block (e.g., socks5, http_connect, ssh) SortOrder *int `json:"sort_order,omitempty" example:"100"` // Custom sort order for UI display (lower values appear first) @@ -87,7 +87,7 @@ type UpdateForwardAgentRequest struct { PublicAddress *string `json:"public_address,omitempty" example:"203.0.113.2"` TunnelAddress *string `json:"tunnel_address,omitempty" example:"192.168.1.100"` // IP or hostname only (no port), configure if agent may serve as relay/exit in any rule Remark *string `json:"remark,omitempty" example:"Updated remark"` - GroupSID *string `json:"group_sid,omitempty" example:"rg_xK9mP2vL3nQ"` // Resource group SID to associate with (use empty string to remove) + GroupSIDs []string `json:"group_sids,omitempty" example:"[\"rg_xK9mP2vL3nQ\"]"` // Resource group SIDs to associate with (empty array to remove all) AllowedPortRange *string `json:"allowed_port_range,omitempty" example:"80,443,8000-9000"` BlockedProtocols *[]string `json:"blocked_protocols,omitempty"` // Protocols to block (nil: no update, empty array: clear, non-empty: set new) SortOrder *int `json:"sort_order,omitempty" example:"100"` // Custom sort order for UI display (lower values appear first) @@ -113,7 +113,7 @@ func (h *Handler) CreateAgent(c *gin.Context) { PublicAddress: req.PublicAddress, TunnelAddress: req.TunnelAddress, Remark: req.Remark, - GroupSID: req.GroupSID, + GroupSIDs: req.GroupSIDs, AllowedPortRange: req.AllowedPortRange, BlockedProtocols: req.BlockedProtocols, SortOrder: req.SortOrder, // nil if not provided, allowing explicit 0 value @@ -197,7 +197,7 @@ func (h *Handler) UpdateAgent(c *gin.Context) { PublicAddress: req.PublicAddress, TunnelAddress: req.TunnelAddress, Remark: req.Remark, - GroupSID: req.GroupSID, + GroupSIDs: req.GroupSIDs, AllowedPortRange: req.AllowedPortRange, BlockedProtocols: req.BlockedProtocols, SortOrder: req.SortOrder, diff --git a/internal/interfaces/http/handlers/node/subscriptionhandler.go b/internal/interfaces/http/handlers/node/subscriptionhandler.go index b49266a..75e700f 100644 --- a/internal/interfaces/http/handlers/node/subscriptionhandler.go +++ b/internal/interfaces/http/handlers/node/subscriptionhandler.go @@ -209,11 +209,15 @@ func detectFormatFromUserAgent(userAgent string) string { return "base64" } - // V2Ray clients - if strings.Contains(ua, "v2ray") { - return "v2ray" + // V2RayN/V2RayNG clients - use base64 format (supports all protocol URIs) + // These are general-purpose clients that parse base64-encoded URI lists + // (vmess://, vless://, trojan://, ss://, etc.) + if strings.Contains(ua, "v2rayn") || strings.Contains(ua, "v2rayng") { + return "base64" } // Default to base64 format for unknown clients + // Note: "v2ray" format (JSON) is only for Shadowsocks-only clients, + // accessible via explicit /s/:token/v2ray endpoint return "base64" } diff --git a/internal/interfaces/http/handlers/node/subscriptionhandler_test.go b/internal/interfaces/http/handlers/node/subscriptionhandler_test.go index 6530d9f..ca25e93 100644 --- a/internal/interfaces/http/handlers/node/subscriptionhandler_test.go +++ b/internal/interfaces/http/handlers/node/subscriptionhandler_test.go @@ -56,14 +56,14 @@ func TestDetectFormatFromUserAgent(t *testing.T) { expected: "base64", }, { - name: "V2Ray client returns v2ray format", + name: "V2RayN client returns base64 format (supports all protocols)", userAgent: "v2rayN/1.0.0", - expected: "v2ray", + expected: "base64", }, { - name: "V2RayNG client returns v2ray format", + name: "V2RayNG client returns base64 format (supports all protocols)", userAgent: "V2RayNG/1.0.0", - expected: "v2ray", + expected: "base64", }, { name: "Unknown browser returns base64 format",