Open-source X Recommendation Algorithm

This commit is contained in:
CI agent
2026-01-20 02:31:49 +00:00
commit aaa167b3de
79 changed files with 8816 additions and 0 deletions

26
thunder/deserializer.rs Normal file
View File

@@ -0,0 +1,26 @@
use crate::schema::{events::Event, tweet_events::TweetEvent};
use anyhow::{Context, Result};
use prost::Message;
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
use xai_thunder_proto::InNetworkEvent;
/// Decodes a Thrift binary-protocol payload into a [`TweetEvent`].
///
/// # Errors
/// Returns the underlying Thrift error, wrapped with the context
/// "Failed to deserialize TweetEvent", when `payload` cannot be read.
pub fn deserialize_tweet_event(payload: &[u8]) -> Result<TweetEvent> {
    let mut reader = std::io::Cursor::new(payload);
    // The second argument enables the protocol's strict mode.
    let mut input = TBinaryInputProtocol::new(&mut reader, true);
    let decoded = TweetEvent::read_from_in_protocol(&mut input);
    decoded.context("Failed to deserialize TweetEvent")
}
/// Decodes a Thrift binary-protocol payload into an [`Event`].
///
/// # Errors
/// Returns the underlying Thrift error, wrapped with the context
/// "Failed to deserialize Event", when `payload` cannot be read.
pub fn deserialize_event(payload: &[u8]) -> Result<Event> {
    let mut reader = std::io::Cursor::new(payload);
    // The second argument enables the protocol's strict mode.
    let mut input = TBinaryInputProtocol::new(&mut reader, true);
    let decoded = Event::read_from_in_protocol(&mut input);
    decoded.context("Failed to deserialize Event")
}
/// Decodes a protobuf-encoded payload into an [`InNetworkEvent`].
///
/// # Errors
/// Returns the prost decode error, wrapped with the context
/// "Failed to deserialize InNetworkEvent", when `payload` is not a valid message.
pub fn deserialize_tweet_event_v2(payload: &[u8]) -> Result<InNetworkEvent> {
    let decoded = InNetworkEvent::decode(payload);
    decoded.context("Failed to deserialize InNetworkEvent")
}

3
thunder/kafka/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod tweet_events_listener;
pub mod tweet_events_listener_v2;
pub mod utils;

View File

