fix: duplicate ids when concurrent update

This commit is contained in:
tt
2025-06-23 15:51:04 +08:00
parent bb6b1762fe
commit c486fd1f75
5 changed files with 733 additions and 131 deletions

View File

@@ -120,8 +120,8 @@ func (wk *wukongDB) AddOrUpdateConversationsBatchIfNotExist(conversations []Conv
func (wk *wukongDB) AddOrUpdateConversationsWithUser(uid string, conversations []Conversation) error {
wk.metrics.AddOrUpdateConversationsAdd(1)
// wk.dblock.conversationLock.lock(uid)
// defer wk.dblock.conversationLock.unlock(uid)
wk.dblock.conversationLock.lock(uid)
defer wk.dblock.conversationLock.unlock(uid)
if wk.opts.EnableCost {
start := time.Now()
defer func() {
@@ -296,28 +296,22 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
return cached, nil
}
// 缓存未命中,从数据库获取
ids, err := wk.getLastConversationIds(uid, updatedAt, limit)
if err != nil {
return nil, err
}
if len(ids) == 0 {
return nil, nil
}
// 缓存未命中,使用全表扫描+过滤的方式直接从数据库获取
db := wk.shardDB(uid)
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationPrimaryKey(uid, 0),
UpperBound: key.NewConversationPrimaryKey(uid, math.MaxUint64),
})
defer iter.Close()
// 使用批量查询优化避免N+1查询问题
conversations, err := wk.getConversationsBatch(uid, ids)
if err != nil {
return nil, err
}
// 过滤会话类型和排除的频道类型
filteredConversations := make([]Conversation, 0, len(conversations))
for _, conversation := range conversations {
var allConversations []Conversation
err := wk.iterateConversation(iter, func(conversation Conversation) bool {
// 过滤会话类型
if conversation.Type != tp {
continue
return true
}
// 过滤排除的频道类型
exclude := false
if len(excludeChannelTypes) > 0 {
for _, excludeChannelType := range excludeChannelTypes {
@@ -328,19 +322,24 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
}
}
if exclude {
continue
return true
}
filteredConversations = append(filteredConversations, conversation)
// 过滤更新时间updatedAt=0表示获取所有会话
if updatedAt == 0 || (conversation.UpdatedAt != nil && uint64(conversation.UpdatedAt.UnixNano()) >= updatedAt) {
allConversations = append(allConversations, conversation)
}
return true
})
if err != nil {
return nil, err
}
// conversations 根据id去重复
filteredConversations = uniqueConversation(filteredConversations)
// 按照更新时间排序
sort.Slice(filteredConversations, func(i, j int) bool {
c1 := filteredConversations[i]
c2 := filteredConversations[j]
// 按照更新时间排序(最新的在前面)
sort.Slice(allConversations, func(i, j int) bool {
c1 := allConversations[i]
c2 := allConversations[j]
if c1.UpdatedAt == nil {
return false
}
@@ -350,6 +349,14 @@ func (wk *wukongDB) GetLastConversations(uid string, tp ConversationType, update
return c1.UpdatedAt.After(*c2.UpdatedAt)
})
// 应用 limit 限制
var filteredConversations []Conversation
if limit > 0 && len(allConversations) > limit {
filteredConversations = allConversations[:limit]
} else {
filteredConversations = allConversations
}
// 将结果写入缓存
wk.conversationCache.SetLastConversations(uid, tp, updatedAt, excludeChannelTypes, limit, filteredConversations)
@@ -415,7 +422,6 @@ func removeDupliConversationByChannel(conversations []Conversation) []Conversati
}
func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error) {
// 直接从数据库获取不再单独缓存ID列表
db := wk.shardDB(uid)
iter, err := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationSecondIndexKey(uid, key.TableConversation.SecondIndex.UpdatedAt, updatedAt, 0),
@@ -459,6 +465,11 @@ func (wk *wukongDB) getLastConversationIds(uid string, updatedAt uint64, limit i
return uniqueIdsMap, nil
}
// GetLastConversationIds 公开方法,用于测试 getLastConversationIds 的重复ID问题
func (wk *wukongDB) GetLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error) {
return wk.getLastConversationIds(uid, updatedAt, limit)
}
// DeleteConversation 删除最近会话
func (wk *wukongDB) DeleteConversation(uid string, channelId string, channelType uint8) error {
@@ -688,106 +699,6 @@ func (wk *wukongDB) getConversation(uid string, id uint64) (Conversation, error)
return conversation, nil
}
// getConversationsBatch 批量获取会话避免N+1查询问题
func (wk *wukongDB) getConversationsBatch(uid string, ids []uint64) ([]Conversation, error) {
if len(ids) == 0 {
return nil, nil
}
// 先尝试从缓存获取部分数据
conversations := make([]Conversation, 0, len(ids))
if len(ids) == 0 {
return conversations, nil
}
// 从数据库获取缺失的会话
var dbConversations []Conversation
var err error
// 如果ID数量较少使用优化的多范围查询
if len(ids) <= 10 {
dbConversations, err = wk.getConversationsBatchOptimized(uid, ids)
} else {
// ID数量较多时使用全表扫描+过滤的方式
dbConversations, err = wk.getConversationsBatchFiltered(uid, ids)
}
if err != nil {
return nil, err
}
// 不再单独缓存批量查询结果,由 GetLastConversations 统一缓存
// 合并结果
conversations = append(conversations, dbConversations...)
return conversations, nil
}
// getConversationsBatchOptimized 对少量ID使用多个精确范围查询
func (wk *wukongDB) getConversationsBatchOptimized(uid string, ids []uint64) ([]Conversation, error) {
conversations := make([]Conversation, 0, len(ids))
db := wk.shardDB(uid)
for _, id := range ids {
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationColumnKey(uid, id, key.MinColumnKey),
UpperBound: key.NewConversationColumnKey(uid, id, key.MaxColumnKey),
})
var conversation = EmptyConversation
err := wk.iterateConversation(iter, func(cn Conversation) bool {
conversation = cn
return false
})
iter.Close()
if err != nil {
return nil, err
}
if conversation != EmptyConversation {
conversations = append(conversations, conversation)
}
}
return conversations, nil
}
// getConversationsBatchFiltered 对大量ID使用全表扫描+过滤
func (wk *wukongDB) getConversationsBatchFiltered(uid string, ids []uint64) ([]Conversation, error) {
// 创建ID集合用于快速查找
idSet := make(map[uint64]struct{}, len(ids))
for _, id := range ids {
idSet[id] = struct{}{}
}
// 使用单个迭代器查询所有会话数据
db := wk.shardDB(uid)
iter := db.NewIter(&pebble.IterOptions{
LowerBound: key.NewConversationPrimaryKey(uid, 0),
UpperBound: key.NewConversationPrimaryKey(uid, math.MaxUint64),
})
defer iter.Close()
conversations := make([]Conversation, 0, len(ids))
err := wk.iterateConversation(iter, func(conversation Conversation) bool {
// 只收集我们需要的会话ID
if _, exists := idSet[conversation.Id]; exists {
conversations = append(conversations, conversation)
// 如果已经找到所有需要的会话,可以提前退出
if len(conversations) == len(ids) {
return false
}
}
return true
})
if err != nil {
return nil, err
}
return conversations, nil
}
// func (wk *wukongDB) getConversationIdsByUid(uid string) ([]uint64, error) {
// iter := wk.shardDB(uid).NewIter(&pebble.IterOptions{
// LowerBound: key.NewConversationPrimaryKey(uid, 0),

View File

@@ -0,0 +1,339 @@
package wkdb_test
import (
"testing"
"time"
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
"github.com/stretchr/testify/assert"
)
// 辅助方法:通过反射或者添加测试方法来访问私有的 getLastConversationIds
// 这里我们通过 GetLastConversations 来间接测试,因为它内部调用了 getLastConversationIds
// TestDuplicateIdsInGetLastConversationIds 测试 getLastConversationIds 方法中出现重复 ID 的问题
func TestDuplicateIdsInGetLastConversationIds(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "test_duplicate_user1"
channelId := "test_channel"
channelType := uint8(1)
now := time.Now()
// 场景1测试同一个会话的多次快速更新
t.Run("MultipleQuickUpdates", func(t *testing.T) {
// 创建初始会话
conversation := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId + "_1",
ChannelType: channelType,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &now,
UpdatedAt: &now,
}
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
// 快速连续更新同一个会话多次
for i := 0; i < 5; i++ {
updatedTime := now.Add(time.Duration(i+1) * time.Millisecond)
conversation.UnreadCount = uint32(i + 2)
conversation.UpdatedAt = &updatedTime
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
}
// 检查是否有重复的 ID
ids, err := d.GetLastConversationIds(uid, 0, 10)
assert.NoError(t, err)
// 检查重复
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
for id, count := range idMap {
if count > 1 {
t.Logf("Found duplicate ID: %d, count: %d", id, count)
}
}
})
// 场景2测试批量更新中包含重复会话
t.Run("BatchUpdateWithDuplicates", func(t *testing.T) {
channelId2 := channelId + "_2"
baseTime := now.Add(time.Hour)
// 创建包含重复会话的批量更新
conversations := []wkdb.Conversation{
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId2,
ChannelType: channelType,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &baseTime,
UpdatedAt: &baseTime,
},
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId2, // 相同的 channelId
ChannelType: channelType, // 相同的 channelType
Type: wkdb.ConversationTypeChat,
UnreadCount: 2, // 不同的 UnreadCount
CreatedAt: &baseTime,
UpdatedAt: &baseTime,
},
}
err := d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 检查是否有重复的 ID
ids, err := d.GetLastConversationIds(uid, 0, 10)
assert.NoError(t, err)
// 检查重复
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
for id, count := range idMap {
if count > 1 {
t.Logf("Found duplicate ID in batch update: %d, count: %d", id, count)
}
}
})
// 场景3测试相同时间戳的更新
t.Run("SameTimestampUpdates", func(t *testing.T) {
channelId3 := channelId + "_3"
sameTime := now.Add(2 * time.Hour)
// 创建初始会话
conversation := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId3,
ChannelType: channelType,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &sameTime,
UpdatedAt: &sameTime,
}
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
// 使用相同的时间戳更新多次
for i := 0; i < 3; i++ {
conversation.UnreadCount = uint32(i + 2)
// 故意使用相同的 UpdatedAt 时间
conversation.UpdatedAt = &sameTime
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
}
// 检查是否有重复的 ID
ids, err := d.GetLastConversationIds(uid, 0, 10)
assert.NoError(t, err)
// 检查重复
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
for id, count := range idMap {
if count > 1 {
t.Logf("Found duplicate ID with same timestamp: %d, count: %d", id, count)
}
}
})
// 场景4测试并发更新模拟
t.Run("ConcurrentUpdates", func(t *testing.T) {
channelId4 := channelId + "_4"
baseTime := now.Add(3 * time.Hour)
// 创建初始会话
conversation := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId4,
ChannelType: channelType,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &baseTime,
UpdatedAt: &baseTime,
}
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
// 模拟并发更新:快速连续的更新操作
done := make(chan bool, 3)
for i := 0; i < 3; i++ {
go func(index int) {
defer func() { done <- true }()
for j := 0; j < 2; j++ {
updateTime := baseTime.Add(time.Duration(index*10+j) * time.Microsecond)
conv := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId4,
ChannelType: channelType,
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(index*10 + j + 1),
CreatedAt: &baseTime,
UpdatedAt: &updateTime,
}
d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conv})
}
}(i)
}
// 等待所有 goroutine 完成
for i := 0; i < 3; i++ {
<-done
}
// 检查是否有重复的 ID
ids, err := d.GetLastConversationIds(uid, 0, 10)
assert.NoError(t, err)
// 检查重复
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
for id, count := range idMap {
if count > 1 {
t.Logf("Found duplicate ID in concurrent updates: %d, count: %d", id, count)
}
}
})
// 最终检查:获取所有会话 ID 并分析重复情况
t.Run("FinalAnalysis", func(t *testing.T) {
ids, err := d.GetLastConversationIds(uid, 0, 100)
assert.NoError(t, err)
t.Logf("Total IDs returned: %d", len(ids))
// 统计重复情况
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
duplicateCount := 0
for id, count := range idMap {
if count > 1 {
duplicateCount++
t.Logf("Duplicate ID: %d appears %d times", id, count)
}
}
if duplicateCount > 0 {
t.Logf("Found %d duplicate IDs out of %d unique IDs", duplicateCount, len(idMap))
} else {
t.Log("No duplicate IDs found")
}
// 验证去重逻辑是否工作
assert.Equal(t, len(idMap), len(ids), "去重逻辑应该确保没有重复ID")
})
}
// TestGetLastConversationIdsDeduplication 专门测试去重逻辑
func TestGetLastConversationIdsDeduplication(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "dedup_test_user"
now := time.Now()
// 创建多个会话,然后多次更新以增加重复 ID 的可能性
for i := 0; i < 5; i++ {
channelId := "channel_" + string(rune('A'+i))
// 创建初始会话
conversation := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: channelId,
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &now,
UpdatedAt: &now,
}
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
// 多次更新每个会话
for j := 0; j < 3; j++ {
updateTime := now.Add(time.Duration(i*100+j) * time.Millisecond)
conversation.UnreadCount = uint32(j + 2)
conversation.UpdatedAt = &updateTime
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
}
}
// 获取 ID 列表并检查去重效果
ids, err := d.GetLastConversationIds(uid, 0, 20)
assert.NoError(t, err)
// 验证去重逻辑
idMap := make(map[uint64]int)
for _, id := range ids {
idMap[id]++
}
// 检查是否有重复
hasDeduplication := false
for id, count := range idMap {
if count > 1 {
t.Errorf("ID %d appears %d times, deduplication failed", id, count)
}
if count == 1 && len(ids) > len(idMap) {
hasDeduplication = true
}
}
if hasDeduplication {
t.Logf("Deduplication worked: %d raw IDs reduced to %d unique IDs", len(ids), len(idMap))
}
t.Logf("Final result: %d unique conversation IDs", len(idMap))
}

