mirror of
https://gitee.com/WuKongDev/WuKongIM.git
synced 2026-05-07 01:00:47 +08:00
fix: duplicate ids when concurrent update
This commit is contained in:
@@ -120,8 +120,8 @@ func (wk *wukongDB) AddOrUpdateConversationsBatchIfNotExist(conversations []Conv
|
||||
|
||||
func (wk *wukongDB) AddOrUpdateConversationsWithUser(uid string, conversations []Conversation) error {
|
||||
wk.metrics.AddOrUpdateConversationsAdd(1)
|
||||
// wk.dblock.conversationLock.lock(uid)
|
||||
// defer wk.dblock.conversationLock.unlock(uid)
|
||||
wk.dblock.conversationLock.lock(uid)
|
||||
defer wk.dblock.conversationLock.unlock(uid)
|
||||
if wk.opts.EnableCost {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -296,28 +296,22 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库获取
|
||||
ids, err := wk.getLastConversationIds(uid, updatedAt, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// 缓存未命中,使用全表扫描+过滤的方式直接从数据库获取
|
||||
db := wk.shardDB(uid)
|
||||
iter := db.NewIter(&pebble.IterOptions{
|
||||
LowerBound: key.NewConversationPrimaryKey(uid, 0),
|
||||
UpperBound: key.NewConversationPrimaryKey(uid, math.MaxUint64),
|
||||
})
|
||||
defer iter.Close()
|
||||
|
||||
// 使用批量查询优化,避免N+1查询问题
|
||||
conversations, err := wk.getConversationsBatch(uid, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 过滤会话类型和排除的频道类型
|
||||
filteredConversations := make([]Conversation, 0, len(conversations))
|
||||
for _, conversation := range conversations {
|
||||
var allConversations []Conversation
|
||||
err := wk.iterateConversation(iter, func(conversation Conversation) bool {
|
||||
// 过滤会话类型
|
||||
if conversation.Type != tp {
|
||||
continue
|
||||
return true
|
||||
}
|
||||
|
||||
// 过滤排除的频道类型
|
||||
exclude := false
|
||||
if len(excludeChannelTypes) > 0 {
|
||||
for _, excludeChannelType := range excludeChannelTypes {
|
||||
@@ -328,19 +322,24 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
|
||||
}
|
||||
}
|
||||
if exclude {
|
||||
continue
|
||||
return true
|
||||
}
|
||||
|
||||
filteredConversations = append(filteredConversations, conversation)
|
||||
// 过滤更新时间(updatedAt=0表示获取所有会话)
|
||||
if updatedAt == 0 || (conversation.UpdatedAt != nil && uint64(conversation.UpdatedAt.UnixNano()) >= updatedAt) {
|
||||
allConversations = append(allConversations, conversation)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// conversations 根据id去重复
|
||||
filteredConversations = uniqueConversation(filteredConversations)
|
||||
|
||||
// 按照更新时间排序
|
||||
sort.Slice(filteredConversations, func(i, j int) bool {
|
||||
c1 := filteredConversations[i]
|
||||
c2 := filteredConversations[j]
|
||||
// 按照更新时间排序(最新的在前面)
|
||||
sort.Slice(allConversations, func(i, j int) bool {
|
||||
c1 := allConversations[i]
|
||||
c2 := allConversations[j]
|
||||
if c1.UpdatedAt == nil {
|
||||
return false
|
||||
}
|
||||
@@ -350,6 +349,14 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
|
||||
return c1.UpdatedAt.After(*c2.UpdatedAt)
|
||||
})
|
||||
|
||||
// 应用 limit 限制
|
||||
var filteredConversations []Conversation
|
||||
if limit > 0 && len(allConversations) > limit {
|
||||
filteredConversations = allConversations[:limit]
|
||||
} else {
|
||||
filteredConversations = allConversations
|
||||
}
|
||||
|
||||
// 将结果写入缓存
|
||||
wk.conversationCache.SetLastConversations(uid, tp, updatedAt, excludeChannelTypes, limit, filteredConversations)
|
||||
|
||||
@@ -415,7 +422,6 @@ func removeDupliConversationByChannel(conversations []Conversation) []Conversati
|
||||
}
|
||||
|
||||
func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error) {
|
||||
// 直接从数据库获取(不再单独缓存ID列表)
|
||||
db := wk.shardDB(uid)
|
||||
iter, err := db.NewIter(&pebble.IterOptions{
|
||||
LowerBound: key.NewConversationSecondIndexKey(uid, key.TableConversation.SecondIndex.UpdatedAt, updatedAt, 0),
|
||||
@@ -459,6 +465,11 @@ func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit i
|
||||
return uniqueIdsMap, nil
|
||||
}
|
||||
|
||||
// GetLastConversationIds 公开方法,用于测试 getLastConversationIds 的重复ID问题
|
||||
func (wk *wukongDB) GetLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error) {
|
||||
return wk.getLastConversationIds(uid, updatedAt, limit)
|
||||
}
|
||||
|
||||
// DeleteConversation 删除最近会话
|
||||
func (wk *wukongDB) DeleteConversation(uid string, channelId string, channelType uint8) error {
|
||||
|
||||
@@ -688,106 +699,6 @@ func (wk *wukongDB) getConversation(uid string, id uint64) (Conversation, error)
|
||||
return conversation, nil
|
||||
}
|
||||
|
||||
// getConversationsBatch 批量获取会话,避免N+1查询问题
|
||||
func (wk *wukongDB) getConversationsBatch(uid string, ids []uint64) ([]Conversation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 先尝试从缓存获取部分数据
|
||||
conversations := make([]Conversation, 0, len(ids))
|
||||
|
||||
if len(ids) == 0 {
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// 从数据库获取缺失的会话
|
||||
var dbConversations []Conversation
|
||||
var err error
|
||||
|
||||
// 如果ID数量较少,使用优化的多范围查询
|
||||
if len(ids) <= 10 {
|
||||
dbConversations, err = wk.getConversationsBatchOptimized(uid, ids)
|
||||
} else {
|
||||
// ID数量较多时,使用全表扫描+过滤的方式
|
||||
dbConversations, err = wk.getConversationsBatchFiltered(uid, ids)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 不再单独缓存批量查询结果,由 GetLastConversations 统一缓存
|
||||
// 合并结果
|
||||
conversations = append(conversations, dbConversations...)
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// getConversationsBatchOptimized 对少量ID使用多个精确范围查询
|
||||
func (wk *wukongDB) getConversationsBatchOptimized(uid string, ids []uint64) ([]Conversation, error) {
|
||||
conversations := make([]Conversation, 0, len(ids))
|
||||
db := wk.shardDB(uid)
|
||||
|
||||
for _, id := range ids {
|
||||
iter := db.NewIter(&pebble.IterOptions{
|
||||
LowerBound: key.NewConversationColumnKey(uid, id, key.MinColumnKey),
|
||||
UpperBound: key.NewConversationColumnKey(uid, id, key.MaxColumnKey),
|
||||
})
|
||||
|
||||
var conversation = EmptyConversation
|
||||
err := wk.iterateConversation(iter, func(cn Conversation) bool {
|
||||
conversation = cn
|
||||
return false
|
||||
})
|
||||
iter.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conversation != EmptyConversation {
|
||||
conversations = append(conversations, conversation)
|
||||
}
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// getConversationsBatchFiltered 对大量ID使用全表扫描+过滤
|
||||
func (wk *wukongDB) getConversationsBatchFiltered(uid string, ids []uint64) ([]Conversation, error) {
|
||||
// 创建ID集合用于快速查找
|
||||
idSet := make(map[uint64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
idSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
// 使用单个迭代器查询所有会话数据
|
||||
db := wk.shardDB(uid)
|
||||
iter := db.NewIter(&pebble.IterOptions{
|
||||
LowerBound: key.NewConversationPrimaryKey(uid, 0),
|
||||
UpperBound: key.NewConversationPrimaryKey(uid, math.MaxUint64),
|
||||
})
|
||||
defer iter.Close()
|
||||
|
||||
conversations := make([]Conversation, 0, len(ids))
|
||||
err := wk.iterateConversation(iter, func(conversation Conversation) bool {
|
||||
// 只收集我们需要的会话ID
|
||||
if _, exists := idSet[conversation.Id]; exists {
|
||||
conversations = append(conversations, conversation)
|
||||
// 如果已经找到所有需要的会话,可以提前退出
|
||||
if len(conversations) == len(ids) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// func (wk *wukongDB) getConversationIdsByUid(uid string) ([]uint64, error) {
|
||||
// iter := wk.shardDB(uid).NewIter(&pebble.IterOptions{
|
||||
// LowerBound: key.NewConversationPrimaryKey(uid, 0),
|
||||
|
||||
339
pkg/wkdb/conversation_duplicate_ids_test.go
Normal file
339
pkg/wkdb/conversation_duplicate_ids_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package wkdb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// 辅助方法:通过反射或者添加测试方法来访问私有的 getLastConversationIds
|
||||
// 这里我们通过 GetLastConversations 来间接测试,因为它内部调用了 getLastConversationIds
|
||||
|
||||
// TestDuplicateIdsInGetLastConversationIds 测试 getLastConversationIds 方法中出现重复 ID 的问题
|
||||
func TestDuplicateIdsInGetLastConversationIds(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "test_duplicate_user1"
|
||||
channelId := "test_channel"
|
||||
channelType := uint8(1)
|
||||
now := time.Now()
|
||||
|
||||
// 场景1:测试同一个会话的多次快速更新
|
||||
t.Run("MultipleQuickUpdates", func(t *testing.T) {
|
||||
// 创建初始会话
|
||||
conversation := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId + "_1",
|
||||
ChannelType: channelType,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 快速连续更新同一个会话多次
|
||||
for i := 0; i < 5; i++ {
|
||||
updatedTime := now.Add(time.Duration(i+1) * time.Millisecond)
|
||||
conversation.UnreadCount = uint32(i + 2)
|
||||
conversation.UpdatedAt = &updatedTime
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 检查是否有重复的 ID
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查重复
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
t.Logf("Found duplicate ID: %d, count: %d", id, count)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 场景2:测试批量更新中包含重复会话
|
||||
t.Run("BatchUpdateWithDuplicates", func(t *testing.T) {
|
||||
channelId2 := channelId + "_2"
|
||||
baseTime := now.Add(time.Hour)
|
||||
|
||||
// 创建包含重复会话的批量更新
|
||||
conversations := []wkdb.Conversation{
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId2,
|
||||
ChannelType: channelType,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &baseTime,
|
||||
UpdatedAt: &baseTime,
|
||||
},
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId2, // 相同的 channelId
|
||||
ChannelType: channelType, // 相同的 channelType
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 2, // 不同的 UnreadCount
|
||||
CreatedAt: &baseTime,
|
||||
UpdatedAt: &baseTime,
|
||||
},
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, conversations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查是否有重复的 ID
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查重复
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
t.Logf("Found duplicate ID in batch update: %d, count: %d", id, count)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 场景3:测试相同时间戳的更新
|
||||
t.Run("SameTimestampUpdates", func(t *testing.T) {
|
||||
channelId3 := channelId + "_3"
|
||||
sameTime := now.Add(2 * time.Hour)
|
||||
|
||||
// 创建初始会话
|
||||
conversation := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId3,
|
||||
ChannelType: channelType,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &sameTime,
|
||||
UpdatedAt: &sameTime,
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 使用相同的时间戳更新多次
|
||||
for i := 0; i < 3; i++ {
|
||||
conversation.UnreadCount = uint32(i + 2)
|
||||
// 故意使用相同的 UpdatedAt 时间
|
||||
conversation.UpdatedAt = &sameTime
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 检查是否有重复的 ID
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查重复
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
t.Logf("Found duplicate ID with same timestamp: %d, count: %d", id, count)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 场景4:测试并发更新(模拟)
|
||||
t.Run("ConcurrentUpdates", func(t *testing.T) {
|
||||
|
||||
channelId4 := channelId + "_4"
|
||||
baseTime := now.Add(3 * time.Hour)
|
||||
|
||||
// 创建初始会话
|
||||
conversation := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId4,
|
||||
ChannelType: channelType,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &baseTime,
|
||||
UpdatedAt: &baseTime,
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 模拟并发更新:快速连续的更新操作
|
||||
done := make(chan bool, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
go func(index int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < 2; j++ {
|
||||
updateTime := baseTime.Add(time.Duration(index*10+j) * time.Microsecond)
|
||||
conv := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId4,
|
||||
ChannelType: channelType,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: uint32(index*10 + j + 1),
|
||||
CreatedAt: &baseTime,
|
||||
UpdatedAt: &updateTime,
|
||||
}
|
||||
|
||||
d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conv})
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 检查是否有重复的 ID
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查重复
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
t.Logf("Found duplicate ID in concurrent updates: %d, count: %d", id, count)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 最终检查:获取所有会话 ID 并分析重复情况
|
||||
t.Run("FinalAnalysis", func(t *testing.T) {
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 100)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Logf("Total IDs returned: %d", len(ids))
|
||||
|
||||
// 统计重复情况
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
duplicateCount := 0
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
duplicateCount++
|
||||
t.Logf("Duplicate ID: %d appears %d times", id, count)
|
||||
}
|
||||
}
|
||||
|
||||
if duplicateCount > 0 {
|
||||
t.Logf("Found %d duplicate IDs out of %d unique IDs", duplicateCount, len(idMap))
|
||||
} else {
|
||||
t.Log("No duplicate IDs found")
|
||||
}
|
||||
|
||||
// 验证去重逻辑是否工作
|
||||
assert.Equal(t, len(idMap), len(ids), "去重逻辑应该确保没有重复ID")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetLastConversationIdsDeduplication 专门测试去重逻辑
|
||||
func TestGetLastConversationIdsDeduplication(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "dedup_test_user"
|
||||
now := time.Now()
|
||||
|
||||
// 创建多个会话,然后多次更新以增加重复 ID 的可能性
|
||||
for i := 0; i < 5; i++ {
|
||||
channelId := "channel_" + string(rune('A'+i))
|
||||
|
||||
// 创建初始会话
|
||||
conversation := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: channelId,
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 多次更新每个会话
|
||||
for j := 0; j < 3; j++ {
|
||||
updateTime := now.Add(time.Duration(i*100+j) * time.Millisecond)
|
||||
conversation.UnreadCount = uint32(j + 2)
|
||||
conversation.UpdatedAt = &updateTime
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 ID 列表并检查去重效果
|
||||
ids, err := d.GetLastConversationIds(uid, 0, 20)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证去重逻辑
|
||||
idMap := make(map[uint64]int)
|
||||
for _, id := range ids {
|
||||
idMap[id]++
|
||||
}
|
||||
|
||||
// 检查是否有重复
|
||||
hasDeduplication := false
|
||||
for id, count := range idMap {
|
||||
if count > 1 {
|
||||
t.Errorf("ID %d appears %d times, deduplication failed", id, count)
|
||||
}
|
||||
if count == 1 && len(ids) > len(idMap) {
|
||||
hasDeduplication = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasDeduplication {
|
||||
t.Logf("Deduplication worked: %d raw IDs reduced to %d unique IDs", len(ids), len(idMap))
|
||||
}
|
||||
|
||||
t.Logf("Final result: %d unique conversation IDs", len(idMap))
|
||||
}
|
||||
349
pkg/wkdb/conversation_full_scan_test.go
Normal file
349
pkg/wkdb/conversation_full_scan_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package wkdb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestGetLastConversationsFullScan 测试全表扫描+过滤的 GetLastConversations 实现
|
||||
func TestGetLastConversationsFullScan(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "full_scan_test_user"
|
||||
now := time.Now()
|
||||
|
||||
// 创建多个不同类型和时间的会话
|
||||
conversations := []wkdb.Conversation{
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_1",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 5,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
},
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_2",
|
||||
ChannelType: 2,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 3,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
},
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_3",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeCMD, // 不同类型
|
||||
UnreadCount: 2,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
},
|
||||
}
|
||||
|
||||
// 添加会话
|
||||
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 测试1:获取所有聊天类型的会话
|
||||
t.Run("GetAllChatConversations", func(t *testing.T) {
|
||||
// 先检查是否成功添加了会话
|
||||
allConversations, err := d.GetConversations(uid)
|
||||
assert.NoError(t, err)
|
||||
t.Logf("Total conversations added: %d", len(allConversations))
|
||||
for i, conv := range allConversations {
|
||||
t.Logf("Conversation %d: ID=%d, Type=%d, ChannelId=%s, ChannelType=%d",
|
||||
i, conv.Id, conv.Type, conv.ChannelId, conv.ChannelType)
|
||||
}
|
||||
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
|
||||
assert.NoError(t, err)
|
||||
t.Logf("GetLastConversations returned: %d conversations", len(result))
|
||||
assert.Len(t, result, 2) // 应该有2个聊天类型的会话
|
||||
|
||||
// 验证结果都是聊天类型
|
||||
for _, conv := range result {
|
||||
assert.Equal(t, wkdb.ConversationTypeChat, conv.Type)
|
||||
}
|
||||
})
|
||||
|
||||
// 测试2:排除特定频道类型
|
||||
t.Run("ExcludeChannelTypes", func(t *testing.T) {
|
||||
excludeTypes := []uint8{2} // 排除频道类型2
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, excludeTypes, 10)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1) // 应该只有1个会话(排除了频道类型2)
|
||||
|
||||
// 验证结果不包含被排除的频道类型
|
||||
for _, conv := range result {
|
||||
assert.NotEqual(t, uint8(2), conv.ChannelType)
|
||||
}
|
||||
})
|
||||
|
||||
// 测试3:限制数量
|
||||
t.Run("LimitResults", func(t *testing.T) {
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1) // 应该只返回1个会话
|
||||
})
|
||||
|
||||
// 测试4:按更新时间过滤
|
||||
t.Run("FilterByUpdatedAt", func(t *testing.T) {
|
||||
// 更新其中一个会话的时间
|
||||
futureTime := now.Add(time.Hour)
|
||||
updatedConv := conversations[0]
|
||||
updatedConv.UpdatedAt = &futureTime
|
||||
updatedConv.UnreadCount = 10
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{updatedConv})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 使用未来时间作为过滤条件
|
||||
futureNano := uint64(futureTime.UnixNano())
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, futureNano, nil, 10)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1) // 应该只有1个会话满足时间条件
|
||||
|
||||
// 验证返回的是更新后的会话
|
||||
assert.Equal(t, uint32(10), result[0].UnreadCount)
|
||||
})
|
||||
|
||||
// 测试5:时间排序
|
||||
t.Run("TimeOrdering", func(t *testing.T) {
|
||||
// 创建不同时间的会话
|
||||
time1 := now.Add(time.Minute)
|
||||
time2 := now.Add(2 * time.Minute)
|
||||
time3 := now.Add(3 * time.Minute)
|
||||
|
||||
newConversations := []wkdb.Conversation{
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_time_1",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &time1,
|
||||
UpdatedAt: &time1,
|
||||
},
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_time_2",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 2,
|
||||
CreatedAt: &time2,
|
||||
UpdatedAt: &time2,
|
||||
},
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_time_3",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 3,
|
||||
CreatedAt: &time3,
|
||||
UpdatedAt: &time3,
|
||||
},
|
||||
}
|
||||
|
||||
err := d.AddOrUpdateConversationsWithUser(uid, newConversations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(result) >= 3)
|
||||
|
||||
// 验证按时间降序排列(最新的在前面)
|
||||
for i := 0; i < len(result)-1; i++ {
|
||||
if result[i].UpdatedAt != nil && result[i+1].UpdatedAt != nil {
|
||||
assert.True(t, result[i].UpdatedAt.After(*result[i+1].UpdatedAt) || result[i].UpdatedAt.Equal(*result[i+1].UpdatedAt))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetLastConversationsPerformance 测试全表扫描的性能
|
||||
func TestGetLastConversationsPerformance(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "performance_test_user"
|
||||
now := time.Now()
|
||||
|
||||
// 创建大量会话数据
|
||||
conversations := make([]wkdb.Conversation, 0, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
updatedAt := now.Add(time.Duration(i) * time.Minute)
|
||||
conversations = append(conversations, wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_" + string(rune('A'+i%26)) + string(rune('0'+i/26)),
|
||||
ChannelType: uint8(1 + i%3), // 频道类型 1, 2, 3
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: uint32(i + 1),
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &updatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// 批量添加会话
|
||||
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 测试性能
|
||||
start := time.Now()
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 20)
|
||||
duration := time.Since(start)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 20)
|
||||
t.Logf("Full scan query took: %v for 100 conversations", duration)
|
||||
|
||||
// 验证结果正确性
|
||||
assert.True(t, len(result) <= 20)
|
||||
for _, conv := range result {
|
||||
assert.Equal(t, wkdb.ConversationTypeChat, conv.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetLastConversationsNoDuplicates 测试全表扫描不会产生重复结果
|
||||
func TestGetLastConversationsNoDuplicates(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "no_duplicates_test_user"
|
||||
now := time.Now()
|
||||
|
||||
// 创建会话并多次更新
|
||||
conversation := wkdb.Conversation{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "test_channel",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
// 添加初始会话
|
||||
err = d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 多次更新同一个会话
|
||||
for i := 0; i < 5; i++ {
|
||||
updateTime := now.Add(time.Duration(i+1) * time.Minute)
|
||||
conversation.UnreadCount = uint32(i + 2)
|
||||
conversation.UpdatedAt = &updateTime
|
||||
|
||||
err = d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 获取会话列表
|
||||
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证没有重复
|
||||
channelMap := make(map[string]int)
|
||||
for _, conv := range result {
|
||||
key := conv.ChannelId + ":" + string(rune(conv.ChannelType))
|
||||
channelMap[key]++
|
||||
}
|
||||
|
||||
for channel, count := range channelMap {
|
||||
assert.Equal(t, 1, count, "Channel %s should appear only once, but appeared %d times", channel, count)
|
||||
}
|
||||
|
||||
// 验证返回的是最新的数据
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, uint32(6), result[0].UnreadCount) // 最后一次更新的值
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetLastConversationsCache 测试缓存功能
|
||||
func TestGetLastConversationsCache(t *testing.T) {
|
||||
d := newTestDB(t)
|
||||
err := d.Open()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := d.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
uid := "cache_test_user"
|
||||
now := time.Now()
|
||||
|
||||
// 创建测试数据
|
||||
conversations := []wkdb.Conversation{
|
||||
{
|
||||
Id: d.NextPrimaryKey(),
|
||||
Uid: uid,
|
||||
ChannelId: "channel_1",
|
||||
ChannelType: 1,
|
||||
Type: wkdb.ConversationTypeChat,
|
||||
UnreadCount: 5,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
},
|
||||
}
|
||||
|
||||
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 第一次查询(缓存未命中)
|
||||
start1 := time.Now()
|
||||
result1, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
|
||||
duration1 := time.Since(start1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result1, 1)
|
||||
|
||||
// 第二次查询(缓存命中)
|
||||
start2 := time.Now()
|
||||
result2, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
|
||||
duration2 := time.Since(start2)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result2, 1)
|
||||
|
||||
// 验证结果一致
|
||||
assert.Equal(t, result1[0].Id, result2[0].Id)
|
||||
assert.Equal(t, result1[0].UnreadCount, result2[0].UnreadCount)
|
||||
|
||||
t.Logf("First query (cache miss): %v", duration1)
|
||||
t.Logf("Second query (cache hit): %v", duration2)
|
||||
|
||||
// 缓存命中应该更快
|
||||
if duration2 < duration1 {
|
||||
t.Logf("Cache hit is faster by: %v", duration1-duration2)
|
||||
}
|
||||
}
|
||||
@@ -291,7 +291,7 @@ func BenchmarkGetLastConversations(b *testing.B) {
|
||||
|
||||
// 创建大量会话数据
|
||||
conversations := make([]wkdb.Conversation, 0, 1000)
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := 0; i < 1000; i++ {
|
||||
updatedAt := now.Add(time.Duration(i) * time.Minute)
|
||||
conversations = append(conversations, wkdb.Conversation{
|
||||
Id: uint64(i + 1),
|
||||
|
||||
@@ -242,6 +242,9 @@ type ConversationDB interface {
|
||||
|
||||
// UpdateConversationDeletedAtMsgSeq 更新最近会话的已删除的消息序号位置
|
||||
UpdateConversationDeletedAtMsgSeq(uid string, channelId string, channelType uint8, deletedAtMsgSeq uint64) error
|
||||
|
||||
// GetLastConversationIds 获取最近会话ID列表(用于测试重复ID问题)
|
||||
GetLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error)
|
||||
}
|
||||
|
||||
type ChannelClusterConfigDB interface {
|
||||
|
||||
Reference in New Issue
Block a user