@@ -0,0 +1,390 @@
use anyhow::{Context, Result};
use log::{error, info, warn};
use prost::Message;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use xai_kafka::{KafkaProducer, KafkaProducerConfig};
use xai_thunder_proto::{
InNetworkEvent, LightPost, TweetCreateEvent, TweetDeleteEvent, in_network_event,
};
use crate::{
args::Args,
crate::config::MIN_VIDEO_DURATION_MS,
deserializer::deserialize_tweet_event,
kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
metrics,
schema::{tweet::Tweet, tweet_events::TweetEventData},
};
/// Process-wide count of batches handled by `process_message_batch`.
/// It is incremented once per batch and used to emit a milestone log line
/// only when the count is a multiple of 1000.
static BATCH_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Periodically queries the consumer's per-partition lag and publishes each
/// value to `metrics::KAFKA_PARTITION_LAG`, labelled by topic and partition id.
///
/// The loop never terminates on its own; a failed lag query is logged and
/// retried on the next tick.
///
/// An `interval_secs` of 0 disables monitoring: `tokio::time::interval` panics
/// on a zero period, which would otherwise abort the spawned task.
async fn monitor_partition_lag(
    consumer: Arc<RwLock<KafkaConsumer>>,
    topic: String,
    interval_secs: u64,
) {
    if interval_secs == 0 {
        warn!(
            "Partition lag monitoring disabled for topic '{}': interval must be greater than 0",
            topic
        );
        return;
    }

    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
    loop {
        interval.tick().await;

        // The read guard is released at the end of each iteration, so the
        // polling loop (which takes the write lock) is only blocked briefly.
        let consumer = consumer.read().await;
        match consumer.get_partition_lags().await {
            Ok(lag_info) => {
                for partition_lag in lag_info {
                    let partition_str = partition_lag.partition_id.to_string();
                    metrics::KAFKA_PARTITION_LAG
                        .with_label_values(&[&topic, &partition_str])
                        .set(partition_lag.lag as f64);
                }
            }
            Err(e) => {
                warn!("Failed to get partition lag info: {}", e);
            }
        }
    }
}
/// Returns `true` when `tweet` carries exactly one media item, that item is a
/// video, and the video's duration is at least `MIN_VIDEO_DURATION_MS`.
///
/// Posts with no media, several media items, non-video media, or a video
/// without a known duration are not eligible.
fn is_eligible_video(tweet: &Tweet) -> bool {
    use crate::schema::tweet_media::MediaInfo;

    // The one-element slice pattern rejects both empty and multi-media posts.
    let Some([only_media]) = tweet.media.as_ref().map(|items| items.as_slice()) else {
        return false;
    };

    match only_media.media_info.as_ref() {
        Some(MediaInfo::VideoInfo(video)) => video
            .duration_millis
            .is_some_and(|millis| millis >= MIN_VIDEO_DURATION_MS),
        _ => false,
    }
}
/// Launches `monitor_partition_lag` as a detached Tokio task.
///
/// The join handle is dropped, so the task keeps running in the background
/// and cannot be awaited or cancelled by the caller.
pub fn start_partition_lag_monitor(
    consumer: Arc<RwLock<KafkaConsumer>>,
    topic: String,
    interval_secs: u64,
) {
    let monitor = async move {
        info!(
            "Starting partition lag monitoring task for topic '{}' (interval: {}s)",
            topic, interval_secs
        );
        monitor_partition_lag(consumer, topic, interval_secs).await;
    };
    tokio::spawn(monitor);
}
/// Starts background consumption of the tweet-events topic.
///
/// All partitions `0..args.tweet_events_num_partitions` are distributed over
/// `args.kafka_num_threads` consumer tasks. When the process is not in serving
/// mode a Kafka producer is started first and shared with every task.
///
/// # Panics
/// Panics if the producer cannot be started.
pub async fn start_tweet_event_processing(
    base_config: KafkaConsumerConfig,
    producer_config: KafkaProducerConfig,
    args: &Args,
) {
    let num_partitions = args.tweet_events_num_partitions as usize;
    let kafka_num_threads = args.kafka_num_threads;
    info!(
        "Starting {} message processing threads for {} partitions ({} partitions per thread)",
        kafka_num_threads,
        num_partitions,
        num_partitions.div_ceil(kafka_num_threads)
    );

    // Every partition of the topic is consumed.
    let partitions_to_use = (0..num_partitions as i32).collect::<Vec<i32>>();

    let producer = if args.is_serving {
        info!("Kafka producer disabled, skipping producer initialization");
        None
    } else {
        info!("Kafka producer enabled, starting producer...");
        let shared_producer = Arc::new(RwLock::new(KafkaProducer::new(producer_config)));
        let start_result = shared_producer.write().await.start().await;
        if let Err(e) = start_result {
            panic!("Failed to start Kafka producer: {:#}", e);
        }
        Some(shared_producer)
    };

    spawn_processing_threads(base_config, partitions_to_use, producer, args);
}
/// Splits `partitions_to_use` into contiguous chunks and spawns one Tokio task
/// per chunk.
///
/// Each task creates its own consumer for its partitions, starts a lag monitor
/// for it, and then runs `process_tweet_events`. Fewer than
/// `args.kafka_num_threads` tasks are spawned when the partitions run out first.
///
/// # Panics
/// A spawned task panics if its consumer cannot be created or if its
/// processing loop returns an error.
fn spawn_processing_threads(
    base_config: KafkaConsumerConfig,
    partitions_to_use: Vec<i32>,
    producer: Option<Arc<RwLock<KafkaProducer>>>,
    args: &Args,
) {
    let total_partitions = partitions_to_use.len();
    let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
    let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
    let batch_size = args.kafka_batch_size;
    let post_retention_sec = args.post_retention_seconds;

    // (thread id, first partition index) pairs; stops once partitions are exhausted.
    let assignments = (0..args.kafka_num_threads)
        .map(|thread_id| (thread_id, thread_id * partitions_per_thread))
        .take_while(|&(_, start_idx)| start_idx < total_partitions);

    for (thread_id, start_idx) in assignments {
        let end_idx = (start_idx + partitions_per_thread).min(total_partitions);
        let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();

        let mut thread_config = base_config.clone();
        thread_config.partitions = Some(thread_partitions.clone());
        let topic = thread_config.base_config.topic.clone();
        let producer_clone = producer.clone();

        tokio::spawn(async move {
            info!(
                "Starting message processing thread {} for partitions {:?}",
                thread_id, thread_partitions
            );

            let consumer = match create_kafka_consumer(thread_config).await {
                Ok(consumer) => consumer,
                Err(e) => panic!(
                    "Failed to create consumer for thread {}: {:#}",
                    thread_id, e
                ),
            };

            // Lag is reported per consumer, i.e. for this task's partitions.
            start_partition_lag_monitor(
                Arc::clone(&consumer),
                topic,
                lag_monitor_interval_secs,
            );

            let outcome = process_tweet_events(
                consumer,
                batch_size,
                producer_clone,
                post_retention_sec as i64,
            )
            .await;
            if let Err(e) = outcome {
                panic!(
                    "Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
                    thread_id, e
                );
            }
        });
    }
}
/// Process a batch of messages: deserialize, extract posts, and store them
async fn process_message_batch(
messages: Vec<KafkaMessage>,
batch_num: usize,
producer: Option<Arc<RwLock<KafkaProducer>>>,
post_retention_sec: i64,
) -> Result<()> {
let results = deserialize_kafka_messages(messages, deserialize_tweet_event)?;
let mut create_tweets = Vec::new();
let mut delete_tweets = Vec::new();
let mut first_post_id = 0;
let mut first_user_id = 0;
let len_posts = results.len();
let now_secs = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
for tweet_event in results {
let data = tweet_event.data.unwrap();
match data {
TweetEventData::TweetCreateEvent(create_event) => {
first_post_id = create_event.tweet.as_ref().unwrap().id.unwrap();
first_user_id = create_event.user.as_ref().unwrap().id.unwrap();
let tweet = create_event.tweet.as_ref().unwrap();
let core_data = tweet.core_data.as_ref().unwrap();
if let Some(nullcast) = core_data.nullcast
&& nullcast
{
continue;
}
create_tweets.push(LightPost {
post_id: tweet.id.unwrap(),
author_id: create_event.user.as_ref().unwrap().id.unwrap(),
created_at: core_data.created_at_secs.unwrap(),
in_reply_to_post_id: core_data
.reply
.as_ref()
.and_then(|r| r.in_reply_to_status_id),
in_reply_to_user_id: core_data
.reply
.as_ref()
.and_then(|r| r.in_reply_to_user_id),
is_retweet: core_data.share.is_some(),
is_reply: core_data.reply.is_some(),
source_post_id: core_data.share.as_ref().and_then(|s| s.source_status_id),
source_user_id: core_data.share.as_ref().and_then(|s| s.source_user_id),
has_video: is_eligible_video(tweet),
conversation_id: core_data.conversation_id,
});
}
TweetEventData::TweetDeleteEvent(delete_event) => {
let created_at_secs = delete_event
.tweet
.as_ref()
.unwrap()
.core_data
.as_ref()
.unwrap()
.created_at_secs
.unwrap();
if now_secs - created_at_secs > post_retention_sec {
continue;
}
delete_tweets.push(delete_event.tweet.as_ref().unwrap().id.unwrap());
}
TweetEventData::QuotedTweetDeleteEvent(delete_event) => {
delete_tweets.push(delete_event.quoting_tweet_id.unwrap());
}
_ => {
log::info!("Other non post creation/deletion event")
}
}
}
// Send each LightPost as an InNetworkEvent to the producer in separate tasks (only if producer is enabled)
if let Some(ref producer) = producer {
let mut send_tasks = Vec::with_capacity(create_tweets.len());
for light_post in &create_tweets {
let event = InNetworkEvent {
event_variant: Some(in_network_event::EventVariant::TweetCreateEvent(
TweetCreateEvent {
post_id: light_post.post_id,
author_id: light_post.author_id,
created_at: light_post.created_at,
in_reply_to_post_id: light_post.in_reply_to_post_id,
in_reply_to_user_id: light_post.in_reply_to_user_id,
is_retweet: light_post.is_retweet,
is_reply: light_post.is_reply,
source_post_id: light_post.source_post_id,
source_user_id: light_post.source_user_id,
has_video: light_post.has_video,
conversation_id: light_post.conversation_id,
},
)),
};
let payload = event.encode_to_vec();
let producer_clone = Arc::clone(producer);
send_tasks.push(tokio::spawn(async move {
let producer_lock = producer_clone.read().await;
if let Err(e) = producer_lock.send(&payload).await {
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
}
}));
}
for post_id in &delete_tweets {
let event = InNetworkEvent {
event_variant: Some(in_network_event::EventVariant::TweetDeleteEvent(
TweetDeleteEvent {
post_id: *post_id,
deleted_at: now_secs,
},
)),
};
let payload = event.encode_to_vec();
let producer_clone = Arc::clone(producer);
send_tasks.push(tokio::spawn(async move {
let producer_lock = producer_clone.read().await;
if let Err(e) = producer_lock.send(&payload).await {
warn!("Failed to send InNetworkEvent to producer: {:#}", e);
}
}));
}
// Wait for all send tasks to complete
for task in send_tasks {
if let Err(e) = task.await {
error!("Error writing to kafka {}", e);
}
}
}
// Log every 100th batch
let batch_count = BATCH_LOG_COUNTER.fetch_add(1, Ordering::Relaxed);
if batch_count.is_multiple_of(1000) {
info!(
"Batch processing milestone: processed {} batches total, latest batch {} had {} posts (first: post_id={}, user_id={})",
batch_count + 1,
batch_num,
len_posts,
first_post_id,
first_user_id
);
}
Ok(())
}
/// Polling loop for one consumer: accumulates polled messages and, once at
/// least `batch_size` are buffered, processes them as a batch and commits the
/// consumer offsets.
///
/// Batches are processed inline (awaited) before the next poll. Poll errors
/// are logged, counted in `metrics::KAFKA_POLL_ERRORS`, and retried after a
/// 100 ms pause.
///
/// # Errors
/// Returns when batch processing or the offset commit fails; otherwise the
/// loop runs forever.
async fn process_tweet_events(
    consumer: Arc<RwLock<KafkaConsumer>>,
    batch_size: usize,
    producer: Option<Arc<RwLock<KafkaProducer>>>,
    post_retention_sec: i64,
) -> Result<()> {
    let mut pending = Vec::new();
    let mut batch_num = 0;

    loop {
        // The write guard is a temporary and is released right after the poll.
        let polled = consumer.write().await.poll(100).await;

        let messages = match polled {
            Ok(messages) => messages,
            Err(e) => {
                warn!("Error polling messages: {:#}", e);
                metrics::KAFKA_POLL_ERRORS.inc();
                tokio::time::sleep(Duration::from_millis(100)).await;
                continue;
            }
        };

        pending.extend(messages);
        if pending.len() < batch_size {
            // Keep buffering until a full batch is available.
            continue;
        }

        batch_num += 1;
        let batch = std::mem::take(&mut pending);
        process_message_batch(batch, batch_num, producer.clone(), post_retention_sec)
            .await
            .context("Error processing tweet event batch")?;
        consumer.write().await.commit_offsets()?;
    }
}

View File