View File

@@ -0,0 +1,349 @@
package wkdb_test
import (
"testing"
"time"
"github.com/WuKongIM/WuKongIM/pkg/wkdb"
"github.com/stretchr/testify/assert"
)
// TestGetLastConversationsFullScan 测试全表扫描+过滤的 GetLastConversations 实现
func TestGetLastConversationsFullScan(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "full_scan_test_user"
now := time.Now()
// 创建多个不同类型和时间的会话
conversations := []wkdb.Conversation{
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_1",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 5,
CreatedAt: &now,
UpdatedAt: &now,
},
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_2",
ChannelType: 2,
Type: wkdb.ConversationTypeChat,
UnreadCount: 3,
CreatedAt: &now,
UpdatedAt: &now,
},
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_3",
ChannelType: 1,
Type: wkdb.ConversationTypeCMD, // 不同类型
UnreadCount: 2,
CreatedAt: &now,
UpdatedAt: &now,
},
}
// 添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 测试1获取所有聊天类型的会话
t.Run("GetAllChatConversations", func(t *testing.T) {
// 先检查是否成功添加了会话
allConversations, err := d.GetConversations(uid)
assert.NoError(t, err)
t.Logf("Total conversations added: %d", len(allConversations))
for i, conv := range allConversations {
t.Logf("Conversation %d: ID=%d, Type=%d, ChannelId=%s, ChannelType=%d",
i, conv.Id, conv.Type, conv.ChannelId, conv.ChannelType)
}
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
t.Logf("GetLastConversations returned: %d conversations", len(result))
assert.Len(t, result, 2) // 应该有2个聊天类型的会话
// 验证结果都是聊天类型
for _, conv := range result {
assert.Equal(t, wkdb.ConversationTypeChat, conv.Type)
}
})
// 测试2排除特定频道类型
t.Run("ExcludeChannelTypes", func(t *testing.T) {
excludeTypes := []uint8{2} // 排除频道类型2
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, excludeTypes, 10)
assert.NoError(t, err)
assert.Len(t, result, 1) // 应该只有1个会话排除了频道类型2
// 验证结果不包含被排除的频道类型
for _, conv := range result {
assert.NotEqual(t, uint8(2), conv.ChannelType)
}
})
// 测试3限制数量
t.Run("LimitResults", func(t *testing.T) {
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 1)
assert.NoError(t, err)
assert.Len(t, result, 1) // 应该只返回1个会话
})
// 测试4按更新时间过滤
t.Run("FilterByUpdatedAt", func(t *testing.T) {
// 更新其中一个会话的时间
futureTime := now.Add(time.Hour)
updatedConv := conversations[0]
updatedConv.UpdatedAt = &futureTime
updatedConv.UnreadCount = 10
err := d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{updatedConv})
assert.NoError(t, err)
// 使用未来时间作为过滤条件
futureNano := uint64(futureTime.UnixNano())
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, futureNano, nil, 10)
assert.NoError(t, err)
assert.Len(t, result, 1) // 应该只有1个会话满足时间条件
// 验证返回的是更新后的会话
assert.Equal(t, uint32(10), result[0].UnreadCount)
})
// 测试5时间排序
t.Run("TimeOrdering", func(t *testing.T) {
// 创建不同时间的会话
time1 := now.Add(time.Minute)
time2 := now.Add(2 * time.Minute)
time3 := now.Add(3 * time.Minute)
newConversations := []wkdb.Conversation{
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_time_1",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &time1,
UpdatedAt: &time1,
},
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_time_2",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 2,
CreatedAt: &time2,
UpdatedAt: &time2,
},
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_time_3",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 3,
CreatedAt: &time3,
UpdatedAt: &time3,
},
}
err := d.AddOrUpdateConversationsWithUser(uid, newConversations)
assert.NoError(t, err)
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
assert.True(t, len(result) >= 3)
// 验证按时间降序排列(最新的在前面)
for i := 0; i < len(result)-1; i++ {
if result[i].UpdatedAt != nil && result[i+1].UpdatedAt != nil {
assert.True(t, result[i].UpdatedAt.After(*result[i+1].UpdatedAt) || result[i].UpdatedAt.Equal(*result[i+1].UpdatedAt))
}
}
})
}
// TestGetLastConversationsPerformance 测试全表扫描的性能
func TestGetLastConversationsPerformance(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "performance_test_user"
now := time.Now()
// 创建大量会话数据
conversations := make([]wkdb.Conversation, 0, 100)
for i := 0; i < 100; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_" + string(rune('A'+i%26)) + string(rune('0'+i/26)),
ChannelType: uint8(1 + i%3), // 频道类型 1, 2, 3
Type: wkdb.ConversationTypeChat,
UnreadCount: uint32(i + 1),
CreatedAt: &now,
UpdatedAt: &updatedAt,
})
}
// 批量添加会话
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 测试性能
start := time.Now()
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 20)
duration := time.Since(start)
assert.NoError(t, err)
assert.Len(t, result, 20)
t.Logf("Full scan query took: %v for 100 conversations", duration)
// 验证结果正确性
assert.True(t, len(result) <= 20)
for _, conv := range result {
assert.Equal(t, wkdb.ConversationTypeChat, conv.Type)
}
}
// TestGetLastConversationsNoDuplicates 测试全表扫描不会产生重复结果
func TestGetLastConversationsNoDuplicates(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "no_duplicates_test_user"
now := time.Now()
// 创建会话并多次更新
conversation := wkdb.Conversation{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "test_channel",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 1,
CreatedAt: &now,
UpdatedAt: &now,
}
// 添加初始会话
err = d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
// 多次更新同一个会话
for i := 0; i < 5; i++ {
updateTime := now.Add(time.Duration(i+1) * time.Minute)
conversation.UnreadCount = uint32(i + 2)
conversation.UpdatedAt = &updateTime
err = d.AddOrUpdateConversationsWithUser(uid, []wkdb.Conversation{conversation})
assert.NoError(t, err)
}
// 获取会话列表
result, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
assert.NoError(t, err)
// 验证没有重复
channelMap := make(map[string]int)
for _, conv := range result {
key := conv.ChannelId + ":" + string(rune(conv.ChannelType))
channelMap[key]++
}
for channel, count := range channelMap {
assert.Equal(t, 1, count, "Channel %s should appear only once, but appeared %d times", channel, count)
}
// 验证返回的是最新的数据
if len(result) > 0 {
assert.Equal(t, uint32(6), result[0].UnreadCount) // 最后一次更新的值
}
}
// TestGetLastConversationsCache 测试缓存功能
func TestGetLastConversationsCache(t *testing.T) {
d := newTestDB(t)
err := d.Open()
assert.NoError(t, err)
defer func() {
err := d.Close()
assert.NoError(t, err)
}()
uid := "cache_test_user"
now := time.Now()
// 创建测试数据
conversations := []wkdb.Conversation{
{
Id: d.NextPrimaryKey(),
Uid: uid,
ChannelId: "channel_1",
ChannelType: 1,
Type: wkdb.ConversationTypeChat,
UnreadCount: 5,
CreatedAt: &now,
UpdatedAt: &now,
},
}
err = d.AddOrUpdateConversationsWithUser(uid, conversations)
assert.NoError(t, err)
// 第一次查询(缓存未命中)
start1 := time.Now()
result1, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration1 := time.Since(start1)
assert.NoError(t, err)
assert.Len(t, result1, 1)
// 第二次查询(缓存命中)
start2 := time.Now()
result2, err := d.GetLastConversations(uid, wkdb.ConversationTypeChat, 0, nil, 10)
duration2 := time.Since(start2)
assert.NoError(t, err)
assert.Len(t, result2, 1)
// 验证结果一致
assert.Equal(t, result1[0].Id, result2[0].Id)
assert.Equal(t, result1[0].UnreadCount, result2[0].UnreadCount)
t.Logf("First query (cache miss): %v", duration1)
t.Logf("Second query (cache hit): %v", duration2)
// 缓存命中应该更快
if duration2 < duration1 {
t.Logf("Cache hit is faster by: %v", duration1-duration2)
}
}

View File

@@ -291,7 +291,7 @@ func BenchmarkGetLastConversations(b *testing.B) {
// 创建大量会话数据
conversations := make([]wkdb.Conversation, 0, 1000)
for i := 0; i < 100; i++ {
for i := 0; i < 1000; i++ {
updatedAt := now.Add(time.Duration(i) * time.Minute)
conversations = append(conversations, wkdb.Conversation{
Id: uint64(i + 1),

View File

@@ -242,6 +242,9 @@ type ConversationDB interface {
// UpdateConversationDeletedAtMsgSeq 更新最近会话的已删除的消息序号位置
UpdateConversationDeletedAtMsgSeq(uid string, channelId string, channelType uint8, deletedAtMsgSeq uint64) error
// GetLastConversationIds 获取最近会话ID列表用于测试重复ID问题
GetLastConversationIds(uid string, updatedAt uint64, limit int) ([]uint64, error)
}
type ChannelClusterConfigDB interface {