@@ -0,0 +1,249 @@
use anyhow::Result;
use log::{info, warn};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use xai_thunder_proto::{LightPost, TweetDeleteEvent, in_network_event};
use crate::{
args::Args,
deserializer::deserialize_tweet_event_v2,
kafka::utils::{create_kafka_consumer, deserialize_kafka_messages},
metrics,
posts::post_store::PostStore,
};
/// Process-wide count of `deserialize_batch` calls.
/// It is incremented once per call and used to log deserialization throughput
/// only when the count is a multiple of 1000.
static DESER_LOG_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Starts background consumption of the in-network events topic into `post_store`.
///
/// All partitions `0..args.kafka_tweet_events_v2_num_partitions` are
/// distributed over `args.kafka_num_threads` consumer tasks. Each task reports
/// on `tx` once it has caught up with its partitions.
pub async fn start_tweet_event_processing_v2(
    base_config: KafkaConsumerConfig,
    post_store: Arc<PostStore>,
    args: &Args,
    tx: tokio::sync::mpsc::Sender<i64>,
) {
    let num_partitions = args.kafka_tweet_events_v2_num_partitions;
    let kafka_num_threads = args.kafka_num_threads;
    info!(
        "Starting {} message processing threads for {} partitions ({} partitions per thread)",
        kafka_num_threads,
        num_partitions,
        num_partitions.div_ceil(kafka_num_threads)
    );

    // Every partition of the topic is consumed.
    let partitions_to_use = (0..num_partitions as i32).collect::<Vec<i32>>();
    spawn_processing_threads_v2(base_config, partitions_to_use, post_store, args, tx);
}
/// Splits `partitions_to_use` into contiguous chunks and spawns one Tokio task
/// per chunk.
///
/// Each task creates its own consumer, starts a lag monitor for it, and then
/// runs `process_tweet_events_v2`. Fewer than `args.kafka_num_threads` tasks
/// are spawned when the partitions run out first. All tasks share one
/// semaphore with 3 permits that bounds concurrent batch updates.
///
/// # Panics
/// A spawned task panics if its consumer cannot be created or if its
/// processing loop returns an error.
fn spawn_processing_threads_v2(
    base_config: KafkaConsumerConfig,
    partitions_to_use: Vec<i32>,
    post_store: Arc<PostStore>,
    args: &Args,
    tx: tokio::sync::mpsc::Sender<i64>,
) {
    let total_partitions = partitions_to_use.len();
    let partitions_per_thread = total_partitions.div_ceil(args.kafka_num_threads);
    let lag_monitor_interval_secs = args.lag_monitor_interval_secs;
    let batch_size = args.kafka_batch_size;

    // Shared limit on how many tasks may apply a batch to the store at once.
    let semaphore = Arc::new(Semaphore::new(3));

    // (thread id, first partition index) pairs; stops once partitions are exhausted.
    let assignments = (0..args.kafka_num_threads)
        .map(|thread_id| (thread_id, thread_id * partitions_per_thread))
        .take_while(|&(_, start_idx)| start_idx < total_partitions);

    for (thread_id, start_idx) in assignments {
        let end_idx = (start_idx + partitions_per_thread).min(total_partitions);
        let thread_partitions = partitions_to_use[start_idx..end_idx].to_vec();

        let mut thread_config = base_config.clone();
        thread_config.partitions = Some(thread_partitions.clone());
        let topic = thread_config.base_config.topic.clone();

        let post_store_clone = Arc::clone(&post_store);
        let tx_clone = tx.clone();
        let semaphore_clone = Arc::clone(&semaphore);

        tokio::spawn(async move {
            info!(
                "Starting message processing thread {} for partitions {:?}",
                thread_id, thread_partitions
            );

            let consumer = match create_kafka_consumer(thread_config).await {
                Ok(consumer) => consumer,
                Err(e) => panic!(
                    "Failed to create consumer for thread {}: {:#}",
                    thread_id, e
                ),
            };

            // Lag is reported per consumer, i.e. for this task's partitions.
            crate::kafka::tweet_events_listener::start_partition_lag_monitor(
                Arc::clone(&consumer),
                topic,
                lag_monitor_interval_secs,
            );

            let outcome = process_tweet_events_v2(
                consumer,
                post_store_clone,
                batch_size,
                tx_clone,
                semaphore_clone,
            )
            .await;
            if let Err(e) = outcome {
                panic!(
                    "Tweet events processing thread {} exited unexpectedly: {:#}. This is a critical failure - the feeder cannot function without tweet event processing.",
                    thread_id, e
                );
            }
        });
    }
}
/// Process a single batch of messages: deserialize, extract posts, and store them
fn deserialize_batch(
messages: Vec<KafkaMessage>,
) -> Result<(Vec<LightPost>, Vec<TweetDeleteEvent>)> {
let start_time = Instant::now();
let num_messages = messages.len();
let results = deserialize_kafka_messages(messages, deserialize_tweet_event_v2)?;
let deser_elapsed = start_time.elapsed();
if DESER_LOG_COUNTER
.fetch_add(1, Ordering::Relaxed)
.is_multiple_of(1000)
{
info!(
"Deserialized {} messages in {:?} ({:.2} msgs/sec)",
num_messages,
deser_elapsed,
num_messages as f64 / deser_elapsed.as_secs_f64()
);
}
let mut create_tweets = Vec::with_capacity(results.len());
let mut delete_tweets = Vec::with_capacity(10);
for tweet_event in results {
match tweet_event.event_variant.unwrap() {
in_network_event::EventVariant::TweetCreateEvent(create_event) => {
create_tweets.push(LightPost {
post_id: create_event.post_id,
author_id: create_event.author_id,
created_at: create_event.created_at,
in_reply_to_post_id: create_event.in_reply_to_post_id,
in_reply_to_user_id: create_event.in_reply_to_user_id,
is_retweet: create_event.is_retweet,
is_reply: create_event.is_reply
|| create_event.in_reply_to_post_id.is_some()
|| create_event.in_reply_to_user_id.is_some(),
source_post_id: create_event.source_post_id,
source_user_id: create_event.source_user_id,
has_video: create_event.has_video,
conversation_id: create_event.conversation_id,
});
}
in_network_event::EventVariant::TweetDeleteEvent(delete_event) => {
delete_tweets.push(delete_event);
}
}
}
Ok((create_tweets, delete_tweets))
}
/// Polling loop for one consumer: accumulates polled messages, and applies
/// them to `post_store` in batches on the blocking thread pool.
///
/// * `consumer` - shared consumer for this task's partitions.
/// * `post_store` - store receiving inserted posts and delete markers.
/// * `batch_size` - poll size and number of buffered messages that triggers a batch.
/// * `tx` - receives this task's total lag exactly once, when the task first
///   observes that it has caught up (total lag below `partitions * batch_size`).
/// * `semaphore` - limits concurrent batch updates once the task has caught up.
///
/// When catch-up is first detected, any buffered messages are flushed even if
/// fewer than `batch_size` are pending, and the signal is sent regardless of
/// the buffer size. Otherwise a task that catches up with a short buffer would
/// never signal and the receiver would wait forever.
///
/// Poll errors are logged, counted in `metrics::KAFKA_POLL_ERRORS`, and
/// retried after a 100 ms pause. The loop does not terminate on its own.
async fn process_tweet_events_v2(
    consumer: Arc<RwLock<KafkaConsumer>>,
    post_store: Arc<PostStore>,
    batch_size: usize,
    tx: tokio::sync::mpsc::Sender<i64>,
    semaphore: Arc<Semaphore>,
) -> Result<()> {
    let mut message_buffer = Vec::new();
    let mut batch_count = 0_usize;
    let mut init_data_downloaded = false;

    loop {
        let poll_result = {
            let mut consumer_lock = consumer.write().await;
            consumer_lock.poll(batch_size).await
        };
        let messages = match poll_result {
            Ok(messages) => messages,
            Err(e) => {
                warn!("Error polling messages: {:#}", e);
                metrics::KAFKA_POLL_ERRORS.inc();
                tokio::time::sleep(Duration::from_millis(100)).await;
                continue;
            }
        };

        // `Some(total_lag)` only in the single iteration where catch-up is first seen.
        let caught_up_lag = if init_data_downloaded {
            None
        } else {
            let consumer_lock = consumer.read().await;
            match consumer_lock.get_partition_lags().await {
                Ok(lags) => {
                    let total_lag: i64 = lags.iter().map(|l| l.lag).sum();
                    (total_lag < (lags.len() * batch_size) as i64).then_some(total_lag)
                }
                Err(_) => None,
            }
        };
        if caught_up_lag.is_some() {
            init_data_downloaded = true;
        }

        message_buffer.extend(messages);

        // On catch-up, flush a partial buffer so everything downloaded so far
        // is in the store before the signal is sent.
        let flush_for_catchup = caught_up_lag.is_some() && !message_buffer.is_empty();
        if message_buffer.len() >= batch_size || flush_for_catchup {
            batch_count += 1;
            let messages = std::mem::take(&mut message_buffer);
            let post_store_clone = Arc::clone(&post_store);

            // After init, throttle updates so serving requests keep enough CPU.
            let permit = if init_data_downloaded {
                Some(semaphore.clone().acquire_owned().await.unwrap())
            } else {
                None
            };

            // Deserialization and store updates run on the blocking thread pool.
            let outcome = tokio::task::spawn_blocking(move || {
                let _permit = permit; // Hold permit until task completes
                match deserialize_batch(messages) {
                    Err(e) => warn!("Error processing batch {}: {:#}", batch_count, e),
                    Ok((light_posts, delete_posts)) => {
                        post_store_clone.insert_posts(light_posts);
                        post_store_clone.mark_as_deleted(delete_posts);
                    }
                };
            })
            .await;
            if let Err(e) = outcome {
                warn!("Batch {} processing task failed: {}", batch_count, e);
            }
        }

        if let Some(lag) = caught_up_lag {
            info!("Completed kafka init for a single thread");
            if let Err(e) = tx.send(lag).await {
                log::error!("error sending {}", e);
            }
        }
    }
}

48
thunder/kafka/utils.rs Normal file
View File

@@ -0,0 +1,48 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use tokio::sync::RwLock;
use xai_kafka::{KafkaMessage, config::KafkaConsumerConfig, consumer::KafkaConsumer};
use crate::metrics;
/// Builds a [`KafkaConsumer`] from `config`, starts it, and returns it wrapped
/// in `Arc<RwLock<..>>` so it can be shared between tasks.
///
/// # Errors
/// Returns the start error with the context "Failed to start Kafka consumer".
pub async fn create_kafka_consumer(
    config: KafkaConsumerConfig,
) -> Result<Arc<RwLock<KafkaConsumer>>> {
    let mut consumer = KafkaConsumer::new(config);
    let started = consumer.start().await;
    started.context("Failed to start Kafka consumer")?;

    let shared = Arc::new(RwLock::new(consumer));
    Ok(shared)
}
/// Process a batch of Kafka messages and deserialize them using the provided deserializer function
pub fn deserialize_kafka_messages<T, F>(
messages: Vec<KafkaMessage>,
deserializer: F,
) -> Result<Vec<T>>
where
F: Fn(&[u8]) -> Result<T>,
{
let _timer = metrics::Timer::new(metrics::BATCH_PROCESSING_TIME.clone());
let mut kafka_data = Vec::with_capacity(messages.len());
for msg in messages.iter() {
if let Some(payload) = &msg.payload {
match deserializer(payload) {
Ok(deserialized_msg) => {
kafka_data.push(deserialized_msg);
}
Err(e) => {
log::error!("Failed to parse Kafka message: {}", e);
metrics::KAFKA_MESSAGES_FAILED_PARSE.inc();
}
}
}
}
Ok(kafka_data)
}

115
thunder/kafka_utils.rs Normal file
View File

@@ -0,0 +1,115 @@
use anyhow::{Context, Result};
use std::sync::Arc;
use xai_kafka::KafkaProducerConfig;
use xai_kafka::config::{KafkaConfig, KafkaConsumerConfig, SslConfig};
use xai_wily::WilyConfig;
use crate::{
args,
kafka::{
tweet_events_listener::start_tweet_event_processing,
tweet_events_listener_v2::start_tweet_event_processing_v2,
},
};
// Kafka topic and destination names used by `start_kafka`.
// NOTE(review): all four values are empty strings in this snapshot, presumably
// redacted for publication; confirm and fill in real values before use.

/// Topic of the raw tweet events consumed in non-serving mode.
const TWEET_EVENT_TOPIC: &str = "";
/// Destination of the raw tweet events consumed in non-serving mode.
const TWEET_EVENT_DEST: &str = "";
/// Destination the in-network events are produced to in non-serving mode.
const IN_NETWORK_EVENTS_DEST: &str = "";
/// Topic of the in-network events (produced in non-serving mode, consumed in serving mode).
const IN_NETWORK_EVENTS_TOPIC: &str = "";
pub async fn start_kafka(
args: &args::Args,
post_store: Arc<crate::posts::post_store::PostStore>,
user: &str,
tx: tokio::sync::mpsc::Sender<i64>,
) -> Result<()> {
let sasl_password = std::env::var("")
.ok()
.or(args.sasl_password.clone())?;
let producer_sasl_password = std::env::var("")
.ok()
.or(args.producer_sasl_password.clone());
if args.is_serving {
let unique_id = uuid::Uuid::new_v4().to_string();
let v2_tweet_events_consumer_config = KafkaConsumerConfig {
base_config: KafkaConfig {
dest: args.in_network_events_consumer_dest.clone(),
topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
sasl_username: Some(args.producer_sasl_username.clone()),
sasl_password: producer_sasl_password.clone(),
}),
..Default::default()
},
group_id: format!("{}-{}", args.kafka_group_id, unique_id),
auto_offset_reset: args.auto_offset_reset.clone(),
fetch_timeout_ms: args.fetch_timeout_ms,
max_partition_fetch_bytes: Some(1024 * 1024 * 100),
skip_to_latest: args.skip_to_latest,
..Default::default()
};
// Start Kafka background tasks
start_tweet_event_processing_v2(
v2_tweet_events_consumer_config,
Arc::clone(&post_store),
args,
tx,
)
.await;
}
// Only start Kafka processing and background tasks if not in serving mode
if !args.is_serving {
// Create Kafka consumer config
let tweet_events_consumer_config = KafkaConsumerConfig {
base_config: KafkaConfig {
dest: TWEET_EVENT_DEST.to_string(),
topic: TWEET_EVENT_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.sasl_mechanism.clone()),
sasl_username: Some(args.sasl_username.clone()),
sasl_password: Some(sasl_password.clone()),
}),
..Default::default()
},
group_id: format!("{}-{}", args.kafka_group_id, user),
auto_offset_reset: args.auto_offset_reset.clone(),
enable_auto_commit: false,
fetch_timeout_ms: args.fetch_timeout_ms,
max_partition_fetch_bytes: Some(1024 * 1024 * 10),
partitions: None,
skip_to_latest: args.skip_to_latest,
..Default::default()
};
let producer_config = KafkaProducerConfig {
base_config: KafkaConfig {
dest: IN_NETWORK_EVENTS_DEST.to_string(),
topic: IN_NETWORK_EVENTS_TOPIC.to_string(),
wily_config: Some(WilyConfig::default()),
ssl: Some(SslConfig {
security_protocol: args.security_protocol.clone(),
sasl_mechanism: Some(args.producer_sasl_mechanism.clone()),
sasl_username: Some(args.producer_sasl_username.clone()),
sasl_password: producer_sasl_password.clone(),
}),
..Default::default()
},
..Default::default()
};
start_tweet_event_processing(tweet_events_consumer_config, producer_config, args).await;
}
Ok(())
}

11
thunder/lib.rs Normal file
View File

@@ -0,0 +1,11 @@
pub mod args;
pub mod config;
pub mod deserializer;
pub mod kafka;
pub mod kafka_utils;
pub mod metrics;
pub mod o2;
pub mod posts;
pub mod schema;
pub mod strato_client;
pub mod thunder_service;

100
thunder/main.rs Normal file
View File

@@ -0,0 +1,100 @@
use anyhow::{Context, Result};
use axum::Router;
use clap::Parser;
use log::info;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tonic::service::Routes;
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};
use thunder::{
args, kafka_utils, posts::post_store::PostStore, strato_client::StratoClient,
thunder_service::ThunderServiceImpl,
};
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let args = args::Args::parse();
// Initialize PostStore
let post_store = Arc::new(PostStore::new(
args.post_retention_seconds,
args.request_timeout_ms,
));
info!(
"Initialized PostStore for in-memory post storage (retention: {} seconds / {:.1} days, request_timeout: {}ms)",
args.post_retention_seconds,
args.post_retention_seconds as f64 / 86400.0,
args.request_timeout_ms
);
// Initialize StratoClient for fetching following lists
let strato_client = Arc::new(StratoClient::new());
info!("Initialized StratoClient");
// Create ThunderService with the PostStore, StratoClient, and concurrency limit
let thunder_service = ThunderServiceImpl::new(
Arc::clone(&post_store),
Arc::clone(&strato_client),
args.max_concurrent_requests,
);
info!(
"Initialized with max_concurrent_requests={}",
args.max_concurrent_requests
);
let routes = Routes::new(thunder_service.server());
// Set up gRPC config
let grpc_config = GrpcConfig::new(args.grpc_port, routes);
// Create HTTP server with gRPC support
let mut http_server = HttpServer::new(
args.http_port,
Router::new(),
Some(grpc_config),
CancellationToken::new(),
Duration::from_secs(10),
)
.await
.context("Failed to create HTTP server")?;
if args.enable_profiling {
xai_profiling::spawn_server(3000, CancellationToken::new()).await;
}
// Create channel for post events
let (tx, mut rx) = tokio::sync::mpsc::channel::<i64>(args.kafka_num_threads);
kafka_utils::start_kafka(&args, post_store.clone(), "", tx).await?;
if args.is_serving {
// Wait for Kafka catchup signal
let start = Instant::now();
for _ in 0..args.kafka_num_threads {
rx.recv().await;
}
info!("Kafka init took {:?}", start.elapsed());
post_store.finalize_init().await?;
// Start stats logger
Arc::clone(&post_store).start_stats_logger();
info!("Started PostStore stats logger",);
// Start auto-trim task to remove posts older than retention period
Arc::clone(&post_store).start_auto_trim(2); // Run every 2 minutes
info!(
"Started PostStore auto-trim task (interval: 2 minutes, retention: {:.1} days)",
args.post_retention_seconds as f64 / 86400.0
);
}
http_server.set_readiness(true);
info!("HTTP/gRPC server is ready");
// Wait for termination signal
http_server.wait_for_termination().await;
info!("Server terminated");
Ok(())
}

1
thunder/posts/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod post_store;

526
thunder/posts/post_store.rs Normal file
View File

@@ -0,0 +1,526 @@
use anyhow::Result;
use dashmap::DashMap;
use log::info;
use std::collections::{HashSet, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use xai_thunder_proto::{LightPost, TweetDeleteEvent};
use crate::config::{
DELETE_EVENT_KEY, MAX_ORIGINAL_POSTS_PER_AUTHOR, MAX_REPLY_POSTS_PER_AUTHOR,
MAX_TINY_POSTS_PER_USER_SCAN, MAX_VIDEO_POSTS_PER_AUTHOR,
};
use crate::metrics::{
POST_STORE_DELETED_POSTS, POST_STORE_DELETED_POSTS_FILTERED, POST_STORE_ENTITY_COUNT,
POST_STORE_POSTS_RETURNED, POST_STORE_POSTS_RETURNED_RATIO, POST_STORE_REQUEST_TIMEOUTS,
POST_STORE_REQUESTS, POST_STORE_TOTAL_POSTS, POST_STORE_USER_COUNT,
};
/// Compact timeline entry: just a post's id and its timestamp, so per-user
/// timelines stay small while full post data lives elsewhere.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TinyPost {
    /// Id of the referenced post.
    pub post_id: i64,
    /// Timestamp associated with the entry.
    pub created_at: i64,
}

impl TinyPost {
    /// Builds a `TinyPost` from a post id and a timestamp.
    pub fn new(post_id: i64, created_at: i64) -> Self {
        Self {
            post_id,
            created_at,
        }
    }
}
/// A thread-safe store for posts grouped by user ID
/// Note: LightPost is now defined in the protobuf schema (in-network.proto)
///
/// Every map is held behind an `Arc`, so cloning a `PostStore` yields another
/// handle to the same underlying maps.
#[derive(Clone)]
pub struct PostStore {
/// Full post data indexed by post_id
posts: Arc<DashMap<i64, LightPost>>,
/// Maps user_id to a deque of TinyPost references for original posts (non-reply, non-retweet)
original_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
/// Maps user_id to a deque of TinyPost references for replies and retweets
secondary_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
/// Maps user_id to a deque of TinyPost references for video posts
video_posts_by_user: Arc<DashMap<i64, VecDeque<TinyPost>>>,
/// Ids of posts for which a delete event has been recorded (the value is
/// always `true`); consulted on insert so deleted posts are not stored again
deleted_posts: Arc<DashMap<i64, bool>>,
/// Retention period for posts in seconds
retention_seconds: u64,
/// Request timeout for get_posts_by_users iteration (0 = no timeout)
request_timeout: Duration,
}
impl PostStore {
/// Creates a new empty PostStore with the specified retention period and request timeout
pub fn new(retention_seconds: u64, request_timeout_ms: u64) -> Self {
PostStore {
posts: Arc::new(DashMap::new()),
original_posts_by_user: Arc::new(DashMap::new()),
secondary_posts_by_user: Arc::new(DashMap::new()),
video_posts_by_user: Arc::new(DashMap::new()),
deleted_posts: Arc::new(DashMap::new()),
retention_seconds,
request_timeout: Duration::from_millis(request_timeout_ms),
}
}
/// Applies a list of delete events to the store.
///
/// For each event the full post data is removed, the post id is recorded in
/// `deleted_posts`, and a marker (post id + deletion time) is appended to the
/// timeline stored under `DELETE_EVENT_KEY` in `original_posts_by_user`.
pub fn mark_as_deleted(&self, posts: Vec<TweetDeleteEvent>) {
    for delete_event in posts {
        let deleted_id = delete_event.post_id;
        self.posts.remove(&deleted_id);
        self.deleted_posts.insert(deleted_id, true);

        // Delete markers share one timeline keyed by DELETE_EVENT_KEY; the
        // marker's timestamp is the deletion time, not the post's creation time.
        self.original_posts_by_user
            .entry(DELETE_EVENT_KEY)
            .or_default()
            .push_back(TinyPost {
                post_id: deleted_id,
                created_at: delete_event.deleted_at,
            });
    }
}
/// Inserts posts into the post store
pub fn insert_posts(&self, mut posts: Vec<LightPost>) {
// Filter to keep only posts created in the last retention_seconds and not from the future
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
posts.retain(|p| {
p.created_at < current_time
&& current_time - p.created_at <= (self.retention_seconds as i64)
});
// Sort remaining posts by created_at timestamp
posts.sort_unstable_by_key(|p| p.created_at);
Self::insert_posts_internal(self, posts);
}
/// Finishes the initial load: sorts all user timelines, trims old posts, and
/// removes the full post data of every post recorded as deleted.
///
/// # Errors
/// Currently always returns `Ok(())`.
pub async fn finalize_init(&self) -> Result<()> {
    self.sort_all_user_posts().await;
    self.trim_old_posts().await;

    // The feeder does not guarantee create/delete ordering, so a post may have
    // been inserted after its delete was seen; purge those once more.
    self.deleted_posts.iter().for_each(|deleted| {
        self.posts.remove(deleted.key());
    });
    Ok(())
}
/// Stores each post and indexes it in the per-author timelines.
///
/// * Posts already recorded as deleted are ignored.
/// * A post whose id is already stored replaces the stored data but is not
///   added to any timeline again.
/// * Original posts (neither reply nor retweet) go to the original timeline,
///   everything else to the secondary timeline.
/// * Non-reply posts are also added to the video timeline when they have a
///   video, or when they retweet a stored non-reply post that has a video.
fn insert_posts_internal(&self, posts: Vec<LightPost>) {
    for post in posts {
        let post_id = post.post_id;
        if self.deleted_posts.contains_key(&post_id) {
            continue;
        }

        // Read everything needed later before `post` is moved into the map.
        let author_id = post.author_id;
        let created_at = post.created_at;
        let is_reply = post.is_reply;
        let is_retweet = post.is_retweet;
        let has_video = post.has_video;
        let source_post_id = post.source_post_id;

        // Store the full post data; an existing entry means it is already indexed.
        if self.posts.insert(post_id, post).is_some() {
            continue;
        }

        let tiny_post = TinyPost::new(post_id, created_at);
        let timeline = if !is_reply && !is_retweet {
            &self.original_posts_by_user
        } else {
            &self.secondary_posts_by_user
        };
        timeline
            .entry(author_id)
            .or_default()
            .push_back(tiny_post.clone());

        // Replies are never video-eligible. A retweet inherits eligibility from
        // its source post when that post is stored, is not a reply, and has video.
        let video_eligible = !is_reply
            && (has_video
                || (is_retweet
                    && source_post_id
                        .and_then(|source_id| self.posts.get(&source_id))
                        .is_some_and(|source| !source.is_reply && source.has_video)));

        if video_eligible {
            self.video_posts_by_user
                .entry(author_id)
                .or_default()
                .push_back(tiny_post);
        }
    }
}
/// Collects video posts authored by `user_ids` from the video timelines.
///
/// Delegates to `get_posts_from_map` with `MAX_VIDEO_POSTS_PER_AUTHOR` as the
/// per-user cap and an empty following set, then records the number of
/// returned posts in `POST_STORE_POSTS_RETURNED`.
pub fn get_videos_by_users(
    &self,
    user_ids: &[i64],
    exclude_tweet_ids: &HashSet<i64>,
    start_time: Instant,
    request_user_id: i64,
) -> Vec<LightPost> {
    let no_following_users = HashSet::new();
    let videos = self.get_posts_from_map(
        &self.video_posts_by_user,
        user_ids,
        MAX_VIDEO_POSTS_PER_AUTHOR,
        exclude_tweet_ids,
        &no_following_users,
        start_time,
        request_user_id,
    );
    POST_STORE_POSTS_RETURNED.observe(videos.len() as f64);
    videos
}
/// Returns original and secondary (reply/retweet) posts authored by the given users.
///
/// Original posts are fetched first with no reply filtering; secondary posts are then
/// fetched with `user_ids` as the following set, which enables the reply filter in
/// [`Self::get_posts_from_map`]. The two lists are concatenated (originals first) and
/// the combined size is recorded in `POST_STORE_POSTS_RETURNED`.
pub fn get_all_posts_by_users(
    &self,
    user_ids: &[i64],
    exclude_tweet_ids: &HashSet<i64>,
    start_time: Instant,
    request_user_id: i64,
) -> Vec<LightPost> {
    let followed: HashSet<i64> = user_ids.iter().copied().collect();
    let no_following_filter: HashSet<i64> = HashSet::new();

    let mut combined = self.get_posts_from_map(
        &self.original_posts_by_user,
        user_ids,
        MAX_ORIGINAL_POSTS_PER_AUTHOR,
        exclude_tweet_ids,
        &no_following_filter,
        start_time,
        request_user_id,
    );

    combined.extend(self.get_posts_from_map(
        &self.secondary_posts_by_user,
        user_ids,
        MAX_REPLY_POSTS_PER_AUTHOR,
        exclude_tweet_ids,
        &followed,
        start_time,
        request_user_id,
    ));

    POST_STORE_POSTS_RETURNED.observe(combined.len() as f64);
    combined
}
/// Collects posts for `user_ids` from one timeline map.
///
/// For each user, in the order given, the timeline is walked from the back (newest
/// entries) and the following steps are applied lazily, in this order:
/// 1. drop ids listed in `exclude_tweet_ids`;
/// 2. look at no more than `MAX_TINY_POSTS_PER_USER_SCAN` remaining entries;
/// 3. resolve each entry to its `LightPost` in `posts` (entries with no stored post
///    are dropped);
/// 4. drop posts present in `deleted_posts`, counting each in
///    `POST_STORE_DELETED_POSTS_FILTERED`;
/// 5. drop retweets whose `source_user_id` is `request_user_id`;
/// 6. when `following_users` is non-empty, apply the reply filter described below;
/// 7. keep at most `max_per_user` posts.
///
/// Reply filter: a post without `in_reply_to_post_id` is kept. A reply is kept when
/// its parent is stored and is neither a reply nor a retweet, or when the parent
/// replies directly to the post's `conversation_id` and the post's
/// `in_reply_to_user_id` is in `following_users`. Replies whose parent is not stored
/// are dropped.
///
/// If `request_timeout` is non-zero and `start_time.elapsed()` reaches it, iteration
/// stops, an error is logged, `POST_STORE_REQUEST_TIMEOUTS` is incremented, and the
/// posts gathered so far are returned.
// The argument list mirrors the independent inputs supplied by each caller.
#[allow(clippy::too_many_arguments)]
pub fn get_posts_from_map(
    &self,
    posts_map: &Arc<DashMap<i64, VecDeque<TinyPost>>>,
    user_ids: &[i64],
    max_per_user: usize,
    exclude_tweet_ids: &HashSet<i64>,
    following_users: &HashSet<i64>,
    start_time: Instant,
    request_user_id: i64,
) -> Vec<LightPost> {
    POST_STORE_REQUESTS.inc();

    let reply_filtering = !following_users.is_empty();

    // Rejects posts with a deletion record and counts each rejection.
    let is_live = |post: &LightPost| -> bool {
        let deleted = self.deleted_posts.contains_key(&post.post_id);
        if deleted {
            POST_STORE_DELETED_POSTS_FILTERED.inc();
        }
        !deleted
    };

    // True for retweets of the requesting user's own posts.
    let is_retweet_of_requester = |post: &LightPost| -> bool {
        post.is_retweet && post.source_user_id == Some(request_user_id)
    };

    // Reply filter; see the function documentation for the rules.
    let passes_reply_filter = |post: &LightPost| -> bool {
        if !reply_filtering {
            return true;
        }
        let Some(parent_id) = post.in_reply_to_post_id else {
            return true;
        };
        let Some(parent) = self.posts.get(&parent_id) else {
            return false;
        };
        if !parent.is_retweet && !parent.is_reply {
            return true;
        }
        let Some(convo_id) = post.conversation_id else {
            return false;
        };
        let parent_replies_to_root = parent.in_reply_to_post_id == Some(convo_id);
        let replies_to_followed_user = post
            .in_reply_to_user_id
            .is_some_and(|uid| following_users.contains(&uid));
        parent_replies_to_root && replies_to_followed_user
    };

    let mut collected: Vec<LightPost> = Vec::new();
    let mut eligible_total: usize = 0;

    for (processed, author_id) in user_ids.iter().enumerate() {
        let timed_out =
            !self.request_timeout.is_zero() && start_time.elapsed() >= self.request_timeout;
        if timed_out {
            log::error!(
                "Timed out fetching posts for user={}; Processed: {}/{}. Stage: {}",
                request_user_id,
                processed,
                user_ids.len(),
                if reply_filtering {
                    "secondary"
                } else {
                    "original"
                }
            );
            POST_STORE_REQUEST_TIMEOUTS.inc();
            break;
        }

        let Some(timeline_ref) = posts_map.get(author_id) else {
            continue;
        };
        let timeline = timeline_ref.value();
        eligible_total += timeline.len();

        // Newest entries sit at the back. The scan cap stops us from walking all the
        // way back through an inactive author's history.
        // Each stored post is copied out of `posts` right away so the map guard is
        // released before any further lookup (avoids nested read locks while a
        // writer may be waiting).
        let selected = timeline
            .iter()
            .rev()
            .filter(|tiny| !exclude_tweet_ids.contains(&tiny.post_id))
            .take(MAX_TINY_POSTS_PER_USER_SCAN)
            .filter_map(|tiny| self.posts.get(&tiny.post_id).map(|stored| *stored.value()))
            .filter(|post| is_live(post))
            .filter(|post| !is_retweet_of_requester(post))
            .filter(|post| passes_reply_filter(post))
            .take(max_per_user);

        collected.extend(selected);
    }

    // Track ratio of returned posts to eligible posts.
    if eligible_total > 0 {
        POST_STORE_POSTS_RETURNED_RATIO.observe(collected.len() as f64 / eligible_total as f64);
    }

    collected
}
/// Start a background task that periodically logs PostStore statistics
pub fn start_stats_logger(self: Arc<Self>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
let user_count = self.original_posts_by_user.len();
let total_posts = self.posts.len();
let deleted_posts = self.deleted_posts.len();
// Sum up all VecDeque sizes for each map
let original_posts_count: usize = self
.original_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
let secondary_posts_count: usize = self
.secondary_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
let video_posts_count: usize = self
.video_posts_by_user
.iter()
.map(|entry| entry.value().len())
.sum();
// Update Prometheus gauges
POST_STORE_USER_COUNT.set(user_count as f64);
POST_STORE_TOTAL_POSTS.set(total_posts as f64);
POST_STORE_DELETED_POSTS.set(deleted_posts as f64);
// Update entity count gauge with labels
POST_STORE_ENTITY_COUNT
.with_label_values(&["users"])
.set(user_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["posts"])
.set(total_posts as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["original_posts"])
.set(original_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["secondary_posts"])
.set(secondary_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["video_posts"])
.set(video_posts_count as f64);
POST_STORE_ENTITY_COUNT
.with_label_values(&["deleted_posts"])
.set(deleted_posts as f64);
info!(
"PostStore Stats: {} users, {} total posts, {} deleted posts",
user_count, total_posts, deleted_posts
);
}
});
}
/// Start a background task that periodically trims old posts
pub fn start_auto_trim(self: Arc<Self>, interval_minutes: u64) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(interval_minutes * 60));
loop {
interval.tick().await;
let trimmed = self.trim_old_posts().await;
if trimmed > 0 {
info!("Auto-trim: removed {} old posts", trimmed);
}
}
});
}
/// Trims posts older than the retention period from every per-user timeline.
///
/// The work runs on the blocking thread pool. For each timeline, entries are popped
/// from the front while they are older than `retention_seconds`; each popped id is
/// also removed from `posts`, and from `deleted_posts` when the timeline belongs to
/// `DELETE_EVENT_KEY`. Timelines with a lot of spare capacity are shrunk, and users
/// whose timeline became empty are removed from the map.
///
/// Returns the number of entries trimmed from the original and secondary timelines.
/// Video timeline entries are trimmed as well but are not counted.
///
/// # Panics
///
/// Panics if the system clock is before the UNIX epoch or if the blocking task
/// itself panics.
pub async fn trim_old_posts(&self) -> usize {
    let posts_map = Arc::clone(&self.posts);
    let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
    let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
    let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
    let deleted_posts = Arc::clone(&self.deleted_posts);
    let retention_seconds = self.retention_seconds;

    tokio::task::spawn_blocking(move || {
        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let mut total_trimmed = 0;

        // Helper closure to trim posts from a given map.
        let trim_map = |posts_by_user: &DashMap<i64, VecDeque<TinyPost>>,
                        posts_map: &DashMap<i64, LightPost>,
                        deleted_posts: &DashMap<i64, bool>|
         -> usize {
            let mut trimmed = 0;
            let mut users_to_remove = Vec::new();

            for mut entry in posts_by_user.iter_mut() {
                let user_id = *entry.key();
                let user_posts = entry.value_mut();

                while let Some(oldest_post) = user_posts.front() {
                    // A timestamp ahead of the local clock has age 0 and is kept;
                    // plain subtraction would underflow and either panic (debug) or
                    // wrap to a huge age and trim a brand-new post (release).
                    // A negative timestamp cannot be represented as `u64` and is
                    // treated as expired.
                    let expired = match u64::try_from(oldest_post.created_at) {
                        Ok(created_at) => {
                            current_time.saturating_sub(created_at) > retention_seconds
                        }
                        Err(_) => true,
                    };
                    if !expired {
                        break;
                    }

                    if let Some(trimmed_post) = user_posts.pop_front() {
                        posts_map.remove(&trimmed_post.post_id);
                        if user_id == DELETE_EVENT_KEY {
                            deleted_posts.remove(&trimmed_post.post_id);
                        }
                        trimmed += 1;
                    }
                }

                // Give memory back once more than half of the buffer is unused,
                // keeping 50% headroom for future inserts.
                if user_posts.capacity() > user_posts.len() * 2 {
                    let new_cap = user_posts.len() as f32 * 1.5_f32;
                    user_posts.shrink_to(new_cap as usize);
                }

                if user_posts.is_empty() {
                    users_to_remove.push(user_id);
                }
            }

            // Removal happens after iteration; `remove_if` re-checks emptiness in
            // case new posts were inserted in the meantime.
            for user_id in users_to_remove {
                posts_by_user.remove_if(&user_id, |_, posts| posts.is_empty());
            }

            trimmed
        };

        total_trimmed += trim_map(&original_posts_by_user, &posts_map, &deleted_posts);
        total_trimmed += trim_map(&secondary_posts_by_user, &posts_map, &deleted_posts);
        // Video entries are trimmed too, but deliberately left out of the total.
        trim_map(&video_posts_by_user, &posts_map, &deleted_posts);

        total_trimmed
    })
    .await
    .expect("spawn_blocking failed")
}
/// Sorts all user post lists by creation time (newest first)
pub async fn sort_all_user_posts(&self) {
let original_posts_by_user = Arc::clone(&self.original_posts_by_user);
let secondary_posts_by_user = Arc::clone(&self.secondary_posts_by_user);
let video_posts_by_user = Arc::clone(&self.video_posts_by_user);
tokio::task::spawn_blocking(move || {
// Sort original posts
for mut entry in original_posts_by_user.iter_mut() {
let user_posts = entry.value_mut();
user_posts
.make_contiguous()
.sort_unstable_by_key(|a| a.created_at);
}
// Sort secondary posts
for mut entry in secondary_posts_by_user.iter_mut() {
let user_posts = entry.value_mut();
user_posts
.make_contiguous()
.sort_unstable_by_key(|a| a.created_at);
}
// Sort video posts
for mut entry in video_posts_by_user.iter_mut() {
let user_posts = entry.value_mut();
user_posts
.make_contiguous()
.sort_unstable_by_key(|a| a.created_at);
}
})
.await
.expect("spawn_blocking failed");
}
/// Removes every stored post and empties the original, secondary and video timelines.
///
/// `deleted_posts` is left untouched.
// NOTE(review): confirm that keeping `deleted_posts` across a clear is intentional.
pub fn clear(&self) {
    self.posts.clear();

    let timeline_maps = [
        &self.original_posts_by_user,
        &self.secondary_posts_by_user,
        &self.video_posts_by_user,
    ];
    for timelines in timeline_maps {
        timelines.clear();
    }

    info!("PostStore cleared");
}
}
impl Default for PostStore {
    /// Builds a store with a two-day retention window and the request timeout
    /// disabled (a zero timeout is never enforced).
    fn default() -> Self {
        let two_days_in_seconds = 2 * 24 * 60 * 60;
        let no_timeout = 0;
        Self::new(two_days_in_seconds, no_timeout)
    }
}

339
thunder/thunder_service.rs Normal file
View File

@@ -0,0 +1,339 @@
use lazy_static::lazy_static;
use log::{debug, info, warn};
use std::cmp::Reverse;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Semaphore;
use tonic::{Request, Response, Status};
use xai_thunder_proto::{
GetInNetworkPostsRequest, GetInNetworkPostsResponse, LightPost,
in_network_posts_service_server::{InNetworkPostsService, InNetworkPostsServiceServer},
};
use crate::config::{
MAX_INPUT_LIST_SIZE, MAX_POSTS_TO_RETURN, MAX_VIDEOS_TO_RETURN,
};
use crate::metrics::{
GET_IN_NETWORK_POSTS_COUNT, GET_IN_NETWORK_POSTS_DURATION,
GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO, GET_IN_NETWORK_POSTS_EXCLUDED_SIZE,
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE, GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS,
GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR, GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO,
GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS, GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS,
GET_IN_NETWORK_POSTS_MAX_RESULTS, IN_FLIGHT_REQUESTS, REJECTED_REQUESTS, Timer,
};
use crate::posts::post_store::PostStore;
use crate::strato_client::StratoClient;
/// Implementation of the `InNetworkPostsService` gRPC service.
///
/// Holds shared handles to the post store and the Strato client, plus a semaphore
/// whose permits bound the number of requests processed concurrently; requests that
/// cannot obtain a permit are rejected rather than queued.
pub struct ThunderServiceImpl {
/// PostStore for retrieving posts by user ID
post_store: Arc<PostStore>,
/// StratoClient for fetching following lists when not provided
strato_client: Arc<StratoClient>,
/// Semaphore to limit concurrent requests and prevent overload
request_semaphore: Arc<Semaphore>,
}
impl ThunderServiceImpl {
    /// Builds the service around a shared post store and Strato client.
    ///
    /// `max_concurrent_requests` is the number of permits in the request semaphore.
    pub fn new(
        post_store: Arc<PostStore>,
        strato_client: Arc<StratoClient>,
        max_concurrent_requests: usize,
    ) -> Self {
        info!(
            "Initializing ThunderService with max_concurrent_requests={}",
            max_concurrent_requests
        );

        let request_semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
        Self {
            post_store,
            strato_client,
            request_semaphore,
        }
    }

    /// Wraps this service in a gRPC server that accepts and sends Zstd-compressed
    /// messages.
    pub fn server(self) -> InNetworkPostsServiceServer<Self> {
        let server = InNetworkPostsServiceServer::new(self);
        server
            .accept_compressed(tonic::codec::CompressionEncoding::Zstd)
            .send_compressed(tonic::codec::CompressionEncoding::Zstd)
    }

    /// Computes summary statistics for `posts`, reports them as metrics labelled with
    /// `stage`, and writes them to the debug log.
    ///
    /// `stage` distinguishes the point in the pipeline at which the posts were
    /// sampled (the callers in this file pass `"retrieved"` and `"scored"`).
    /// Reported values: seconds since the newest post, time span between the oldest
    /// and newest post, fraction of replies, number of unique authors, and posts per
    /// author. An empty slice only produces a debug log line.
    fn analyze_and_report_post_statistics(posts: &[LightPost], stage: &str) {
        if posts.is_empty() {
            debug!("[{}] No posts found for analysis", stage);
            return;
        }

        let total = posts.len();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;

        // Age of the newest and of the oldest post, relative to `now`.
        let newest_created_at = posts.iter().map(|post| post.created_at).max();
        let oldest_created_at = posts.iter().map(|post| post.created_at).min();
        let time_since_most_recent = newest_created_at.map(|created_at| now - created_at);
        let time_since_oldest = oldest_created_at.map(|created_at| now - created_at);
        let time_range = time_since_oldest
            .zip(time_since_most_recent)
            .map(|(oldest, newest)| oldest - newest);

        // Replies versus everything else.
        let reply_count = posts.iter().filter(|post| post.is_reply).count();
        let original_count = total - reply_count;
        let reply_ratio = reply_count as f64 / total as f64;

        // Distinct authors and the average number of posts each contributed.
        let unique_author_count = posts
            .iter()
            .map(|post| post.author_id)
            .collect::<HashSet<_>>()
            .len();
        let posts_per_author = if unique_author_count > 0 {
            Some(total as f64 / unique_author_count as f64)
        } else {
            None
        };

        // Report metrics with the stage label.
        if let Some(freshness) = time_since_most_recent {
            GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS
                .with_label_values(&[stage])
                .observe(freshness as f64);
        }
        if let Some(range) = time_range {
            GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS
                .with_label_values(&[stage])
                .observe(range as f64);
        }
        GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO
            .with_label_values(&[stage])
            .observe(reply_ratio);
        GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS
            .with_label_values(&[stage])
            .observe(unique_author_count as f64);
        if let Some(per_author) = posts_per_author {
            GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR
                .with_label_values(&[stage])
                .observe(per_author);
        }

        // Log the same statistics with the stage label.
        debug!(
            "[{}] Post statistics: total={}, original={}, replies={}, unique_authors={}, posts_per_author={:.2}, reply_ratio={:.2}, time_since_most_recent={:?}s, time_range={:?}s",
            stage,
            total,
            original_count,
            reply_count,
            unique_author_count,
            posts_per_author.unwrap_or(0.0),
            reply_ratio,
            time_since_most_recent,
            time_range
        );
    }
}
#[tonic::async_trait]
impl InNetworkPostsService for ThunderServiceImpl {
/// Get posts from users in the network
///
/// Request flow:
/// 1. Acquire a concurrency permit without waiting; if none is available the call
///    fails with `RESOURCE_EXHAUSTED`.
/// 2. Resolve the following list: the list from the request is used as-is, except
///    that an empty list on a `debug` request is fetched from Strato.
/// 3. Cap the following and exclude lists at `MAX_INPUT_LIST_SIZE` entries each.
/// 4. On the blocking thread pool, query the post store (video timelines for video
///    requests, original + secondary timelines otherwise), then keep the
///    `max_results` most recent posts.
///
/// # Errors
///
/// * `RESOURCE_EXHAUSTED` when no semaphore permit is available.
/// * `INTERNAL` when the Strato lookup fails or the blocking task fails to complete.
async fn get_in_network_posts(
&self,
request: Request<GetInNetworkPostsRequest>,
) -> Result<Response<GetInNetworkPostsResponse>, Status> {
// Try to acquire semaphore permit without blocking
// If we're at capacity, reject immediately with RESOURCE_EXHAUSTED
let _permit = match self.request_semaphore.try_acquire() {
Ok(permit) => {
IN_FLIGHT_REQUESTS.inc();
permit
}
Err(_) => {
REJECTED_REQUESTS.inc();
return Err(Status::resource_exhausted(
"Server at capacity, please retry",
));
}
};
// Use a guard to decrement in_flight_requests when the request completes
// (the guard's Drop runs on every exit path, including early returns below)
struct InFlightGuard;
impl Drop for InFlightGuard {
fn drop(&mut self) {
IN_FLIGHT_REQUESTS.dec();
}
}
let _in_flight_guard = InFlightGuard;
// Start timer for total latency
// NOTE(review): presumably `Timer` records the elapsed time when dropped — confirm in crate::metrics
let _total_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION.clone());
let req = request.into_inner();
if req.debug {
info!(
"Received GetInNetworkPosts request: user_id={}, following_count={}, exclude_tweet_ids={}",
req.user_id,
req.following_user_ids.len(),
req.exclude_tweet_ids.len(),
);
}
// If the following list is empty AND this is a debug request, fetch it from Strato.
// Non-debug requests with an empty list proceed with the empty list.
let following_user_ids = if req.following_user_ids.is_empty() && req.debug {
info!(
"Following list is empty, fetching from Strato for user {}",
req.user_id
);
match self
.strato_client
.fetch_following_list(req.user_id as i64, MAX_INPUT_LIST_SIZE as i32)
.await
{
Ok(following_list) => {
info!(
"Fetched {} following users from Strato for user {}",
following_list.len(),
req.user_id
);
following_list.into_iter().map(|id| id as u64).collect()
}
Err(e) => {
warn!(
"Failed to fetch following list from Strato for user {}: {}",
req.user_id, e
);
return Err(Status::internal(format!(
"Failed to fetch following list: {}",
e
)));
}
}
} else {
req.following_user_ids
};
// Record metrics for request parameters
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE.observe(following_user_ids.len() as f64);
GET_IN_NETWORK_POSTS_EXCLUDED_SIZE.observe(req.exclude_tweet_ids.len() as f64);
// Start timer for latency without strato call
let _processing_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO.clone());
// Default max_results if not specified (a request value of 0 or less selects the default)
let max_results = if req.max_results > 0 {
req.max_results as usize
} else if req.is_video_request {
MAX_VIDEOS_TO_RETURN
} else {
MAX_POSTS_TO_RETURN
};
GET_IN_NETWORK_POSTS_MAX_RESULTS.observe(max_results as f64);
// Limit following_user_ids and exclude_tweet_ids to first K entries
let following_count = following_user_ids.len();
if following_count > MAX_INPUT_LIST_SIZE {
warn!(
"Limiting following_user_ids from {} to {} entries for user {}",
following_count, MAX_INPUT_LIST_SIZE, req.user_id
);
}
let following_user_ids: Vec<u64> = following_user_ids
.into_iter()
.take(MAX_INPUT_LIST_SIZE)
.collect();
let exclude_count = req.exclude_tweet_ids.len();
if exclude_count > MAX_INPUT_LIST_SIZE {
warn!(
"Limiting exclude_tweet_ids from {} to {} entries for user {}",
exclude_count, MAX_INPUT_LIST_SIZE, req.user_id
);
}
let exclude_tweet_ids: Vec<u64> = req
.exclude_tweet_ids
.into_iter()
.take(MAX_INPUT_LIST_SIZE)
.collect();
// Clone Arc references needed inside spawn_blocking
let post_store = Arc::clone(&self.post_store);
let request_user_id = req.user_id as i64;
// Use spawn_blocking to avoid blocking tokio's async runtime
let proto_posts = tokio::task::spawn_blocking(move || {
// Create exclude tweet IDs set for efficient filtering of previously seen posts
let exclude_tweet_ids: HashSet<i64> =
exclude_tweet_ids.iter().map(|&id| id as i64).collect();
// The post-store timeout is measured from here, i.e. from the moment the blocking
// task starts running, not from when the request was received.
let start_time = Instant::now();
// Fetch posts for the followed users: video timelines for video requests,
// otherwise original + secondary timelines
let all_posts: Vec<LightPost> = if req.is_video_request {
post_store.get_videos_by_users(
&following_user_ids,
&exclude_tweet_ids,
start_time,
request_user_id,
)
} else {
post_store.get_all_posts_by_users(
&following_user_ids,
&exclude_tweet_ids,
start_time,
request_user_id,
)
};
// Analyze posts and report statistics after querying post_store
ThunderServiceImpl::analyze_and_report_post_statistics(&all_posts, "retrieved");
let scored_posts = score_recent(all_posts, max_results);
// Analyze posts and report statistics after scoring
ThunderServiceImpl::analyze_and_report_post_statistics(&scored_posts, "scored");
scored_posts
})
.await
.map_err(|e| Status::internal(format!("Failed to process posts: {}", e)))?;
if req.debug {
info!(
"Returning {} posts for user {}",
proto_posts.len(),
req.user_id
);
}
// Record the number of posts returned
GET_IN_NETWORK_POSTS_COUNT.observe(proto_posts.len() as f64);
let response = GetInNetworkPostsResponse { posts: proto_posts };
Ok(Response::new(response))
}
}
/// Ranks posts by recency and keeps the top `max_results`.
///
/// Posts are ordered by `created_at` descending (newest first) using an unstable
/// sort, so posts with equal timestamps may appear in any relative order. The list
/// is then cut to at most `max_results` entries.
fn score_recent(mut light_posts: Vec<LightPost>, max_results: usize) -> Vec<LightPost> {
    light_posts.sort_unstable_by(|a, b| Reverse(a.created_at).cmp(&Reverse(b.created_at)));
    // Drop everything past the requested limit in place.
    light_posts.truncate(max_results);
    light_posts
}