Open-source X Recommendation Algorithm

This commit is contained in:
CI agent
2026-05-15 07:28:04 +00:00
parent aaa167b3de
commit e414c171ed
187 changed files with 18260 additions and 923 deletions

2
.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
*.npz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text

View File

@@ -6,6 +6,7 @@ This repository contains the core recommendation system powering the "For You" f
## Table of Contents
- [Updates — May 15th, 2026](#updates--may-15th-2026)
- [Overview](#overview)
- [System Architecture](#system-architecture)
- [Components](#components)
@@ -22,6 +23,26 @@ This repository contains the core recommendation system powering the "For You" f
---
## Updates — May 15th, 2026
This release updates the For You algorithm code, including a runnable end-to-end inference pipeline alongside new components for content understanding, ads, and candidate sourcing.
1. **End-to-end inference pipeline:** A new [`phoenix/run_pipeline.py`](phoenix/run_pipeline.py) replaces the separate `run_ranker.py` and `run_retrieval.py` scripts with a single entry point that runs **retrieval → ranking** from exported checkpoints, mirroring how the two stages are composed in production.
2. **Pre-trained model artifacts:** A pre-trained mini Phoenix model (256-dim embeddings, 4 attention heads, 2 transformer layers) is now packaged as a ~3 GB archive distributed via Git LFS, enabling out-of-the-box inference without training your own model first.
3. **Grox content-understanding pipeline:** A new [`grox/`](grox/) service is included, providing classifiers, embedders, and a task-execution engine for content understanding workloads such as spam detection, post-category classification, and PTOS policy enforcement.
4. **Ads blending system:** A new [`home-mixer/ads/`](home-mixer/ads/) module handles ad injection and positioning within the feed, with brand-safety tracking that respects sensitive-content boundaries.
5. **Query hydrators:** Home mixer now hydrates user context including followed topics, starter packs, impression bloom filters, IP, mutual follow graphs, and served history.
6. **Candidate hydrators:** Additional hydrators for engagement counts, brand safety signals, language codes, media detection, quote post expansion, mutual follow scores, and more.
7. **Candidate sources:** Adds sources for ads, who to follow, Phoenix MoE, Phoenix topics, and prompts, and updates the existing Thunder and Phoenix sources.
---
## Overview
The For You feed algorithm retrieves, ranks, and filters posts from two sources:

View File

@@ -2,23 +2,39 @@ use crate::filter::Filter;
use crate::hydrator::Hydrator;
use crate::query_hydrator::QueryHydrator;
use crate::scorer::Scorer;
use crate::selector::SelectResult;
use crate::selector::Selector;
use crate::side_effect::{SideEffect, SideEffectInput};
use crate::source::Source;
use crate::util;
use futures::future::join_all;
use log::{error, info, warn};
use std::any::type_name_of_val;
use std::sync::Arc;
use std::time::Instant;
use tonic::async_trait;
use tracing::{Span, field::Empty, info};
use xai_stats_receiver::{HistogramBuckets, global_stats_receiver};
// Stats scope passed by `stat_result_size` when observing the final result size.
const FINAL_RESULT_SIZE_SCOPE: [(&str, &str); 1] = [("requests", "result_size")];
// Stats scope passed by `stat_result_size` when counting empty final results.
const FINAL_RESULT_EMPTY_SCOPE: [(&str, &str); 1] = [("requests", "result_empty")];
/// Identifies one stage of the candidate pipeline.
///
/// Used as the `stage` label of the [`PipelineComponents`] groups returned by
/// `components()`.
#[derive(Copy, Clone, Debug)]
pub enum PipelineStage {
// Components returned by `query_hydrators()`.
QueryHydrator,
// Components returned by `dependent_query_hydrators()`.
DependentQueryHydrator,
// Components returned by `sources()`.
Source,
// Components returned by `hydrators()`.
Hydrator,
// Components returned by `post_selection_hydrators()`.
PostSelectionHydrator,
// Components returned by `filters()`.
Filter,
// Components returned by `post_selection_filters()`.
PostSelectionFilter,
// Components returned by `scorers()`.
Scorer,
// The single component returned by `selector()`.
Selector,
// Components returned by `side_effects()`.
SideEffect,
}
/// The component names configured for a single pipeline stage.
pub struct PipelineComponents {
/// Stage the listed components belong to.
pub stage: PipelineStage,
/// Names of the components, as reported by each component's `name()`.
pub components: Vec<String>,
}
pub struct PipelineResult<Q, C> {
@@ -28,18 +44,36 @@ pub struct PipelineResult<Q, C> {
pub query: Arc<Q>,
}
/// Provides a stable request identifier for logging/tracing.
pub trait HasRequestId {
fn request_id(&self) -> &str;
impl<Q: Default, C> PipelineResult<Q, C> {
    /// Build a result that carries no candidates and a default-constructed query.
    ///
    /// Handy for answering a request immediately (e.g. test users) without
    /// executing the pipeline at all.
    pub fn empty() -> Self {
        let default_query = Arc::new(Q::default());
        Self {
            retrieved_candidates: Vec::new(),
            filtered_candidates: Vec::new(),
            selected_candidates: Vec::new(),
            query: default_query,
        }
    }
}
/// Bound required of every query type that flows through the pipeline.
pub trait PipelineQuery: Clone + Send + Sync + 'static {
/// Feature-switch parameters carried by this query.
fn params(&self) -> &xai_feature_switches::Params;
/// Decider carried by this query, or `None` when there is none.
fn decider(&self) -> Option<&xai_decider::Decider>;
}
/// Marker bound required of every candidate type that flows through the pipeline.
pub trait PipelineCandidate: Clone + Send + Sync + 'static {}
// Blanket impl: any `Clone + Send + Sync + 'static` type is a `PipelineCandidate`.
impl<T> PipelineCandidate for T where T: Clone + Send + Sync + 'static {}
#[async_trait]
pub trait CandidatePipeline<Q, C>: Send + Sync
where
Q: HasRequestId + Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
fn dependent_query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>] {
&[]
}
fn sources(&self) -> &[Box<dyn Source<Q, C>>];
fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
@@ -49,37 +83,48 @@ where
fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
fn result_size(&self) -> usize;
fn finalize(&self, _query: &Q, _candidates: &mut Vec<C>) {}
#[xai_stats_macro::receive_stats(latency=Bucket500To2500)]
async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
let hydrated_query = self.hydrate_query(query).await;
let hydrated_query = self.hydrate_dependent_query(hydrated_query).await;
let candidates = self.fetch_candidates(&hydrated_query).await;
let hydrated_candidates = self.hydrate(&hydrated_query, candidates).await;
let (kept_candidates, mut filtered_candidates) = self
.filter(&hydrated_query, hydrated_candidates.clone())
.await;
let (kept_candidates, mut filtered_candidates) =
self.filter(&hydrated_query, hydrated_candidates.clone());
let scored_candidates = self.score(&hydrated_query, kept_candidates).await;
let selected_candidates = self.select(&hydrated_query, scored_candidates);
let SelectResult {
selected: selected_candidates,
non_selected: mut non_selected_candidates,
} = self.select(&hydrated_query, scored_candidates);
let post_selection_hydrated_candidates = self
.hydrate_post_selection(&hydrated_query, selected_candidates)
.await;
let (mut final_candidates, post_selection_filtered_candidates) = self
.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates)
.await;
let (mut final_candidates, post_selection_filtered_candidates) =
self.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates);
filtered_candidates.extend(post_selection_filtered_candidates);
final_candidates.truncate(self.result_size());
let truncated_candidates =
final_candidates.split_off(self.result_size().min(final_candidates.len()));
non_selected_candidates.extend(truncated_candidates);
self.finalize(&hydrated_query, &mut final_candidates);
self.stat_result_size(&final_candidates);
let arc_hydrated_query = Arc::new(hydrated_query);
let input = Arc::new(SideEffectInput {
query: arc_hydrated_query.clone(),
selected_candidates: final_candidates.clone(),
non_selected_candidates, // candidates are moved so we don't need to clone them
});
self.run_side_effects(input);
@@ -91,78 +136,158 @@ where
}
}
/// Return all configured components grouped by stage.
///
/// Produces one [`PipelineComponents`] entry per stage, each holding the
/// `name()` of every component configured for that stage. Components are
/// listed regardless of whether they would be enabled for a given query.
fn components(&self) -> Vec<PipelineComponents> {
// Local helper: turn a slice of boxed components into a `PipelineComponents`
// entry, using `name` to extract each component's name.
fn stage<T: ?Sized>(
stage: PipelineStage,
items: &[Box<T>],
name: impl Fn(&T) -> &str,
) -> PipelineComponents {
PipelineComponents {
stage,
components: items
.iter()
.map(|item| name(item.as_ref()).to_string())
.collect(),
}
}
vec![
stage(PipelineStage::QueryHydrator, self.query_hydrators(), |h| {
h.name()
}),
stage(
PipelineStage::DependentQueryHydrator,
self.dependent_query_hydrators(),
|h| h.name(),
),
stage(PipelineStage::Source, self.sources(), |s| s.name()),
stage(PipelineStage::Hydrator, self.hydrators(), |h| h.name()),
stage(PipelineStage::Filter, self.filters(), |f| f.name()),
stage(PipelineStage::Scorer, self.scorers(), |s| s.name()),
// The selector is a single component rather than a slice, so its entry
// is built by hand instead of through `stage`.
PipelineComponents {
stage: PipelineStage::Selector,
components: vec![self.selector().name().to_string()],
},
stage(
PipelineStage::PostSelectionHydrator,
self.post_selection_hydrators(),
|h| h.name(),
),
stage(
PipelineStage::PostSelectionFilter,
self.post_selection_filters(),
|f| f.name(),
),
// `side_effects()` returns an `Arc<Vec<..>>`; `as_ref()` borrows the
// inner collection for the duration of this call.
stage(
PipelineStage::SideEffect,
self.side_effects().as_ref(),
|s| s.name(),
),
]
}
/// Name of this pipeline, derived from the implementing type's name via
/// `util::short_type_name`.
fn name(&self) -> &'static str {
    let full_type_name = type_name_of_val(self);
    util::short_type_name(full_type_name)
}
// -------------------------- Pipeline Execution --------------------------
/// Run all query hydrators in parallel and merge results into the query.
#[tracing::instrument(skip_all, name = "query_hydrators", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
))]
async fn hydrate_query(&self, query: Q) -> Q {
let request_id = query.request_id().to_string();
let hydrators: Vec<_> = self
.query_hydrators()
.iter()
.filter(|h| h.enable(&query))
.collect();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(&query));
let start = Instant::now();
let all = self.query_hydrators();
Self::record_enabled_components(all.iter(), |h| h.enable(&query), |h| h.name());
let hydrators: Vec<_> = all.iter().filter(|h| h.enable(&query)).collect();
let hydrate_futures = hydrators.iter().map(|h| h.run(&query));
let results = join_all(hydrate_futures).await;
let mut hydrated_query = query;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
hydrator.update(&mut hydrated_query, hydrated);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::QueryHydrator,
hydrator.name(),
err
);
}
if let Ok(hydrated) = result {
hydrator.update(&mut hydrated_query, hydrated);
}
}
self.log_stage(start);
hydrated_query
}
/// Run the dependent query hydrators concurrently and fold their successful
/// results into the query.
///
/// This stage is invoked **after** `hydrate_query`, so the query it receives
/// already carries the features populated by the initial hydrators. When no
/// dependent hydrators are configured the query is returned untouched and
/// nothing is recorded or logged. Results that come back as `Err` are skipped.
#[tracing::instrument(skip_all, name = "dependent_query_hydrators", fields(
    total_count = Empty,
    enabled_count = Empty,
    disabled = Empty,
))]
async fn hydrate_dependent_query(&self, query: Q) -> Q {
    let configured = self.dependent_query_hydrators();
    if configured.is_empty() {
        return query;
    }

    let started_at = Instant::now();
    Self::record_enabled_components(configured.iter(), |h| h.enable(&query), |h| h.name());

    // Only hydrators enabled for this query are executed; all of them read the
    // same input query and run concurrently.
    let enabled: Vec<_> = configured.iter().filter(|h| h.enable(&query)).collect();
    let outcomes = join_all(enabled.iter().map(|h| h.run(&query))).await;

    // Merge in hydrator order so later hydrators win on overlapping fields.
    let mut merged = query;
    for (hydrator, outcome) in enabled.iter().zip(outcomes) {
        match outcome {
            Ok(hydrated) => hydrator.update(&mut merged, hydrated),
            Err(_) => {}
        }
    }

    self.log_stage(started_at);
    merged
}
/// Run all candidate sources in parallel and collect results.
#[tracing::instrument(skip_all, name = "sources", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
candidate_count = Empty,
))]
async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
let request_id = query.request_id().to_string();
let sources: Vec<_> = self.sources().iter().filter(|s| s.enable(query)).collect();
let source_futures = sources.iter().map(|s| s.get_candidates(query));
let start = Instant::now();
let all = self.sources();
Self::record_enabled_components(all.iter(), |s| s.enable(query), |s| s.name());
let sources: Vec<_> = all.iter().filter(|s| s.enable(query)).collect();
let source_futures = sources.iter().map(|s| s.run(query));
let results = join_all(source_futures).await;
let mut collected = Vec::new();
for (source, result) in sources.iter().zip(results) {
match result {
Ok(mut candidates) => {
info!(
"request_id={} stage={:?} component={} fetched {} candidates",
request_id,
PipelineStage::Source,
source.name(),
candidates.len()
);
collected.append(&mut candidates);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Source,
source.name(),
err
);
}
}
for mut candidates in results.into_iter().flatten() {
collected.append(&mut candidates);
}
Span::current().record("candidate_count", collected.len());
self.log_stage_size(start, collected.len());
collected
}
/// Hydrate candidates with the pipeline's main hydrator list.
///
/// Delegates to `run_hydrators` using `self.hydrators()` and the
/// `PipelineStage::Hydrator` label.
#[tracing::instrument(skip_all, name = "hydrators", fields(
    total_count = Empty,
    enabled_count = Empty,
    disabled = Empty,
))]
async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
    let configured = self.hydrators();
    self.run_hydrators(query, candidates, configured, PipelineStage::Hydrator)
        .await
}
/// Run post-selection candidate hydrators in parallel and merge results into candidates.
#[tracing::instrument(skip_all, name = "post_selection_hydrators", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
))]
async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
self.run_hydrators(
query,
@@ -179,139 +304,114 @@ where
query: &Q,
mut candidates: Vec<C>,
hydrators: &[Box<dyn Hydrator<Q, C>>],
stage: PipelineStage,
_stage: PipelineStage,
) -> Vec<C> {
let request_id = query.request_id().to_string();
let start = Instant::now();
Self::record_enabled_components(hydrators.iter(), |h| h.enable(query), |h| h.name());
let hydrators: Vec<_> = hydrators.iter().filter(|h| h.enable(query)).collect();
let expected_len = candidates.len();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(query, &candidates));
let hydrate_futures = hydrators.iter().map(|h| h.run(query, &candidates));
let results = join_all(hydrate_futures).await;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
if hydrated.len() == expected_len {
hydrator.update_all(&mut candidates, hydrated);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
stage,
hydrator.name(),
expected_len,
hydrated.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
hydrator.name(),
err
);
}
}
hydrator.update_all(&mut candidates, result);
}
self.log_stage_size(start, candidates.len());
candidates
}
/// Run all filters sequentially. Each filter partitions candidates into kept and removed.
async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
#[tracing::instrument(skip_all, name = "filters", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
input_count = candidates.len(),
kept_count = Empty,
removed_count = Empty,
filter_rate = Empty,
))]
fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(query, candidates, self.filters(), PipelineStage::Filter)
.await
}
/// Run post-scoring filters sequentially on already-scored candidates.
async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
#[tracing::instrument(skip_all, name = "post_selection_filters", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
input_count = candidates.len(),
kept_count = Empty,
removed_count = Empty,
filter_rate = Empty,
))]
fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(
query,
candidates,
self.post_selection_filters(),
PipelineStage::PostSelectionFilter,
)
.await
}
// Shared helper to run filters sequentially from a provided filter list.
async fn run_filters(
fn run_filters(
&self,
query: &Q,
mut candidates: Vec<C>,
filters: &[Box<dyn Filter<Q, C>>],
stage: PipelineStage,
_stage: PipelineStage,
) -> (Vec<C>, Vec<C>) {
let request_id = query.request_id().to_string();
Self::record_enabled_components(filters.iter(), |f| f.enable(query), |f| f.name());
let mut all_removed = Vec::new();
let mut removed_per_filter: Vec<(String, usize)> = Vec::new();
for filter in filters.iter().filter(|f| f.enable(query)) {
let backup = candidates.clone();
match filter.filter(query, candidates).await {
Ok(result) => {
candidates = result.kept;
all_removed.extend(result.removed);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
filter.name(),
err
);
candidates = backup;
}
let result = filter.run(query, candidates);
if !result.removed.is_empty() {
removed_per_filter.push((filter.name().to_string(), result.removed.len()));
}
candidates = result.kept;
all_removed.extend(result.removed);
}
info!(
"request_id={} stage={:?} kept {}, removed {}",
request_id,
stage,
candidates.len(),
all_removed.len()
);
let kept = candidates.len();
let removed = all_removed.len();
let total = kept + removed;
let rate = if total > 0 {
removed as f64 / total as f64
} else {
0.0
};
Span::current().record("kept_count", kept);
Span::current().record("removed_count", removed);
Span::current().record("filter_rate", format!("{:.3}", rate).as_str());
self.log_filters(kept, removed, &removed_per_filter);
(candidates, all_removed)
}
/// Run all scorers sequentially and apply their results to candidates.
#[tracing::instrument(skip_all, name = "scorers", fields(
total_count = Empty,
enabled_count = Empty,
disabled = Empty,
))]
async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
let request_id = query.request_id().to_string();
let expected_len = candidates.len();
for scorer in self.scorers().iter().filter(|s| s.enable(query)) {
match scorer.score(query, &candidates).await {
Ok(scored) => {
if scored.len() == expected_len {
scorer.update_all(&mut candidates, scored);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
PipelineStage::Scorer,
scorer.name(),
expected_len,
scored.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Scorer,
scorer.name(),
err
);
}
}
let start = Instant::now();
let all = self.scorers();
Self::record_enabled_components(all.iter(), |s| s.enable(query), |s| s.name());
for scorer in all.iter().filter(|s| s.enable(query)) {
let scored = scorer.run(query, &candidates).await;
scorer.update_all(&mut candidates, scored);
}
self.log_stage_size(start, candidates.len());
candidates
}
/// Select (sort/truncate) candidates using the configured selector
fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
fn select(&self, query: &Q, candidates: Vec<C>) -> SelectResult<C> {
if self.selector().enable(query) {
self.selector().select(query, candidates)
self.selector().run(query, candidates)
} else {
candidates
SelectResult {
selected: candidates,
non_selected: vec![],
}
}
}
@@ -326,4 +426,68 @@ where
let _ = join_all(futures).await;
});
}
// -------------------------- Helpers --------------------------
/// Walk `items`, evaluating `is_enabled` once per component, and write the
/// outcome to the current tracing span:
///
/// * `total_count` — number of components seen,
/// * `enabled_count` — how many of them were enabled,
/// * `disabled` — comma-joined names of the disabled components; only
///   recorded when at least one component is disabled.
fn record_enabled_components<'a, T: 'a>(
    items: impl Iterator<Item = &'a T>,
    is_enabled: impl Fn(&T) -> bool,
    get_name: impl Fn(&T) -> &str,
) {
    let mut seen = 0usize;
    let mut disabled_names: Vec<&str> = Vec::new();
    for component in items {
        seen += 1;
        if is_enabled(component) {
            continue;
        }
        disabled_names.push(get_name(component));
    }

    let enabled = seen - disabled_names.len();
    let current = Span::current();
    current.record("total_count", seen);
    current.record("enabled_count", enabled);
    if disabled_names.is_empty() {
        return;
    }
    let joined = disabled_names.join(",");
    current.record("disabled", joined.as_str());
}
// -------------------------- Logging and Stats --------------------------
/// Log the time elapsed since `start`, in milliseconds, as `latency_ms`.
fn log_stage(&self, start: Instant) {
info!("latency_ms={}", start.elapsed().as_millis());
}
/// Log the time elapsed since `start` (milliseconds, `latency_ms`) together
/// with `size`, the item count supplied by the calling stage.
fn log_stage_size(&self, start: Instant, size: usize) {
info!("latency_ms={} size={}", start.elapsed().as_millis(), size);
}
/// Log the overall kept/removed counts of a filter stage, followed by a
/// `name=count` breakdown (comma separated) for each entry in
/// `removed_per_filter`.
fn log_filters(&self, kept: usize, removed: usize, removed_per_filter: &[(String, usize)]) {
    let mut removed_summary = String::new();
    for (position, (filter_name, removed_count)) in removed_per_filter.iter().enumerate() {
        if position > 0 {
            removed_summary.push(',');
        }
        removed_summary.push_str(filter_name);
        removed_summary.push('=');
        removed_summary.push_str(&removed_count.to_string());
    }
    info!(
        "kept {}, removed {} removed_per_filter [{}]",
        kept, removed, removed_summary,
    );
}
/// Report the size of the final response under the `<name>.execute` metric.
///
/// Observes the candidate count in a histogram and additionally bumps the
/// "empty result" counter when no candidates are returned. Does nothing when
/// no global stats receiver is installed.
fn stat_result_size(&self, final_candidates: &[C]) {
    let Some(receiver) = global_stats_receiver() else {
        return;
    };
    let response_size = final_candidates.len();
    let metric_name = format!("{}.execute", self.name());
    let metric = metric_name.as_str();

    receiver.observe(
        metric,
        &FINAL_RESULT_SIZE_SCOPE,
        response_size as f64,
        HistogramBuckets::Bucket0To50,
    );
    if final_candidates.is_empty() {
        receiver.incr(metric, &FINAL_RESULT_EMPTY_SCOPE, 1u64);
    }
}
}

View File

@@ -1,7 +1,11 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use std::any::{Any, type_name_of_val};
use tracing::{Span, field::Empty};
use xai_stats_receiver::global_stats_receiver;
// Stats scope used by `Filter::stat` for the count of kept candidates.
const KEPT_SCOPE: [(&str, &str); 1] = [("requests", "kept")];
// Stats scope used by `Filter::stat` for the count of removed candidates.
const REMOVED_SCOPE: [(&str, &str); 1] = [("requests", "removed")];
pub struct FilterResult<C> {
pub kept: Vec<C>,
@@ -9,24 +13,58 @@ pub struct FilterResult<C> {
}
/// Filters run sequentially and partition candidates into kept and removed sets
#[async_trait]
pub trait Filter<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Decide if this filter should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Instrumented entry point: applies [`Self::filter`], records the kept and
/// removed counts plus the removal rate (three decimals) on the current
/// tracing span, emits stats via [`Self::stat`], and returns the result.
#[xai_stats_macro::receive_stats(latency=Bucket0To50)]
#[tracing::instrument(skip_all, name = "filter", fields(
    name = self.name(),
    input_count = candidates.len(),
    kept_count = Empty,
    removed_count = Empty,
    filter_rate = Empty,
))]
fn run(&self, query: &Q, candidates: Vec<C>) -> FilterResult<C> {
    let outcome = self.filter(query, candidates);

    let kept_count = outcome.kept.len();
    let removed_count = outcome.removed.len();
    // Guard the division: an empty input yields a rate of zero.
    let filter_rate = match kept_count + removed_count {
        0 => 0.0,
        total => removed_count as f64 / total as f64,
    };
    let rate_text = format!("{:.3}", filter_rate);

    let current = Span::current();
    current.record("kept_count", kept_count);
    current.record("removed_count", removed_count);
    current.record("filter_rate", rate_text.as_str());

    self.stat(&outcome);
    outcome
}
/// Filter candidates by evaluating each against some criteria.
/// Returns a FilterResult containing kept candidates (which continue to the next stage)
/// and removed candidates (which are excluded from further processing).
async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;
fn filter(&self, query: &Q, candidates: Vec<C>) -> FilterResult<C>;
/// Returns a stable name for logging/metrics.
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
/// Increment the kept and removed counters of the `<name>.run` metric by the
/// sizes of the respective sets in `result`. No-op without a global stats
/// receiver.
fn stat(&self, result: &FilterResult<C>) {
    let Some(receiver) = global_stats_receiver() else {
        return;
    };
    let metric_name = format!("{}.run", self.name());
    let metric = metric_name.as_str();
    let kept = result.kept.len() as u64;
    let removed = result.removed.len() as u64;
    receiver.incr(metric, &KEPT_SCOPE, kept);
    receiver.incr(metric, &REMOVED_SCOPE, removed);
}
}

View File

@@ -1,13 +1,17 @@
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use std::any::{Any, type_name_of_val};
use std::hash::Hash;
use tonic::async_trait;
use tracing::warn;
use xai_stats_receiver::global_stats_receiver;
// Hydrators run in parallel and update candidate fields
#[async_trait]
pub trait Hydrator<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Decide if this hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
@@ -19,17 +23,41 @@ where
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a hydrator is not allowed - use a filter stage instead.
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
/// Instrumented entry point around [`Self::hydrate`].
///
/// Enforces the "one result per input candidate" contract: when the
/// hydrator returns a different number of results, a warning is logged and
/// every candidate is reported as failed with a `length_mismatch` error, so
/// the returned vector always has `candidates.len()` entries.
#[xai_stats_macro::receive_stats(latency=Bucket50To500, size=Bucket500To2500)]
#[tracing::instrument(skip_all, name = "hydrator", fields(name = self.name()))]
async fn run(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
    let expected_len = candidates.len();
    let hydrated = self.hydrate(query, candidates).await;
    let actual_len = hydrated.len();
    if actual_len == expected_len {
        return hydrated;
    }

    let message = format!(
        "Hydrator length_mismatch expected={} got={}",
        expected_len, actual_len
    );
    warn!(
        "Skipped: length_mismatch expected={} got={}",
        expected_len, actual_len
    );
    vec![Err(message); expected_len]
}
/// Update a single candidate with the hydrated fields.
/// Only the fields this hydrator is responsible for should be copied.
fn update(&self, candidate: &mut C, hydrated: C);
/// Update all candidates with the hydrated fields from `hydrated`.
/// Update all successfully hydrated candidates with the fields from `hydrated`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
for (c, h) in candidates.iter_mut().zip(hydrated) {
self.update(c, h);
fn update_all(&self, candidates: &mut [C], hydrated: Vec<Result<C, String>>) {
for (candidate, hydrated) in candidates.iter_mut().zip(hydrated) {
if let Ok(hydrated) = hydrated {
self.update(candidate, hydrated);
}
}
}
@@ -37,3 +65,125 @@ where
util::short_type_name(type_name_of_val(self))
}
}
// Stats scope used by `CachedHydrator::stat_cache` for the cache-hit counter.
const CACHE_HIT_SCOPE: [(&str, &str); 1] = [("requests", "cache_hit")];
// Stats scope used by `CachedHydrator::stat_cache` for the cache-miss counter.
const CACHE_MISS_SCOPE: [(&str, &str); 1] = [("requests", "cache_miss")];
/// Asynchronous key/value store used by [`CachedHydrator`] implementations.
#[async_trait]
pub trait CacheStore<K, V>: Send + Sync {
/// Look up `key`; `None` is treated by callers as a cache miss.
async fn get(&self, key: &K) -> Option<V>;
/// Store `value` under `key`.
async fn insert(&self, key: K, value: V);
}
/// A hydrator whose per-candidate results can be served from a cache.
///
/// Any type implementing this trait also gets a `Hydrator` implementation
/// through the blanket impl in this file: that impl looks every candidate up
/// in [`Self::cache_store`], calls [`Self::hydrate_from_client`] for the
/// misses, and writes successful client results back to the store.
#[async_trait]
pub trait CachedHydrator<Q, C>: Any + Send + Sync
where
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Key type used to look candidates up in the cache store.
type CacheKey: Eq + Hash + Send + Sync + 'static;
/// Value type held in the cache store.
type CacheValue: Clone + Send + Sync + 'static;
/// Decide if this hydrator should run for the given query (default: always).
fn enable(&self, _query: &Q) -> bool {
true
}
/// The store consulted for cached values and updated after client hydration.
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue>;
/// Cache key for `candidate`.
fn cache_key(&self, candidate: &C) -> Self::CacheKey;
/// Value to cache for a candidate that was successfully hydrated by the client.
fn cache_value(&self, hydrated: &C) -> Self::CacheValue;
/// Build the hydrated candidate from a cached value.
fn hydrate_from_cache(&self, value: Self::CacheValue) -> C;
/// Hydrate the candidates that missed the cache. The blanket `Hydrator` impl
/// checks that exactly one result per input candidate is returned and pairs
/// results with inputs by position.
async fn hydrate_from_client(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
/// Copy the fields this hydrator is responsible for from `hydrated` into `candidate`.
fn update(&self, candidate: &mut C, hydrated: C);
/// Returns a name for logging/metrics, derived from the implementing type's name.
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
/// Increment the hit/miss counters of the `<name>.cache` metric. Counters
/// are only touched for non-zero counts and when a global stats receiver exists.
fn stat_cache(&self, cache_hits: usize, cache_misses: usize) {
if let Some(receiver) = global_stats_receiver() {
let metric_name = format!("{}.cache", self.name());
if cache_hits > 0 {
receiver.incr(metric_name.as_str(), &CACHE_HIT_SCOPE, cache_hits as u64);
}
if cache_misses > 0 {
receiver.incr(metric_name.as_str(), &CACHE_MISS_SCOPE, cache_misses as u64);
}
}
}
}
#[async_trait]
impl<Q, C, T> Hydrator<Q, C> for T
where
Q: PipelineQuery,
C: PipelineCandidate,
T: CachedHydrator<Q, C> + ?Sized,
{
fn enable(&self, query: &Q) -> bool {
CachedHydrator::enable(self, query)
}
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
let mut results = vec![None; candidates.len()];
let mut missing_candidates = Vec::new();
let mut missing_keys = Vec::new();
let mut missing_indices = Vec::new();
let mut cache_hits = 0usize;
let mut cache_misses = 0usize;
for (index, candidate) in candidates.iter().enumerate() {
let key = self.cache_key(candidate);
match self.cache_store().get(&key).await {
Some(value) => {
results[index] = Some(Ok(self.hydrate_from_cache(value)));
cache_hits += 1;
}
None => {
missing_candidates.push(candidate.clone());
missing_keys.push(key);
missing_indices.push(index);
cache_misses += 1;
}
}
}
self.stat_cache(cache_hits, cache_misses);
if !missing_candidates.is_empty() {
let hydrated_missing = self.hydrate_from_client(query, &missing_candidates).await;
if hydrated_missing.len() != missing_candidates.len() {
let message = format!(
"CachedHydrator length_mismatch expected={} got={}",
missing_candidates.len(),
hydrated_missing.len()
);
return vec![Err(message); candidates.len()];
}
for ((index, key), hydrated) in missing_indices
.into_iter()
.zip(missing_keys.into_iter())
.zip(hydrated_missing.into_iter())
{
if let Ok(ref hydrated_candidate) = hydrated {
let value = self.cache_value(hydrated_candidate);
self.cache_store().insert(key, value).await;
}
results[index] = Some(hydrated);
}
}
results
.into_iter()
.map(|result| {
result.unwrap_or_else(|| Err("Missing hydration result for candidate".to_string()))
})
.collect()
}
fn update(&self, candidate: &mut C, hydrated: C) {
CachedHydrator::update(self, candidate, hydrated);
}
}

View File

@@ -1,18 +1,32 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::candidate_pipeline::PipelineQuery;
use crate::util;
use tracing::error;
#[async_trait]
pub trait QueryHydrator<Q>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
Q: PipelineQuery,
{
/// Decide if this query hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Instrumented entry point around [`Self::hydrate`]: logs the error when
/// hydration fails and passes the result through unchanged.
#[xai_stats_macro::receive_stats]
#[tracing::instrument(skip_all, name = "query_hydrator", fields(name = self.name()))]
async fn run(&self, query: &Q) -> Result<Q, String> {
    let outcome = self.hydrate(query).await;
    if let Err(err) = &outcome {
        error!("Failed: {}", err);
    }
    outcome
}
/// Hydrate the query by performing async operations.
/// Returns a new query with this hydrator's fields populated.
async fn hydrate(&self, query: &Q) -> Result<Q, String>;

View File

@@ -1,35 +1,61 @@
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use std::any::type_name_of_val;
use tonic::async_trait;
use tracing::warn;
/// Scorers update candidate fields (like a score field) and run sequentially
#[async_trait]
pub trait Scorer<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Decide if this scorer should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Instrumented entry point around [`Self::score`].
///
/// Enforces the "one result per input candidate" contract: when the scorer
/// returns a different number of results, a warning is logged and every
/// candidate is reported as failed with a `length_mismatch` error, so the
/// returned vector always has `candidates.len()` entries.
#[xai_stats_macro::receive_stats]
#[tracing::instrument(skip_all, name = "scorer", fields(name = self.name()))]
async fn run(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
    let expected_len = candidates.len();
    let scored = self.score(query, candidates).await;
    let actual_len = scored.len();
    if actual_len == expected_len {
        return scored;
    }

    let message = format!(
        "Scorer length_mismatch expected={} got={}",
        expected_len, actual_len
    );
    warn!(
        "Skipped: length_mismatch expected={} got={}",
        expected_len, actual_len
    );
    vec![Err(message); expected_len]
}
/// Score candidates by performing async operations.
/// Returns candidates with this scorer's fields populated.
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
async fn score(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
/// Update a single candidate with the scored fields.
/// Only the fields this scorer is responsible for should be copied.
fn update(&self, candidate: &mut C, scored: C);
/// Update all candidates with the scored fields from `scored`.
/// Update all successfully scored candidates with the fields from `scored`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
for (c, s) in candidates.iter_mut().zip(scored) {
self.update(c, s);
fn update_all(&self, candidates: &mut [C], scored: Vec<Result<C, String>>) {
for (candidate, scored) in candidates.iter_mut().zip(scored) {
if let Ok(scored) = scored {
self.update(candidate, scored);
}
}
}

View File

@@ -1,25 +1,65 @@
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use std::any::type_name_of_val;
use tracing::{Span, field::Empty};
/// Outcome of a selection step: the candidates that were chosen and the ones
/// that were left out.
pub struct SelectResult<C> {
    /// Candidates chosen by the selector.
    pub selected: Vec<C>,
    /// Candidates the selector did not choose.
    pub non_selected: Vec<C>,
}

impl<C> SelectResult<C> {
    /// Number of *selected* candidates; `non_selected` is not counted.
    pub fn len(&self) -> usize {
        let Self { selected, .. } = self;
        selected.len()
    }

    /// `true` only when both `selected` and `non_selected` are empty.
    ///
    /// NOTE(review): this is not equivalent to `len() == 0` — a result with
    /// no selected but some non-selected candidates has `len() == 0` yet is
    /// not "empty". Confirm this asymmetry is intended before relying on
    /// either method interchangeably.
    pub fn is_empty(&self) -> bool {
        [&self.selected, &self.non_selected]
            .iter()
            .all(|bucket| bucket.is_empty())
    }
}
pub trait Selector<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Default selection: sort and truncate based on provided configs
fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
let mut sorted = self.sort(candidates);
if let Some(limit) = self.size() {
sorted.truncate(limit);
}
sorted
}
/// Decide if this selector should run for the given query.
///
/// The default implementation ignores the query and always returns `true`.
fn enable(&self, _query: &Q) -> bool {
true
}
/// Runs [`Self::select`] and records the sizes of both result lists on the
/// current tracing span before handing the result back unchanged.
#[xai_stats_macro::receive_stats(latency=Bucket0To50, size=Bucket0To50)]
#[tracing::instrument(skip_all, name = "selector", fields(
    name = self.name(),
    input_count = candidates.len(),
    selected_count = Empty,
    non_selected_count = Empty,
))]
fn run(&self, query: &Q, candidates: Vec<C>) -> SelectResult<C> {
    let outcome = self.select(query, candidates);
    // These two span fields are declared `Empty` above and filled in here,
    // once the counts are known.
    Span::current()
        .record("selected_count", outcome.selected.len())
        .record("non_selected_count", outcome.non_selected.len());
    outcome
}
/// Default selection: sort the candidates, then split them at `size()`.
///
/// Everything up to the limit ends up in `selected`, the remainder in
/// `non_selected`. Without a limit every candidate is selected.
fn select(&self, _query: &Q, candidates: Vec<C>) -> SelectResult<C> {
    let mut selected = self.sort(candidates);
    let non_selected = match self.size() {
        // Clamp the split point: `split_off` panics when it is past the end.
        Some(limit) => selected.split_off(limit.min(selected.len())),
        None => vec![],
    };
    SelectResult {
        selected,
        non_selected,
    }
}
/// Extract the score from a candidate to use for sorting.
fn score(&self, candidate: &C) -> f64;

View File

@@ -1,27 +1,35 @@
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use std::any::type_name_of_val;
use std::sync::Arc;
use tonic::async_trait;
// A side effect is an action whose outcome does not affect the result returned by the pipeline.
/// Input handed to a [`SideEffect`]: the query plus two candidate lists.
#[derive(Clone)]
pub struct SideEffectInput<Q, C> {
/// The query, shared behind an `Arc`.
pub query: Arc<Q>,
/// Presumably the `selected` half of the selector's result; confirm at the call site.
pub selected_candidates: Vec<C>,
/// Presumably the `non_selected` half of the selector's result; confirm at the call site.
pub non_selected_candidates: Vec<C>,
}
#[async_trait]
pub trait SideEffect<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Decide if this side-effect should be run.
///
/// The default implementation ignores the query and always returns `true`.
fn enable(&self, _query: Arc<Q>) -> bool {
true
}
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
/// Entry point for executing the side effect.
///
/// Delegates directly to [`Self::side_effect`] and returns its result unchanged;
/// the `receive_stats` attribute is the only thing this wrapper adds.
#[xai_stats_macro::receive_stats]
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String> {
self.side_effect(input).await
}
/// The side effect itself; required from every implementor.
///
/// Receives the shared [`SideEffectInput`] and reports failure as an `Err(String)`.
async fn side_effect(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))

View File

@@ -1,20 +1,37 @@
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
use crate::util;
use tracing::{error, info};
#[async_trait]
pub trait Source<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
Q: PipelineQuery,
C: PipelineCandidate,
{
/// Decide if this source should run for the given query.
///
/// The default implementation ignores the query and always returns `true`.
fn enable(&self, _query: &Q) -> bool {
true
}
async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;
#[xai_stats_macro::receive_stats(size=Bucket500To1000)]
#[tracing::instrument(skip_all, name = "source", fields(name = self.name()))]
async fn run(&self, query: &Q) -> Result<Vec<C>, String> {
match self.source(query).await {
Ok(candidates) => {
info!("Fetched {} candidates", candidates.len());
Ok(candidates)
}
Err(err) => {
error!("Failed: {}", err);
Err(err)
}
}
}
/// Produce the candidates for `query`; required from every implementor.
///
/// Failure is reported as an `Err(String)`.
async fn source(&self, query: &Q) -> Result<Vec<C>, String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))

View File

@@ -0,0 +1,3 @@
/// Returns the last path segment of a fully qualified type name.
///
/// `"crate::scorers::WeightedScorer"` becomes `"WeightedScorer"`. For a generic
/// type the argument list is dropped, so `"a::Cache<b::Key>"` becomes `"Cache"`;
/// splitting the whole string on `::` would instead yield the tail of the last
/// type argument (`"Key>"`).
///
/// A name that starts with `<` (for example a qualified path such as
/// `"<a::B as c::D>::E"`) has no leading type segment to isolate, so the whole
/// string is split and the final `::` segment is returned.
pub fn short_type_name(full: &'static str) -> &'static str {
    // Only the part before the first generic argument list names the type itself.
    let base = match full.find('<') {
        Some(idx) if idx > 0 => &full[..idx],
        _ => full,
    };
    base.rsplit("::").next().unwrap_or(base)
}

0
grox/__init__.py Normal file
View File

View File

@@ -0,0 +1,160 @@
import re
import uuid
import logging
from grox.data_loaders.media_loader import MediaLoader
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.lm.post import PostRenderer
from grox.lm.user import UserRenderer
from grox.lm.convo import Role, Message, Conversation
from grox.config.config import ModelName, grox_config
from grok_sampler.config import GrokModelConfig
from grox.prompts.template import BangerMiniVlmScreenScore
from grok_sampler.vision_sampler import VisionSampler
from grox.data_loaders.data_types import (
Post,
ContentCategoryType,
ContentCategoryResult,
TweetBoolMetadata,
ContentCategoryScore,
)
from grox.classifiers.content.classifier import ContentClassifier
from monitor.metrics import Metrics
from pydantic import BaseModel
from strato_http.queries.grok_topics import StratoGrokTopics
logger = logging.getLogger(__name__)
class BangerInitialScreenResult(BaseModel):
    """Payload validated from the ``<json>`` section of the initial-screen model reply."""

    # Required fields.
    quality_score: float
    description: str
    tags: list[str]

    # Optional fields; each defaults to None when the model omits it.
    taxonomy_categories: list[dict] | None = None
    tweet_bool_metadata: TweetBoolMetadata | None = None
    is_image_editable_by_grok: bool | None = None
    slop_score: int | None = None
    has_minor_score: float | None = None
class BangerInitialScreenClassifier(ContentClassifier):
    """Initial quality screen for a post, backed by the primary vision-language model.

    One model call yields a single ``BangerInitialScreenResult``; the same score
    and metadata are then reported once per configured category
    (``BANGER_INITIAL_SCREEN`` and ``GROK_RANKER``).
    """

    # Group 1 is the text before the ``<json>`` tag (logged as reasoning),
    # group 2 is the JSON payload between the tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(self):
        model_cfg = grox_config.get_model(ModelName.VLM_PRIMARY)
        # NOTE(review): this mutates the object returned by get_model(); confirm
        # it is not a config instance shared with other samplers.
        model_cfg.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_cfg.model_dump()))
        super().__init__(
            categories=[
                ContentCategoryType.BANGER_INITIAL_SCREEN,
                ContentCategoryType.GROK_RANKER,
            ],
            llm=sampler,
        )
        # Topics for the request in flight; written by classify() and read by _to_convo().
        self._topics = None

    @staticmethod
    def build_convo(post: Post, topics: list | None = None) -> Conversation:
        """Build the system / user / assistant-prefill conversation for ``post``.

        ``topics`` is passed to the system prompt template as the ``topics`` parameter.
        """
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        system_msg = Message(
            role=Role.SYSTEM,
            content=[BangerMiniVlmScreenScore().render(params={"topics": topics})],
        )
        convo.messages.append(system_msg)

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post))
        instruction = (
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        user_msg.content.append(instruction)
        convo.messages.append(user_msg)

        # The conversation ends with an empty assistant turn.
        assistant_msg = Message(role=Role.ASSISTANT, content=[""], separator="")
        convo.messages.append(assistant_msg)
        return convo

    async def classify(
        self, post: Post, topics: list | None = None
    ) -> list[ContentCategoryResult]:
        """Classify ``post``, using ``topics`` when rendering the system prompt."""
        self._topics = topics
        return await super().classify(post)

    async def _classify_for_eval(self, post: Post) -> str:
        """Return the raw model output for ``post`` (no topics, no parsing)."""
        self._topics = None
        convo = await self._to_convo(post)
        logger.info(f"Banger initial screen conversation for post {post.id}")
        raw_output = await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )
        logger.info(f"Banger initial screen result for post {post.id}: {raw_output}")
        return raw_output

    async def _to_convo(self, post: Post) -> Conversation:
        return self.build_convo(post, topics=self._topics)

    async def _sample(self, convo: Conversation) -> str:
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Turn the raw model output into one result per configured category.

        Raises:
            ValueError: if ``output`` has no ``<json>...</json>`` section.
        """
        match = self.result_pattern.search(output)
        if match is None:
            raise ValueError(f"Invalid output: {output}")

        reasoning = match.group(1).strip()
        logger.info(
            f"Banger initial screen result reasoning for post {post.id}: {reasoning}"
        )
        result = BangerInitialScreenResult.model_validate_json(match.group(2).strip())

        score = result.quality_score
        Metrics.histogram(
            "banger_initial_screen_score",
            explicit_bucket_boundaries_advisory=[
                0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1,
            ],
        ).record(score)

        # A post counts as positive from a quality score of 0.4 upwards.
        banger_initial_positive = score >= 0.4

        taxonomy_categories = [
            ContentCategoryScore(
                id=tc["id"],
                name=tc["name"],
                score=tc["score"],
                category_id=None,
            )
            for tc in (result.taxonomy_categories or [])
        ]

        return [
            ContentCategoryResult(
                category=cat,
                positive=banger_initial_positive,
                score=score,
                reason=reasoning,
                summary=result.description,
                tags=result.tags,
                taxonomy_categories=taxonomy_categories,
                tweet_bool_metadata=result.tweet_bool_metadata,
                is_image_editable_by_grok=result.is_image_editable_by_grok,
                slop_score=result.slop_score,
                has_minor_score=result.has_minor_score,
            )
            for cat in self.categories
        ]

View File

@@ -0,0 +1,98 @@
import logging
import time
import traceback
from abc import ABC, abstractmethod
from grok_sampler.llm import LiteLLM
from grox.data_loaders.data_types import (
ContentCategoryResult,
ContentCategoryType,
Post,
)
from grox.lm.convo import Conversation
from monitor.metrics import Metrics
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class ContentClassifier(ABC):
    """Base class for LLM-backed content classifiers.

    Subclasses supply ``_to_convo``, ``_sample`` and ``_parse``; ``classify``
    wraps that three-step pipeline with logging, counters and latency metrics.
    """

    def __init__(self, categories: list[ContentCategoryType], llm: LiteLLM):
        self.categories = categories
        self.llm = llm

    @property
    def model_name(self) -> str:
        return self.llm.model_config.model_name

    def _count_per_category(self, metric_name: str) -> None:
        """Add 1 to ``metric_name`` once for every configured category."""
        for category in self.categories:
            Metrics.counter(metric_name).add(
                1, attributes={"category": category.value.lower()}
            )

    async def classify(self, post: Post) -> list[ContentCategoryResult]:
        """Classify ``post`` and return the parsed results.

        Any exception raised by the pipeline is counted, logged with its
        traceback, and re-raised unchanged.
        """
        cls_name = self.__class__.__name__
        logger.info(
            f"[{cls_name}] started processing content classify request: {post.id}"
        )
        self._count_per_category("content.classification.request.count")
        self._count_per_category("content.classification.intake.count")

        start = time.perf_counter()
        try:
            res = await self._classify(post)
        except Exception:
            self._count_per_category("content.classification.error.count")
            logger.error(
                f"[{cls_name}] error processing content classify request: {post.id} {traceback.format_exc()}"
            )
            raise
        self._count_per_category("content.classification.success.count")
        end = time.perf_counter()

        logger.info(
            f"[{cls_name}] finished processing content classify request: {post.id} in {end - start:.2f} seconds"
        )
        Metrics.histogram("content.classification.latency").record(
            end - start, attributes={"class": cls_name}
        )
        self._post_process_for_logging(res, start, end)
        return res

    def _post_process_for_logging(
        self, res: list[ContentCategoryResult], start_time, end_time
    ) -> None:
        """Emit per-category latency and per-result outcome metrics.

        NOTE(review): the latency histogram is also recorded in ``classify``
        with a ``class`` attribute; here it is recorded again once per category.
        """
        for category in self.categories:
            Metrics.histogram("content.classification.latency").record(
                end_time - start_time, attributes={"category": category.value.lower()}
            )
        for result in res:
            Metrics.counter("content.classification.result.count").add(
                1,
                attributes={
                    "category": result.category.value.lower(),
                    "positive": str(result.positive),
                },
            )

    @abstractmethod
    async def _to_convo(self, post: Post) -> Conversation:
        """Build the conversation sent to the model for ``post``."""
        pass

    @abstractmethod
    async def _sample(self, convo: Conversation) -> str:
        """Run the model on ``convo`` and return its raw text output."""
        pass

    @abstractmethod
    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Convert raw model ``output`` into classification results."""
        pass

    async def _classify(self, post: Post) -> list[ContentCategoryResult]:
        """Default pipeline: build the conversation, sample, parse."""
        convo = await self._to_convo(post)
        output = await self._sample(convo)
        return await self._parse(post, output)

View File

@@ -0,0 +1,92 @@
import re
import uuid
import logging
import asyncio
from grox.data_loaders.media_loader import MediaLoader
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.lm.post import PostRenderer
from grox.lm.user import UserRenderer
from grox.lm.convo import Role, Message, Conversation
from grox.config.config import ModelName, grox_config
from grok_sampler.config import GrokModelConfig
from grox.prompts.template import PostSafetyDeluxe
from grok_sampler.vision_sampler import VisionSampler
from grox.data_loaders.data_types import (
Post,
ContentCategoryType,
ContentCategoryResult,
TweetBoolMetadata,
)
from grox.classifiers.content.classifier import ContentClassifier
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class PostSafetyScreenResult(BaseModel):
    """Payload validated from the ``<json>`` section of the safety-screen model reply."""

    # Required: validation fails when the model omits it.
    tweet_bool_metadata: TweetBoolMetadata
class PostSafetyDeluxeClassifier(ContentClassifier):
    """Safety screen that extracts ``TweetBoolMetadata`` for a post.

    The returned result always has ``positive=False`` and ``score=0.0``; the
    metadata attached to it is the only model-derived field.
    """

    # Group 1 is the text before the ``<json>`` tag (logged as reasoning),
    # group 2 is the JSON payload between the tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(self):
        model_cfg = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        # NOTE(review): this mutates the object returned by get_model(); confirm
        # it is not a config instance shared with other samplers.
        model_cfg.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_cfg.model_dump()))
        super().__init__(
            categories=[
                ContentCategoryType.POST_SAFETY_SCREEN,
            ],
            llm=sampler,
        )

    @staticmethod
    def build_convo(post: Post) -> Conversation:
        """Build the system / user / open-assistant conversation for ``post``."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        system_msg = Message(role=Role.SYSTEM, content=[PostSafetyDeluxe().render()])
        convo.messages.append(system_msg)

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post))
        instruction = (
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        user_msg.content.append(instruction)
        convo.messages.append(user_msg)

        # The conversation ends with an assistant turn that has no content.
        convo.messages.append(Message(role=Role.ASSISTANT, content=[]))
        return convo

    async def _to_convo(self, post: Post) -> Conversation:
        return self.build_convo(post)

    async def _sample(self, convo: Conversation) -> str:
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Extract the metadata payload from ``output``.

        Raises:
            ValueError: if ``output`` has no ``<json>...</json>`` section.
        """
        match = self.result_pattern.search(output)
        if match is None:
            raise ValueError(f"Invalid output: {output}")

        reasoning = match.group(1).strip()
        logger.info(f"Post Safety Screen reasoning for post {post.id} : {reasoning}")

        result = PostSafetyScreenResult.model_validate_json(match.group(2).strip())
        logger.info(f"Post Safety Screen result for post {post.id} : {result}")

        screen_result = ContentCategoryResult(
            category=ContentCategoryType.POST_SAFETY_SCREEN,
            positive=False,
            score=0.0,
            tweet_bool_metadata=result.tweet_bool_metadata,
        )
        return [screen_result]

View File

@@ -0,0 +1,158 @@
import uuid
import json
import logging
import re
import json_repair
from grox.data_loaders.media_loader import MediaLoader
from grox.lm.convo import Role, Message, Conversation
from grox.lm.thread import ThreadRenderer
from grox.config.config import ModelName, grox_config
from grok_sampler.config import GrokModelConfig
from grox.prompts.template import ReplyScoringSystem
from grok_sampler.vision_sampler import VisionSampler
from grox.data_loaders.data_types import (
Post,
ReplyScoreResult,
)
from monitor.metrics import Metrics
from pydantic import ValidationError
from grox.data_loaders.strato_loader import TweetStratoLoader
logger = logging.getLogger(__name__)
class ReplyScorer:
    """Scores a reply with a mini VLM, falling back to the primary VLM.

    The primary path samples ``VLM_MINI_CRITICAL``. When its output contains no
    JSON-looking object mentioning ``score``, the request is retried once on
    ``VLM_PRIMARY_CRITICAL`` with a prefilled assistant turn.
    """

    model_name = "GROK"

    def __init__(self):
        vlm_config = grox_config.get_model(ModelName.VLM_MINI_CRITICAL)
        vlm_config.temperature = 0.000001
        self.vlm = VisionSampler(GrokModelConfig(**vlm_config.model_dump()))

        vlm_fallback_config = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        vlm_fallback_config.temperature = 0.000001
        self.vlm_fallback = VisionSampler(
            GrokModelConfig(**vlm_fallback_config.model_dump())
        )

    async def score(self, post: Post) -> list[ReplyScoreResult]:
        """Score ``post`` and return a single-element list with the result.

        Raises:
            ValueError: if no score can be recovered from the model output.
        """
        convo = await self._to_convo(post)
        result = await self._sample(convo, post)
        parsed = await self._parse(result)
        Metrics.histogram(
            "ranked_replies_scores",
            explicit_bucket_boundaries_advisory=[0.0, 1.0, 2.0, 3.0],
        ).record(parsed[0].score)
        return parsed

    async def _to_convo(self, post: Post, non_reasoning: bool = False) -> Conversation:
        """Build the scoring conversation for ``post``.

        With ``non_reasoning=True`` the final assistant turn is prefilled with
        an empty string and an empty separator; otherwise it has no content.
        """
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        # The non-reasoning variant used to prepend a prefix to this prompt, but
        # the prefix is the empty string here, so both variants share one prompt.
        system_prompt = ReplyScoringSystem().render(
            params={"large_account_follower_threshold": ""}
        )
        convo.messages.append(Message(role=Role.SYSTEM, content=[system_prompt]))
        convo.messages.append(
            ThreadRenderer.render(
                post, role=Role.HUMAN, separator="\n\n", include_signals=True
            )
        )
        if non_reasoning:
            convo.messages.append(
                Message(role=Role.ASSISTANT, content=[""], separator="")
            )
        else:
            convo.messages.append(Message(role=Role.ASSISTANT, content=[]))
        return convo

    async def _sample(self, convo: Conversation, post: Post) -> str:
        """Sample the mini model; retry on the fallback model if no score is visible."""
        # Serialized once and reused for both the primary and the fallback call.
        schema = json.dumps(ReplyScoreResult.model_json_schema())
        output = await self.vlm.sample(
            convo.interleave(),
            conversation_id=convo.conversation_id,
            json_schema=schema,
        )
        match = re.search(r"\{.*\}", output, re.DOTALL)
        if not (match and "score" in match.group(0)):
            fallback_convo = await self._to_convo(post, non_reasoning=True)
            output = await self.vlm_fallback.sample(
                fallback_convo.interleave(),
                conversation_id=fallback_convo.conversation_id,
                json_schema=schema,
            )
        return output

    async def _clean_output(self, output: str) -> str:
        """Strip a trailing ``<|eos|>`` marker and a surrounding Markdown code fence."""
        # removesuffix() is already a no-op when the marker is absent.
        output = output.removesuffix("<|eos|>").strip()
        if output.startswith("```json"):
            output = output[7:]
        elif output.startswith("```"):
            output = output[3:]
        if output.endswith("```"):
            output = output[:-3]
        return output.strip()

    async def _parse(self, output: str) -> list[ReplyScoreResult]:
        """Recover ``score`` and ``reason`` from raw model output.

        Tries, in order: strict JSON validation, ``json_repair``, and finally
        regular expressions over the cleaned text.

        Raises:
            ValueError: if none of the strategies yields a score.
        """
        score = None
        reason = ""

        match = re.search(r"\{.*\}", output, re.DOTALL)
        if match and "score" in match.group(0):
            raw_result = match.group(0).strip()
        else:
            raw_result = output
        cleaned_result = await self._clean_output(raw_result)

        try:
            result = ReplyScoreResult.model_validate_json(cleaned_result)
            score = result.score
            reason = result.reason
        except (ValidationError, ValueError):
            try:
                repaired = json_repair.repair_json(cleaned_result, return_objects=True)
                if isinstance(repaired, dict) and "score" in repaired:
                    result = ReplyScoreResult.model_validate(repaired)
                    score = result.score
                    reason = result.reason
                    Metrics.counter("task.reply_ranker.json_repaired.count").add(1)
            except Exception:
                # Best effort: the regex fallback below still gets a chance, but
                # the failure is no longer swallowed silently.
                logger.warning(
                    "json_repair could not recover reply score output; falling back to regex",
                    exc_info=True,
                )

        if score is None:
            score_match = re.search(r'"score":\s*([\d.]+)', cleaned_result)
            if score_match:
                try:
                    score = float(score_match.group(1).strip())
                except ValueError:
                    # e.g. "1.2.3" matches the pattern but is not a float.
                    score = None
                    Metrics.counter("task.reply_ranker.invalid.count").add(
                        1,
                        attributes={
                            "filter": "reply_ranking",
                            "reason": "invalid_score_format",
                        },
                    )

        if not reason:
            reason_match = re.search(
                r'"reason":\s*"((?:[^"\\]|\\.)*)"', cleaned_result, re.DOTALL
            )
            if reason_match:
                reason = reason_match.group(1)

        # A score of 0 is valid; only a missing score is an error.
        if score is None:
            logger.error(f"Invalid output format: {output}")
            Metrics.counter("task.reply_ranker.invalid.count").add(
                1, attributes={"filter": "reply_ranking", "reason": "invalid_format"}
            )
            raise ValueError(f"Invalid output: {output}")

        return [ReplyScoreResult(score=score, reason=reason)]

View File

@@ -0,0 +1,288 @@
import logging
import re
import traceback
import uuid
from typing import List
from grox.data_loaders.media_loader import MediaLoader
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.lm.convo import Role, Message, Conversation
from grox.lm.post import PostRenderer
from grox.lm.user import UserRenderer
from grox.config.config import ModelName, grox_config
from grok_sampler.config import GrokModelConfig, EapiModelConfig
from grox.prompts.template import (
SafetyPtos,
ViolentMediaPolicy,
AdultContentPolicy,
SpamPolicy,
IllegalAndRegulatedBehaviorsPolicy,
HateOrAbusePolicy,
ViolentSpeechPolicy,
SuicideOrSelfHarmPolicy,
)
from grok_sampler.vision_sampler import VisionSampler
from grok_sampler.eapi_sampler import EapiSampler
from grox.data_loaders.data_types import (
Post,
SafetyPostAnnotations,
ContentCategoryResult,
ContentCategoryType,
SafetyPtosViolatedPolicy,
SafetyPolicy,
SafetyPolicyCategory,
)
from grox.classifiers.content.classifier import ContentClassifier
logger = logging.getLogger(__name__)
_THINKING_RESTRICTION_LINES = {
"",
"",
}
def _strip_thinking_restrictions(text: str) -> str:
lines = text.splitlines(keepends=True)
return "".join(
line for line in lines if line.strip() not in _THINKING_RESTRICTION_LINES
).lstrip("\n")
def _render_safety_ptos_for_reasoning() -> str:
    """Render the ``SafetyPtos`` prompt with its restriction lines removed."""
    rendered = SafetyPtos().render()
    return _strip_thinking_restrictions(rendered)
class SafetyPtosCategoryClassifier(ContentClassifier):
    """First-stage PTOS safety classifier producing ``SafetyPostAnnotations``.

    ``deluxe`` switches to the prompt with restriction lines stripped and an
    assistant turn without prefilled content.
    """

    # Group 2 is the JSON payload between the ``<json>`` tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(
        self,
        model_name: ModelName = ModelName.VLM_SAFETY,
        deluxe: bool = False,
    ):
        self.deluxe = deluxe
        model_cfg = grox_config.get_model(model_name)
        model_cfg.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_cfg.model_dump()))
        super().__init__(categories=[ContentCategoryType.SAFETY_PTOS], llm=sampler)

    def build_convo(self, post: Post) -> Conversation:
        """Build the system / user / assistant conversation for ``post``."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        system_text = (
            _render_safety_ptos_for_reasoning() if self.deluxe else SafetyPtos().render()
        )
        convo.messages.append(Message(role=Role.SYSTEM, content=[system_text]))

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post, include_reply_to=True))
        instruction = (
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        user_msg.content.append(instruction)
        convo.messages.append(user_msg)

        # Deluxe leaves the assistant turn without content; standard prefills an
        # empty string with an empty separator.
        if self.deluxe:
            assistant_msg = Message(role=Role.ASSISTANT, content=[])
        else:
            assistant_msg = Message(role=Role.ASSISTANT, content=[""], separator="")
        convo.messages.append(assistant_msg)
        return convo

    async def classify_post(self, post: Post) -> SafetyPostAnnotations:
        """Run the classifier on ``post`` and validate the JSON payload.

        Raises:
            ValueError: if the model output has no ``<json>...</json>`` section.
        """
        convo = await self._to_convo(post)
        raw_output = await self._sample(convo, post)
        mode = "deluxe" if self.deluxe else "standard"
        logger.info(
            f"safety ptos category classifier ({mode}) result for {post.id}: {raw_output}"
        )
        match = self.result_pattern.search(raw_output)
        if match is None:
            raise ValueError(
                f"Invalid output for safety ptos category classifier ({mode}): {raw_output}"
            )
        return SafetyPostAnnotations.model_validate_json(match.group(2).strip())

    async def _to_convo(self, post: Post) -> Conversation:
        return self.build_convo(post)

    async def _sample(self, convo: Conversation, post: Post = None) -> str:
        # ``post`` is accepted but not used by this implementation.
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _parse(self, post: Post, output: str) -> List[ContentCategoryResult]:
        """Report a fixed positive result whenever ``output`` contains a JSON section.

        Raises:
            ValueError: if ``output`` has no ``<json>...</json>`` section.
        """
        if self.result_pattern.search(output) is None:
            mode = "deluxe" if self.deluxe else "standard"
            raise ValueError(
                f"Invalid parsing for safety ptos category classifier ({mode}): {output}"
            )
        return [
            ContentCategoryResult(
                category=ContentCategoryType.SAFETY_PTOS, positive=True, score=0.0
            )
        ]
class SafetyPtosPolicyClassifier(ContentClassifier):
    """Second-stage PTOS classifier: evaluates one violated category against its policy.

    The entry point is ``classify_policy_for_violation``; ``_parse`` returns an
    empty list, so the inherited ``classify`` yields no results for this class.
    """

    # Group 2 is the JSON payload between the ``<json>`` tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(self, deluxe: bool = False):
        self.deluxe = deluxe
        vlm_config = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        vlm_config.temperature = 0.000001
        vlm = VisionSampler(GrokModelConfig(**vlm_config.model_dump()))
        super().__init__(categories=[ContentCategoryType.SAFETY_PTOS], llm=vlm)

        # The EAPI samplers exist only in deluxe mode; code that uses them must
        # check ``self.deluxe`` first.
        if deluxe:
            eapi_config_reasoning = grox_config.get_eapi_model(
                ModelName.EAPI_REASONING_INTERNAL
            )
            self.eapi_reasoning = EapiSampler(
                EapiModelConfig(**eapi_config_reasoning.model_dump())
            )
            eapi_config_reasoning_x_algo = grox_config.get_eapi_model(
                ModelName.EAPI_REASONING
            )
            self.eapi_reasoning_x_algo = EapiSampler(
                EapiModelConfig(**eapi_config_reasoning_x_algo.model_dump())
            )

    @staticmethod
    def _get_policy_prompt(violation: SafetyPtosViolatedPolicy) -> str:
        """Render the policy prompt matching ``violation.category``.

        Raises:
            ValueError: if the category has no policy prompt.
        """
        if violation.category == SafetyPolicyCategory.ViolentMedia:
            return ViolentMediaPolicy().render()
        elif violation.category == SafetyPolicyCategory.AdultContent:
            return AdultContentPolicy().render()
        elif violation.category == SafetyPolicyCategory.Spam:
            return SpamPolicy().render()
        elif violation.category == SafetyPolicyCategory.IllegalAndRegulatedBehaviors:
            return IllegalAndRegulatedBehaviorsPolicy().render()
        elif violation.category == SafetyPolicyCategory.HateOrAbuse:
            return HateOrAbusePolicy().render()
        elif violation.category == SafetyPolicyCategory.ViolentSpeech:
            return ViolentSpeechPolicy().render()
        elif violation.category == SafetyPolicyCategory.SuicideOrSelfHarm:
            return SuicideOrSelfHarmPolicy().render()
        else:
            raise ValueError(
                f"No policy prompt available for category: {violation.category.value}"
            )

    def build_convo(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> Conversation:
        """Build the policy-specific conversation for ``post`` and ``violation``."""
        content = self._get_policy_prompt(violation)
        if self.deluxe:
            content = _strip_thinking_restrictions(content)

        convo = Conversation(conversation_id=uuid.uuid4().hex)
        convo.messages.append(Message(role=Role.SYSTEM, content=[content]))

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post, include_reply_to=True))
        user_msg.content.append(
            f"\n\nAnalyze the post {post.id} for the specific safety policy violation category: {violation.category.value}"
        )
        user_msg.content.append(
            "\n\nProvide the requested JSON object for the specific safety policy type."
        )
        convo.messages.append(user_msg)

        # Deluxe leaves the assistant turn without content; standard prefills an
        # empty string with an empty separator.
        if self.deluxe:
            convo.messages.append(Message(role=Role.ASSISTANT, content=[]))
        else:
            convo.messages.append(
                Message(role=Role.ASSISTANT, content=[""], separator="")
            )
        return convo

    # Categories for which ``_get_policy_prompt`` has a prompt.
    SUPPORTED_POLICY_CATEGORIES = {
        SafetyPolicyCategory.ViolentMedia,
        SafetyPolicyCategory.AdultContent,
        SafetyPolicyCategory.Spam,
        SafetyPolicyCategory.IllegalAndRegulatedBehaviors,
        SafetyPolicyCategory.HateOrAbuse,
        SafetyPolicyCategory.ViolentSpeech,
        SafetyPolicyCategory.SuicideOrSelfHarm,
    }

    # Categories that, in deluxe mode, are sampled through ``_sample_4_2``.
    DELUXE_4_2_CATEGORIES = {
        SafetyPolicyCategory.AdultContent,
        SafetyPolicyCategory.ViolentMedia,
    }

    async def classify_policy_for_violation(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> SafetyPolicy | None:
        """Evaluate ``post`` against the policy for ``violation.category``.

        Returns:
            The validated ``SafetyPolicy``, or ``None`` when the category is not
            in ``SUPPORTED_POLICY_CATEGORIES``.

        Raises:
            ValueError: if the model output has no ``<json>...</json>`` section.
        """
        if violation.category not in self.SUPPORTED_POLICY_CATEGORIES:
            return None

        convo = await self._to_convo(post, violation)

        # NOTE(review): the original condition opened with a dangling ``and``
        # (a syntax error), so a first operand appears to have been removed.
        # The two remaining operands are kept as written; confirm whether a
        # further gate belongs here.
        if self.deluxe and violation.category in self.DELUXE_4_2_CATEGORIES:
            mode = "deluxe-4.2"
            result = await self._sample_4_2(convo, post)
        else:
            mode = "deluxe" if self.deluxe else "standard"
            result = await self._sample(convo, post)

        logger.info(
            f"safety ptos policy classifier ({mode}) result for post {post.id}, violation {violation.category}: {result}"
        )
        match = self.result_pattern.search(result)
        if match:
            json_str = match.group(2).strip()
            return SafetyPolicy.model_validate_json(json_str)
        else:
            raise ValueError(
                f"Invalid output for safety ptos policy ({mode}): {result}"
            )

    async def _to_convo(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> Conversation:
        return self.build_convo(post, violation)

    async def _sample_4_2(self, convo: Conversation, post: Post) -> str:
        """Sample the EAPI reasoning model, falling back to the VLM on any error."""
        try:
            return await self.eapi_reasoning_x_algo.sample(
                convo.interleaveToEapi(), conversation_id=convo.conversation_id
            )
        except Exception:
            logger.error(
                f"Failed to call 4.2 reasoning, error: {traceback.format_exc()}"
            )
            return await self.llm.sample(
                convo.interleave(), conversation_id=convo.conversation_id
            )

    async def _sample(self, convo: Conversation, post: Post = None) -> str:
        # ``post`` is accepted but not used by this implementation.
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _parse(self, post: Post, output: str) -> List[ContentCategoryResult]:
        # Always empty: results come from classify_policy_for_violation instead.
        return []

View File

@@ -0,0 +1,104 @@
import logging
import re
import uuid
from pydantic import ValidationError
from grox.lm.convo import Role, Message, Conversation
from grox.lm.thread import ThreadRenderer
from grox.config.config import ModelName, grox_config
from grok_sampler.config import GrokModelConfig
from grox.prompts.template import SpamSystemLowFollower
from grok_sampler.vision_sampler import VisionSampler
from grox.data_loaders.data_types import (
Post,
ContentCategoryType,
ContentCategoryResult,
SpamSampleResult,
)
from grox.classifiers.content.classifier import ContentClassifier
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.data_loaders.media_loader import MediaLoader
logger = logging.getLogger(__name__)
class SpamEapiLowFollowerClassifier(ContentClassifier):
    """Spam classifier for replies, reporting a single ``SPAM_COMMENT`` result."""

    def __init__(self, model_name: ModelName = ModelName.VLM_PRIMARY):
        model_cfg = grox_config.get_model(model_name)
        model_cfg.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_cfg.model_dump()))
        super().__init__(categories=[ContentCategoryType.SPAM_COMMENT], llm=sampler)

    @property
    def model_name(self) -> str:
        # Fixed label; overrides the base property that reads the sampler config.
        return "grox"

    async def _classify(self, post: Post) -> list[ContentCategoryResult]:
        """Run the pipeline and keep only the ``SPAM_COMMENT`` result."""
        convo = await self._to_convo(post)
        raw_output = await self._sample(convo)
        all_results = await self._parse(post, raw_output)
        spam_results = [
            item
            for item in all_results
            if item.category == ContentCategoryType.SPAM_COMMENT
        ]
        assert len(spam_results) == 1
        return spam_results

    async def _to_convo(self, post: Post) -> Conversation:
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        system_msg = Message(
            role=Role.SYSTEM, content=[SpamSystemLowFollower().render()]
        )
        convo.messages.append(system_msg)
        thread_msg = ThreadRenderer.render(post, role=Role.HUMAN, separator="\n\n")
        convo.messages.append(thread_msg)
        return convo

    async def _sample(self, convo: Conversation) -> str:
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _clean_output(self, output: str) -> str:
        """Strip a trailing ``<|eos|>`` marker and a surrounding Markdown code fence."""
        text = output
        if text.endswith("<|eos|>"):
            text = text.removesuffix("<|eos|>")
        text = text.strip()

        # Only one opening fence is removed; the longer form is checked first.
        if text.startswith("```json"):
            text = text[len("```json"):]
        elif text.startswith("```"):
            text = text[len("```"):]

        if text.endswith("```"):
            text = text[: -len("```")]
        return text.strip()

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Turn raw model output into a spam / not-spam result.

        Falls back to a regex over the ``"decision"`` field when the output does
        not validate as ``SpamSampleResult``.

        Raises:
            ValueError: if no decision can be recovered.
        """
        cleaned = await self._clean_output(output)

        decision = None
        summary = ""
        try:
            parsed = SpamSampleResult.model_validate_json(cleaned)
        except ValidationError:
            fallback = re.search(r'"decision":\s*"(.*?)"', cleaned)
            if fallback:
                decision = fallback.group(1).strip()
        else:
            decision = parsed.decision
            summary = parsed.reason

        if not decision:
            raise ValueError(f"Invalid output format: {output}")

        is_spam = decision == "spam"
        if is_spam:
            logger.info(f"Spam found for low follower user: {post.id}")

        return [
            ContentCategoryResult(
                category=ContentCategoryType.SPAM_COMMENT,
                positive=is_spam,
                score=1.0 if is_spam else 0.0,
                summary=summary,
            )
        ]

View File

@@ -0,0 +1,394 @@
import asyncio
import base64
import logging
import os
import subprocess
import tempfile
import time
import traceback
from multiprocessing import Event, Process, Queue
from multiprocessing.synchronize import Event as MultiprocessingEvent
from queue import Empty
import aiohttp
from cachetools import TTLCache
from pydantic import BaseModel
from grox.config.config import grox_config
from grox.schedules.init import init_proc
from monitor.logging import Logging
from monitor.metrics import Metrics
logger = logging.getLogger(__name__)
class _ASRRequest(BaseModel):
    """A transcription request for one post's video."""

    post_id: str
    video_url: str
    # Passed to _extract_wav_from_url as its max_duration_s argument.
    max_audio_duration_s: float
class _ASRResult(BaseModel):
    """Outcome of a transcription request; both payload fields default to None."""

    post_id: str
    transcript: str | None = None
    error: str | None = None
def _extract_wav_from_url(
    video_url: str, max_duration_s: float | None = None
) -> bytes | None:
    """Extract the audio track of ``video_url`` as WAV bytes using ffmpeg.

    The audio is written as 16-bit PCM (``pcm_s16le``), 16000 Hz, one channel.
    A positive ``max_duration_s`` is passed to ffmpeg as ``-t``.

    Returns:
        The WAV file contents, or ``None`` when ffmpeg exits non-zero without
        having produced an output file.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero but an output
            file exists.
        subprocess.TimeoutExpired: ffmpeg ran longer than 180 seconds.
    """
    # NOTE(review): the two timeout values are presumably microseconds (60 s);
    # confirm against the ffmpeg protocol options.
    network_opts = [
        "-timeout",
        "60000000",
        "-rw_timeout",
        "60000000",
        "-reconnect",
        "1",
        "-reconnect_streamed",
        "1",
        "-reconnect_delay_max",
        "5",
    ]
    audio_opts = ["-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1"]
    if max_duration_s is not None and max_duration_s > 0:
        duration_opts = ["-t", str(max_duration_s)]
    else:
        duration_opts = []

    with tempfile.TemporaryDirectory() as workdir:
        wav_path = os.path.join(workdir, "audio.wav")
        cmd = [
            "ffmpeg",
            "-y",
            *network_opts,
            "-i",
            video_url,
            *audio_opts,
            *duration_opts,
            wav_path,
        ]
        result = subprocess.run(cmd, capture_output=True, timeout=180)
        if result.returncode != 0:
            if os.path.exists(wav_path):
                raise subprocess.CalledProcessError(
                    result.returncode, cmd, result.stdout, result.stderr
                )
            return None
        with open(wav_path, "rb") as wav_file:
            return wav_file.read()
def _clean_asr(raw: str) -> str:
if "<asr_text>" in raw:
raw = raw.split("<asr_text>", 1)[1]
if "</asr_text>" in raw:
raw = raw.split("</asr_text>", 1)[0]
return raw.strip()
class _ASRWorker:
    """Worker-process side of the ASR pipeline.

    Each worker process pulls ``(_ASRRequest, logging-context)`` tuples from
    the task queue, extracts audio with ffmpeg, asks the ASR HTTP endpoint for
    a transcript and pushes exactly one ``_ASRResult`` per request onto the
    response queue.
    """

    def __init__(
        self, task_queue: Queue, resp_queue: Queue, shutdown_event: MultiprocessingEvent
    ):
        self._task_queue: Queue[tuple[_ASRRequest, dict[str, str]]] = task_queue
        self._resp_queue: Queue[_ASRResult] = resp_queue
        self._shutdown_event: MultiprocessingEvent = shutdown_event

    async def _transcribe(self, request: _ASRRequest) -> str | None:
        """Extract audio for ``request`` and return its cleaned transcript.

        Returns None when audio extraction yields nothing. Raises a plain
        ``Exception`` when the ASR endpoint answers with a non-200 status.
        """
        cfg = grox_config.asr

        # ffmpeg is blocking, so it runs on the loop's default executor.
        extract_started = time.monotonic()
        wav_bytes = await asyncio.get_running_loop().run_in_executor(
            None, _extract_wav_from_url, request.video_url, request.max_audio_duration_s
        )
        extract_elapsed = time.monotonic() - extract_started
        Metrics.histogram("asr_proc.extract_duration_s").record(extract_elapsed)
        if wav_bytes is None:
            return None

        audio_size = len(wav_bytes)
        Metrics.histogram("asr_proc.audio_bytes").record(audio_size)
        logger.debug(f"Extracted audio in {extract_elapsed:.2f}s, size={audio_size} bytes")

        request_started = time.monotonic()
        encoded_audio = base64.b64encode(wav_bytes).decode()
        audio_part = {
            "type": "audio_url",
            "audio_url": {"url": f"data:audio/wav;base64,{encoded_audio}"},
        }
        prompt_part = {"type": "text", "text": "Transcribe this audio."}
        payload = {
            "model": "default",
            "messages": [{"role": "user", "content": [audio_part, prompt_part]}],
            "temperature": cfg.temperature,
            "max_tokens": cfg.max_tokens,
        }

        async with self._session.post(
            f"{cfg.endpoint}/v1/chat/completions",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=cfg.timeout),
        ) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                raise Exception(
                    f"ASR request failed with status {resp.status}: {error_text}"
                )

            data = await resp.json()
            transcript = _clean_asr(data["choices"][0]["message"]["content"].strip())
            request_elapsed = time.monotonic() - request_started
            Metrics.histogram("asr_proc.transcribe_duration_s").record(request_elapsed)
            Metrics.histogram("asr_proc.transcript_chars").record(len(transcript))

            # Token accounting is optional in the response; record what is there.
            if "usage" in data:
                usage = data["usage"]
                for field in ("prompt_tokens", "completion_tokens", "total_tokens"):
                    if field in usage:
                        Metrics.histogram(f"asr_proc.{field}").record(usage[field])
            return transcript

    def _report_error(
        self, request: _ASRRequest, attributes: dict[str, str], reason: str, error: str
    ) -> None:
        """Count a failed request under ``reason`` and send the error result."""
        Metrics.counter("asr_proc.error.count").add(
            1, attributes={**attributes, "reason": reason}
        )
        self._resp_queue.put(_ASRResult(post_id=request.post_id, error=error))

    async def _process(self, request: _ASRRequest, ctx: dict[str, str]) -> None:
        """Handle one request end to end; never raises, always answers."""
        attributes = {"pid": str(os.getpid())}
        with Metrics.tracer("asr_proc").start_as_current_span("asr.process"):
            Logging.set_context(**ctx)
            started = time.perf_counter()
            try:
                Metrics.counter("asr_proc.total.count").add(1, attributes=attributes)
                transcript = await self._transcribe(request)
                if transcript is not None:
                    Metrics.counter("asr_proc.success.count").add(
                        1, attributes=attributes
                    )
                    self._resp_queue.put(
                        _ASRResult(post_id=request.post_id, transcript=transcript)
                    )
                else:
                    logger.debug(
                        f"Video has no audio stream for post {request.post_id}, skipping ASR"
                    )
                    Metrics.counter("asr_proc.skip.count").add(
                        1, attributes={**attributes, "reason": "no_audio_stream"}
                    )
                    self._resp_queue.put(
                        _ASRResult(post_id=request.post_id, error="no_audio_stream")
                    )
            except subprocess.TimeoutExpired:
                logger.warning(
                    f"FFmpeg timeout extracting audio for post {request.post_id}"
                )
                self._report_error(request, attributes, "ffmpeg_timeout", "ffmpeg_timeout")
            except subprocess.CalledProcessError as e:
                error_msg = e.stderr.decode() if e.stderr else str(e)
                logger.warning(f"FFmpeg error for post {request.post_id}: {error_msg}")
                self._report_error(
                    request, attributes, "ffmpeg_error", f"ffmpeg_error: {error_msg}"
                )
            except asyncio.TimeoutError:
                logger.warning(f"ASR request timed out for post {request.post_id}")
                self._report_error(request, attributes, "asr_timeout", "asr_timeout")
            except Exception as e:
                logger.error(
                    f"ASR processing failed for post {request.post_id}: {traceback.format_exc()}"
                )
                self._report_error(request, attributes, "unknown", str(e))
            finally:
                Metrics.histogram("asr_proc.duration").record(
                    time.perf_counter() - started
                )

    async def _init_run(self) -> None:
        """Per-process initialisation: common process setup plus HTTP session."""
        await init_proc("asr_proc")
        self._session = aiohttp.ClientSession()

    async def _run(self) -> None:
        """Poll the task queue until shutdown is flagged and the queue is empty."""
        logger.info("starting ASR worker process loop")
        in_flight: set[asyncio.Task] = set()
        while not (self._is_shutdown() and self._task_queue.empty()):
            try:
                request, ctx = self._task_queue.get_nowait()
            except Empty:
                await asyncio.sleep(0.01)
                continue
            try:
                # Keep a strong reference until the task finishes.
                job = asyncio.create_task(self._process(request, ctx))
                in_flight.add(job)
                job.add_done_callback(in_flight.discard)
            except Exception:
                logger.error(
                    f"error processing ASR request {request.post_id}: {traceback.format_exc()}"
                )
        if in_flight:
            logger.info(f"ASR worker draining {len(in_flight)} in-flight tasks")
            await asyncio.gather(*in_flight, return_exceptions=True)
        logger.warning("ASR worker process loop done")

    def run(self) -> None:
        """Process entry point: initialise, loop, and always close the session."""

        async def main():
            await self._init_run()
            try:
                await self._run()
            finally:
                await self._session.close()

        asyncio.run(main())

    def _start_loop(self) -> Process:
        """Spawn and start one worker process running :meth:`run`."""
        proc = Process(target=self.run)
        proc.start()
        return proc

    def start(self) -> list[Process]:
        """Start ``grox_config.asr.max_workers`` worker processes."""
        workers: list[Process] = []
        for _ in range(grox_config.asr.max_workers):
            workers.append(self._start_loop())
        return workers

    def _is_shutdown(self) -> bool:
        """Return the shutdown flag; any failure reading it counts as shutdown."""
        try:
            return self._shutdown_event.is_set()
        except BrokenPipeError:
            logger.error("Broken pipe error, assuming shutdown")
        except Exception:
            logger.error(
                f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
            )
        return True
class ASRProcessor:
    """Parent-process facade over the pool of ASR worker processes.

    All state is class-level: callers use the classmethods directly. Requests
    are de-duplicated per ``post_id`` while in flight and successful
    transcripts are cached for a short time.
    """

    _task_queue: Queue = Queue()
    _resp_queue: Queue = Queue()
    _shutdown_event = Event()
    # post_id -> future shared by every caller currently waiting on that post.
    _inflights: dict[str, asyncio.Future[str | None]] = {}
    _workers: list[Process] = []
    _initialized = False
    _result_task: asyncio.Task | None = None
    _cache: TTLCache = TTLCache(maxsize=1_000, ttl=300)

    @classmethod
    async def process(
        cls, post_id: str, video_url: str, max_audio_duration_s: float | None = None
    ) -> str | None:
        """Return the transcript for ``post_id``'s video, or None on failure.

        Args:
            post_id: Cache / de-duplication key.
            video_url: Video to transcribe.
            max_audio_duration_s: Cap on audio length; defaults to
                ``grox_config.asr.max_audio_duration_s`` when None.

        Raises:
            RuntimeError: :meth:`start` has not been called.
        """
        if not cls._initialized:
            raise RuntimeError("ASR processor not initialized")
        cached = cls._cache.get(post_id)
        if cached is not None:
            Metrics.counter("asr_proc.cache_hit.count").add(1)
            return cached
        if max_audio_duration_s is None:
            max_audio_duration_s = grox_config.asr.max_audio_duration_s
        future = cls._submit(post_id, video_url, max_audio_duration_s)
        # The future may be shared with other callers for the same post; shield
        # it so that cancelling this caller does not cancel it for everyone.
        transcript = await asyncio.shield(future)
        if transcript is not None:
            cls._cache[post_id] = transcript
        return transcript

    @classmethod
    def _submit(
        cls, post_id: str, video_url: str, max_audio_duration_s: float
    ) -> asyncio.Future[str | None]:
        """Enqueue a request, or return the existing future for ``post_id``."""
        if post_id in cls._inflights:
            return cls._inflights[post_id]
        request = _ASRRequest(
            post_id=post_id,
            video_url=video_url,
            max_audio_duration_s=max_audio_duration_s,
        )
        cls._task_queue.put((request, Logging.get_context()))
        future: asyncio.Future[str | None] = asyncio.get_running_loop().create_future()
        cls._inflights[post_id] = future
        return future

    @classmethod
    def _resolve(cls, result: _ASRResult) -> None:
        """Complete the in-flight future that ``result`` answers, if any."""
        future = cls._inflights.pop(result.post_id, None)
        if not future:
            logger.warning(f"no future found for post {result.post_id}")
            return
        if result.error:
            if result.error == "no_audio_stream":
                logger.debug(f"ASR skipped for post {result.post_id}: no audio stream")
            else:
                logger.warning(f"ASR failed for post {result.post_id}: {result.error}")
        if future.done():
            # Already cancelled or resolved elsewhere; set_result would raise.
            return
        future.set_result(None if result.error else result.transcript)

    @classmethod
    async def _result_loop(cls) -> None:
        """Move worker results from the response queue onto waiting futures."""
        logger.info("ASR processor result loop started")
        while not cls._shutdown_event.is_set() or cls._inflights:
            try:
                result: _ASRResult = cls._resp_queue.get(block=False)
            except Empty:
                await asyncio.sleep(0.01)
                continue
            except Exception:
                logger.error(f"Error processing ASR result: {traceback.format_exc()}")
                # Yield so a persistently failing queue cannot starve the loop.
                await asyncio.sleep(0.01)
                continue
            try:
                cls._resolve(result)
            except Exception:
                logger.error(f"Error processing ASR result: {traceback.format_exc()}")
        logger.warning("ASR processor result loop done")

    @classmethod
    def start(cls) -> None:
        """Spawn the workers and the result loop; must run inside an event loop."""
        if cls._initialized:
            logger.warning("ASR processor already initialized")
            return
        logger.info(
            f"starting ASR processor with {grox_config.asr.max_workers} workers"
        )
        cls._workers = _ASRWorker(
            cls._task_queue, cls._resp_queue, cls._shutdown_event
        ).start()
        cls._result_task = asyncio.create_task(cls._result_loop())
        cls._initialized = True

    @classmethod
    async def stop(cls, timeout: float = 5) -> None:
        """Signal shutdown, wait for workers, and release every waiting caller.

        Args:
            timeout: Seconds to wait for the workers to exit, and again for
                the result loop to deliver the results they produced.
        """
        logger.warning("stopping ASR processor")
        cls._shutdown_event.set()

        # Join off the event loop: the result loop must keep draining the
        # response queue while workers finish their in-flight requests.
        loop = asyncio.get_running_loop()
        alive = [worker for worker in cls._workers if worker.is_alive()]
        if alive:
            await asyncio.gather(
                *(loop.run_in_executor(None, worker.join, timeout) for worker in alive)
            )

        # Let the result loop hand out what the workers produced; wait_for
        # cancels it if it is still waiting once the timeout expires.
        if cls._result_task and not cls._result_task.done():
            try:
                await asyncio.wait_for(cls._result_task, timeout)
            except asyncio.TimeoutError:
                logger.warning("ASR result loop did not drain in time, cancelled")

        # Nothing will answer the remaining requests; unblock their callers.
        for future in cls._inflights.values():
            if not future.done():
                future.set_result(None)
        cls._inflights.clear()
        logger.warning("ASR processor stopped")

View File

@@ -0,0 +1,232 @@
import struct
import time
import uuid
import asyncio
import logging
import traceback
from abc import abstractmethod
from typing import override
from collections.abc import AsyncGenerator
from concurrent.futures import ThreadPoolExecutor
from aiokafka import TopicPartition
from kafka_cli.config import KafkaMessage
from grox.config.config import KafkaTopicName, grox_config
from kafka_cli.consumer import KafkaConsumer
from grox.data_loaders.data_types import Post, User, GroxContentAnalysis
from grox.data_loaders.message_queue_loader import (
MessageQueueLoader,
MessageQueuePayload,
)
from monitor.metrics import Metrics
from thrifts.serdes import SerDesError
logger = logging.getLogger(__name__)
# Size of the thread pool used to decode a polled batch, and the divisor used
# to split that batch into per-thread groups.
MAX_WORKING_THREADS = 12
class _Payload(MessageQueuePayload):
    """MessageQueuePayload plus the Kafka coordinates of its source message."""

    # Filled from KafkaMessage.tp by the concrete loaders.
    tp: TopicPartition
    # Filled from KafkaMessage.offset by the concrete loaders.
    offset: int
class KafkaLoader(MessageQueueLoader):
    """Message-queue loader backed by a Kafka consumer with a prefetch buffer.

    A background task keeps an in-memory queue topped up from Kafka;
    :meth:`poll` hands those payloads to the caller. Subclasses only implement
    :meth:`_messages_to_payloads` to decode their topic's message format.
    """

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__()
        self._initialized = False
        self._shutdown_event = asyncio.Event()
        self.topic_name = topic_name
        self.loader_config = grox_config.get_kafka_loader_topic(topic_name)
        self.consumer_config = grox_config.get_kafka_consumer_topic(topic_name)
        self.consumer = KafkaConsumer(self.consumer_config)
        self.loaded_messages: dict[str, tuple[TopicPartition, int]] = {}
        self.queue: asyncio.Queue[MessageQueuePayload] = asyncio.Queue()
        self._prefetcher_task: asyncio.Task | None = None

    def _is_shutdown(self) -> bool:
        """Return the shutdown flag; an error reading it counts as shutdown."""
        try:
            return self._shutdown_event.is_set()
        except Exception:
            logger.error(
                f"Error checking if KafkaLoader is shutdown: {traceback.format_exc()}"
            )
            return True

    async def start(self):
        """Start the consumer and the prefetch task."""
        logger.info(f"Initializing KafkaLoader, topic: {self.topic_name}")
        await self.consumer.start()
        self._prefetcher_task = asyncio.create_task(self._prefetcher())
        # Only flag as initialized once the consumer actually started.
        self._initialized = True
        logger.info(f"KafkaLoader initialized, topic: {self.topic_name}")

    async def stop(self):
        """Flag shutdown, give the prefetcher 5 s to exit, then stop the consumer."""
        logger.warning(f"Stopping KafkaLoader, topic: {self.topic_name}")
        self._shutdown_event.set()
        try:
            if self._prefetcher_task:
                await asyncio.wait_for(self._prefetcher_task, 5)
        except asyncio.TimeoutError:
            logger.warning(
                f"Waiting prefetcher to stop timed out, topic: {self.topic_name}"
            )
        await self.consumer.stop()
        logger.warning(f"KafkaLoader stopped, topic: {self.topic_name}")

    async def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        """Yield buffered payloads until shutdown is flagged and the buffer is empty.

        Yields None whenever the buffer is momentarily empty (or reading it
        failed) so the consumer of this generator can decide how to wait.
        """
        while not self._shutdown_event.is_set() or not self.queue.empty():
            try:
                yield self.queue.get_nowait()
            except asyncio.QueueEmpty:
                logger.debug(
                    f"Queue is empty, waiting for prefetcher to fill, topic: {self.topic_name}"
                )
                yield None
            except Exception:
                logger.error(
                    f"Error polling from kafka, topic: {self.topic_name}, error: {traceback.format_exc()}"
                )
                yield None

    async def ack(self, mid: str, success: bool = True):
        """No-op: this loader does not act on acknowledgements."""
        pass

    @abstractmethod
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Decode one group of Kafka messages into payloads (topic specific)."""
        pass

    def _process_messages(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Decode ``messages`` on a thread pool, preserving message order."""
        if not messages:
            # Nothing polled: skip building a thread pool for no work.
            return []
        group_size = max(1, len(messages) // MAX_WORKING_THREADS)
        message_groups = [
            messages[i : i + group_size] for i in range(0, len(messages), group_size)
        ]
        payloads: list[_Payload] = []
        with ThreadPoolExecutor(max_workers=MAX_WORKING_THREADS) as executor:
            # executor.map yields results in submission order.
            for group_payloads in executor.map(
                self._messages_to_payloads, message_groups
            ):
                payloads.extend(group_payloads)
        return payloads

    async def _prefetcher(self) -> None:
        """Background task: refill the buffer whenever it drops below threshold."""
        logger.info(f"Starting prefetcher, topic: {self.topic_name}")
        prefetching_threshold = self.loader_config.prefetching_threshold
        prefetching_batch_size = self.loader_config.prefetching_batch_size
        while not self._is_shutdown():
            if self.queue.qsize() >= prefetching_threshold:
                await asyncio.sleep(0.1)
                continue
            logger.debug(
                f"Inventory low at {self.queue.qsize()}, prefetching {prefetching_batch_size} messages, topic: {self.topic_name}"
            )
            try:
                messages = await self.consumer.poll(prefetching_batch_size)
                try:
                    # Decoding waits on a thread pool; run it off the event
                    # loop so poll() consumers stay responsive meanwhile.
                    payloads = await asyncio.to_thread(
                        self._process_messages, messages
                    )
                except SerDesError:
                    logger.error(
                        f"Error processing messages, error: {traceback.format_exc()}"
                    )
                    raise
                # Only the generic payload fields are forwarded downstream.
                for payload in payloads:
                    await self.queue.put(
                        MessageQueuePayload(
                            mid=payload.mid,
                            user=payload.user,
                            post=payload.post,
                            user_context=payload.user_context,
                            grox_content_analysis=payload.grox_content_analysis,
                            deadline_ts_secs=payload.deadline_ts_secs,
                        )
                    )
                logger.debug(
                    f"Prefetched {len(payloads)} messages, inventory now at {self.queue.qsize()}, topic: {self.topic_name}"
                )
            except Exception:
                logger.error(
                    f"Error prefetching messages, error: {traceback.format_exc()}"
                )
                await asyncio.sleep(0.1)
        logger.warning("Prefetcher stopped")
class KafkaPostLoader(KafkaLoader):
    """Loader whose messages decode via ``Post.from_thrift_content_understanding_metadata``."""

    # The constructor is inherited unchanged from KafkaLoader.

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one payload per message with a fresh id and deadline."""
        payloads: list[_Payload] = []
        for msg in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_content_understanding_metadata(msg.value),
                    tp=msg.tp,
                    offset=msg.offset,
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
class KafkaPostEmbeddingRequestLoader(KafkaLoader):
    """Loader whose messages decode via ``Post.from_thrift_post_embedding_request``."""

    # The constructor is inherited unchanged from KafkaLoader.

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one payload per message with a fresh id and deadline."""
        payloads: list[_Payload] = []
        for msg in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_post_embedding_request(msg.value),
                    tp=msg.tp,
                    offset=msg.offset,
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
class KafkaGroxContentAnalysisLoader(KafkaLoader):
    """Loader that fills ``grox_content_analysis`` instead of ``post``."""

    # The constructor is inherited unchanged from KafkaLoader.

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one payload per message with a fresh id and deadline."""
        payloads: list[_Payload] = []
        for msg in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    grox_content_analysis=GroxContentAnalysis.from_thrift_content_understanding_metadata(
                        msg.value
                    ),
                    tp=msg.tp,
                    offset=msg.offset,
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
class KafkaTweetEmbeddingLoader(KafkaLoader):
    """Loader whose messages decode via ``Post.from_thrift_tweet_embedding``."""

    # The constructor is inherited unchanged from KafkaLoader.

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one payload per message with a fresh id and deadline."""
        payloads: list[_Payload] = []
        for msg in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_tweet_embedding(msg.value),
                    tp=msg.tp,
                    offset=msg.offset,
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads

View File

@@ -0,0 +1,39 @@
import logging
from abc import ABC, abstractmethod
from grox.data_loaders.data_types import Post, User, UserContext, GroxContentAnalysis
from collections.abc import AsyncGenerator
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class MessageQueuePayload(BaseModel):
    """One unit of work produced by a MessageQueueLoader.

    Only ``mid`` and ``deadline_ts_secs`` are required; the other fields
    default to None.
    """

    # Message id; the same value is what MessageQueueLoader.ack() receives.
    mid: str
    post: Post | None = None
    user: User | None = None
    user_context: UserContext | None = None
    grox_content_analysis: GroxContentAnalysis | None = None
    # Integer deadline; presumably Unix epoch seconds, going by the name --
    # TODO confirm with the producers of this field.
    deadline_ts_secs: int
class MessageQueueLoader(ABC):
    """Abstract interface every message-queue backed loader implements."""

    def __init__(self):
        return None

    @abstractmethod
    async def start(self):
        """Bring the loader up."""

    @abstractmethod
    async def stop(self):
        """Shut the loader down."""

    @abstractmethod
    def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        """Return an async generator of payloads; items may be None."""

    @abstractmethod
    async def ack(self, mid: str, success: bool = True):
        """Acknowledge the message ``mid`` with its processing outcome."""

View File

@@ -0,0 +1,154 @@
import asyncio
import logging
from grox.data_loaders.data_types import Post, User
from strato_http.queries.data_types import (
ReplyRankingScore,
ReplyRankingScoreKafka,
)
from strato_http.queries.content_understanding_author_metadata import (
StratoContentUnderstandingAuthorMetadata,
)
from strato_http.queries.content_understanding_post_quote_metadata import (
StratoContentUnderstandingPostQuoteMetadata,
)
from strato_http.queries.content_understanding_metadata_v2 import (
StratoContentUnderstandingMetadataV2,
)
from strato_http.queries.reply_ranking_score import StratoReplyRankingScore
from strato_http.queries.reply_spam_annotation import StratoReplySpamAnnotation
from strato_http.queries.reply_ranking_score_kafka_v2 import (
StratoReplyRankingScoreV2Kafka,
)
from strato_http.queries.safety_label import StratoSafetyLabel
from strato_http.queries.user_recent_posts import StratoUserRecentPosts
from grox.data_loaders.mappers.post_mapper import PostMapper
logger = logging.getLogger(__name__)
class TweetStratoLoader:
    """Loads a single post from one of two Strato metadata columns."""

    content_understanding_metadata_strato = StratoContentUnderstandingMetadataV2()
    content_understanding_post_quote_metadata_strato = (
        StratoContentUnderstandingPostQuoteMetadata()
    )

    @classmethod
    async def load_post(
        cls, tweet_id: str, include_ancestors: bool = True
    ) -> Post | None:
        """Fetch and map the post with id ``tweet_id``.

        ``include_ancestors`` selects the source: the metadata-v2 column when
        true, the post-with-quote column otherwise. Returns None when the
        selected column has nothing for this id.
        """
        if include_ancestors:
            source = cls.content_understanding_metadata_strato
            to_post = PostMapper.from_strato_content_understanding_metadata
        else:
            source = cls.content_understanding_post_quote_metadata_strato
            to_post = PostMapper.from_strato_post_with_quote_metadata

        raw = await source.fetch(int(tweet_id))
        if not raw:
            return None
        return to_post(raw)
class UserStratoLoader:
    """Loads a user from the author-metadata Strato column."""

    strato = StratoContentUnderstandingAuthorMetadata()

    @classmethod
    async def load_user(cls, user_id: int) -> User | None:
        """Return the mapped user, or None (with a warning) when not found."""
        strato_user = await cls.strato.fetch(user_id)
        if strato_user:
            return PostMapper._from_strato_user_metadata_to_user(strato_user)
        logger.warning(f"failed to hydrate user with {user_id=}, not found")
        return None
class ReplyRankingScoreStratoLoader:
    """Writes reply-ranking scores to Strato, keyed by integer post id."""

    strato = StratoReplyRankingScore()
    reply_ranking_v2_kafka_strato = StratoReplyRankingScoreV2Kafka()

    @classmethod
    async def save_reply_ranking_score(
        cls, post_id: str, reply_ranking_score: ReplyRankingScore
    ):
        """``put`` the score for ``post_id`` into the reply-ranking column."""
        key = int(post_id)
        await cls.strato.put(key, reply_ranking_score)

    @classmethod
    async def save_reply_ranking_kafka_v2(
        cls, post_id: str, reply_ranking_score_kafka: ReplyRankingScoreKafka
    ):
        """``insert`` the score for ``post_id`` through the v2 Kafka column."""
        key = int(post_id)
        await cls.reply_ranking_v2_kafka_strato.insert(key, reply_ranking_score_kafka)
class ReplySpamStratoLoader:
    """Writes reply-spam annotations to Strato, keyed by integer post id."""

    strato = StratoReplySpamAnnotation()

    @classmethod
    async def save_spam_reply_annotation(
        cls, post_id: str, score: float, positive: bool, reason: str
    ):
        """``put`` the annotation (score, verdict, reason) for ``post_id``."""
        key = int(post_id)
        await cls.strato.put(key, score, positive, reason)
class UserRecentPostsLoader:
    """Loads a user's recent posts, hydrated and tagged with safety labels."""

    recent_posts_strato = StratoUserRecentPosts()
    post_hydrator = StratoContentUnderstandingPostQuoteMetadata()
    safety_label = StratoSafetyLabel()

    @classmethod
    async def load(cls, user_id: int, limit: int = 10) -> list[Post]:
        """Return the hydrated recent posts of ``user_id``.

        Args:
            user_id: Author whose recent posts are requested.
            limit: Passed as both ``limit`` and ``max_per_type`` to the
                recent-posts query.

        Returns:
            Posts that hydrated and mapped successfully, each with
            ``safety_labels`` populated. Posts that fail to hydrate or map
            are logged and skipped. Empty when nothing is found.
        """
        res = await cls.recent_posts_strato.fetch(
            user_id, limit=limit, max_per_type=limit
        )
        if not res or "v" not in res:
            logger.warning(f"No recent posts found for {user_id=}")
            return []

        post_ids: list[int] = []
        for post_type, posts in res["v"]:
            # Retweets are resolved to the post they react to.
            id_key = "inReactionToPostId" if post_type == "TypeRetweet" else "postId"
            post_ids.extend(post[id_key] for post in posts if id_key in post)
        if not post_ids:
            return []

        results = await asyncio.gather(
            *(cls.post_hydrator.fetch(post_id) for post_id in post_ids),
            return_exceptions=True,
        )
        hydrated: list[Post] = []
        for post_id, result in zip(post_ids, results):
            # BaseException also covers a CancelledError returned by gather.
            if isinstance(result, BaseException):
                logger.warning(
                    f"Failed to hydrate recent post {post_id} for {user_id=}: {result}"
                )
                continue
            if result is None:
                continue
            try:
                hydrated.append(PostMapper.from_strato_post_with_quote_metadata(result))
            except Exception:
                logger.warning(
                    f"Failed to map recent post {post_id} for {user_id=}", exc_info=True
                )

        # Fan the label lookups out like the hydration above instead of paying
        # one round trip per post; a failing scan still propagates.
        if hydrated:
            labels = await asyncio.gather(
                *(cls.safety_label.scan(post.id) for post in hydrated)
            )
            for post, post_labels in zip(hydrated, labels):
                post.safety_labels = post_labels
        return hydrated

370
grox/dispatcher.py Normal file
View File

@@ -0,0 +1,370 @@
import asyncio
import logging
import traceback
import multiprocessing
from queue import Empty, Queue
from threading import Event
from multiprocessing import Process
from tenacity import retry, wait_incrementing, stop_after_attempt
from monitor.metrics import Metrics
from grox.config.config import TaskGeneratorType, grox_config
from grox.schedules.init import init_proc
from grox.schedules.types import TaskResult, TaskPayload
from grox.schedules.context import ScheduleContext
from grox.generators.task_generator import TaskGenerator, PriorityTaskGenerator
from grox.generators.stream_generator import (
PostStreamTaskGenerator,
PostStreamRecoveryTaskGenerator,
PostStreamTestTaskGenerator,
PostStreamDelayedTaskGenerator,
PostSafetyStreamTaskGenerator,
ReplyRankingRecoveryTaskGenerator,
PostEmbeddingRequestWithSummaryStreamTaskGenerator,
PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator,
MinTractionPostStreamForGroxTaskGenerator,
MinTractionPostStreamForGroxPtosTaskGenerator,
PostEmbeddingV5StreamTaskGenerator,
PostEmbeddingV5ForReplyStreamTaskGenerator,
MinTractionPostStreamForGroxMultiModalTaskGenerator,
PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator,
SafetyPtosRecoveryStreamTaskGenerator,
SafetyPtosDeluxeStreamTaskGenerator,
)
logger = logging.getLogger(__name__)
class Dispatcher:
    """Feeds tasks from the configured generators to workers and handles results.

    The dispatcher runs in its own process (see :meth:`start`). Inside that
    process three coroutines run side by side: one fills the task queue, one
    consumes the response queue (retrying or acknowledging), and one waits
    for the queue-connection shutdown signal to stop the generators.
    """

    # Generator class for each configured type. Kept as pairs compared with
    # ``==`` (not a dict) so matching behaves like a match-case value pattern.
    _GENERATOR_CLASSES: tuple[tuple[TaskGeneratorType, type[TaskGenerator]], ...] = (
        (TaskGeneratorType.POST_STREAM, PostStreamTaskGenerator),
        (TaskGeneratorType.POST_STREAM_RECOVERY, PostStreamRecoveryTaskGenerator),
        (TaskGeneratorType.POST_STREAM_TEST, PostStreamTestTaskGenerator),
        (TaskGeneratorType.POST_STREAM_DELAYED, PostStreamDelayedTaskGenerator),
        (TaskGeneratorType.POST_SAFETY_STREAM, PostSafetyStreamTaskGenerator),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX,
            MinTractionPostStreamForGroxTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_PTOS,
            MinTractionPostStreamForGroxPtosTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_MULTI_MODAL,
            MinTractionPostStreamForGroxMultiModalTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_FOR_REPLY_RECOVERY,
            PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator,
        ),
        (TaskGeneratorType.SAFETY_PTOS_RECOVERY, SafetyPtosRecoveryStreamTaskGenerator),
        (TaskGeneratorType.SAFETY_PTOS_DELUXE, SafetyPtosDeluxeStreamTaskGenerator),
        (TaskGeneratorType.POST_EMBEDDING_V5_STREAM, PostEmbeddingV5StreamTaskGenerator),
        (
            TaskGeneratorType.POST_EMBEDDING_V5_FOR_REPLY_STREAM,
            PostEmbeddingV5ForReplyStreamTaskGenerator,
        ),
        (TaskGeneratorType.REPLY_RANKING_RECOVERY, ReplyRankingRecoveryTaskGenerator),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY,
            PostEmbeddingRequestWithSummaryStreamTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_RECOVERY,
            PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator,
        ),
    )

    def __init__(self, context: ScheduleContext):
        self.config = grox_config.dispatcher
        self.context = context
        self._task_queue: Queue[TaskPayload] = self.context["task_queue"]
        self._resp_queue: Queue[TaskResult] = self.context["resp_queue"]
        self._shutdown_event: Event = self.context["shutdown_event"]
        self._queue_connection_shutdown_event: Event = self.context[
            "queue_connection_shutdown_event"
        ]
        self._process = None

    @staticmethod
    def _event_is_set(event: Event, outcome: str) -> bool:
        """Read ``event``; any failure is logged and treated as "set"."""
        try:
            return event.is_set()
        except BrokenPipeError:
            logger.error(f"Broken pipe error, assuming {outcome}")
            return True
        except Exception:
            logger.error(
                f"Error checking shutdown event, assuming {outcome}: {traceback.format_exc()}"
            )
            return True

    def _is_shutdown(self) -> bool:
        """True once the global shutdown event is set (or unreadable)."""
        return self._event_is_set(self._shutdown_event, "shutdown")

    def _is_queue_connection_shutdown(self) -> bool:
        """True once the queue-connection shutdown event is set (or unreadable)."""
        return self._event_is_set(
            self._queue_connection_shutdown_event, "queue connection shutdown"
        )

    @retry(
        stop=stop_after_attempt(3), wait=wait_incrementing(start=1, increment=3, max=9)
    )
    async def _init_run(self):
        """Per-process setup; retried up to three times by the decorator."""
        await init_proc("dispatcher")
        self._in_flights: set[str] = set()
        self._task_generator = self._get_task_generators()
        await self._task_generator.start()

    def _get_task_generators(self) -> TaskGenerator:
        """Build the weighted composite generator described by the config.

        Raises:
            ValueError: a configured generator type has no known class.
        """
        generators: list[tuple[TaskGenerator, int]] = []
        for task_generator_config in self.config.task_generators:
            generator_cls = next(
                (
                    cls
                    for generator_type, cls in self._GENERATOR_CLASSES
                    if task_generator_config.type == generator_type
                ),
                None,
            )
            if generator_cls is None:
                raise ValueError(
                    f"Invalid task generator type: {task_generator_config.type}"
                )
            generators.append(
                (
                    generator_cls(task_generator_config.max_qps),
                    task_generator_config.weight,
                )
            )
        return PriorityTaskGenerator(generators)

    def _mark_finished(self, payload_id: str) -> None:
        """Drop ``payload_id`` from the in-flight set and refresh the gauge."""
        self._in_flights.discard(payload_id)
        Metrics.gauge("dispatcher.inflight.count").set(len(self._in_flights))

    async def _submit_task(self, task_payload: TaskPayload) -> None:
        """Record the task as in flight and put it on the task queue."""
        inflight_gauge = Metrics.gauge("dispatcher.inflight.count")
        self._in_flights.add(task_payload.payload_id)
        inflight_gauge.set(len(self._in_flights))
        Metrics.counter("dispatcher.task.sent.count").add(
            1,
            attributes={
                "task_type": task_payload.task_type.value
                if task_payload.task_type
                else "none"
            },
        )
        self._task_queue.put(task_payload)
        logger.debug(
            f"Submitted task: {task_payload.payload_id}, queue size: {self._task_queue.qsize()}"
        )

    async def _fill_loop(self):
        """Pull tasks from the generator and submit them, honouring the in-flight cap."""
        logger.info("Starting fill loop")
        while not self._is_shutdown():
            try:
                async for task_payload in self._task_generator.poll():
                    if task_payload is None:
                        await asyncio.sleep(0.01)
                        continue
                    # Back-pressure: wait for results before submitting more.
                    while len(self._in_flights) >= self.config.max_in_flight:
                        await asyncio.sleep(0.01)
                    await self._submit_task(task_payload)
            except Exception:
                logger.error(
                    f"Error polling from task queues: {traceback.format_exc()}"
                )
                # Avoid a hot error loop when polling keeps failing.
                await asyncio.sleep(0.1)

    async def _poll_result(self) -> TaskResult | None:
        """Non-blocking read of one result; None when nothing could be read."""
        try:
            res = self._resp_queue.get(block=False)
            logger.debug(f"Dispatcher received result: {res.task.payload_id}")
            Metrics.counter("dispatcher.result.received.count").add(1)
            return res
        except Empty:
            return None
        except BrokenPipeError:
            logger.error("Broken pipe error, shutting down")
            return None
        except Exception:
            logger.error(f"failed to poll result: {traceback.format_exc()}")
            return None

    async def _result_loop(self) -> None:
        """Consume results until shutdown is flagged and nothing is in flight.

        Successful tasks are acknowledged. Failed tasks are resubmitted while
        ``task.attempt`` is below ``max_attempts``; after that they are counted
        as final failures and acknowledged if their origin can be identified.
        """
        logger.info("Starting result loop")
        max_attempts = self.config.max_attempts
        while not self._is_shutdown() or self._in_flights:
            try:
                result = await self._poll_result()
                if result is None:
                    await asyncio.sleep(0.01)
                    continue
                task = result.task
                if result.success:
                    Metrics.counter("dispatcher.result.success.count").add(1)
                    self._mark_finished(task.payload_id)
                    await self._task_generator.ack(result)
                elif task.attempt < max_attempts:
                    Metrics.counter("dispatcher.result.failed.count").add(1)
                    logger.warning(
                        f"Task {task.payload_id} failed, retrying... (attempt {task.attempt})"
                    )
                    task.attempt += 1
                    await self._submit_task(task)
                else:
                    self._mark_finished(task.payload_id)
                    logger.error(
                        f"Task {task.payload_id} failed after {max_attempts} attempts, error is {result.error}"
                    )
                    origin = self._task_generator.identify_task_origin(result)
                    Metrics.counter("dispatcher.result.failed.final.count").add(
                        1,
                        attributes={"origin": origin.value if origin else "unknown"},
                    )
                    if origin is None:
                        logger.warning(
                            f"No origin found for task {task.payload_id}, skipping ack"
                        )
                    else:
                        await self._task_generator.ack(result)
            except Empty:
                await asyncio.sleep(0.1)
            except BrokenPipeError:
                logger.error("Broken pipe error, shutting down")
                break
        logger.warning("Result loop finished")

    async def _wait_for_queue_connection_shutdown(self):
        """Stop the task generators once the queue-connection event is set."""
        while not self._is_queue_connection_shutdown():
            await asyncio.sleep(1)
        logger.warning("Shutdowning task generators")
        await self._task_generator.stop()
        logger.warning("Task generators stopped")

    async def _run(self, started_event: Event):
        """Child-process main coroutine; signals readiness via ``started_event``."""
        await self._init_run()
        started_event.set()
        await asyncio.gather(
            self._fill_loop(),
            self._result_loop(),
            self._wait_for_queue_connection_shutdown(),
        )

    def run(self, started_event: Event):
        """Child-process entry point."""
        asyncio.run(self._run(started_event))

    async def start(self):
        """Launch the dispatcher process and wait until it reports ready.

        Raises:
            RuntimeError: the child process exited before finishing its
                initialisation (previously this waited forever).
        """
        logger.info("Starting Grox dispatcher...")
        started_event = multiprocessing.Event()
        self._process = Process(
            target=self.run, args=(started_event,), name="grox-dispatcher"
        )
        self._process.start()
        # Wait on a thread so the caller's event loop keeps running, and poll
        # with a timeout so a dead child is noticed instead of hanging.
        loop = asyncio.get_running_loop()
        while not await loop.run_in_executor(None, started_event.wait, 1.0):
            if not self._process.is_alive():
                raise RuntimeError(
                    "Grox dispatcher process exited before finishing initialization"
                )
        logger.info("Grox dispatcher started")

    async def stop(self):
        """Wait up to ``graceful_shutdown_timeout`` for the process to exit.

        This only joins; the shutdown events themselves are set elsewhere.
        """
        logger.warning("Stopping Grox dispatcher...")
        if self._process and self._process.is_alive():
            # Join on a thread so the event loop is not blocked meanwhile.
            await asyncio.get_running_loop().run_in_executor(
                None, self._process.join, self.config.graceful_shutdown_timeout
            )
        else:
            logger.warning("Dispatcher process is not alive, skipping join")
        logger.warning("Dispatcher stopped")

View File

@@ -0,0 +1,287 @@
import os
import time
import logging
import json
import numpy as np
from grox.lm.post import LitePostRenderer, MMEmbedPostRenderer, EvalPostRenderer
from grox.lm.convo import Image as ConvoImage, Video as ConvoVideo, Content
from embed.embed_cli import XaiEmbeddingClient
from monitor.metrics import Metrics
from grox.config.config import ModelName, grox_config
from grox.data_loaders.data_types import Post
from grox.data_loaders.media_loader import MediaLoader
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.data_loaders.media_description_loader import MediaDescriptionLoader
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
StratoContentUnderstandingUnifiedPostAnnotations,
)
from grox.config.config import EmbeddingModelConfig
logger = logging.getLogger(__name__)
class MultimodalPostEmbedderV2:
def __init__(
    self,
    model: str = "qwen3",
    use_grok_summary: bool = False,
    renderer_version="mmembed_summary",
    use_media_descriptions: bool = False,
    use_post_context_summary: bool = False,
    use_grok_summary_path: str = "",
    custom_endpoint: str = "",
    instruction: str = "",
):
    """Create the embedding clients and record the rendering options.

    Args:
        model: Selects the embedding client in ``_get_client``.
        use_grok_summary: Stored on the instance as-is.
        renderer_version: Stored on the instance as-is.
        use_media_descriptions: Stored on the instance as-is.
        use_post_context_summary: Stored on the instance as-is.
        use_grok_summary_path: Optional ``.jsonl`` file with one object per
            line holding ``post_id`` and ``summary``; loaded into
            ``grok_summary_versioned`` when given.
        custom_endpoint: When non-empty, a client for this endpoint is
            created and takes precedence over ``model``.
        instruction: Stored on the instance as-is.

    Raises:
        AssertionError: ``use_grok_summary_path`` is given but does not exist
            or does not end in ``.jsonl``.
    """
    # NOTE(review): the config objects below are mutated in place; if
    # grox_config returns shared instances this affects other users of
    # them -- confirm that is intended.
    embed_config = grox_config.get_embedding_model(ModelName.EMBED_PRIMARY)
    video_embed_config = grox_config.get_embedding_model(ModelName.EMBED_PRIMARY_VIDEO)
    embed_config.text_max_len = 8192
    video_embed_config.text_max_len = 8192
    qwen_3_embed_06b_config = grox_config.get_embedding_model(
        ModelName.EMBED_SMALL
    )
    qwen_3_embed_8b_config = grox_config.get_embedding_model(
        ModelName.EMBED_LARGE
    )
    recsys_v4_embed_config = grox_config.get_embedding_model(
        ModelName.RECSYS_EMBED_V4
    )
    qwen_3_embed_06b_config.text_max_len = 4096
    qwen_3_embed_8b_config.text_max_len = 4096
    recsys_v4_embed_config.text_max_len = 4096
    if custom_endpoint:
        custom_embed_config = EmbeddingModelConfig(
            model_name="custom", endpoint=custom_endpoint, text_max_len=4096
        )
        self._custom_embed_client = XaiEmbeddingClient(config=custom_embed_config)
    self.use_custom_embed = bool(custom_endpoint)
    self.renderer_version = renderer_version
    self._client = XaiEmbeddingClient(config=embed_config)
    self._video_client = XaiEmbeddingClient(config=video_embed_config)
    self._qwen_3_embed_06b_client = XaiEmbeddingClient(
        config=qwen_3_embed_06b_config
    )
    self._qwen_3_embed_8b_client = XaiEmbeddingClient(config=qwen_3_embed_8b_config)
    self._recsys_v4_embed_client = XaiEmbeddingClient(config=recsys_v4_embed_config)
    self.model = model
    self.use_grok_summary = use_grok_summary
    self.use_media_descriptions = use_media_descriptions
    self.use_grok_summary_versioned = False
    # Always defined so readers never hit AttributeError when no path is given.
    self.grok_summary_versioned: dict[str, str] = {}
    self.instruction = instruction
    self.use_post_context_summary = use_post_context_summary
    if use_grok_summary_path:
        assert os.path.exists(use_grok_summary_path), (
            f"Grok summary path {use_grok_summary_path} does not exist"
        )
        assert use_grok_summary_path.endswith(".jsonl"), (
            f"Grok summary path {use_grok_summary_path} is not a jsonl file"
        )
        self.use_grok_summary_versioned = True
        with open(use_grok_summary_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                json_line = json.loads(line)
                self.grok_summary_versioned[str(json_line["post_id"]).strip()] = (
                    json_line["summary"]
                )
def _get_client(
self, num_text: int, num_image: int, num_video: int
) -> XaiEmbeddingClient:
if self.use_custom_embed:
return self._custom_embed_client
if self.model == "qwen3":
return self._qwen_3_embed_06b_client
if self.model == "qwen3_8b":
return self._qwen_3_embed_8b_client
if self.model == "v4":
return self._recsys_v4_embed_client
if num_video > 0:
logger.info(
f"Using video client for post with {num_text} text, {num_image} images, and {num_video} videos"
)
return self._video_client
return self._client
def document_original(
self, content: list[Content]
) -> tuple[list[tuple[str, str | bytes]], int, int, int]:
# Builds the "default" document layout for a post: media parts are appended
# as they are encountered, while all text is accumulated into one string
# that is appended as a single trailing ("text", ...) part.
# Returns (document, num_text, num_image, num_video).
# NOTE(review): indentation is not visible in this view, so the nesting
# described in the comments below is inferred - confirm against the real file.
def get_convo_video_instruction(video: ConvoVideo) -> str:
# Produces one space-joined description of a video: total duration, the
# sampling interval, and one sentence per frame with its subtitle (if any).
res = [f"The video lasts for {video.total_duration:.2f} seconds."]
# Timestamp of frame i is i * video.duration.
bucket_times = [i * video.duration for i in range(len(video.frames))]
res.append(
f"The following frames are sampled at every {video.duration:.2f} second interval."
)
# `frame` itself is unused here; only the index is needed.
for i, frame in enumerate(video.frames):
# Subtitles may be missing entirely or shorter than the frame list.
subtitle = (
video.subtitles[i]
if video.subtitles and i < len(video.subtitles)
else None
)
subtitle_str = (
f"with subtitle: {subtitle}" if subtitle else "(no subtitles)"
)
res.append(f"At {bucket_times[i]:.2f} seconds, {subtitle_str}.")
res.append("The frames are listed below:")
return " ".join(res)
document = []
num_text = 0
num_image = 0
num_video = 0
# Accumulates every string part plus the generated video descriptions.
new_text_part = ""
for c in content:
if isinstance(c, ConvoImage):
# Each image is preceded by a short text marker.
document.append(("text", f"Image: \n"))
document.append(("image", c.content))
num_image += 1
elif isinstance(c, ConvoVideo):
# Videos are only added when combined bytes are available.
# NOTE(review): whether `num_video += 1` sits inside this `if` cannot be
# determined from this view - confirm.
if c.combined_video_bytes:
new_text_part += get_convo_video_instruction(c)
document.append(("video", c.combined_video_bytes))
num_video += 1
elif isinstance(c, str):
new_text_part += c
num_text += 1
new_text_part = new_text_part.strip()
# NOTE(review): an empty text part is appended before the merged text;
# its purpose is not visible here (possibly a separator) - confirm.
document.append(
("text", "")
)
document.append(("text", new_text_part))
return document, num_text, num_image, num_video
def document_v1(
    self, content: list[Content]
) -> tuple[list[tuple[str, str | bytes]], int, int, int]:
    """Build the "v1" document layout, keeping parts in input order.

    Strings become stripped ``("text", ...)`` parts, images become
    ``("image", ...)`` parts, and each video is expanded into its frames
    (as images), each optionally followed by a subtitle text part.

    Returns:
        ``(document, num_text, num_image, num_video)``.
    """

    def video_frames(
        video: ConvoVideo, index: int
    ) -> list[tuple[str, str | bytes]]:
        # `index` is accepted by the signature but not used.
        parts: list[tuple[str, str | bytes]] = []
        for pos, frame in enumerate(video.frames):
            parts.append(("image", frame))
            # Skip the subtitle when the list is missing or too short.
            if not video.subtitles or pos >= len(video.subtitles):
                continue
            if video.subtitles[pos] is None:
                continue
            if video.subtitles[pos].strip() != "":
                parts.append(("text", "subtitle: " + video.subtitles[pos] + " "))
        return parts

    document = []
    num_text = num_image = num_video = 0
    for item in content:
        if isinstance(item, str):
            document.append(("text", item.strip()))
            num_text += 1
        elif isinstance(item, ConvoImage):
            document.append(("image", item.content))
            num_image += 1
        elif isinstance(item, ConvoVideo):
            document.extend(video_frames(item, num_video))
            num_video += 1
    return document, num_text, num_image, num_video
async def _create_embeddings_for_post(
self,
content: list[Content],
is_query: bool = False,
document_version: str = "v1",
) -> tuple[list[tuple[str, str | bytes]], np.ndarray]:
if document_version == "default":
document_fn = self.document_original
elif document_version == "v1":
document_fn = self.document_v1
else:
raise ValueError(f"document_version not found: {document_version}")
document, num_text, num_image, num_video = document_fn(content)
logger.info(
f"creating embeddings for post with {num_text} text, {num_image} images, and {num_video} videos"
)
client = self._get_client(num_text, num_image, num_video)
return document, await client.create_embeddings_async(
[document], is_query=is_query
)
async def hydrate_grok_post_summary(self, post: Post):
    """Look up the unified post annotations and store them on the post.

    When the lookup returns a truthy result, its
    ``["annotations"]["description"]`` value is written to ``post.summary``;
    otherwise the post is left untouched.
    """
    annotations = await StratoContentUnderstandingUnifiedPostAnnotations().fetch(
        int(post.id)
    )
    if not annotations:
        return
    post.summary = annotations["annotations"]["description"]
def get_detailed_instruct(self, instruction: str) -> str:
    """Wrap ``instruction`` in the fixed "Instruct / Query" prompt prefix."""
    template = "Instruct: {}\nQuery: Please embed the following post:"
    return template.format(instruction)
def _get_document_fn(self, document_version: str):
if document_version == "default":
return self.document_original
if document_version == "v1":
return self.document_v1
raise ValueError(f"document_version not found: {document_version}")
async def embed_texts_batch(
    self, texts: list[str], is_query: bool = True, document_version: str = "v1"
) -> list[list[float]]:
    """Embed several plain-text strings in one client call.

    Each text is turned into its own single-part document with the builder
    selected by ``document_version``; all documents are sent together.

    Returns:
        One flattened embedding (list of floats) per input text, or an
        empty list when ``texts`` is empty.
    """
    if not texts:
        return []

    build_document = self._get_document_fn(document_version)
    batch: list[list[tuple[str, str | bytes]]] = []
    for single_text in texts:
        built, _, _, _ = build_document([single_text])
        batch.append(built)

    # Text-only input: the counts only steer client selection.
    client = self._get_client(num_text=1, num_image=0, num_video=0)
    vectors = await client.create_embeddings_async(batch, is_query=is_query)
    return [vector.flatten().tolist() for vector in vectors]
async def embed(
    self, post: Post, is_query: bool = False, document_version: str = "v1"
) -> tuple[list[tuple[str, str | bytes]], list[float]]:
    """Render a post, optionally enrich it, and compute its embedding.

    Content is assembled in this order: the instruction prefix (when
    ``self.instruction`` is set), the output of the renderer selected by
    ``self.renderer_version``, then any configured summary and media
    description parts.

    NOTE(review): previously the renderer output replaced the instruction
    prefix, so the instruction never reached the model for the "lite",
    "eval" and "mmembed_summary" renderers. Keeping the prefix changes the
    model input for deployments that set an instruction - confirm this is
    the intended behaviour before rollout.

    Args:
        post: Post to embed.
        is_query: Forwarded to the embedding client.
        document_version: Document layout, "default" or "v1".

    Returns:
        ``(document, embedding)`` where ``embedding`` is a flat list of
        floats.

    Raises:
        ValueError: If ``renderer_version`` is unknown and there is no
            instruction to embed either (this used to surface as an
            ``UnboundLocalError``), or if ``document_version`` is unknown.
    """
    content: list[Content] = []
    if self.instruction:
        content.append(self.get_detailed_instruct(self.instruction))

    # Append the rendered post after the instruction instead of replacing it.
    if self.renderer_version == "lite":
        content.extend(LitePostRenderer.render_for_embedding(post))
    elif self.renderer_version == "eval":
        content.extend(EvalPostRenderer.render_for_embedding(post))
    elif self.renderer_version == "mmembed_summary":
        content.extend(
            await MMEmbedPostRenderer.render_for_embedding(
                post, use_grok_summary=self.use_grok_summary
            )
        )
    elif not content:
        # Unknown renderer and no instruction: nothing to embed.
        raise ValueError(f"renderer_version not found: {self.renderer_version}")

    if self.use_grok_summary_versioned:
        if str(post.id) in self.grok_summary_versioned:
            content.append(
                f"\nPost summary and description: {self.grok_summary_versioned[str(post.id)]}"
            )
    if self.use_grok_summary and not self.use_grok_summary_versioned:
        await self.hydrate_grok_post_summary(post)
        content.append(f"\nPost summary and description: {post.summary}")
    if self.use_post_context_summary:
        content.append(f"\nPost summary and description: {post.summary}")
    if self.use_media_descriptions:
        await MediaDescriptionLoader.hydrate_media_descriptions(post)
        content.append(
            f"\nThe post has these associated media descriptions: \n{post.media_descriptions}"
        )

    start_time = time.perf_counter_ns()
    document, embedding = await self._create_embeddings_for_post(
        content, is_query, document_version
    )
    duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
    logger.info(f"Embedding finished in {duration_ms:.2f} ms")
    Metrics.histogram("post_embedding_duration_ms").record(duration_ms)
    return document, embedding.flatten().tolist()

View File

@@ -0,0 +1,120 @@
import logging
import time
import numpy as np
from embed.embed_http import ChatTemplate, XaiEmbeddingClientHttp
from embed.embed_http import EmbeddingModelConfig as HttpModelConfig
from grox.config.config import ModelName, grox_config
from grox.data_loaders.data_types import Post, Video
from grox.lm.post_v5 import V5EmbedPostRenderer
from monitor.metrics import Metrics
logger = logging.getLogger(__name__)
# Default system prompt passed to the chat template; empty unless overridden.
DEFAULT_SYSTEM_PROMPT = ""
# Embeddings longer than this are cut to their first TRUNCATE_DIM values
# before being L2-normalised.
TRUNCATE_DIM = 1024
class MultimodalPostEmbedderV5:
    """Embeds posts (text plus images) through the V5 HTTP embedding model."""

    @staticmethod
    def has_video(post: Post) -> bool:
        """Return True when at least one media item of the post is a ``Video``."""
        return bool(post.media) and any(isinstance(m, Video) for m in post.media)

    def __init__(
        self,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_images: int | None = None,
    ):
        """Create the HTTP embedding client.

        Args:
            system_prompt: System prompt placed in the chat template.
            max_images: Forwarded to the renderer as its ``max_images``
                argument.
        """
        self.truncate_dim = TRUNCATE_DIM
        self.max_images = max_images
        model_cfg = grox_config.get_embedding_model(ModelName.RECSYS_EMBED_V5)
        self._client = XaiEmbeddingClientHttp(
            config=HttpModelConfig(
                model_name=model_cfg.model_name,
                endpoint=model_cfg.endpoint,
                text_max_len=4096,
                timeout_seconds=60.0,
            ),
            chat_template=ChatTemplate(system_prompt=system_prompt),
        )

    def _maybe_truncate(self, embedding: np.ndarray) -> list[float]:
        """Cut the vector to ``truncate_dim`` when longer, then L2-normalise.

        A zero-norm vector is returned without scaling.
        """
        vector = embedding
        if self.truncate_dim > 0 and len(vector) > self.truncate_dim:
            vector = vector[: self.truncate_dim]
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
        return vector.tolist()

    async def embed(
        self,
        post: Post,
        transcript: str | None = None,
        is_query: bool = False,
        **kwargs,
    ) -> tuple[list[tuple[str, str | bytes]], list[float]]:
        """Render and embed one post, recording per-stage timings.

        Args:
            post: Post to embed.
            transcript: Optional transcript appended to the rendered text.
            is_query: Accepted but not used by this method.
            **kwargs: Accepted and ignored.

        Returns:
            ``(document, embedding)``. The embedding is an empty list when
            the post has neither text nor images.
        """

        def ms_since(t0: float) -> float:
            return (time.perf_counter() - t0) * 1000

        total_start = time.perf_counter()

        render_start = time.perf_counter()
        text_with_pads, images = V5EmbedPostRenderer.render_for_embedding(
            post, max_images=self.max_images
        )
        render_duration_ms = ms_since(render_start)
        Metrics.histogram("post_embedding_v5.render_duration_ms").record(
            render_duration_ms
        )

        if transcript:
            text_with_pads += f"\nTranscript: {transcript}"

        document: list[tuple[str, str | bytes]] = [("text", text_with_pads)]
        document.extend(("image", img) for img in images)

        if not text_with_pads and not images:
            logger.warning(f"Post {post.id} has no text or media content")
            return document, []

        encode_start = time.perf_counter()
        embedding = await self._client.encode_with_embedded_pads_async(
            text_with_pads, images if images else None
        )
        encode_duration_ms = ms_since(encode_start)
        Metrics.histogram("post_embedding_v5.encode_duration_ms").record(
            encode_duration_ms
        )

        truncate_start = time.perf_counter()
        result = self._maybe_truncate(embedding)
        truncate_duration_ms = ms_since(truncate_start)
        Metrics.histogram("post_embedding_v5.truncate_duration_ms").record(
            truncate_duration_ms
        )

        total_duration_ms = ms_since(total_start)
        Metrics.histogram("post_embedding_v5.total_duration_ms").record(
            total_duration_ms
        )

        Metrics.counter("post_embedding_v5.image_count").add(len(images))
        total_image_bytes = sum(map(len, images)) if images else 0
        Metrics.histogram("post_embedding_v5.image_payload_bytes").record(
            total_image_bytes
        )

        logger.info(
            f"Embedding V5 post={post.id}: total={total_duration_ms:.1f}ms "
            f"(render={render_duration_ms:.1f}ms, encode={encode_duration_ms:.1f}ms, truncate={truncate_duration_ms:.1f}ms), "
            f"images={len(images)}, image_bytes={total_image_bytes:,}, text_len={len(text_with_pads)}, has_transcript={transcript is not None}"
        )
        return document, result

137
grox/engine.py Normal file
View File

@@ -0,0 +1,137 @@
import os
import time
import asyncio
import logging
import traceback
import multiprocessing
from queue import Empty, Queue
from threading import Event
from multiprocessing import Process
from monitor.logging import Logging
from monitor.metrics import Metrics
from grox.config.config import grox_config
from grox.schedules.init import init_proc
from grox.schedules.types import TaskResult, TaskPayload
from grox.plans.plan_master import PlanMaster
from grox.schedules.context import ScheduleContext
from grox.data_loaders.media_processor import MediaProcessor
from grox.data_loaders.asr_processor import ASRProcessor
logger = logging.getLogger(__name__)
class Engine:
    """Executes Grox tasks in a dedicated child process.

    The parent process calls ``start()`` / ``stop()``. The child process
    runs ``run()``, which polls the shared task queue, executes each task
    through ``PlanMaster`` and pushes a ``TaskResult`` onto the response
    queue.
    """

    def __init__(self, context: ScheduleContext):
        """Bind the engine to the queues and shutdown event in ``context``."""
        self.config = grox_config.engine
        self.context = context
        self._task_queue: Queue[TaskPayload] = self.context["task_queue"]
        self._resp_queue: Queue[TaskResult] = self.context["resp_queue"]
        self._shutdown_event: Event = self.context["shutdown_event"]
        self._process = None

    def _is_shutdown(self) -> bool:
        """Return True when shutdown was requested or the event is unreadable."""
        try:
            return self._shutdown_event.is_set()
        except BrokenPipeError:
            logger.error("Broken pipe error, assuming shutdown")
            return True
        except Exception:
            logger.error(
                f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
            )
            return True

    async def _init_run(self):
        """Initialise the engine process and start the media/ASR processors."""
        await init_proc("engine")
        MediaProcessor.start()
        ASRProcessor.start()

    async def _process_task(self, task: TaskPayload) -> TaskResult:
        """Run one task through ``PlanMaster`` and record its duration."""
        logger.debug("engine started processing task")
        start = time.perf_counter()
        res = await PlanMaster.exec(task)
        end = time.perf_counter()
        logger.debug(f"engine finished processing task in {end - start:.2f} seconds")
        Metrics.histogram("engine.task.processing_time").record(end - start)
        return res

    async def _run_task(self, task: TaskPayload):
        """Process ``task`` and always publish a result, even on failure."""
        start = time.perf_counter()
        with Metrics.tracer("engine").start_as_current_span("task.root"):
            Logging.set_context(task=task.payload_id)
            if task.post:
                Logging.set_context(post=task.post.id)
            if task.user:
                Logging.set_context(user=task.user.id)
            if task.user_context:
                Logging.set_context(user=task.user_context.user.id)
            try:
                res = await self._process_task(task)
                self._resp_queue.put(res)
                Metrics.counter("engine.task.success.count").add(1)
            except Exception as e:
                logger.error(f"failed to process task, error: {traceback.format_exc()}")
                # Fixed: the start and finish timestamps were swapped here.
                self._resp_queue.put(
                    TaskResult(
                        task=task,
                        success=False,
                        error=str(e),
                        task_started_at=start,
                        task_finished_at=time.perf_counter(),
                    )
                )
                Metrics.counter("engine.task.failed.count").add(1)

    async def _poll_task(self) -> TaskPayload | None:
        """Fetch one task without blocking; return None when none is available."""
        logger.debug(f"engine polling task, queue size: {self._task_queue.qsize()}")
        try:
            task = self._task_queue.get(block=False)
            logger.debug(f"engine received task: {task.payload_id}")
            Metrics.counter("engine.task.received.count").add(1)
            return task
        except Empty:
            logger.debug("engine polling task returned None")
            return None
        except BrokenPipeError:
            logger.error("Broken pipe error, shutting down")
            return None
        except Exception:
            logger.error(f"failed to poll task: {traceback.format_exc()}")
            return None

    async def _run(self, started_event: Event):
        """Main loop of the engine process.

        Keeps polling until shutdown is requested and the task queue is
        empty. Each task runs concurrently as its own asyncio task.
        """
        await self._init_run()
        started_event.set()
        # The event loop only keeps weak references to tasks, so hold strong
        # ones until each task finishes to keep it from being collected.
        in_flight: set[asyncio.Task] = set()
        while not self._is_shutdown() or not self._task_queue.empty():
            task = await self._poll_task()
            if task is None:
                await asyncio.sleep(0.1)
                continue
            running = asyncio.create_task(self._run_task(task))
            in_flight.add(running)
            running.add_done_callback(in_flight.discard)
        logger.warning("engine stopped")

    def run(self, started_event: Event):
        """Child-process entry point: run the loop, then exit immediately."""
        asyncio.run(self._run(started_event))
        os._exit(0)

    async def start(self):
        """Spawn the engine process and wait until it reports it has started."""
        logger.info("Starting Grox engine...")
        started_event = multiprocessing.Event()
        self._process = Process(
            target=self.run, args=(started_event,), name="grox-engine"
        )
        self._process.start()
        started_event.wait()
        logger.info("Grox engine started")

    async def stop(self):
        """Join the engine process (bounded by the configured timeout) and stop the processors."""
        logger.warning("Stopping Grox engine...")
        if self._process and self._process.is_alive():
            self._process.join(self.config.graceful_shutdown_timeout)
        else:
            logger.warning("Engine process is not alive, skipping join")
        await MediaProcessor.stop()
        await ASRProcessor.stop()
        logger.warning("Engine stopped")

View File

@@ -0,0 +1,222 @@
import abc
import logging
from collections.abc import AsyncGenerator
from grox.config.config import KafkaTopicName, TaskGeneratorType
from grox.schedules.types import TaskResult, TaskPayload, TaskEligibility
from grox.data_loaders.kafka_loader import (
KafkaAdPostLoader,
KafkaPostLoader,
KafkaPostEmbeddingRequestLoader,
)
from grox.generators.task_generator import TaskGenerator
from grox.data_loaders.message_queue_loader import MessageQueueLoader
logger = logging.getLogger(__name__)
class StreamTaskGenerator(TaskGenerator, metaclass=abc.ABCMeta):
    """Task generator backed by a message-queue loader.

    Subclasses supply the loader via ``_get_loader`` and declare which
    eligibilities are attached to every payload.
    """

    # Eligibilities copied onto each payload this generator emits.
    ELIGIBILITIES_TO_INJECT: set[TaskEligibility]

    def __init__(self, max_qps: int | None):
        super().__init__(max_qps)
        self._loader = self._get_loader()

    @abc.abstractmethod
    def _get_loader(self) -> MessageQueueLoader:
        """Return the loader this generator reads messages from."""

    async def start(self) -> None:
        """Start the underlying loader."""
        logger.info("Starting StreamTaskGenerator")
        await self._loader.start()

    async def stop(self) -> None:
        """Flag the generator as shut down, then stop the loader."""
        logger.info("Stopping StreamTaskGenerator")
        await super().stop()
        await self._loader.stop()

    async def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield one ``TaskPayload`` per loader message, or None for empty polls."""
        async for message in self._loader.poll():
            if message:
                yield TaskPayload(
                    payload_id=message.mid,
                    post=message.post,
                    user=message.user,
                    user_context=message.user_context,
                    deadline_ts_secs=message.deadline_ts_secs,
                    task_type=self.TASK_GENERATOR_TYPE,
                    # Copy so payloads never share the class-level set.
                    eligibilities=self.ELIGIBILITIES_TO_INJECT.copy(),
                    grox_content_analysis=message.grox_content_analysis,
                )
            else:
                yield None

    async def ack(self, result: TaskResult):
        """Acknowledge the message behind ``result`` with its success flag."""
        await self._loader.ack(result.task.payload_id, result.success)
class PostStreamTaskGenerator(StreamTaskGenerator):
    """Reads the realtime unified-posts topic; injects SPAM_COMMENT and REPLY_RANKING."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM
    ELIGIBILITIES_TO_INJECT = {
        TaskEligibility.SPAM_COMMENT,
        TaskEligibility.REPLY_RANKING,
    }

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS
        return KafkaPostLoader(topic)
class MinTractionPostStreamForGroxTaskGenerator(StreamTaskGenerator):
    """Reads the min-traction-for-Grox posts topic; injects BANGER_INITIAL_SCREEN."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX
        )
        return KafkaPostLoader(topic)
class MinTractionPostStreamForGroxPtosTaskGenerator(StreamTaskGenerator):
    """Reads the min-traction-for-Grox PTOS posts topic; injects SAFETY_PTOS."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_PTOS
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_PTOS
        )
        return KafkaPostLoader(topic)
class MinTractionPostStreamForGroxMultiModalTaskGenerator(StreamTaskGenerator):
    """Reads the min-traction multi-modal posts topic; injects POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY."""

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_MULTI_MODAL
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_MULTI_MODAL
        )
        return KafkaPostLoader(topic)
class PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator(
    StreamTaskGenerator
):
    """Reads the for-reply embedding-request recovery topic; injects POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY."""

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_FOR_REPLY_RECOVERY
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY}

    def _get_loader(self):
        topic = (
            KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY_FOR_REPLY_RECOVERY
        )
        return KafkaPostLoader(topic)
class PostStreamRecoveryTaskGenerator(StreamTaskGenerator):
    """Reads the unified-posts recovery topic; injects BANGER_INITIAL_SCREEN."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_RECOVERY
        return KafkaPostLoader(topic)
class SafetyPtosRecoveryStreamTaskGenerator(StreamTaskGenerator):
    """Reads the safety-PTOS recovery topic; injects SAFETY_PTOS."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.SAFETY_PTOS_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = KafkaTopicName.SAFETY_PTOS_RECOVERY
        return KafkaPostLoader(topic)
class SafetyPtosDeluxeStreamTaskGenerator(StreamTaskGenerator):
    """Reads the safety-PTOS deluxe topic; injects SAFETY_PTOS."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.SAFETY_PTOS_DELUXE
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = KafkaTopicName.SAFETY_PTOS_DELUXE
        return KafkaPostLoader(topic)
class PostStreamTestTaskGenerator(StreamTaskGenerator):
    """Reads the unified-posts test topic; injects BANGER_INITIAL_SCREEN."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_TEST
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_TEST
        return KafkaPostLoader(topic)
class PostSafetyStreamTaskGenerator(StreamTaskGenerator):
    """Reads the popular unified-posts topic; injects POST_SAFETY."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_SAFETY_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_SAFETY}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_POPULAR
        return KafkaPostLoader(topic)
class PostStreamDelayedTaskGenerator(StreamTaskGenerator):
    """Reads the delayed unified-posts topic and injects no eligibilities."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_DELAYED
    # `{}` is an empty dict, not a set; use an empty set so the value matches
    # the declared `set[TaskEligibility]` type of ELIGIBILITIES_TO_INJECT.
    ELIGIBILITIES_TO_INJECT = set()

    def _get_loader(self):
        return KafkaPostLoader(
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_DELAYED
        )
class ReplyRankingRecoveryTaskGenerator(StreamTaskGenerator):
    """Reads the reply-ranking recovery topic; injects REPLY_RANKING."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.REPLY_RANKING_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.REPLY_RANKING}

    def _get_loader(self):
        topic = KafkaTopicName.REPLY_RANKING_RECOVERY
        return KafkaPostLoader(topic)
class PostEmbeddingRequestWithSummaryStreamTaskGenerator(StreamTaskGenerator):
    """Reads the embedding-requests-with-summary topic; injects POST_EMBEDDING_WITH_SUMMARY."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY}

    def _get_loader(self):
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY
        return KafkaPostLoader(topic)
class PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator(StreamTaskGenerator):
    """Reads the embedding-requests-with-summary recovery topic; injects POST_EMBEDDING_WITH_SUMMARY."""

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_RECOVERY
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY}

    def _get_loader(self):
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY_RECOVERY
        return KafkaPostLoader(topic)
class PostEmbeddingV5StreamTaskGenerator(StreamTaskGenerator):
    """Reads the embedding-requests-with-summary topic; injects MM_EMB_V5."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_V5_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.MM_EMB_V5}

    def _get_loader(self):
        # NOTE(review): same topic as
        # PostEmbeddingRequestWithSummaryStreamTaskGenerator - confirm the
        # shared topic is intentional.
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY
        return KafkaPostLoader(topic)
class PostEmbeddingV5ForReplyStreamTaskGenerator(StreamTaskGenerator):
    """Reads the min-traction multi-modal posts topic; injects MM_EMB_V5_FOR_REPLY."""

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_V5_FOR_REPLY_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.MM_EMB_V5_FOR_REPLY}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_MULTI_MODAL
        )
        return KafkaPostLoader(topic)

View File

@@ -0,0 +1,153 @@
import asyncio
import logging
import random
import traceback
from abc import ABC, abstractmethod
from grox.config.config import TaskGeneratorType
from grox.schedules.types import TaskResult, TaskPayload
from limits import RateLimitItemPerSecond, storage, strategies
from typing import AsyncGenerator
logger = logging.getLogger(__name__)
# Module-wide fixed-window rate limiter backed by in-memory storage. Every
# TaskGenerator in this process shares it and keys its budget by class name.
limiter = strategies.FixedWindowRateLimiter(storage.MemoryStorage())
class TaskGenerator(ABC):
    """Base class for sources of ``TaskPayload`` objects.

    Subclasses implement ``_poll``. ``poll`` wraps it with an optional
    per-class QPS limit enforced through the module-level ``limiter``.
    """

    # Origin tag reported by identify_task_origin(); None when unset.
    TASK_GENERATOR_TYPE: TaskGeneratorType | None = None

    def __init__(self, max_qps: int | None):
        """Create the generator; a falsy ``max_qps`` disables rate limiting."""
        self._shutdown_event = asyncio.Event()
        self._limiter_key = self.__class__.__name__
        if max_qps:
            self._limit = RateLimitItemPerSecond(max_qps, 1)
        else:
            self._limit = None

    def is_shutdown(self) -> bool:
        """Return True once ``stop()`` was called, or if the check itself fails."""
        try:
            return self._shutdown_event.is_set()
        except Exception:
            logger.error(
                f"Error checking if task generator is shutdown: {traceback.format_exc()}"
            )
            return True

    async def start(self) -> None:
        """Hook for subclasses; does nothing by default."""

    async def stop(self) -> None:
        """Mark the generator as shut down."""
        logger.info(f"Stopping task generator {self.__class__.__name__}")
        self._shutdown_event.set()

    async def poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield payloads from ``_poll``, honouring the rate limit.

        Empty polls are passed through as None. While the rate limit is
        exhausted, None is yielded every 10 ms until a slot frees up.
        """
        async for payload in self._poll():
            if payload:
                if self._limit:
                    while not limiter.test(self._limit, self._limiter_key):
                        yield None
                        await asyncio.sleep(0.01)
                    limiter.hit(self._limit, self._limiter_key)
                yield payload
            else:
                yield None

    @abstractmethod
    def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Produce raw payloads (or None when nothing is available)."""

    async def ack(self, result: TaskResult):
        """Acknowledge a finished task; no-op by default."""

    def identify_task_origin(self, result: TaskResult) -> TaskGeneratorType | None:
        """Return the generator type that produced ``result``."""
        return self.TASK_GENERATOR_TYPE
class PriorityTaskGenerator(TaskGenerator):
    """Multiplexes several generators, choosing among them by weighted random pick.

    Each child generator gets a label ("GEN_0", "GEN_1", ...). The label of
    the generator that produced a payload is remembered so that ``ack`` and
    ``identify_task_origin`` can be routed back to it.
    """

    def __init__(self, generators: list[tuple[TaskGenerator, int]]):
        """Register ``(generator, weight)`` pairs.

        Raises:
            ValueError: If the list is empty or any weight is not positive.
        """
        if not generators:
            raise ValueError("No generators provided")
        if any(weight <= 0 for _, weight in generators):
            raise ValueError("All weights must be positive")
        super().__init__(None)
        self._generators: dict[str, TaskGenerator] = {}
        self._weights: dict[str, int] = {}
        for i, (gen, weight) in enumerate(generators):
            label = f"GEN_{i}"
            self._generators[label] = gen
            self._weights[label] = weight
        # payload_id -> label of the generator that produced the payload.
        self._result_cache: dict[str, str] = {}
        # Filled by start(). Initialised here so that polling before start()
        # hits the explicit RuntimeError below instead of an AttributeError.
        self._streams: dict[str, AsyncGenerator[TaskPayload | None, None]] = {}
        summary = list(
            zip(
                self._generators.keys(),
                [gen.__class__.__name__ for gen in self._generators.values()],
                self._weights.values(),
                strict=True,
            )
        )
        logger.info(f"Initialized priority task generator with {summary}")

    async def start(self) -> None:
        """Start every child generator and open its poll stream."""
        logger.info("Starting priority task generators")
        await asyncio.gather(*[gen.start() for gen in self._generators.values()])
        self._streams = {label: gen.poll() for label, gen in self._generators.items()}
        logger.info("Priority task generators started")

    async def stop(self) -> None:
        """Stop every child generator, then mark this generator as shut down."""
        logger.warning("Stopping priority task generators")
        await asyncio.gather(*[gen.stop() for gen in self._generators.values()])
        await super().stop()
        logger.warning("Priority task generators stopped")

    async def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield payloads drawn from the child generators.

        Each round picks a generator at random (weighted) among those not
        yet tried in the round. A generator that returns nothing or raises
        is skipped for the rest of the round. An exhausted generator is
        removed permanently. When no generator produced a payload in a
        round, None is yielded.

        Raises:
            RuntimeError: If called before ``start()``.
        """
        if not self._streams:
            raise RuntimeError("Task generators not started")
        while self._weights:
            _weights = self._weights.copy()
            polled = False
            while _weights:
                labels = list(_weights.keys())
                weights = list(_weights.values())
                label = random.choices(labels, weights, k=1)[0]
                stream = self._streams[label]
                try:
                    payload = await anext(stream)
                    if payload:
                        self._result_cache[payload.payload_id] = label
                        yield payload
                        polled = True
                        break
                    else:
                        del _weights[label]
                except StopAsyncIteration:
                    logger.warning(
                        f"Task generator {label} exhausted, removing from pool"
                    )
                    if label in _weights:
                        del _weights[label]
                    if label in self._weights:
                        del self._weights[label]
                    continue
                except Exception:
                    logger.error(
                        f"Error polling task generator {label}: {traceback.format_exc()}"
                    )
                    if label in _weights:
                        del _weights[label]
                    continue
            if not polled:
                yield None

    async def ack(self, result: TaskResult):
        """Forward the ack to the generator that produced the task."""
        logger.debug(f"Acknowledging task {result.task.payload_id}")
        label = self._result_cache.pop(result.task.payload_id, None)
        if not label:
            logger.warning(
                f"No label found for task {result.task.payload_id}, skipping ack"
            )
            return
        gen = self._generators[label]
        await gen.ack(result)

    def identify_task_origin(self, result: TaskResult) -> TaskGeneratorType | None:
        """Return the origin reported by the producing generator, or None if unknown."""
        logger.debug(f"Identifying task origin for {result.task.payload_id}")
        label = self._result_cache.get(result.task.payload_id)
        if not label:
            logger.warning(
                f"No label found for task {result.task.payload_id}, cannot identify origin"
            )
            return None
        gen = self._generators[label]
        return gen.identify_task_origin(result)

46
grox/lib/stream.py Normal file
View File

@@ -0,0 +1,46 @@
from typing import AsyncIterator, AsyncGenerator, TypeVar
from asyncio import Queue, create_task
from enum import Enum
import logging
T = TypeVar("T")


class StreamStatus(Enum):
    """Sentinel a producer puts on the merge queue when its stream has ended."""

    STOP = "Stop"


logger = logging.getLogger(__name__)


async def parallel_merge(*streams: AsyncIterator[T]) -> AsyncGenerator[T, None]:
    """Merge several async iterators into one, yielding items as they arrive.

    Every stream is drained by its own task into a shared queue. Items are
    yielded in arrival order, so ordering between streams is not defined.

    Args:
        *streams: Async iterators to merge. With none, nothing is yielded.

    Yields:
        Items produced by any of the streams.

    Raises:
        Exception: The first exception raised by any stream is re-raised
            to the consumer.
    """
    if not streams:
        return

    queue: Queue[T | StreamStatus | Exception] = Queue()

    async def enqueue(ait: AsyncIterator[T]):
        try:
            async for item in ait:
                await queue.put(item)
        except GeneratorExit:
            pass
        except Exception as e:
            # Hand the failure to the consumer instead of losing it in the task.
            await queue.put(e)
        finally:
            await queue.put(StreamStatus.STOP)

    producers = [create_task(enqueue(s)) for s in streams]
    nstreams_done = 0
    try:
        while nstreams_done < len(streams):
            item = await queue.get()
            queue.task_done()
            # Identity check: `==` would call the __eq__ of arbitrary stream
            # items, which may raise or return a non-boolean value.
            if item is StreamStatus.STOP:
                nstreams_done += 1
            elif isinstance(item, Exception):
                raise item
            else:
                yield item
    finally:
        # On early exit (error or consumer stopped iterating) do not leave
        # producer tasks running; cancelling a finished task is a no-op.
        for producer in producers:
            producer.cancel()

6
grox/lib/utils.py Normal file
View File

@@ -0,0 +1,6 @@
def camel_to_snake(s: str) -> str:
    """Convert CamelCase to snake_case.

    Every uppercase character is replaced by "_" plus its lowercase form,
    and leading underscores are stripped from the result.
    """
    pieces = []
    for ch in s:
        pieces.append("_" + ch.lower() if ch.isupper() else ch)
    return "".join(pieces).lstrip("_")
def snake_to_camel(s: str) -> str:
    """Convert snake_case to CamelCase by capitalising each "_"-separated word.

    ``str.capitalize`` lowercases the rest of each word, so existing inner
    capitals are not preserved.
    """
    words = s.split("_")
    return "".join(map(str.capitalize, words))

54
grox/main.py Normal file
View File

@@ -0,0 +1,54 @@
import signal
import asyncio
import logging
from grox.engine import Engine
from grox.service import GrpcServer
from grox.dispatcher import Dispatcher
from grox.config.config import grox_config
from grox.schedules.init import init_proc
from grox.schedules.context import (
cleanup,
new_context,
shutdown_context,
queue_connection_shutdown_context,
)
logger = logging.getLogger(__name__)
# Set by the SIGINT/SIGTERM handlers installed in serve(); serve() waits on
# it before starting teardown.
shutdown = asyncio.Event()
async def serve():
    """Run the Grox server until SIGINT/SIGTERM, then shut it down in stages.

    Startup order is engine, dispatcher, gRPC server. On shutdown the queue
    connections are flagged first, then the server sleeps 300 seconds before
    the global shutdown flag is set and all components are stopped together.
    """
    await init_proc("main")
    logger.info("Starting grox server...")

    context = new_context()
    engine = Engine(context)
    dispatcher = Dispatcher(context)
    grpc_server = GrpcServer(context)
    for component in (engine, dispatcher, grpc_server):
        await component.start()
    logger.info("Grox server started")

    # Translate termination signals into the module-level shutdown event.
    event_loop = asyncio.get_running_loop()
    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, lambda: shutdown.set())
    await shutdown.wait()

    logger.warning("Grox server shutting down...")
    queue_connection_shutdown_context(context)
    # Fixed pause between flagging the queue connections and full shutdown.
    await asyncio.sleep(300)
    shutdown_context(context)
    await asyncio.gather(
        grpc_server.stop(),
        dispatcher.stop(),
        engine.stop(),
    )
    cleanup()
    logger.warning("Grox server stopped")
if __name__ == "__main__":
    # Script entry point: run the server coroutine to completion.
    asyncio.run(serve())

106
grox/plans/plan.py Normal file
View File

@@ -0,0 +1,106 @@
import time
import asyncio
import logging
import traceback
from abc import ABC
from functools import cache
from grox.lib.utils import camel_to_snake
from grox.tasks.task import Task, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskResult, TaskContext, TaskPayload, TaskEligibility
logger = logging.getLogger(__name__)
class Plan(ABC):
    """A small dependency graph of tasks executed for eligible payloads.

    Subclasses declare ``TASKS`` (name -> task class), ``TASK_DEPENDENCIES``
    (name -> names that must finish first) and ``REQUIRED_ELIGIBILITY``.
    All tasks are started concurrently; each waits on futures representing
    its dependencies before running.
    """

    TASKS: dict[str, type[Task]] = {}
    TASK_DEPENDENCIES: dict[str, set[str]] = {}
    REQUIRED_ELIGIBILITY: TaskEligibility

    def __init__(self):
        """Validate that every name used in TASK_DEPENDENCIES exists in TASKS.

        Raises:
            ValueError: If a dependency or a dependent is not a known task.
        """
        self.deps = set([d for deps in self.TASK_DEPENDENCIES.values() for d in deps])
        if any(t not in self.TASKS for t in self.deps) or any(
            t not in self.TASKS for t in self.TASK_DEPENDENCIES.keys()
        ):
            raise ValueError("Not every task in TASK_DEPENDENCIES is defined in TASKS")

    async def execute(self, task: TaskPayload) -> TaskResult | None:
        """Run the plan for ``task``.

        Returns:
            None when the payload is not eligible for this plan, otherwise
            a ``TaskResult`` whose ``success`` is False if any task raised.
        """
        if not self._eligible(task):
            return None
        Metrics.counter("plan.execute.count").add(
            1, attributes={"plan_name": self.get_name()}
        )
        logger.debug(f"Creating execution plan for graph: {self.TASK_DEPENDENCIES}")
        loop = asyncio.get_running_loop()
        # One future per task that something else depends on.
        dependencies = {name: loop.create_future() for name in self.deps}
        start = time.perf_counter()
        ctx = TaskContext(task)
        try:
            await asyncio.gather(
                *[self._execute_task(t, ctx, dependencies) for t in self.TASKS.keys()]
            )
            Metrics.counter("plan.execute.success.count").add(
                1, attributes={"plan_name": self.get_name()}
            )
        except Exception as e:
            logger.error(f"Error executing plan: {traceback.format_exc()}")
            ctx.errors.append(e)
            Metrics.counter("plan.execute.failed.count").add(
                1, attributes={"plan_name": self.get_name()}
            )
        finally:
            duration = time.perf_counter() - start
            Metrics.histogram("plan.execute.duration").record(
                duration, attributes={"plan_name": self.get_name()}
            )
            # Release anything still waiting on an unresolved dependency.
            for fut in dependencies.values():
                try:
                    if not fut.done():
                        fut.cancel()
                except Exception:
                    logger.error(
                        f"Error canceling dependency future: {traceback.format_exc()}"
                    )
            dependencies.clear()
        return TaskResult(
            task=task,
            content_categories=[c.model_copy() for c in ctx.content_categories],
            task_started_at=ctx.start_time,
            task_finished_at=time.perf_counter(),
            multimodal_post_embedding=ctx.multimodal_post_embedding,
            reason=ctx.reason,
            success=len(ctx.errors) == 0,
            error="\n".join([str(e) for e in ctx.errors]),
        )

    def _eligible(self, ctx: TaskPayload) -> bool:
        """Return True when the payload carries this plan's required eligibility."""
        return self.REQUIRED_ELIGIBILITY in ctx.eligibilities

    async def _execute_task(
        self, task_name: str, ctx: TaskContext, dependencies: dict[str, asyncio.Future]
    ):
        """Wait for the dependencies of ``task_name``, then run it.

        The outcome is published on the task's own future (when other tasks
        depend on it). If any dependency was skipped, this task is marked
        skipped too and does not run.
        """
        logger.debug(f"Waiting for task to become ready: {task_name}")
        task = self.TASKS[task_name]
        deps = self.TASK_DEPENDENCIES.get(task_name, set())
        dep_futures = [dependencies[d] for d in deps]
        dep_results = await asyncio.gather(*dep_futures)
        task_future = dependencies.get(task_name, None)

        # After a sibling fails, execute() cancels unresolved futures while
        # this task may still be running. Publishing onto a finished future
        # would raise InvalidStateError and hide the real outcome.
        def can_publish() -> bool:
            return task_future is not None and not task_future.done()

        if any(r == TaskResultCategory.SKIPPED for r in dep_results):
            if can_publish():
                task_future.set_result(TaskResultCategory.SKIPPED)
            return
        logger.debug(f"Started executing task: {task_name}")
        try:
            res = await task.exec(ctx)
        except Exception as e:
            if can_publish():
                task_future.set_exception(e)
            raise
        if can_publish():
            task_future.set_result(res)
        logger.debug(f"Finished executing task: {task_name}")

    @cache
    def get_name(self) -> str:
        """Return the plan's class name in snake_case (used as a metric attribute)."""
        return camel_to_snake(self.__class__.__name__)

View File

@@ -0,0 +1,37 @@
from grox.plans.plan import Plan
from grox.tasks.task_pub import (
TaskPublishKafka,
TaskPublishUnifiedPostAnnotationsManhattan,
)
from grox.schedules.types import TaskEligibility
from grox.tasks.task_media import TaskMediaHydrationBanger
from grox.tasks.task_filters import TaskInitialBangerFilter
from grox.tasks.task_banger_screen import TaskBangerScreen
from grox.tasks.task_rate_limit import TaskRateLimitBangerAnnotationWithPost
from grox.tasks.task_grok_upa_action_with_labels import TaskGrokUpaActionWithLabels
class PlanInitialBanger(Plan):
    """Plan for payloads carrying ``TaskEligibility.BANGER_INITIAL_SCREEN``.

    ``TASKS`` maps each task key to the class that implements it and
    ``TASK_DEPENDENCIES`` maps each task key to the keys that must resolve
    before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.BANGER_INITIAL_SCREEN

    TASKS = {
        "task_initial_banger_filter": TaskInitialBangerFilter,
        "task_banger_annotation_rate_limit": TaskRateLimitBangerAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_banger_screen_initial": TaskBangerScreen,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_publish_unified_post_annotations_manhattan": TaskPublishUnifiedPostAnnotationsManhattan,
        "task_publish_kafka": TaskPublishKafka,
    }

    # Linear chain up to the screen; the screen then feeds both the UPA action
    # and the Manhattan publish, and the Kafka publish follows the latter.
    TASK_DEPENDENCIES = {
        "task_initial_banger_filter": set(),
        "task_banger_annotation_rate_limit": {"task_initial_banger_filter"},
        "task_media_hydration": {"task_banger_annotation_rate_limit"},
        "task_banger_screen_initial": {"task_media_hydration"},
        "task_grok_upa_action_with_labels": {"task_banger_screen_initial"},
        "task_publish_unified_post_annotations_manhattan": {"task_banger_screen_initial"},
        "task_publish_kafka": {"task_publish_unified_post_annotations_manhattan"},
    }

62
grox/plans/plan_master.py Normal file
View File

@@ -0,0 +1,62 @@
import asyncio
from grox.plans.plan import Plan
from grox.schedules.types import TaskResult, TaskPayload
from grox.plans.plan_spam_comment import PlanSpamComment
from grox.plans.plan_initial_banger import PlanInitialBanger
from grox.plans.plan_post_embedding_with_summary import PlanPostEmbeddingWithSummary
from grox.plans.plan_post_embedding_v5 import PlanPostEmbeddingV5
from grox.plans.plan_post_embedding_v5_for_reply import PlanPostEmbeddingV5ForReply
from grox.plans.plan_post_embedding_with_summary_for_reply import (
PlanPostEmbeddingWithSummaryForReply,
)
from grox.plans.plan_post_safety import PlanPostSafety
from grox.plans.plan_reply_ranking import PlanReplyRanking
from grox.plans.plan_safety_ptos import PlanSafetyPtos
class PlanMaster:
    """Runs every registered plan against a payload and merges the outcomes."""

    # One shared instance per plan, created when this class body is evaluated.
    ALL_PLANS: list[Plan] = [
        PlanInitialBanger(),
        PlanPostSafety(),
        PlanSpamComment(),
        PlanPostEmbeddingWithSummary(),
        PlanPostEmbeddingWithSummaryForReply(),
        PlanPostEmbeddingV5(),
        PlanPostEmbeddingV5ForReply(),
        PlanReplyRanking(),
        PlanSafetyPtos(),
    ]

    @classmethod
    async def exec(cls, task: TaskPayload) -> TaskResult:
        """Execute all plans concurrently for ``task`` and merge their results.

        Plans whose ``execute`` returns ``None`` are left out of the merge.
        """
        results = await asyncio.gather(*[p.execute(task) for p in cls.ALL_PLANS])
        return cls.merge_results(task, [r for r in results if r is not None])

    @classmethod
    def merge_results(cls, task: TaskPayload, results: list[TaskResult]) -> TaskResult:
        """Fold per-plan results into a single ``TaskResult``.

        Args:
            task: The payload the plans were run for.
            results: Results of the plans that produced one; may be empty.

        Returns:
            A result whose categories are the concatenation of all plan
            categories, whose time span covers every plan, whose embedding is
            the first non-``None`` embedding in ``results`` order, and which is
            successful only if every plan succeeded. With no results at all the
            span collapses to the current ``time.perf_counter()`` reading
            instead of raising ``ValueError`` from ``min()``/``max()``.
        """
        import time  # local import: this module does not import ``time`` at the top

        now = time.perf_counter()
        embeddings = (
            r.multimodal_post_embedding
            for r in results
            if r.multimodal_post_embedding is not None
        )
        return TaskResult(
            task=task,
            content_categories=[
                c.model_copy() for r in results for c in r.content_categories
            ],
            # ``default`` covers the case where no plan produced a result.
            task_started_at=min((r.task_started_at for r in results), default=now),
            task_finished_at=max((r.task_finished_at for r in results), default=now),
            multimodal_post_embedding=next(embeddings, None),
            reason="\n".join([r.reason for r in results if r.reason]),
            success=all(r.success for r in results),
            error="\n".join(
                [r.error or "unknown error" for r in results if not r.success]
            ),
        )

View File

@@ -0,0 +1,29 @@
from grox.plans.plan import Plan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_multimodal_post_embedding import TaskMultimodalPostEmbeddingV5
from grox.tasks.task_write_mm_embedding_sink import (
TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies,
)
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingV5
from grox.tasks.task_asr import TaskASRTranscription
class PlanPostEmbeddingV5(Plan):
    """Plan for payloads carrying ``TaskEligibility.MM_EMB_V5``.

    The five tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.MM_EMB_V5

    TASKS = {
        "task_post_embedding_rate_limit": TaskRateLimitEmbeddingV5,
        "task_media_hydration": TaskMediaHydration,
        "task_asr_transcription": TaskASRTranscription,
        "task_multimodal_post_embedding_v5": TaskMultimodalPostEmbeddingV5,
        "task_write_post_embedding_sink_v5": TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit": set(),
        "task_media_hydration": {"task_post_embedding_rate_limit"},
        "task_asr_transcription": {"task_media_hydration"},
        "task_multimodal_post_embedding_v5": {"task_asr_transcription"},
        "task_write_post_embedding_sink_v5": {"task_multimodal_post_embedding_v5"},
    }

View File

@@ -0,0 +1,32 @@
from grox.plans.plan import Plan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryForReplyFilter
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_multimodal_post_embedding import TaskMultimodalPostEmbeddingV5
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV5
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingV5ForReply
from grox.tasks.task_asr import TaskASRTranscription
class PlanPostEmbeddingV5ForReply(Plan):
    """Plan for payloads carrying ``TaskEligibility.MM_EMB_V5_FOR_REPLY``.

    The six tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.MM_EMB_V5_FOR_REPLY

    TASKS = {
        "task_post_embedding_rate_limit_v5_for_reply": TaskRateLimitEmbeddingV5ForReply,
        "task_post_embedding_filter_for_reply": TaskPostEmbeddingWithSummaryForReplyFilter,
        "task_media_hydration": TaskMediaHydration,
        "task_asr_transcription": TaskASRTranscription,
        "task_multimodal_post_embedding_v5": TaskMultimodalPostEmbeddingV5,
        "task_write_post_embedding_sink_v5": TaskWriteMMEmbeddingSinkV5,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_v5_for_reply": set(),
        "task_post_embedding_filter_for_reply": {"task_post_embedding_rate_limit_v5_for_reply"},
        "task_media_hydration": {"task_post_embedding_filter_for_reply"},
        "task_asr_transcription": {"task_media_hydration"},
        "task_multimodal_post_embedding_v5": {"task_asr_transcription"},
        "task_write_post_embedding_sink_v5": {"task_multimodal_post_embedding_v5"},
    }

View File

@@ -0,0 +1,38 @@
from grox.plans.plan import Plan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryFilter
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_multimodal_post_embedding import (
TaskMultimodalPostEmbeddingWithSummary,
)
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV3
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingWithPostSummary
from grox.tasks.task_summarizer_for_post_embedding import TaskPostEmbeddingSummarizer
class PlanPostEmbeddingWithSummary(Plan):
    """Plan for payloads carrying ``TaskEligibility.POST_EMBEDDING_WITH_SUMMARY``.

    The six tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_EMBEDDING_WITH_SUMMARY

    TASKS = {
        "task_post_embedding_rate_limit_summary": TaskRateLimitEmbeddingWithPostSummary,
        "task_post_embedding_with_summary_filter": TaskPostEmbeddingWithSummaryFilter,
        "task_media_hydration": TaskMediaHydration,
        "task_post_embedding_summarizer": TaskPostEmbeddingSummarizer,
        "task_multimodal_post_embedding_with_summary": TaskMultimodalPostEmbeddingWithSummary,
        "task_write_post_embedding_sink_v3": TaskWriteMMEmbeddingSinkV3,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_summary": set(),
        "task_post_embedding_with_summary_filter": {"task_post_embedding_rate_limit_summary"},
        "task_media_hydration": {"task_post_embedding_with_summary_filter"},
        "task_post_embedding_summarizer": {"task_media_hydration"},
        "task_multimodal_post_embedding_with_summary": {"task_post_embedding_summarizer"},
        "task_write_post_embedding_sink_v3": {"task_multimodal_post_embedding_with_summary"},
    }

View File

@@ -0,0 +1,38 @@
from grox.plans.plan import Plan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryForReplyFilter
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_multimodal_post_embedding import (
TaskMultimodalPostEmbeddingWithSummary,
)
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV3
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingWithPostSummaryForReply
from grox.tasks.task_summarizer_for_post_embedding import TaskPostEmbeddingSummarizer
class PlanPostEmbeddingWithSummaryForReply(Plan):
    """Plan for payloads carrying
    ``TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY``.

    The six tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY

    TASKS = {
        "task_post_embedding_rate_limit_summary_for_reply": TaskRateLimitEmbeddingWithPostSummaryForReply,
        "task_post_embedding_with_summary_filter_for_reply": TaskPostEmbeddingWithSummaryForReplyFilter,
        "task_media_hydration": TaskMediaHydration,
        "task_post_embedding_summarizer": TaskPostEmbeddingSummarizer,
        "task_multimodal_post_embedding_with_summary": TaskMultimodalPostEmbeddingWithSummary,
        "task_write_post_embedding_sink_v3": TaskWriteMMEmbeddingSinkV3,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_summary_for_reply": set(),
        "task_post_embedding_with_summary_filter_for_reply": {
            "task_post_embedding_rate_limit_summary_for_reply",
        },
        "task_media_hydration": {"task_post_embedding_with_summary_filter_for_reply"},
        "task_post_embedding_summarizer": {"task_media_hydration"},
        "task_multimodal_post_embedding_with_summary": {"task_post_embedding_summarizer"},
        "task_write_post_embedding_sink_v3": {"task_multimodal_post_embedding_with_summary"},
    }

View File

@@ -0,0 +1,32 @@
from grox.plans.plan import Plan
from grox.tasks.task_pub import TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation
from grox.schedules.types import TaskEligibility
from grox.tasks.task_filters import TaskPostSafetyDeluxeFilter
from grox.tasks.task_media import TaskMediaHydrationBanger
from grox.tasks.task_post_safety_screen_deluxe import TaskPostSafetyScreenDeluxe
from grox.tasks.task_rate_limit import TaskRateLimitPostSafetyAnnotationWithPost
from grox.tasks.task_grok_upa_action_with_labels import TaskGrokUpaActionWithLabels
class PlanPostSafety(Plan):
    """Plan for payloads carrying ``TaskEligibility.POST_SAFETY``.

    ``TASKS`` maps each task key to the class that implements it and
    ``TASK_DEPENDENCIES`` maps each task key to the keys that must resolve
    before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_SAFETY

    TASKS = {
        "task_post_safety_deluxe_filter": TaskPostSafetyDeluxeFilter,
        "task_post_safety_annotation_rate_limit": TaskRateLimitPostSafetyAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_post_safety_screen_deluxe": TaskPostSafetyScreenDeluxe,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_upsert_tweet_bool_metadata_to_unified_post_annotations_manhattan": TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation,
    }

    # Linear chain up to the screen, which then feeds both the UPA action and
    # the metadata upsert.
    TASK_DEPENDENCIES = {
        "task_post_safety_deluxe_filter": set(),
        "task_post_safety_annotation_rate_limit": {"task_post_safety_deluxe_filter"},
        "task_media_hydration": {"task_post_safety_annotation_rate_limit"},
        "task_post_safety_screen_deluxe": {"task_media_hydration"},
        "task_grok_upa_action_with_labels": {"task_post_safety_screen_deluxe"},
        "task_upsert_tweet_bool_metadata_to_unified_post_annotations_manhattan": {
            "task_post_safety_screen_deluxe",
        },
    }

View File

@@ -0,0 +1,27 @@
from grox.plans.plan import Plan
from grox.tasks.task_pub import TaskWriteReplyRankingManhattan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_filters import TaskReplyRankingFilter
from grox.tasks.task_rank_replies import TaskRankReplies
from grox.tasks.task_rate_limit import TaskRateLimitReplyRankingAnnotationWithPost
class PlanReplyRanking(Plan):
    """Plan for payloads carrying ``TaskEligibility.REPLY_RANKING``.

    The five tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.REPLY_RANKING

    TASKS = {
        "task_reply_ranking_filter": TaskReplyRankingFilter,
        "task_reply_ranking_annotation_rate_limit": TaskRateLimitReplyRankingAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_rank_replies": TaskRankReplies,
        "task_write_reply_ranking_manhattan": TaskWriteReplyRankingManhattan,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        "task_reply_ranking_filter": set(),
        "task_reply_ranking_annotation_rate_limit": {"task_reply_ranking_filter"},
        "task_media_hydration": {"task_reply_ranking_annotation_rate_limit"},
        "task_rank_replies": {"task_media_hydration"},
        "task_write_reply_ranking_manhattan": {"task_rank_replies"},
    }

View File

@@ -0,0 +1,32 @@
from grox.plans.plan import Plan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_filters import TaskSafetyPtosFilter
from grox.tasks.task_safety_ptos_category import TaskSafetyPtosCategoryDetection
from grox.tasks.task_safety_ptos_policy import TaskSafetyPtosPolicyDetection
from grox.tasks.task_rate_limit import TaskRateLimitSafetyPtosAnnotationWithPost
from grox.tasks.task_write_safety_post_annotations_result_sink import (
TaskWriteSafetyPostAnnotationsResultSink,
)
class PlanSafetyPtos(Plan):
    """Plan for payloads carrying ``TaskEligibility.SAFETY_PTOS``.

    The six tasks form a single chain: each one depends only on the task
    listed before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.SAFETY_PTOS

    TASKS = {
        "task_safety_ptos_filter": TaskSafetyPtosFilter,
        "task_safety_ptos_annotation_rate_limit": TaskRateLimitSafetyPtosAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_safety_ptos_category_detection": TaskSafetyPtosCategoryDetection,
        "task_safety_ptos_policy_detection": TaskSafetyPtosPolicyDetection,
        "task_write_safety_post_annotations_result_sink": TaskWriteSafetyPostAnnotationsResultSink,
    }

    # task key -> keys that must resolve first
    TASK_DEPENDENCIES = {
        # ``set()`` rather than ``{}``: a bare ``{}`` is an empty dict, whereas
        # every value in this mapping is meant to be a set of task keys.
        "task_safety_ptos_filter": set(),
        "task_safety_ptos_annotation_rate_limit": {"task_safety_ptos_filter"},
        "task_media_hydration": {"task_safety_ptos_annotation_rate_limit"},
        "task_safety_ptos_category_detection": {"task_media_hydration"},
        "task_safety_ptos_policy_detection": {"task_safety_ptos_category_detection"},
        "task_write_safety_post_annotations_result_sink": {"task_safety_ptos_policy_detection"},
    }

View File

@@ -0,0 +1,29 @@
from grox.plans.plan import Plan
from grox.tasks.task_pub import TaskPublishKafka, TaskWriteReplySpamManhattan
from grox.schedules.types import TaskEligibility
from grox.tasks.task_media import TaskMediaHydration
from grox.tasks.task_filters import TaskSpamFilter
from grox.tasks.task_spam_detection import TaskSpamDetection
from grox.tasks.task_rate_limit import TaskRateLimitReplySpamAnnotationWithPost
class PlanSpamComment(Plan):
    """Plan for payloads carrying ``TaskEligibility.SPAM_COMMENT``.

    ``TASKS`` maps each task key to the class that implements it and
    ``TASK_DEPENDENCIES`` maps each task key to the keys that must resolve
    before it.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.SPAM_COMMENT

    TASKS = {
        "task_spam_filter": TaskSpamFilter,
        "task_reply_spam_annotation_rate_limit": TaskRateLimitReplySpamAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_spam_detection": TaskSpamDetection,
        "task_publish_reply_spam_mh": TaskWriteReplySpamManhattan,
        "task_publish_kafka": TaskPublishKafka,
    }

    # Linear chain up to the detection, which then feeds both publish tasks.
    TASK_DEPENDENCIES = {
        "task_spam_filter": set(),
        "task_reply_spam_annotation_rate_limit": {"task_spam_filter"},
        "task_media_hydration": {"task_reply_spam_annotation_rate_limit"},
        "task_spam_detection": {"task_media_hydration"},
        "task_publish_reply_spam_mh": {"task_spam_detection"},
        "task_publish_kafka": {"task_spam_detection"},
    }

47
grox/schedules/context.py Normal file
View File

@@ -0,0 +1,47 @@
import signal
import logging
from typing import Any
from multiprocessing import Manager
from multiprocessing.managers import DictProxy, SyncManager
# Alias for the manager-backed dict handed out by new_context().
type ScheduleContext = DictProxy[str, Any]
logger = logging.getLogger(__name__)
# Process-wide manager; stays None until get_manager() creates it.
_manager: SyncManager | None = None
def get_manager() -> SyncManager:
    """Return the module-wide ``SyncManager``, creating it on first use."""
    global _manager
    if _manager is not None:
        return _manager
    _manager = Manager()
    return _manager
def new_context() -> ScheduleContext:
    """Build a fresh manager-backed context dict.

    The dict holds four queues (``task_queue``, ``resp_queue``,
    ``live_task_queue``, ``live_resp_queue``) and two events
    (``shutdown_event``, ``queue_connection_shutdown_event``), all created on
    the manager returned by ``get_manager()``.
    """
    mgr = get_manager()
    shared = {
        "task_queue": mgr.Queue(),
        "resp_queue": mgr.Queue(),
        "live_task_queue": mgr.Queue(),
        "live_resp_queue": mgr.Queue(),
        "shutdown_event": mgr.Event(),
        "queue_connection_shutdown_event": mgr.Event(),
    }
    return mgr.dict(**shared)
def prevent_default() -> None:
    """Make the calling process ignore SIGINT and SIGTERM."""
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal.SIG_IGN)
def shutdown_context(context: ScheduleContext) -> None:
    """Set the ``shutdown_event`` stored in ``context``."""
    event = context["shutdown_event"]
    event.set()
def queue_connection_shutdown_context(context: ScheduleContext) -> None:
    """Set the ``queue_connection_shutdown_event`` stored in ``context``."""
    event = context["queue_connection_shutdown_event"]
    event.set()
def cleanup() -> None:
    """Shut down the module-wide manager if one was ever created.

    The module-level reference is left in place afterwards.
    """
    if _manager is None:
        return
    logger.warning("Shutting down context manager")
    _manager.shutdown()

35
grox/schedules/init.py Normal file
View File

@@ -0,0 +1,35 @@
import asyncio
import gc
import logging
import random
import time
import setproctitle
from grox.config.config import grox_config
from grox.schedules.context import prevent_default
from monitor.logging import Logging
from monitor.metrics import Metrics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Strong references to fire-and-forget tasks started by init_proc(). The event
# loop only keeps weak references to tasks, so a task nobody else references
# can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


async def init_proc(proc_name: str):
    """Prepare the current process and start its periodic GC task.

    In order: ignores SIGINT/SIGTERM via ``prevent_default()``, configures
    logging and metrics from ``grox_config``, sets the process title, and
    schedules ``periodic_gc`` on the running event loop.

    Args:
        proc_name: Name used for the metrics init, the process title and the
            GC log lines.
    """
    prevent_default()
    Logging.config(grox_config.logging)
    Metrics.init(proc_name, grox_config.metrics)
    setproctitle.setproctitle(proc_name)
    logger.info(f"Changed process title to {proc_name}")
    gc_task = asyncio.create_task(periodic_gc(proc_name))
    # Keep the task alive for as long as it runs; drop it once it is done.
    _background_tasks.add(gc_task)
    gc_task.add_done_callback(_background_tasks.discard)
async def periodic_gc(proc_name: str):
    """Run ``gc.collect()`` forever, sleeping a jittered interval between runs.

    Each pause is ``grox_config.periodic_gc.interval`` plus a random whole
    number between 0 and ``int(grox_config.periodic_gc.jitter)`` inclusive.
    The duration of every collection is logged.
    """
    while True:
        base_delay = grox_config.periodic_gc.interval
        extra_delay = random.randint(0, int(grox_config.periodic_gc.jitter))
        await asyncio.sleep(base_delay + extra_delay)
        logger.info(f"Running periodic GC for {proc_name}")
        start = time.perf_counter()
        gc.collect()
        end = time.perf_counter()
        logger.info(f"GC for {proc_name} took {end - start:.2f} seconds")

73
grox/schedules/types.py Normal file
View File

@@ -0,0 +1,73 @@
import time
from enum import Enum
from typing import Any
from pydantic import Field, BaseModel
from grox.config.config import TaskGeneratorType
from grox.data_loaders.data_types import (
Post,
User,
UserContext,
ContentCategoryResult,
GroxContentAnalysis,
ReplyScoreResult,
SafetyPostAnnotations,
)
from grox.classifiers.content.classifier_data_collection import (
ClassifierDataCollectionResult,
)
class TaskEligibility(str, Enum):
    """Eligibility tags a payload can carry.

    Every member is also a ``str`` whose value is the lower-case member name.
    """

    SPAM_COMMENT = "spam_comment"
    BANGER_INITIAL_SCREEN = "banger_initial_screen"

    # Post-embedding variants
    POST_EMBEDDING_WITH_SUMMARY = "post_embedding_with_summary"
    POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY = "post_embedding_with_summary_for_reply"
    MM_EMB_V4 = "mm_emb_v4"
    MM_EMB_V5 = "mm_emb_v5"
    MM_EMB_V5_FOR_REPLY = "mm_emb_v5_for_reply"

    REPLY_RANKING = "reply_ranking"
    SAFETY_PTOS = "safety_ptos"
    POST_SAFETY = "post_safety"
class TaskPayload(BaseModel):
    """Unit of work handed to the plans.

    Only ``payload_id`` is required; every other field has a default.
    """

    payload_id: str

    # Optional subjects of the work.
    post: Post | None = None
    user: User | None = None
    user_context: UserContext | None = None
    grox_content_analysis: GroxContentAnalysis | None = None

    # Scheduling metadata.
    attempt: int = 0
    eligibilities: set[TaskEligibility] = Field(default_factory=set)
    deadline_ts_secs: int | None = None
    task_type: TaskGeneratorType | None = None
class TaskResult(BaseModel):
    """Outcome reported for a ``TaskPayload``."""

    task: TaskPayload

    # ``time.perf_counter()`` readings; the finish time defaults to "now".
    task_started_at: float
    task_finished_at: float = Field(default_factory=time.perf_counter)

    content_categories: list[ContentCategoryResult] = Field(default_factory=list)
    multimodal_post_embedding: list[float] | None = None
    reason: str = ""

    success: bool = True
    error: str | None = None
class TaskContext:
    """Mutable state shared by the tasks working on one payload."""

    def __init__(self, task: TaskPayload):
        """Start an empty context for ``task``.

        The eligibility set is copied from the payload, and ``start_time`` is
        taken from ``time.perf_counter()`` at construction.
        """
        self.payload = task
        self.eligibilities: set[TaskEligibility] = task.eligibilities.copy()

        # Result slots, all empty at construction.
        self.content_categories: list[ContentCategoryResult] = []
        self.summary: str = ""
        self.multimodal_post_embedding: list[float] | None = None
        self.multimodal_post_embedding_dict: dict[str, list[float]] = {}
        self.reply_ranking_results: list[ReplyScoreResult] = []
        self.reason: str = ""
        self.available_topics: list | None = None

        self.start_time = time.perf_counter()
        self.errors: list[Exception] = []
        self.safety_annotations: SafetyPostAnnotations | None = None
class TaskError(Exception):
    """Exception type for task errors; adds nothing to ``Exception``."""

View File

@@ -0,0 +1,44 @@
import logging
import time
import traceback
from abc import ABC, abstractmethod
from grok_sampler.eapi_sampler import EapiSampler
from grox.config.config import EapiModelConfig
from monitor.metrics import Metrics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
from typing import TypeVar, Generic, Any
# Type of the input accepted by a summarizer.
T = TypeVar("T")
class EapiSummarizer(ABC, Generic[T]):
    """Base class for summarizers that hold an ``EapiSampler``.

    Subclasses implement ``_summarize``; ``summarize`` wraps it with logging,
    request/error/success counters and a latency histogram.
    """

    def __init__(self, eapi_config: EapiModelConfig, eapi_sampler: EapiSampler):
        """Store the model configuration and the sampler."""
        self.eapi_config = eapi_config
        self.eapi_sampler = eapi_sampler

    async def summarize(self, input: T) -> Any:
        """Summarize ``input`` and record metrics around the call.

        Returns whatever ``_summarize`` returns. Any exception it raises is
        counted, logged with its traceback, and re-raised; latency is recorded
        only for successful calls.
        """
        logger.info(f"[{self.__class__.__name__}] started processing summarize request")
        Metrics.counter("summarize.request.count").add(1)
        start = time.perf_counter()
        try:
            summary = await self._summarize(input)
        except Exception:
            Metrics.counter("summarize.error.count").add(1)
            logger.error(
                f"[{self.__class__.__name__}] error processing summarize request: {traceback.format_exc()}"
            )
            raise
        else:
            Metrics.counter("summarize.success.count").add(1)
            end = time.perf_counter()
            logger.info(
                f"[{self.__class__.__name__}] finished processing summarize request in {end - start:.2f} seconds"
            )
            Metrics.histogram("summarize.latency.seconds").record(end - start)
            return summary

    @abstractmethod
    async def _summarize(self, input: T) -> Any:
        """Produce the summary for ``input``; implemented by subclasses."""

View File

@@ -0,0 +1,55 @@
import uuid
import logging
from grox.lm.post import PostRenderer
from grox.lm.user import UserRenderer
from grox.config.config import grox_config, ModelName
from grox.summarizer.summarizer import Summarizer
from grox.data_loaders.data_types import Post
from grox.lm.convo import Conversation, Message, Role
from grok_sampler.config import GrokModelConfig
from grok_sampler.vision_sampler import VisionSampler
import os
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PostEmbeddingSummarizer(Summarizer):
    """Summarizer that asks a vision model to describe a post.

    The system prompt is read from ``prompt_file`` on every request, and the
    summary is the text found inside the model's ``<description>`` tag.
    """

    def __init__(self, prompt_file: str):
        """Create the sampler and remember the prompt file.

        Raises:
            FileNotFoundError: If ``prompt_file`` does not exist.
        """
        vlm_config = grox_config.get_model(ModelName.VLM_MINI_CRITICAL)
        sampler_config = GrokModelConfig(**vlm_config.model_dump())
        vlm = VisionSampler(sampler_config)
        self.prompt_file: str = prompt_file
        prompt_exists = os.path.exists(self.prompt_file)
        if not prompt_exists:
            raise FileNotFoundError(f"Prompt file {self.prompt_file} not found")
        super().__init__(vlm_config, vlm)

    async def _summarize(self, post: Post) -> str:
        """Sample the model for ``post`` and return the ``<description>`` body.

        Raises:
            IndexError: If the reply has no ``<description>`` opening tag.
        """
        convo = await self._render_vlm_conversation(post)
        result = await self.vlm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )
        # Text between the first opening tag and whichever comes first of the
        # next opening tag, the closing tag, or the end of the reply.
        after_open = result.split("<description>")[1]
        before_close = after_open.split("</description>")[0]
        return before_close

    async def _render_vlm_conversation(
        self, post: Post, disable_thinking: bool = True
    ) -> Conversation:
        """Build the conversation: system prompt, task message, assistant turn.

        With ``disable_thinking`` the assistant turn is pre-filled with an
        empty content item and an empty separator; otherwise it is a default
        assistant message.
        """
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        with open(self.prompt_file, "r") as f:
            prompt = f.read()
        convo.messages.append(Message(role=Role.SYSTEM, content=[prompt]))
        convo.messages.append(await self._build_task_message(post))
        assistant_turn = (
            Message(role=Role.ASSISTANT, content=[""], separator="")
            if disable_thinking
            else Message(role=Role.ASSISTANT)
        )
        convo.messages.append(assistant_turn)
        return convo

    async def _build_task_message(self, post: Post) -> Message:
        """Return the user message: rendered author followed by rendered post."""
        task_message: Message = Message(role=Role.USER, content=[])
        task_message.content.extend(UserRenderer.render(post.user))
        task_message.content.extend(PostRenderer.render(post))
        return task_message

View File

@@ -0,0 +1,44 @@
import logging
import time
import traceback
from abc import ABC, abstractmethod
from grok_sampler.vision_sampler import VisionSampler
from grox.config.config import ModelConfig
from monitor.metrics import Metrics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
from typing import TypeVar, Generic, Any
# Type of the input accepted by a summarizer.
T = TypeVar("T")
class Summarizer(ABC, Generic[T]):
    """Base class for summarizers that hold a ``VisionSampler``.

    Subclasses implement ``_summarize``; ``summarize`` wraps it with logging,
    request/error/success counters and a latency histogram.
    """

    def __init__(self, model_config: ModelConfig, vlm: VisionSampler):
        """Store the model configuration and the sampler."""
        self.model_config = model_config
        self.vlm = vlm

    async def summarize(self, input: T) -> Any:
        """Summarize ``input`` and record metrics around the call.

        Returns whatever ``_summarize`` returns. Any exception it raises is
        counted, logged with its traceback, and re-raised; latency is recorded
        only for successful calls.
        """
        logger.info(f"[{self.__class__.__name__}] started processing summarize request")
        Metrics.counter("summarize.request.count").add(1)
        start = time.perf_counter()
        try:
            summary = await self._summarize(input)
        except Exception:
            Metrics.counter("summarize.error.count").add(1)
            logger.error(
                f"[{self.__class__.__name__}] error processing summarize request: {traceback.format_exc()}"
            )
            raise
        else:
            Metrics.counter("summarize.success.count").add(1)
            end = time.perf_counter()
            logger.info(
                f"[{self.__class__.__name__}] finished processing summarize request in {end - start:.2f} seconds"
            )
            Metrics.histogram("summarize.latency.seconds").record(end - start)
            return summary

    @abstractmethod
    async def _summarize(self, input: T) -> Any:
        """Produce the summary for ``input``; implemented by subclasses."""

View File

@@ -0,0 +1,56 @@
from abc import ABC
from grox.config.env import is_dev, is_prod, is_local, is_mm_emb_prod, is_ptos_prod
from grox.schedules.types import TaskContext
class DisableTaskRule(ABC):
    """Base rule deciding whether a task is switched off.

    The base rule never disables anything; subclasses override
    ``should_disable`` and set ``DISABLE_REASON``.
    """

    # Explanation reported by disable_reason().
    DISABLE_REASON: str = ""

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return ``False``: the base rule leaves every task enabled."""
        return False

    @classmethod
    def disable_reason(cls) -> str | None:
        """Return this rule's ``DISABLE_REASON``."""
        reason = cls.DISABLE_REASON
        return reason
class DisableTaskForLocal(DisableTaskRule):
    """Rule whose verdict is the ``is_local`` flag."""

    DISABLE_REASON = "Task is disabled for local mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return ``is_local``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_local`` is a module-level bool, not a
        # callable -- confirm against grox.config.env.
        return is_local
class DisableTaskForDev(DisableTaskRule):
    """Rule whose verdict is the ``is_dev`` flag."""

    DISABLE_REASON = "Task is disabled for dev mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return ``is_dev``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_dev`` is a module-level bool, not a
        # callable -- confirm against grox.config.env.
        return is_dev
class DisableTaskForNonProd(DisableTaskRule):
    """Rule that disables a task whenever ``is_prod`` is falsy."""

    DISABLE_REASON = "Task is disabled for non-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_prod``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_prod`` is a module-level bool, not a
        # callable -- confirm against grox.config.env.
        return not is_prod
class DisableTaskForNonMmEmbProd(DisableTaskRule):
    """Rule that disables a task whenever ``is_mm_emb_prod`` is falsy."""

    DISABLE_REASON = "Task is disabled for non-mm-emb-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_mm_emb_prod``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_mm_emb_prod`` is a module-level bool, not
        # a callable -- confirm against grox.config.env.
        return not is_mm_emb_prod
class DisableTaskForNonPtosProd(DisableTaskRule):
    """Rule that disables a task whenever ``is_ptos_prod`` is falsy."""

    DISABLE_REASON = "Task is disabled for non-ptos-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_ptos_prod``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_ptos_prod`` is a module-level bool, not a
        # callable -- confirm against grox.config.env.
        return not is_ptos_prod

150
grox/tasks/task.py Normal file
View File

@@ -0,0 +1,150 @@
import logging
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from functools import cache
from tenacity import retry, wait_fixed, stop_after_attempt
from grox.lib.utils import camel_to_snake
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.tasks.disable_rules import DisableTaskRule
from grox.data_loaders.data_types import GroxContentAnalysis, Post, User, UserContext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TaskResultCategory(str, Enum):
    """Outcome of one task run; every member is also a ``str``."""

    SUCCESS = "success"
    SKIPPED = "skipped"
class TaskStopExecution(Exception):
    """Raised by a task body to end the task early.

    ``Task.exec`` turns it into a ``SKIPPED`` result instead of a failure.
    """
class Task(ABC):
    """Base class for tasks; everything is driven through classmethods.

    Subclasses implement ``_exec`` and may override ``_should_skip`` or list
    ``DISABLE_RULES``.
    """

    # Rules consulted by should_disable(); the first one that fires wins.
    DISABLE_RULES: list[type[DisableTaskRule]] = []

    @classmethod
    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task once and report how it ended.

        Returns ``SKIPPED`` when ``should_skip`` says so or when ``_exec``
        raises ``TaskStopExecution``, and ``SUCCESS`` when ``_exec`` returns.
        Any other exception is counted, logged with its traceback and
        re-raised. Counters are tagged with the task name.

        NOTE(review): the ``retry`` decorator presumably re-runs the whole
        method (up to two attempts, one second apart) when it raises --
        confirm against the tenacity documentation.
        """
        metrics_attributes = {"task_name": cls.get_name()}

        def count(metric: str) -> None:
            Metrics.counter(metric).add(1, attributes=metrics_attributes)

        count("task.exec.count")
        logger.debug(f"[{cls.get_name()}] starting task")
        if cls.should_skip(ctx):
            count("task.exec.skipped.count")
            logger.info(f"[{cls.get_name()}] skipping task")
            return TaskResultCategory.SKIPPED
        count("task.exec.intaken.count")
        try:
            await cls._exec(ctx)
        except TaskStopExecution:
            # The task body asked to stop: reported as a skip, not a failure.
            count("task.exec.skipped.count")
            logger.info(f"[{cls.get_name()}] skipping task")
            return TaskResultCategory.SKIPPED
        except Exception:
            count("task.exec.failed.count")
            logger.error(
                f"[{cls.get_name()}] failed to execute task with error: {traceback.format_exc()}"
            )
            raise
        count("task.exec.success.count")
        return TaskResultCategory.SUCCESS

    @classmethod
    @abstractmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Do the task's actual work; implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def should_skip(cls, ctx: TaskContext) -> bool:
        """Return whether the task must not run for ``ctx``.

        Disable rules are checked first; ``_should_skip`` is only consulted
        when none of them fires.
        """
        return True if cls.should_disable(ctx) else cls._should_skip(ctx)

    @classmethod
    def _should_skip(cls, ctx: TaskContext) -> bool:
        """Task-specific skip hook; the default never skips."""
        return False

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return whether any rule in ``DISABLE_RULES`` disables the task.

        Evaluation stops at the first rule that fires, and that rule's reason
        is logged at debug level.
        """
        rule = next((r for r in cls.DISABLE_RULES if r.should_disable(ctx)), None)
        if rule is None:
            return False
        logger.debug(
            f"[{cls.get_name()}] skipping task because {rule.disable_reason()}"
        )
        return True

    @classmethod
    @cache
    def get_name(cls) -> str:
        """Return the snake_case form of the class name (memoised)."""
        class_name = cls.__name__
        return camel_to_snake(class_name)
class TaskWithUser(Task):
    """Task that needs ``ctx.payload.user`` to be present."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Hand the payload's user to ``_exec_with_user``.

        Raises:
            TaskStopExecution: If the payload has no (truthy) user.
        """
        user = ctx.payload.user
        if user:
            await cls._exec_with_user(ctx, user)
            return
        raise TaskStopExecution("No user for task")

    @classmethod
    @abstractmethod
    async def _exec_with_user(cls, ctx: TaskContext, user: User) -> None:
        """Do the task's work for ``user``; implemented by subclasses."""
        raise NotImplementedError()
class TaskWithUserContext(Task):
    """Task that needs ``ctx.payload.user_context`` to be present."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Hand the payload's user context to ``_exec_with_user_context``.

        Raises:
            TaskStopExecution: If the payload has no (truthy) user context.
        """
        user_context = ctx.payload.user_context
        if user_context:
            await cls._exec_with_user_context(ctx, user_context)
            return
        raise TaskStopExecution("No user context for task")

    @classmethod
    @abstractmethod
    async def _exec_with_user_context(
        cls, ctx: TaskContext, user_context: UserContext
    ) -> None:
        """Do the task's work for ``user_context``; implemented by subclasses."""
        raise NotImplementedError()
class TaskWithPost(Task):
    """Task that needs ``ctx.payload.post`` to be present."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Hand the payload's post to ``_exec_with_post``.

        Raises:
            TaskStopExecution: If the payload has no (truthy) post.
        """
        post = ctx.payload.post
        if post:
            await cls._exec_with_post(ctx, post)
            return
        raise TaskStopExecution("No post for task")

    @classmethod
    @abstractmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Do the task's work for ``post``; implemented by subclasses."""
        raise NotImplementedError()
class TaskWithContentAnalysis(Task):
    """Task that needs ``ctx.payload.grox_content_analysis`` to be present."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Hand the payload's content analysis to ``_exec_with_content_analysis``.

        Raises:
            TaskStopExecution: If the payload has no (truthy) content analysis.
        """
        content_analysis: GroxContentAnalysis | None = ctx.payload.grox_content_analysis
        if content_analysis:
            await cls._exec_with_content_analysis(ctx, content_analysis)
            return
        raise TaskStopExecution("No content analysis for task")

    @classmethod
    @abstractmethod
    async def _exec_with_content_analysis(
        cls, ctx: TaskContext, content_analysis: GroxContentAnalysis
    ) -> None:
        """Do the task's work for ``content_analysis``; implemented by subclasses."""
        raise NotImplementedError()

74
grox/tasks/task_asr.py Normal file
View File

@@ -0,0 +1,74 @@
import logging
from grox.data_loaders.asr_processor import ASRProcessor
from grox.data_loaders.data_types import Post, Video
from grox.schedules.types import TaskContext
from grox.tasks.task import TaskWithPost
from monitor.metrics import Metrics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _get_video_url(video: Video) -> tuple[str | None, bool]:
    """Pick the URL to use for ``video``.

    Sources are tried in order: the best variant of ``animatedGifInfo``, the
    best variant of ``videoInfo``, then ``video.url``.

    Returns:
        ``(url, is_animated_gif)``; the flag is ``True`` only when the URL
        came from ``animatedGifInfo``. ``(None, False)`` when no source
        yields a URL.
    """
    # (attribute holding variant info, whether it denotes an animated GIF)
    variant_sources = (("animatedGifInfo", True), ("videoInfo", False))
    for attr_name, is_animated_gif in variant_sources:
        info = getattr(video, attr_name)
        if not info:
            continue
        best = info.get_best_variant()
        if best and best.url:
            return best.url, is_animated_gif
    if video.url:
        return video.url, False
    return None, False
class TaskASRTranscription(TaskWithPost):
    """Runs ``ASRProcessor`` on a post's videos and stores the transcripts."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Process every ``Video`` in ``post.media``, one after another.

        Posts without media or without videos, videos without a URL, and
        animated GIFs are skipped. A non-empty transcript is written to
        ``video.convo_video.asr_transcript`` when ``convo_video`` is set;
        an empty one is counted as a failure.
        """

        def count_skip(reason: str) -> None:
            Metrics.counter("task.asr_transcription.skipped.count").add(
                1, attributes={"reason": reason}
            )

        Metrics.counter("task.asr_transcription.total.count").add(1)

        if not post.media:
            logger.debug(f"Post {post.id} has no media, skipping ASR")
            count_skip("no_media")
            return

        videos = [m for m in post.media if isinstance(m, Video)]
        if not videos:
            logger.debug(f"Post {post.id} has no video, skipping ASR")
            count_skip("no_video")
            return

        Metrics.counter("task.asr_transcription.has_video.count").add(1)
        for video in videos:
            video_url, is_animated_gif = _get_video_url(video)
            if not video_url:
                continue
            if is_animated_gif:
                logger.debug(
                    f"Post {post.id} video {video.id} is an animated GIF, skipping ASR (no audio)"
                )
                count_skip("animated_gif")
                continue

            transcript = await ASRProcessor.process(post.id, video_url)
            if not transcript:
                Metrics.counter("task.asr_transcription.failed.count").add(
                    1, attributes={"reason": "processor_error"}
                )
                continue

            if video.convo_video is not None:
                video.convo_video.asr_transcript = transcript
            # Success is counted whether or not there was a convo_video to
            # attach the transcript to.
            Metrics.counter("task.asr_transcription.success.count").add(1)
            logger.info(
                f"ASR completed for post {post.id} video {video.id}, transcript_len={len(transcript)}"
            )

View File

@@ -0,0 +1,79 @@
import logging
import time
from typing import override
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post, ContentCategoryType
from grox.classifiers.content.banger_initial_screen import BangerInitialScreenClassifier
from strato_http.queries.grok_topics import StratoGrokTopics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in the visible part of this module.
LOG_EVERY_N = 10000
# Age, in seconds, after which the cached grok topics are fetched again.
CACHE_TTL_SECONDS = 3600
class TaskBangerScreen(TaskWithPost):
    """Classifies a post with ``BangerInitialScreenClassifier``.

    The grok topics passed to the classifier are cached on the class and
    fetched again once the cache is older than ``CACHE_TTL_SECONDS``.
    """

    classifier = BangerInitialScreenClassifier()
    # Topic cache shared by all runs, with the time.time() of the last fill.
    _cached_topics = None
    _cache_timestamp = None

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the undecorated ``Task.exec`` for this class.

        NOTE(review): going through ``__wrapped__`` presumably bypasses the
        retry wrapper applied to ``Task.exec`` -- confirm that is intended.
        """
        return await Task.exec.__wrapped__(cls, ctx)

    @classmethod
    def _topics_need_refresh(cls) -> bool:
        """Return whether the topic cache is empty or past its TTL."""
        if cls._cached_topics is None:
            return True
        return (
            cls._cache_timestamp is not None
            and time.time() - cls._cache_timestamp > CACHE_TTL_SECONDS
        )

    @classmethod
    async def _refresh_topics(cls) -> None:
        """Fetch grok topics; replace the cache only if the fetch yields any."""
        logger.info("Fetching grok topics for cache")
        Metrics.counter("task.banger_initial_screen.grok_cache.new_load.count").add(
            1
        )
        query = StratoGrokTopics()
        fetched_topics = await query.fetch()
        if not fetched_topics:
            # A previously cached value, if any, is left in place.
            logger.warning("Failed to fetch grok topics")
            Metrics.counter(
                "task.banger_initial_screen.grok_cache.new_load.failure.count"
            ).add(1)
            return
        cls._cached_topics = fetched_topics
        cls._cache_timestamp = time.time()
        logger.info(f"Cached {len(cls._cached_topics)} categories with topics")
        Metrics.counter(
            "task.banger_initial_screen.grok_cache.new_load.success.count"
        ).add(1)

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify ``post`` and record the outcome on ``ctx``.

        The classifier results are appended to ``ctx.content_categories`` and
        the topics that were used are stored in ``ctx.available_topics``.
        The post counts as passed when any ``BANGER_INITIAL_SCREEN`` result
        is positive.
        """
        Metrics.counter("task.banger_initial_screen.total.count").add(1)

        if cls._topics_need_refresh():
            await cls._refresh_topics()
        else:
            Metrics.counter(
                "task.banger_initial_screen.grok_cache.cache_hit.count"
            ).add(1)

        if cls._cached_topics and len(cls._cached_topics) > 0:
            Metrics.counter("task.banger_initial_screen.with_cached_topics.count").add(
                1
            )
        else:
            Metrics.counter(
                "task.banger_initial_screen.without_cached_topics.count"
            ).add(1)

        res = await cls.classifier.classify(post, topics=cls._cached_topics)
        ctx.content_categories.extend(res)
        ctx.available_topics = cls._cached_topics

        screen_results = [
            r for r in res if r.category == ContentCategoryType.BANGER_INITIAL_SCREEN
        ]
        passed = any(r.positive for r in screen_results)
        if passed:
            Metrics.counter("task.banger_initial_screen.passed.count").add(1)
        else:
            Metrics.counter("task.banger_initial_screen.failed.count").add(1)

View File

@@ -0,0 +1,79 @@
import logging
import os
import time
import traceback
from functools import cache
from thrifts.gen.twitter.strato.columns.content_understanding.content_understanding.ttypes import (
SimpleTweetEmbedding,
)
from thrifts.serdes import Serializer
from grox.tasks.task import TaskWithPost
from grox.tasks.disable_rules import DisableTaskForNonMmEmbProd
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.config.config import grox_config, KafkaTopicName
from kafka_cli.producer import ScramKafkaProducer
from monitor.metrics import Metrics
logger = logging.getLogger(__name__)
class TaskPublishEmbeddingKafka(TaskWithPost):
    """Publishes ``ctx.multimodal_post_embedding`` for a post to a Kafka topic.

    Subclasses select the destination by setting ``KAFKA_TOPIC_NAME``.
    """

    DISABLE_RULES = [DisableTaskForNonMmEmbProd]
    # Declared only; concrete subclasses assign the value.
    KAFKA_TOPIC_NAME: KafkaTopicName

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Publish the embedding if one is present; re-raise any publish failure."""
        vector = ctx.multimodal_post_embedding
        if vector is None:
            Metrics.counter("task.publish_embedding_kafka.skipped.count").add(
                1, attributes={"reason": "no_embedding"}
            )
            logger.info(f"No embedding available for post {post.id}, skipping")
            return

        Metrics.counter("task.publish_embedding_kafka.intaken.count").add(1)
        try:
            await cls._publish_to_kafka(post, vector)
            Metrics.counter("task.publish_embedding_kafka.success.count").add(1)
            if post.created_at:
                # Seconds elapsed between post creation and this publish.
                Metrics.histogram("task.publish_embedding_kafka.e2e_latency").record(
                    time.time() - post.created_at.timestamp()
                )
        except Exception:
            Metrics.counter("task.publish_embedding_kafka.failed.count").add(1)
            logger.error(
                f"Failed to publish embedding to Kafka: {traceback.format_exc()}"
            )
            raise

    @classmethod
    async def _publish_to_kafka(cls, post: Post, embedding: list[float]) -> None:
        """Serialize a ``SimpleTweetEmbedding`` record and send it, keyed by post id."""
        serialized_bytes = Serializer.serialize(
            SimpleTweetEmbedding(tweetId=int(post.id), embedding1=embedding)
        )
        producer = cls._get_kafka_producer()
        await producer.send(id=post.id, value=serialized_bytes)
        logger.info(
            f"Published embedding for post {post.id} to {cls.KAFKA_TOPIC_NAME.value}"
        )

    @classmethod
    @cache
    def _get_kafka_producer(cls) -> ScramKafkaProducer:
        """Build the producer for this class's topic; ``functools.cache`` memoizes it per class."""
        producer_config = grox_config.get_kafka_producer_topic(cls.KAFKA_TOPIC_NAME)
        logger.info(
            f"Creating embedding kafka producer with config: {producer_config.model_dump()}"
        )
        return ScramKafkaProducer(producer_config)


class TaskPublishEmbeddingV4Kafka(TaskPublishEmbeddingKafka):
    """Publishes embeddings to the GROX_MULTIMODAL_EMBEDDING_V4 topic."""

    KAFKA_TOPIC_NAME = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V4


class TaskPublishEmbeddingV5Kafka(TaskPublishEmbeddingKafka):
    """Publishes embeddings to the GROX_MULTIMODAL_EMBEDDING_V5 topic."""

    KAFKA_TOPIC_NAME = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V5

370
grox/tasks/task_filters.py Normal file
View File

@@ -0,0 +1,370 @@
import logging
from abc import abstractmethod
from typing import override
from datetime import datetime, timezone
from grox.config.config import grox_config, TaskGeneratorType
from grox.tasks.task import Task, TaskStopExecution
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post, User
logger = logging.getLogger(__name__)
class TaskFilter(Task):
    """Base task that halts the pipeline when the context is not eligible.

    Subclasses implement ``_eligible``; a falsy result raises
    ``TaskStopExecution``.
    """

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        is_eligible = await cls._eligible(ctx)
        if is_eligible:
            return
        raise TaskStopExecution()

    @classmethod
    @abstractmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Return True when processing of ``ctx`` should continue."""
        raise NotImplementedError()
class TaskFilterWithUser(TaskFilter):
    """Filter base for payloads carrying a user; a missing user is ineligible."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        if ctx.payload.user:
            return await cls._eligible_with_user(ctx.payload.user, ctx)
        return False

    @classmethod
    @abstractmethod
    async def _eligible_with_user(cls, user: User, ctx: TaskContext) -> bool:
        """Return True when ``user`` should continue through the pipeline."""
        raise NotImplementedError()
class TaskFilterWithPost(TaskFilter):
    """Filter base for payloads carrying a post; a missing post is ineligible."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        if ctx.payload.post:
            return await cls._eligible_with_post(ctx.payload.post, ctx)
        return False

    @classmethod
    @abstractmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` should continue through the pipeline."""
        raise NotImplementedError()
class TaskSpamFilter(TaskFilterWithPost):
    """Decides whether a reply is eligible for reply-spam detection.

    A post is eligible only when it is a reply, its author and the authors of
    both the direct parent (``ancestors[-1]``) and the root (``ancestors[0]``)
    are known, the replier is neither of those authors nor user id 0, and
    neither ancestor author exceeds the follower-count threshold.
    Every rejection increments ``task.filter.skipped.count`` with a reason.
    """

    # NOTE(review): this is an empty-string placeholder. Comparing an int
    # follower count against a str raises TypeError in Python 3, so the value
    # must be replaced with an integer before the threshold check can run.
    FOLLOWER_COUNT_THRESHOLD_FOR_SPAM_DETECTION = ""

    @classmethod
    def _skip(cls, reason: str) -> bool:
        """Record a skip with ``reason`` and return False for the caller to return."""
        Metrics.counter("task.filter.skipped.count").add(
            1, attributes={"filter": "spam_detection", "reason": reason}
        )
        return False

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        if not post.ancestors:
            return cls._skip("not_reply")
        if not post.user:
            return cls._skip("no_user")
        # The original repeated this exact check twice in a row; the second
        # copy could never be reached, so only one is kept.
        if post.user.id == 0:
            return cls._skip("is_system_account")

        parent_user = post.ancestors[-1].user
        if not parent_user:
            return cls._skip("previous_post_no_user")
        if post.user.id == parent_user.id:
            logger.info(
                f"Skipping reply spam since the replier is same as reply target post {post.id}"
            )
            return cls._skip("same_user_reply")

        root_user = post.ancestors[0].user
        if not root_user:
            return cls._skip("root_post_no_user")
        if post.user.id == root_user.id:
            logger.info(
                f"Skipping reply spam since the replier is same as reply root post {post.id}"
            )
            return cls._skip("same_user_reply_as_root")

        # A missing follower count is treated as zero.
        in_reply_user_follower_count = parent_user.follower_count or 0
        root_user_follower_count = root_user.follower_count or 0
        if (
            in_reply_user_follower_count
            > cls.FOLLOWER_COUNT_THRESHOLD_FOR_SPAM_DETECTION
            or root_user_follower_count
            > cls.FOLLOWER_COUNT_THRESHOLD_FOR_SPAM_DETECTION
        ):
            return cls._skip("reply_ranking_target")
        return True
class TaskReplyRankingFilter(TaskFilterWithPost):
    """Decides whether a reply is eligible for reply ranking.

    Eligible replies have a known author, known parent and root authors, at
    least one of those two authors above the follower-count threshold, and a
    replier who is neither of them.
    """

    # NOTE(review): empty-string placeholder; an int-vs-str comparison raises
    # TypeError in Python 3, so an integer value is needed before use.
    FOLLOWER_COUNT_THRESHOLD_FOR_REPLY_RANKING = ""

    @classmethod
    def _skip(cls, reason: str) -> bool:
        """Count one skipped post under ``reason`` and return False."""
        Metrics.counter("task.filter.skipped.count").add(
            1, attributes={"filter": "reply_ranking", "reason": reason}
        )
        return False

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        if not post.ancestors:
            return cls._skip("not_reply")
        if not post.user:
            return cls._skip("no_user")

        parent = post.ancestors[-1]
        root = post.ancestors[0]
        if not parent.user:
            return cls._skip("previous_post_no_user")
        if not root.user:
            return cls._skip("root_post_no_user")

        # Missing follower counts count as zero.
        in_reply_user_follower_count = parent.user.follower_count or 0
        root_user_follower_count = root.user.follower_count or 0
        threshold = cls.FOLLOWER_COUNT_THRESHOLD_FOR_REPLY_RANKING
        if (
            in_reply_user_follower_count <= threshold
            and root_user_follower_count <= threshold
        ):
            return cls._skip("low_blast_radius")

        if post.user.id == parent.user.id:
            logger.info(
                f"Skipping reply ranking since the replier is same as reply target post {post.id}"
            )
            return cls._skip("same_user_reply")
        if post.user.id == root.user.id:
            logger.info(
                f"Skipping reply ranking since the replier is same as reply root post {post.id}"
            )
            return cls._skip("same_user_reply_as_root")

        Metrics.counter("task.reply_ranking.eligible.count").add(1)
        return True
class TaskPostEmbeddingWithSummaryFilter(TaskFilterWithPost):
    """Admits only non-reply posts whose author is known and not protected."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        # Pick the first applicable skip reason, checked in the same order as
        # the individual guards: reply, missing author, protected author.
        if post.ancestors:
            skip_reason = "is_reply"
        elif not post.user:
            skip_reason = "no_user"
        elif post.user.is_protected:
            skip_reason = "private_account"
        else:
            skip_reason = None

        if skip_reason is not None:
            Metrics.counter("task.filter.skipped.count").add(
                1,
                attributes={
                    "filter": "post_embedding_with_summary",
                    "reason": skip_reason,
                },
            )
            return False

        Metrics.counter("task.post_embedding_with_summary.eligible.count").add(1)
        return True
class TaskPostEmbeddingWithSummaryForReplyFilter(TaskFilterWithPost):
    """Admits only replies where the replier, parent author and root author are
    all known and none of them is protected.

    Every rejection increments ``task.filter.skipped.count`` with a reason.
    """

    @classmethod
    def _skip(cls, reason: str) -> bool:
        """Count one skipped post under ``reason`` and return False."""
        Metrics.counter("task.filter.skipped.count").add(
            1,
            attributes={
                "filter": "post_embedding_with_summary_for_reply",
                "reason": reason,
            },
        )
        return False

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        if not post.ancestors:
            return cls._skip("is_original")
        if not post.user:
            return cls._skip("no_user")

        # The ancestor authors were previously dereferenced unconditionally,
        # which raised AttributeError when one was missing. They are now
        # rejected explicitly, using the same reason labels as the other reply
        # filters in this module.
        parent_user = post.ancestors[-1].user
        if not parent_user:
            return cls._skip("previous_post_no_user")
        root_user = post.ancestors[0].user
        if not root_user:
            return cls._skip("root_post_no_user")

        if (
            parent_user.is_protected
            or root_user.is_protected
            or post.user.is_protected
        ):
            return cls._skip("private_account")

        Metrics.counter(
            "task.post_embedding_with_summary_for_reply.eligible.count"
        ).add(1)
        return True
class TaskSafetyPtosFilter(TaskFilterWithPost):
    """Admits any post that has an author, for the PTOS safety pipelines.

    Metric names depend on whether the payload's task type is
    ``SAFETY_PTOS_DELUXE``.
    """

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        if ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE:
            filter_name = "safety_ptos_deluxe_detection"
        else:
            filter_name = "safety_ptos_detection"

        Metrics.counter(f"task.{filter_name}.request.count").add(1)
        if post.user:
            Metrics.counter(f"task.{filter_name}.eligible.count").add(1)
            return True

        Metrics.counter("task.filter.skipped.count").add(
            1, attributes={"filter": filter_name, "reason": "no_user"}
        )
        return False
class TaskPostSafetyDeluxeFilter(TaskFilterWithPost):
    """Admits original (non-reply) posts with a known, non-protected author."""

    @classmethod
    def _count_skip(cls, reason: str) -> None:
        """Increment the shared skip counter for this filter."""
        Metrics.counter("task.filter.skipped.count").add(
            1, attributes={"filter": "post_safety_deluxe", "reason": reason}
        )

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        if not post.user:
            cls._count_skip("no_user")
            return False

        if post.ancestors:
            cls._count_skip("reply")
            logger.info(f"Skipping post {post.id} because it is a reply")
            return False

        filter_reason = cls._get_hardcoded_filter_reason(post)
        if filter_reason:
            logger.info(
                f"Skipping upa deluxe {post.id} because it is hit by hardcoded filters, reason: {filter_reason}"
            )
            cls._count_skip(filter_reason)
            return False

        Metrics.counter("task.post_safety_deluxe.eligible.count").add(1)
        return True

    @classmethod
    def _get_hardcoded_filter_reason(cls, post: Post) -> str | None:
        """Return a skip reason for rules that need no model call, or None."""
        if post.user and post.user.is_protected:
            return "private_account"
        return None
class TaskInitialBangerFilter(TaskFilterWithPost):
    """Admits original posts with a known, non-protected author to the
    initial banger screen.

    A post without an author is rejected without a skip metric.
    """

    @classmethod
    def _count_skip(cls, reason: str) -> None:
        """Increment the shared skip counter under the content_understanding filter."""
        Metrics.counter("task.filter.skipped.count").add(
            1,
            attributes={"filter": "content_understanding", "reason": reason},
        )

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        Metrics.counter("task.initial_banger_filter.count").add(1)
        if not post.user:
            return False

        if post.ancestors:
            cls._count_skip("reply")
            logger.info(f"Skipping post {post.id} because it is a reply")
            return False

        filter_reason = cls._get_hardcoded_filter_reason(post)
        if filter_reason:
            logger.info(
                f"Skipping post {post.id} because it is hit by hardcoded filters, reason: {filter_reason}"
            )
            cls._count_skip(filter_reason)
            return False

        logger.info(f"Post {post.id} is eligible for initial banger")
        Metrics.counter("task.initial_banger_filter.eligible.count").add(1)
        return True

    @classmethod
    def _get_hardcoded_filter_reason(cls, post: Post) -> str | None:
        """Return a skip reason for rules that need no model call, or None."""
        if post.user and post.user.is_protected:
            return "private_account"
        return None

View File

@@ -0,0 +1,57 @@
import logging
from grox.tasks.task import Task
from grox.tasks.disable_rules import DisableTaskForNonProd
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from strato_http.queries.grok_upa_action_with_labels import (
StratoGrokUpaActionWithLabels,
)
logger = logging.getLogger(__name__)
class TaskGrokUpaActionWithLabels(Task):
    """Sends the first available tweet bool metadata for a post to the
    ``StratoGrokUpaActionWithLabels`` query and records which labels it applied.
    """

    DISABLE_RULES = [DisableTaskForNonProd]
    _strato_grok_upa_action_with_labels = StratoGrokUpaActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        Metrics.counter("task.grok_upa_action_with_labels.count").add(1)

        post = ctx.payload.post
        if not post:
            return
        results = ctx.content_categories
        if not results:
            return

        # Use the first classification result that carries bool metadata.
        grok_response = None
        for candidate in results:
            if candidate.tweet_bool_metadata:
                grok_response = candidate
                break
        if not grok_response or not grok_response.tweet_bool_metadata:
            return

        tweet_id = int(post.id)
        tweet_bool_metadata = grok_response.tweet_bool_metadata.model_dump()
        action_result = await cls._strato_grok_upa_action_with_labels.execute(
            tweet_id, tweet_bool_metadata
        )

        # A falsy result is counted as a failed call.
        if not action_result:
            logger.info(f"grokUpaActionWithLabels failed for post {tweet_id}")
            Metrics.counter("task.grok_upa_action_with_labels.failed.count").add(1)
            return

        if len(action_result.applied_labels) > 0:
            logger.info(
                f"grokUpaActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {tweet_id}"
            )
            Metrics.counter("task.grok_upa_action_with_labels.applied.count").add(1)
            for label in action_result.applied_labels:
                Metrics.counter(
                    "task.grok_upa_action_with_labels.applied_label.count"
                ).add(1, attributes={"label": label})
        else:
            logger.info(
                f"grokUpaActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {tweet_id}"
            )
            Metrics.counter("task.grok_upa_action_with_labels.empty.count").add(1)

View File

@@ -0,0 +1,35 @@
import time
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.schedules.types import TaskPayload
from grox.tasks.task import TaskStopExecution
from monitor.metrics import Metrics
from tenacity import retry, wait_chain, wait_fixed, stop_after_attempt
class TaskLoadPostWithNotFoundRetry(TaskWithPost):
    """Reloads the payload post from Strato and replaces it on the context.

    ``exec`` is wrapped with tenacity ``retry``: at most three attempts with a
    one-second fixed wait between them.
    NOTE(review): which exceptions actually reach the retry wrapper depends on
    ``Task.exec.__wrapped__``, which is not visible here — confirm.
    """

    @classmethod
    @retry(stop=stop_after_attempt(3), wait=wait_chain(wait_fixed(1), wait_fixed(1)))
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        return await Task.exec.__wrapped__(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Load ``post.id`` and store the result in ``ctx.payload.post``.

        Raises:
            TaskStopExecution: the post was not found and the task type's value
                contains "recovery".
            ValueError: the post was not found for any other task type.
        """
        start_time = time.perf_counter_ns()
        loaded_post = await TweetStratoLoader.load_post(post.id)
        if loaded_post is None:
            task_type = (
                ctx.payload.task_type.value
                if ctx.payload and ctx.payload.task_type
                else "None"
            )
            if "recovery" in task_type:
                raise TaskStopExecution(f"Post not found: {post.id}")
            else:
                raise ValueError(f"Post not found: {post.id}")
        ctx.payload.post = loaded_post
        # perf_counter_ns() returns nanoseconds; 1 ms == 1_000_000 ns. The
        # previous divisor of 1_000 produced microseconds under a metric that
        # is named in milliseconds.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.embedding_load_post.duration_ms").record(duration_ms)

View File

@@ -0,0 +1,27 @@
from grox.tasks.task import TaskWithPost
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.data_loaders.strato_loader import TweetStratoLoader
from grox.tasks.task import TaskStopExecution
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
StratoPostMultimodalEmbeddingGrokSummaryMh,
)
class TaskLoadPostWithSummary(TaskWithPost):
    """Reloads the payload post and attaches its stored "v3" summary.

    Execution stops (``TaskStopExecution``) when either the post or the
    summary cannot be found.
    """

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        loaded_post = await TweetStratoLoader.load_post(post.id)
        if loaded_post is None:
            raise TaskStopExecution(f"Post not found: {post.id}")

        summary_query = StratoPostMultimodalEmbeddingGrokSummaryMh()
        summary = await summary_query.fetch(int(post.id), "v3")
        if summary is None:
            raise TaskStopExecution(f"Summary not found: {post.id}")

        loaded_post.summary = summary
        ctx.payload.post = loaded_post

View File

@@ -0,0 +1,23 @@
import logging
from grox.data_loaders.data_types import Post
from grox.data_loaders.strato_loader import UserRecentPostsLoader
from grox.schedules.types import TaskContext
from grox.tasks.task import TaskWithPost, TaskStopExecution
logger = logging.getLogger(__name__)
class TaskLoadUserRecentPosts(TaskWithPost):
    """Loads the post author's recent posts onto ``post.user.recent_posts``."""

    # NOTE(review): empty-string placeholder passed through as ``limit`` —
    # confirm the value the loader expects.
    RECENT_POSTS_LIMIT = ""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        author = post.user
        # Both a missing author and a falsy author id stop execution.
        if not author or not author.id:
            raise TaskStopExecution("Post has no author user to load recent posts for")

        recent_posts = await UserRecentPostsLoader.load(
            author.id, limit=cls.RECENT_POSTS_LIMIT
        )
        author.recent_posts = recent_posts
        logger.info(f"Loaded {len(recent_posts)} recent posts for user {post.user.id}")

21
grox/tasks/task_media.py Normal file
View File

@@ -0,0 +1,21 @@
from grox.tasks.task import TaskWithPost
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.data_loaders.media_processor import MediaProcessor
from monitor.metrics import Metrics
class TaskMediaHydration(TaskWithPost):
    """Runs ``MediaProcessor.process`` on the post with a 360-minute video limit."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        attempts = Metrics.counter("task.media_hydration.total.count")
        attempts.add(1)
        await MediaProcessor.process(post, video_duration_limit_minutes=360)
        # Reached only when processing did not raise.
        completed = Metrics.counter("task.media_hydration.passed.count")
        completed.add(1)
class TaskMediaHydrationBanger(TaskWithPost):
    """Same media processing as the plain hydration task, counted under the
    ``media_hydration_banger`` metric names."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        attempts = Metrics.counter("task.media_hydration_banger.total.count")
        attempts.add(1)
        await MediaProcessor.process(post, video_duration_limit_minutes=360)
        # Reached only when processing did not raise.
        completed = Metrics.counter("task.media_hydration_banger.passed.count")
        completed.add(1)

View File

@@ -0,0 +1,86 @@
import logging
from typing import override
from grox.tasks.task import TaskWithPost
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext, TaskEligibility
from grox.data_loaders.data_types import Post, Video
from grox.embedder.multimodal_post_embedder_v2 import MultimodalPostEmbedderV2
from grox.embedder.multimodal_post_embedder_v5 import MultimodalPostEmbedderV5
logger = logging.getLogger(__name__)
class TaskMultimodalPostEmbeddingWithSummary(TaskWithPost):
    """Embeds a post and stores the vector under the "v3" key on the context.

    Embedder failures are not caught here and propagate to the caller.
    """

    embedder = MultimodalPostEmbedderV2(
        model="qwen3", renderer_version="lite", use_post_context_summary=True
    )

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        # embed() returns a pair; only the second element is used.
        _unused, embedding = await cls.embedder.embed(post)
        ctx.multimodal_post_embedding_dict["v3"] = embedding
        logger.info(
            f"TaskMultimodalPostEmbeddingWithSummary Embedding Added, length: {len(embedding)}"
        )
        Metrics.counter("task.multimodal_post_embedding_with_summary.count").add(1)
class TaskMultimodalPostEmbeddingRecsysV4(TaskWithPost):
    """Embeds a post with the "v4" model and stores it under the "v4" key.

    This task is best-effort: an embedder exception is counted and logged,
    and the task returns without storing anything or re-raising.
    """

    embedder = MultimodalPostEmbedderV2(
        model="v4",
        renderer_version="lite",
        use_post_context_summary=True,
    )

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        try:
            _unused, embedding = await cls.embedder.embed(post)
        except Exception as e:
            Metrics.counter("task.multimodal_post_embedding_recsys_v4.error").add(1)
            logger.warning(
                f"TaskMultimodalPostEmbeddingRecsysV4 failed for post {post.id}: {e}"
            )
            return
        else:
            ctx.multimodal_post_embedding_dict["v4"] = embedding
            logger.info(
                f"TaskMultimodalPostEmbeddingRecsysV4 Embedding Added, length: {len(embedding)}"
            )
            Metrics.counter("task.multimodal_post_embedding_recsys_v4.count").add(1)
class TaskMultimodalPostEmbeddingV5(TaskWithPost):
    """Embeds a post with the V5 embedder, passing along any ASR transcripts
    found on its videos, and stores the vector under the "v5_1" key.

    Unlike a best-effort task, a failure here is counted, logged and re-raised.
    """

    embedder = MultimodalPostEmbedderV5()

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        try:
            # Collect the transcript of every video that has one.
            transcripts = [
                m.convo_video.asr_transcript
                for m in (post.media or [])
                if isinstance(m, Video)
                and m.convo_video
                and m.convo_video.asr_transcript
            ]
            # Several transcripts are joined with newlines; none yields None.
            transcript = "\n".join(transcripts) if transcripts else None
            _unused, embedding = await cls.embedder.embed(post, transcript=transcript)
        except Exception as e:
            Metrics.counter("task.multimodal_post_embedding_v5.error").add(1)
            logger.warning(
                f"TaskMultimodalPostEmbeddingV5 failed for post {post.id}: {e}"
            )
            raise

        ctx.multimodal_post_embedding_dict["v5_1"] = embedding
        logger.info(
            f"TaskMultimodalPostEmbeddingV5 Embedding Added, length: {len(embedding)}, has_transcript={transcript is not None}"
        )
        Metrics.counter("task.multimodal_post_embedding_v5.count").add(1)
        if transcript:
            Metrics.counter(
                "task.multimodal_post_embedding_v5.with_transcript.count"
            ).add(1)

View File

@@ -0,0 +1,25 @@
import logging
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.classifiers.content.post_safety_screen_deluxe import (
PostSafetyDeluxeClassifier,
)
logger = logging.getLogger(__name__)
class TaskPostSafetyScreenDeluxe(TaskWithPost):
    """Runs ``PostSafetyDeluxeClassifier`` on a post and appends its results
    to ``ctx.content_categories``."""

    classifier = PostSafetyDeluxeClassifier()

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        # Calls the function stored under ``Task.exec.__wrapped__`` directly.
        # NOTE(review): presumably this bypasses a decorator on Task.exec — confirm.
        return await Task.exec.__wrapped__(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        Metrics.counter("task.post_safety_screen_deluxe.total.count").add(1)
        classification = await cls.classifier.classify(post)
        ctx.content_categories.extend(classification)

555
grox/tasks/task_pub.py Normal file
View File

@@ -0,0 +1,555 @@
import time
import logging
import traceback
from functools import cache
import thrifts.gen.twitter.strato.columns.content_understanding.content_understanding.ttypes as t
from thrifts.serdes import Serializer
from grox.tasks.task import Task
from monitor.metrics import Metrics
from grox.config.config import KafkaTopicName, grox_config
from kafka_cli.producer import KafkaProducer
from grox.schedules.types import ReplyScoreResult, TaskContext
from grox.tasks.disable_rules import (
DisableTaskForDev,
DisableTaskForLocal,
DisableTaskForNonProd,
)
from strato_http.queries.unified_post_annotations import (
StratoUnifiedPostAnnotations,
StratoUpsertTweetBoolMetadataToUnifiedPostAnnotations,
)
from strato_http.queries.grok_reply_spam_action_with_labels import (
StratoGrokReplySpamActionWithLabels,
)
from grox.data_loaders.data_types import (
Post,
ContentCategoryType,
ContentCategoryResult,
ContentCategoryScore,
Image,
Video,
)
from strato_http.queries.data_types import (
EntityWithMetadata,
FoundMetadata,
UnifiedPostAnnotations,
ReplyRankingScore,
ReplyRankingScoreKafka,
QualifiedId,
)
from grox.data_loaders.strato_loader import (
ReplyRankingScoreStratoLoader,
ReplySpamStratoLoader,
)
logger = logging.getLogger(__name__)
class TaskPublishKafka(Task):
    """Publishes a post's classification results to the GROX_CONTENT_ANALYSIS
    Kafka topic as a serialized ``GroxContentAnalysis`` record."""

    DISABLE_RULES = [DisableTaskForLocal, DisableTaskForDev]

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Publish when both a post and results exist; re-raise publish failures."""
        post = ctx.payload.post
        results = ctx.content_categories
        if not post:
            return
        if not results:
            Metrics.counter("task.publish_kafka.skipped.count").add(
                1, attributes={"reason": "no_results"}
            )
            return

        Metrics.counter("task.publish_kafka.intaken.count").add(1)
        try:
            author_id = post.user.id if post.user else None
            await cls._publish_to_kafka(
                post.id,
                author_id,
                results,
                summary=ctx.summary,
                embedding=ctx.multimodal_post_embedding,
            )
            # One success count per published category.
            for result in results:
                Metrics.counter("task.publish_kafka.success.count").add(
                    1, attributes={"category": result.category.value}
                )
            if not post.created_at:
                Metrics.counter("task.publish_kafka.post_no_created_at.count").add(1)
            else:
                # Seconds between post creation and this publish, per category.
                latency = time.time() - post.created_at.timestamp()
                for res in results:
                    Metrics.histogram("task.classification_e2e_latency").record(
                        latency, attributes={"category": res.category.value}
                    )
        except Exception:
            Metrics.counter("task.publish_kafka.failed.count").add(1)
            logger.error(
                f"Failed to publish classification record: {traceback.format_exc()}"
            )
            raise

    @staticmethod
    def _to_thrift_category(r: ContentCategoryResult):
        """Convert one classification result into a thrift ``CategoryResult``."""
        taxonomy = None
        if r.taxonomy_categories:
            taxonomy = [
                t.TaxonomyCategoryScore(id=tc.id, name=tc.name, score=tc.score)
                for tc in r.taxonomy_categories
            ]
        return t.CategoryResult(
            category=r.category.name,
            positive=r.positive,
            score=r.score,
            summary=r.summary,
            taxonomyCategories=taxonomy,
            keywords=None,
        )

    @classmethod
    async def _publish_to_kafka(
        cls,
        post_id: str,
        user_id: int | None,
        results: list[ContentCategoryResult],
        summary: str,
        embedding: list[float] | None,
    ):
        """Build, serialize and send the record keyed by ``post_id``.

        ``embedding`` is accepted but not included in the record.
        """
        grox_content_analysis = t.GroxContentAnalysis(
            postId=int(post_id),
            userId=user_id,
            categoryResults=[cls._to_thrift_category(r) for r in results],
            summary=summary,
            createdAt=int(time.time()),
        )
        serialized_bytes = Serializer.serialize(grox_content_analysis)
        producer = cls._get_kafka_producer()
        await producer.send(id=post_id, value=serialized_bytes)

    @classmethod
    @cache
    def _get_kafka_producer(cls):
        """Build the producer once; ``functools.cache`` memoizes it per class."""
        producer_config = grox_config.get_kafka_producer_topic(
            KafkaTopicName.GROX_CONTENT_ANALYSIS
        )
        logger.info(
            f"Creating kafka producer with config: {producer_config.model_dump()}"
        )
        return KafkaProducer(producer_config)
class TaskPublishUnifiedPostAnnotationsManhattan(Task):
    """Writes the banger-screen result for a post as a ``UnifiedPostAnnotations``
    record through ``StratoUnifiedPostAnnotations``.

    Only the first ``BANGER_INITIAL_SCREEN`` result in ``ctx.content_categories``
    is used. Topic ids returned by the classifier are validated against
    ``ctx.available_topics`` before being written as entities.
    """

    DISABLE_RULES = [DisableTaskForNonProd]

    # (attribute on tweet_bool_metadata, counter incremented when it is truthy)
    _BOOL_FLAG_METRICS = (
        (
            "isHighQuality",
            "task.publish_unified_post_annotations.is_high_quality_true.count",
        ),
        ("isNsfw", "task.publish_unified_post_annotations.is_nsfw_true.count"),
        ("isGore", "task.publish_unified_post_annotations.is_gore_true.count"),
        ("isViolent", "task.publish_unified_post_annotations.is_violent_true.count"),
        ("isSpam", "task.publish_unified_post_annotations.is_spam_true.count"),
        (
            "isSoftNsfw",
            "task.publish_unified_post_annotations.is_soft_nsfw_true.count",
        ),
        ("isAdult", "task.publish_unified_post_annotations.is_adult_true.count"),
    )

    @classmethod
    def _record_response_metrics(cls, grok_response, post: Post) -> None:
        """Emit counters describing the classifier response (no data is changed)."""
        if grok_response.slop_score is not None:
            for level in (1, 2, 3):
                if grok_response.slop_score == level:
                    Metrics.counter(
                        f"task.publish_unified_post_annotations.slop_score_{level}.count"
                    ).add(1)
                    break

        bool_metadata = grok_response.tweet_bool_metadata
        if bool_metadata:
            for attr, metric_name in cls._BOOL_FLAG_METRICS:
                if getattr(bool_metadata, attr):
                    Metrics.counter(metric_name).add(1)
                    if attr == "isNsfw":
                        # NOTE(review): record_nsfw_detection is not among the
                        # imports visible at the top of this module; presumably
                        # it is defined elsewhere in the file — confirm.
                        record_nsfw_detection(post)

        if grok_response.tags and len(grok_response.tags) > 0:
            Metrics.counter(
                "task.publish_unified_post_annotations.tags_non_empty.count"
            ).add(1)
        if grok_response.is_image_editable_by_grok:
            Metrics.counter(
                "task.publish_unified_post_annotations.is_image_editable_by_grok_true.count"
            ).add(1)

    @classmethod
    def _resolve_grok_topics(cls, grok_response, available_topics) -> list:
        """Keep only classifier topics whose id exists in ``available_topics``.

        Returns one ``ContentCategoryScore`` per distinct valid topic id,
        choosing the highest-scoring occurrence of each id.
        """
        if not grok_response.taxonomy_categories:
            return []
        if not available_topics:
            logger.warning("No available topics to validate grok_topics")
            return []

        # Map every known id straight to (display name, owning category id).
        # Going id -> name -> category id, as before, returned the wrong
        # category whenever two topics shared a display name.
        id_to_topic = {}
        for category in available_topics:
            id_to_topic[category.categoryEntityId] = (
                category.categoryName,
                category.categoryEntityId,
            )
            for sub in category.subtopics:
                id_to_topic[sub.topicEntityId] = (
                    sub.topicName,
                    category.categoryEntityId,
                )

        topic_id_to_best_score = {}
        for grok_topic in grok_response.taxonomy_categories:
            topic_id = grok_topic.id
            known = id_to_topic.get(topic_id)
            if known is None:
                logger.warning(
                    f"Invalid topic ID from Grok: {topic_id} not found in available topics"
                )
                Metrics.counter(
                    "task.publish_unified_post_annotations.invalid_grok_topic.count"
                ).add(1)
                continue
            topic_name, category_id = known
            resolved_grok_topic = ContentCategoryScore(
                id=topic_id,
                name=topic_name,
                score=grok_topic.score,
                category_id=category_id,
            )
            logger.info(
                f"Validated grok_topic: ID {topic_id} -> '{topic_name}' (category_id: {category_id})"
            )
            if (
                topic_id not in topic_id_to_best_score
                or grok_topic.score > topic_id_to_best_score[topic_id].score
            ):
                topic_id_to_best_score[topic_id] = resolved_grok_topic
        return list(topic_id_to_best_score.values())

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Build the annotations record for the payload post and write it."""
        Metrics.counter("task.publish_unified_post_annotations.count").add(1)
        results = ctx.content_categories
        if not results:
            logger.info("No unified post annotations to publish")
            return
        post = ctx.payload.post
        if not post:
            return
        grok_response = next(
            (
                r
                for r in results
                if r.category == ContentCategoryType.BANGER_INITIAL_SCREEN
            ),
            None,
        )
        if not grok_response:
            return

        cls._record_response_metrics(grok_response, post)

        # Compute media facts once, as real booleans/ints. The previous
        # ``post.media and any(...)`` expression produced ``None`` or ``[]``
        # instead of ``False`` for posts without media.
        media = post.media or []
        has_video = any(isinstance(m, Video) for m in media)
        has_image = any(isinstance(m, Image) for m in media)
        if has_video:
            Metrics.counter(
                "task.publish_unified_post_annotations.has_video_true.count"
            ).add(1)
        if has_image:
            Metrics.counter(
                "task.publish_unified_post_annotations.has_image_true.count"
            ).add(1)

        resolved_grok_topics = cls._resolve_grok_topics(
            grok_response, ctx.available_topics
        )
        for topic in resolved_grok_topics:
            sanitized_topic_name_for_metric = (
                topic.name.lower().replace(" ", "_").replace("&", "and")
            )
            Metrics.counter(
                f"task.publish_unified_post_annotations.topic_{sanitized_topic_name_for_metric}.count"
            ).add(1)

        entities = []
        if resolved_grok_topics:
            Metrics.counter(
                "task.publish_unified_post_annotations.with_grok_topics.count"
            ).add(1)
            entities = [
                EntityWithMetadata(
                    qualifiedId=QualifiedId(domainId=236, entityId=str(grok_topic.id)),
                    score=grok_topic.score,
                    categoryId=QualifiedId(
                        domainId=236, entityId=str(grok_topic.category_id)
                    )
                    if grok_topic.category_id
                    else None,
                )
                for grok_topic in resolved_grok_topics
            ]
        else:
            Metrics.counter(
                "task.publish_unified_post_annotations.with_empty_grok_topics.count"
            ).add(1)

        annotations = UnifiedPostAnnotations(
            tweetId=post.id,
            entities=entities,
            tags=[{"tag": tag, "score": 0.0} for tag in (grok_response.tags or [])],
            tweetBoolMetadata=grok_response.tweet_bool_metadata.model_dump()
            if grok_response.tweet_bool_metadata
            else None,
            description=grok_response.summary,
            isImageEditableByGrok=grok_response.is_image_editable_by_grok,
            slopScore=grok_response.slop_score,
            originalOcrText="",
            evergreenScore=None,
            hasVideo=has_video,
            hasImage=has_image,
            imageDescription=None,
            videoDescription=None,
            qualityScore=grok_response.score,
            hasMinorScore=grok_response.has_minor_score,
            hasCard=post.card is not None,
            foundMetadata=FoundMetadata(
                imageCount=sum(1 for m in media if isinstance(m, Image)),
                videoCount=sum(1 for m in media if isinstance(m, Video)),
                cardCount=1 if post.card else 0,
                cardV2Count=len(post.cardsV2) if post.cardsV2 else 0,
            ),
        )
        await StratoUnifiedPostAnnotations().put(int(post.id), annotations)
        Metrics.counter("task.publish_unified_post_annotations.success.count").add(1)
class TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation(Task):
    """Upsert the post-safety-screen boolean metadata for a post.

    Looks up the ``POST_SAFETY_SCREEN`` entry in ``ctx.content_categories``,
    bumps one counter per boolean flag that is set, and writes the dumped
    metadata through ``StratoUpsertTweetBoolMetadataToUnifiedPostAnnotations``.
    """

    DISABLE_RULES = [DisableTaskForNonProd]

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Run the upsert; returns early when there is nothing to publish."""
        Metrics.counter(
            "task.upsert_tweet_bool_metadata_to_unified_post_annotations.count"
        ).add(1)

        categories = ctx.content_categories
        if not categories:
            logger.info("No unified post annotations to publish")
            return

        post = ctx.payload.post
        if not post:
            return

        # First POST_SAFETY_SCREEN result wins; later ones are ignored.
        screen = None
        for candidate in categories:
            if candidate.category == ContentCategoryType.POST_SAFETY_SCREEN:
                screen = candidate
                break
        if not screen or not screen.tweet_bool_metadata:
            return
        bool_metadata = screen.tweet_bool_metadata

        # (metadata attribute, counter bumped when that attribute is truthy)
        flag_counters = (
            (
                "isHighQuality",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_high_quality_true.count",
            ),
            (
                "isNsfw",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_nsfw_true.count",
            ),
            (
                "isGore",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_gore_true.count",
            ),
            (
                "isViolent",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_violent_true.count",
            ),
            (
                "isSpam",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_spam_true.count",
            ),
            (
                "isSoftNsfw",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_soft_nsfw_true.count",
            ),
            (
                "isAdult",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_adult_true.count",
            ),
        )
        for attribute, counter_name in flag_counters:
            if getattr(bool_metadata, attribute):
                Metrics.counter(counter_name).add(1)

        await StratoUpsertTweetBoolMetadataToUnifiedPostAnnotations().put(
            int(post.id), bool_metadata.model_dump()
        )
        Metrics.counter(
            "task.upsert_tweet_bool_metadata_to_unified_post_annotations.success.count"
        ).add(1)
class TaskWriteReplyRankingManhattan(Task):
    """Persist the reply-ranking score computed for a post.

    Takes the first entry of ``ctx.reply_ranking_results`` and saves its score
    and reasoning through ``ReplyRankingScoreStratoLoader``. When the score is
    exactly ``0.0`` the ``StratoGrokReplySpamActionWithLabels`` action is
    executed for the post before the score is saved.
    """

    DISABLE_RULES = [DisableTaskForNonProd]
    _strato_grok_reply_spam_action_with_labels = StratoGrokReplySpamActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Publish the reply-ranking result for ``ctx.payload.post``.

        Returns silently when there is no post; counts a ``skipped`` metric
        when there are no results.

        Raises:
            Exception: whatever the publish step raised, after it has been
                counted as ``failed`` and logged.
        """
        post = ctx.payload.post
        results = ctx.reply_ranking_results
        if not post:
            return
        if not results:
            Metrics.counter("task.write_reply_ranking_manhattan.skipped.count").add(
                1, attributes={"reason": "no_results"}
            )
            return
        Metrics.counter("task.write_reply_ranking_manhattan.intaken.count").add(1)
        try:
            await cls._publish_to_reply_ranking_manhattan(post, results)
        except Exception:
            Metrics.counter("task.write_reply_ranking_manhattan.failed.count").add(1)
            logger.error(
                f"Failed to write reply ranking score to manhattan: {traceback.format_exc()}"
            )
            raise
        else:
            # post.user is optional (the publish helper handles a missing user),
            # so the success log must not dereference it unconditionally;
            # otherwise a successful write would be reported as a failure.
            user_id = post.user.id if post.user else None
            logger.info(
                f"Published reply ranking post to manhattan: {post.id=} post.user.id={user_id!r}"
            )

    @classmethod
    async def _publish_to_reply_ranking_manhattan(
        cls, post: Post, results: list[ReplyScoreResult]
    ):
        """Save the first reply-ranking result for ``post``.

        Args:
            post: the post being scored.
            results: reply-ranking results; only the first one is used. With
                no usable result the score defaults to ``3.0`` and the
                reasoning to an empty string.
        """
        logger.info(
            f"[_publish_to_reply_ranking_manhattan] checking results: {results}"
        )
        reply_ranking_result = next(iter(results), None)
        score = reply_ranking_result.score if reply_ranking_result else 3.0
        # Coerce a missing reason to "" so the [-500:] slice below cannot fail.
        # NOTE(review): assumes `reason` may be None — confirm on ReplyScoreResult.
        reasoning = (reply_ranking_result.reason if reply_ranking_result else "") or ""
        if post.user:
            logger.info(
                f"[_publish_to_reply_ranking_manhattan] {reasoning=} {post.id=} {post.user.id=} {score=}"
            )
        else:
            logger.info(
                f"Missing user id [_publish_to_reply_ranking_manhattan] {reasoning=} {post.id=} {score=}"
            )
        if score == 0.0:
            action_result = (
                await cls._strato_grok_reply_spam_action_with_labels.execute(
                    int(post.id)
                )
            )
            if action_result and len(action_result.applied_labels) > 0:
                logger.info(
                    f"grokReplySpamActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post.id}"
                )
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.applied.count"
                ).add(1)
            elif action_result:
                logger.info(
                    f"grokReplySpamActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {post.id}"
                )
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.empty.count"
                ).add(1)
            else:
                logger.info(f"grokReplySpamActionWithLabels failed for post {post.id}")
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.failed.count"
                ).add(1)
        # Only the trailing 500 characters of the reasoning are stored.
        await ReplyRankingScoreStratoLoader.save_reply_ranking_score(
            post_id=post.id,
            reply_ranking_score=ReplyRankingScore(
                score=score, reasoning=reasoning[-500:]
            ),
        )
        await ReplyRankingScoreStratoLoader.save_reply_ranking_kafka_v2(
            post_id=post.id,
            reply_ranking_score_kafka=ReplyRankingScoreKafka(
                postId=int(post.id), score=score, reasoning=reasoning[-500:]
            ),
        )
        Metrics.counter("task.write_reply_ranking_manhattan.success.count").add(
            1, attributes={"column": "reply_ranking"}
        )
class TaskWriteReplySpamManhattan(Task):
    """Persist reply-spam classification results for a post.

    For every ``SPAM_COMMENT`` entry in ``ctx.content_categories`` the
    annotation is saved through ``ReplySpamStratoLoader``; positive results
    additionally trigger ``StratoGrokReplySpamActionWithLabels`` first.
    """

    DISABLE_RULES = [DisableTaskForNonProd]
    _strato_grok_reply_spam_action_with_labels = StratoGrokReplySpamActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Write spam annotations for ``ctx.payload.post``; no-op without a post."""
        post = ctx.payload.post
        if not post:
            return
        # post.user is optional elsewhere in this pipeline; resolve the id once
        # so the success log below cannot raise after the annotation was saved.
        user_id = post.user.id if post.user else None
        for result in ctx.content_categories or []:
            if result.category != ContentCategoryType.SPAM_COMMENT:
                continue
            if result.positive:
                action_result = (
                    await cls._strato_grok_reply_spam_action_with_labels.execute(
                        int(post.id)
                    )
                )
                if action_result and len(action_result.applied_labels) > 0:
                    logger.info(
                        f"grokReplySpamActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.applied.count"
                    ).add(1)
                elif action_result:
                    logger.info(
                        f"grokReplySpamActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.empty.count"
                    ).add(1)
                else:
                    logger.info(
                        f"grokReplySpamActionWithLabels failed for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.failed.count"
                    ).add(1)
            # The annotation is stored for positive and negative results alike.
            await ReplySpamStratoLoader.save_spam_reply_annotation(
                post.id, result.score, result.positive, ""
            )
            logger.info(
                f"Published reply spam annotation to manhattan: {post.id=} post.user.id={user_id!r}"
            )

View File

@@ -0,0 +1,31 @@
import logging
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.classifiers.content.reply_ranking import ReplyScorer
logger = logging.getLogger(__name__)
class TaskRankReplies(TaskWithPost):
    """Score a reply with ``ReplyScorer`` and collect the results on the context."""

    scorer = ReplyScorer()

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the function stored on ``Task.exec.__wrapped__``."""
        # NOTE(review): presumably this bypasses a decorator applied to
        # Task.exec — confirm in grox.tasks.task.
        unwrapped_exec = Task.exec.__wrapped__
        return await unwrapped_exec(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Log the post's signals, score it, and extend ``ctx.reply_ranking_results``."""
        author = post.user
        # The author may be missing; the two user-derived fields then log as None.
        risky_label = author.has_risky_user_safety_label if author else None
        blocks_last_24h = (
            author.num_legit_blocks_received_last_24hrs if author else None
        )
        logger.info(
            f"[task_rank_replies] {post.id=} "
            f"is_pasted={post.is_pasted} "
            f"user_agent={post.user_agent!r} "
            f"composition_source={post.composition_source!r} "
            f"app_attestation_status={post.app_attestation_status!r} "
            f"has_risky_user_safety_label={risky_label} "
            f"num_legit_blocks_received_last_24hrs={blocks_last_24h}"
        )
        scores = await cls.scorer.score(post)
        ctx.reply_ranking_results.extend(scores)

View File

@@ -0,0 +1,191 @@
import logging
from abc import abstractmethod
from typing import Awaitable, override
from cachetools import TTLCache
from grox.tasks.task import Task, TaskContext, TaskStopExecution
from grox.config.config import TaskGeneratorType
from monitor.metrics import Metrics
from grox.data_loaders.data_types import Post, User
logger = logging.getLogger(__name__)
class TaskRateLimit(Task):
    """Base task that stops execution when ``_eligible`` reports False."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Record the eligibility outcome and halt when not eligible.

        Raises:
            TaskStopExecution: when ``_eligible`` returns a falsy value.
        """
        passed = await cls._eligible(ctx)
        counter = Metrics.counter("task.rate_limit.count")
        counter.add(1, attributes={"task_name": cls.get_name(), "passed": passed})
        if passed:
            return
        raise TaskStopExecution()

    @classmethod
    @abstractmethod
    def _eligible(cls, ctx: TaskContext) -> Awaitable[bool]:
        """Decide whether the task may proceed; subclasses must override."""
        raise NotImplementedError()
class TaskRateLimitWithPost(TaskRateLimit):
    """Rate limiter whose decision is based on the payload's post."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Return False without a post, else defer to ``_eligible_with_post``."""
        post = ctx.payload.post
        if post:
            return await cls._eligible_with_post(post, ctx)
        return False

    @classmethod
    @abstractmethod
    def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> Awaitable[bool]:
        """Decide eligibility for ``post``; subclasses must override."""
        raise NotImplementedError()
class TaskRateLimitWithUser(TaskRateLimit):
    """Rate limiter whose decision is based on the payload's user."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Return False without a user, else defer to ``_eligible_with_user``."""
        user = ctx.payload.user
        if user:
            return await cls._eligible_with_user(user, ctx)
        return False

    @classmethod
    @abstractmethod
    def _eligible_with_user(cls, user: User, ctx: TaskContext) -> Awaitable[bool]:
        """Decide eligibility for ``user``; subclasses must override."""
        raise NotImplementedError()
class TaskRateLimitEmbeddingWithPostSummary(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_MM_EMB_SUMMARY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_SUMMARY
        if post_id in seen:
            logger.info(
                f"Post {post_id} already hit rate limit for mm emb with summary"
            )
            return False
        seen[post_id] = True
        return True
class TaskRateLimitEmbeddingWithPostSummaryForReply(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_MM_EMB_SUMMARY_REPLY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_SUMMARY_REPLY
        if post_id in seen:
            logger.info(
                f"Post {post_id} already hit rate limit for mm emb with summary for reply"
            )
            return False
        seen[post_id] = True
        return True
class TaskRateLimitEmbeddingV5(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_MM_EMB_V5 = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_V5
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for mm emb v5")
            return False
        seen[post_id] = True
        return True
class TaskRateLimitEmbeddingV5ForReply(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_MM_EMB_V5_REPLY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_V5_REPLY
        if post_id in seen:
            logger.info(
                f"Post {post_id} already hit rate limit for mm emb v5 for reply"
            )
            return False
        seen[post_id] = True
        return True
class TaskRateLimitBangerAnnotationWithPost(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_BANGER = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_BANGER
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for banger")
            return False
        seen[post_id] = True
        return True
class TaskRateLimitReplySpamAnnotationWithPost(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_REPLY_SPAM = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_REPLY_SPAM
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for reply spam")
            return False
        seen[post_id] = True
        return True
class TaskRateLimitReplyRankingAnnotationWithPost(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_REPLY_RANKING = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_REPLY_RANKING
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for reply ranking")
            return False
        seen[post_id] = True
        return True
class TaskRateLimitPostSafetyAnnotationWithPost(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from the class cache."""

    # Post ids already admitted; TTLCache configured with maxsize=10_000, ttl=60.
    POST_CACHE_FOR_POST_SAFETY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen, False while cached."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_POST_SAFETY
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for post safety")
            return False
        seen[post_id] = True
        return True
class TaskRateLimitSafetyPtosAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limiter with separate caches for regular and deluxe PTOS runs."""

    # One cache per task type, so a deluxe run and a regular run of the same
    # post do not block each other. Both: maxsize=10_000, ttl=60.
    POST_CACHE_FOR_SAFETY_PTOS = TTLCache(maxsize=10_000, ttl=60)
    POST_CACHE_FOR_SAFETY_PTOS_DELUXE = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True the first time ``post.id`` is seen in the matching cache."""
        post_id = post.id
        if ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE:
            seen = cls.POST_CACHE_FOR_SAFETY_PTOS_DELUXE
            label = "safety ptos deluxe"
        else:
            seen = cls.POST_CACHE_FOR_SAFETY_PTOS
            label = "safety ptos"
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for {label}")
            return False
        seen[post_id] = True
        return True

View File

@@ -0,0 +1,61 @@
import logging
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from grox.classifiers.content.safety_ptos import SafetyPtosCategoryClassifier
from grox.config.config import ModelName, TaskGeneratorType
logger = logging.getLogger(__name__)
class TaskSafetyPtosCategoryDetection(TaskWithPost):
    """Classify a post into safety categories and store the result on the context.

    The deluxe classifier and a ``deluxe`` metric prefix are used when the
    payload's task type is ``SAFETY_PTOS_DELUXE``.
    """

    classifier = SafetyPtosCategoryClassifier(ModelName.VLM_SAFETY)
    deluxe_classifier = SafetyPtosCategoryClassifier(
        ModelName.VLM_PRIMARY_CRITICAL, deluxe=True
    )

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify ``post``, set ``ctx.safety_annotations`` and emit metrics."""
        if ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE:
            active_classifier = cls.deluxe_classifier
            metric_prefix = "task.safety_ptos_deluxe_category"
            mode = " (deluxe)"
        else:
            active_classifier = cls.classifier
            metric_prefix = "task.safety_ptos_category"
            mode = ""

        annotations = await active_classifier.classify_post(post)
        ctx.safety_annotations = annotations
        violated = annotations.violatedPolicies or []
        Metrics.counter(f"{metric_prefix}.classified.count").add(1)

        if not violated:
            Metrics.counter(f"{metric_prefix}.no_violations.count").add(1)
            logger.info(f"Post {post.id}: No safety violations detected{mode}")
            return

        violation_count = len(violated)
        Metrics.counter(f"{metric_prefix}.has_violations.count").add(1)
        Metrics.counter(f"{metric_prefix}.violations.count").add(violation_count)
        details = []
        for violation in violated:
            category_value = violation.category.value
            Metrics.counter(f"{metric_prefix}.violations_by_category.count").add(
                1, attributes={"category": category_value}
            )
            details.append(f"{category_value}({violation.score})")
        logger.info(
            f"Post {post.id}: Found {violation_count} violations{mode} - Details: {', '.join(details)}"
        )

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the function stored on ``Task.exec.__wrapped__``."""
        unwrapped_exec = Task.exec.__wrapped__
        return await unwrapped_exec(cls, ctx)

View File

@@ -0,0 +1,89 @@
import logging
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import (
Post,
SafetyPolicyCategory,
SafetyPolicyType,
SafetyPtosViolatedPolicy,
)
from grox.classifiers.content.safety_ptos import SafetyPtosPolicyClassifier
from grox.config.config import TaskGeneratorType
logger = logging.getLogger(__name__)
class TaskSafetyPtosPolicyDetection(TaskWithPost):
    """Resolve a concrete safety policy for each violated category on the context.

    In deluxe mode an extra ``AdultContent`` violation is injected for a
    recheck when none is present, and dropped again if the classifier reports
    no policy or ``NoViolation`` for it.
    """

    violated_policy_classifier = SafetyPtosPolicyClassifier()
    deluxe_violated_policy_classifier = SafetyPtosPolicyClassifier(deluxe=True)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify each violation's policy and write the list back to the context."""
        if not ctx.safety_annotations:
            return

        is_deluxe = ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE
        if is_deluxe:
            active_classifier = cls.deluxe_violated_policy_classifier
            metric_prefix = "task.safety_ptos_deluxe_policy"
        else:
            active_classifier = cls.violated_policy_classifier
            metric_prefix = "task.safety_ptos_policy"

        # Work on a copy of the list; it is assigned back at the end.
        violations = list(ctx.safety_annotations.violatedPolicies or [])

        injected_recheck = None
        if is_deluxe and not any(
            existing.category == SafetyPolicyCategory.AdultContent
            for existing in violations
        ):
            injected_recheck = SafetyPtosViolatedPolicy(
                category=SafetyPolicyCategory.AdultContent,
                reason="adult content recheck",
                score=50,
            )
            violations.append(injected_recheck)

        for violation in violations:
            resolved = await active_classifier.classify_policy_for_violation(
                post, violation
            )
            violation.safetyPolicy = resolved
            if violation.safetyPolicy:
                cls._record_policy_metrics(metric_prefix, violation)

        if injected_recheck is not None:
            recheck_policy = injected_recheck.safetyPolicy
            if (
                not recheck_policy
                or recheck_policy.policyType == SafetyPolicyType.NoViolation
            ):
                violations.remove(injected_recheck)

        ctx.safety_annotations.violatedPolicies = violations

    @classmethod
    def _record_policy_metrics(
        cls, metric_prefix: str, violation: SafetyPtosViolatedPolicy
    ) -> None:
        """Count a classified violation, overall and per known category.

        Callers in this class only invoke it when ``violation.safetyPolicy``
        is set; the per-category counter reads ``safetyPolicy.policyType``.
        """
        Metrics.counter(f"{metric_prefix}.classified_total.count").add(1)
        metric_key_by_category = {
            SafetyPolicyCategory.ViolentMedia: "violent_media",
            SafetyPolicyCategory.AdultContent: "adult_content",
            SafetyPolicyCategory.Spam: "spam",
            SafetyPolicyCategory.IllegalAndRegulatedBehaviors: "illegal_and_regulated_behaviors",
            SafetyPolicyCategory.HateOrAbuse: "hate_or_abuse",
            SafetyPolicyCategory.ViolentSpeech: "violent_speech",
            SafetyPolicyCategory.SuicideOrSelfHarm: "suicide_or_self_harm",
        }
        category_key = metric_key_by_category.get(violation.category)
        if not category_key:
            # Categories outside the table only count towards the total.
            return
        Metrics.counter(
            f"{metric_prefix}.classified_{category_key}_violations.count"
        ).add(1)
        policy_type_name = violation.safetyPolicy.policyType.name
        Metrics.counter(
            f"{metric_prefix}.classified_{category_key}_policy_types.count"
        ).add(1, attributes={"policy_type": policy_type_name})

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the function stored on ``Task.exec.__wrapped__``."""
        unwrapped_exec = Task.exec.__wrapped__
        return await unwrapped_exec(cls, ctx)

View File

@@ -0,0 +1,57 @@
import logging
from typing import override
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post, ContentCategoryType
from grox.classifiers.content.spam import SpamEapiLowFollowerClassifier
logger = logging.getLogger(__name__)
class TaskSpamDetection(TaskWithPost):
    """Run the low-follower spam classifier on a reply and emit bucketed metrics."""

    eapi_low_follower_classifier = SpamEapiLowFollowerClassifier()

    @classmethod
    def get_follower_bucket_string(cls, post: Post) -> str:
        """Bucket a reply by the follower counts of the users it replies to.

        Uses the last ancestor (the post replied to) and the first ancestor
        (the conversation root). Both counts must be within a threshold for
        the post to fall into that bucket.

        Args:
            post: the post to bucket.

        Returns:
            ``"invalid"`` when the post has no ancestors, otherwise one of
            ``"lte_100"``, ``"lte_500"``, ``"lte_1000"`` or ``"gt_1000"``.
        """
        if not post.ancestors:
            return "invalid"

        def _followers(ancestor) -> int:
            # A missing user is treated like a missing count (0) so that metric
            # bucketing cannot crash the task after classification already ran.
            author = ancestor.user
            return (author.follower_count or 0) if author else 0

        in_reply_user_follower_count = _followers(post.ancestors[-1])
        root_user_follower_count = _followers(post.ancestors[0])
        if in_reply_user_follower_count <= 100 and root_user_follower_count <= 100:
            return "lte_100"
        elif in_reply_user_follower_count <= 500 and root_user_follower_count <= 500:
            return "lte_500"
        elif in_reply_user_follower_count <= 1000 and root_user_follower_count <= 1000:
            return "lte_1000"
        else:
            return "gt_1000"

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the function stored on ``Task.exec.__wrapped__``."""
        return await Task.exec.__wrapped__(cls, ctx)

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify ``post``, append the results to the context and count the outcome."""
        res = await cls.eapi_low_follower_classifier.classify(post)
        ctx.content_categories.extend(res)
        passed = any(
            r.positive for r in res if r.category == ContentCategoryType.SPAM_COMMENT
        )
        follower_bucket_string = cls.get_follower_bucket_string(post)
        if passed and follower_bucket_string != "gt_1000":
            logger.info(
                f"Reply Spam Found for lower than 1000 follower bucket. The post_id is {post.id} and the follower bucket is {follower_bucket_string}"
            )
        if passed:
            Metrics.counter("task.spam_comment_detection.positive.count").add(
                1, attributes={"reason": follower_bucket_string}
            )
        else:
            Metrics.counter("task.spam_comment_detection.negative.count").add(
                1, attributes={"reason": follower_bucket_string}
            )

View File

@@ -0,0 +1,25 @@
import logging
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from grox.tasks.task_load_post_with_not_found_retry import TaskLoadPostWithNotFoundRetry
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext, TaskPayload
from grox.data_loaders.data_types import Post
from grox.summarizer.post_embedding_summarizer import PostEmbeddingSummarizer
logger = logging.getLogger(__name__)
class TaskPostEmbeddingSummarizer(TaskWithPost):
    """Summarize a post with ``PostEmbeddingSummarizer`` and store it on the post."""

    summarizer = PostEmbeddingSummarizer(prompt_file="")

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the function stored on ``Task.exec.__wrapped__``."""
        unwrapped_exec = Task.exec.__wrapped__
        return await unwrapped_exec(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Set ``post.summary`` to the summarizer output and count the run.

        Raises:
            AssertionError: when the summarizer returns ``None``.
        """
        summary = await cls.summarizer.summarize(post)
        assert summary is not None
        post.summary = summary
        Metrics.counter("task.post_embedding_summarizer.count").add(1)

View File

@@ -0,0 +1,193 @@
import logging
import time
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
from grox.tasks.disable_rules import DisableTaskForNonMmEmbProd
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import Post
from strato_http.queries.post_multimodal_embedding_sink import (
StratoPostMultimodalEmbeddingSink,
)
from tenacity import retry, wait_chain, wait_fixed, stop_after_attempt
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
StratoPostMultimodalEmbeddingMhSearchAi,
TweetEmbedding,
StratoPostMultimodalEmbeddingMhSearchAiNoCache,
StratoPostMultimodalEmbeddingGrokSummaryMh,
StratoMultiModalEmbeddingTopic,
)
from grox.tasks.task_embedding_pub import (
TaskPublishEmbeddingV4Kafka,
TaskPublishEmbeddingV5Kafka,
)
logger = logging.getLogger(__name__)
class TaskWriteMMEmbeddingSinkBase(TaskWithPost):
    """Common base for tasks that write a multimodal post embedding to a sink.

    Subclasses set ``model_version``, the key they read from
    ``ctx.multimodal_post_embedding_dict``.
    """

    # Embedding version handled by the concrete subclass.
    model_version: str
    DISABLE_RULES = [DisableTaskForNonMmEmbProd]

    @classmethod
    @retry(
        wait=wait_chain(wait_fixed(1), wait_fixed(2)),
        stop=stop_after_attempt(3),
    )
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the unwrapped ``Task.exec`` under the tenacity retry policy above.

        The policy stops after 3 attempts, with fixed waits of 1 and then 2.
        """
        unwrapped_exec = Task.exec.__wrapped__
        return await unwrapped_exec(cls, ctx)
class TaskWriteMMEmbeddingSinkExperiment(TaskWriteMMEmbeddingSinkBase):
    """Write the ``v2`` embedding through ``StratoPostMultimodalEmbeddingMhSearchAi``."""

    model_version = "v2"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding for ``post`` and record duration and count metrics.

        Raises:
            KeyError: when the context has no entry for ``model_version``.
            AssertionError: when that entry is ``None``.
        """
        start_time = time.perf_counter_ns()
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAi()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram(
            "task.write_post_embedding_sink_experiment.duration_ms"
        ).record(duration_ms)
        Metrics.counter("task.write_post_embedding_sink_experiment.count").add(1)
class TaskWriteMMEmbeddingSinkV3(TaskWriteMMEmbeddingSinkBase):
    """Write the ``v3`` summary and embedding for a post.

    Stores ``post.summary`` via ``StratoPostMultimodalEmbeddingGrokSummaryMh``,
    the embedding via ``StratoPostMultimodalEmbeddingMhSearchAi``, and inserts
    the embedding into ``StratoMultiModalEmbeddingTopic``.
    """

    model_version = "v3"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Persist summary and embedding, then record duration and count metrics.

        Raises:
            AssertionError: when ``post.summary`` or the embedding is ``None``.
            KeyError: when the context has no entry for ``model_version``.
        """
        start_time = time.perf_counter_ns()
        summary = post.summary
        assert summary is not None
        stratoPostMultimodalEmbeddingGrokSummaryMh = (
            StratoPostMultimodalEmbeddingGrokSummaryMh()
        )
        await stratoPostMultimodalEmbeddingGrokSummaryMh.put(
            int(post.id), cls.model_version, summary
        )
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAi()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        stratoMultiModalEmbeddingTopic = StratoMultiModalEmbeddingTopic()
        await stratoMultiModalEmbeddingTopic.insert(
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding)
        )
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v3.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v3.count").add(1)
class TaskWriteMMEmbeddingSinkV4(TaskWriteMMEmbeddingSinkBase):
    """Write the ``v4`` embedding and publish it via ``TaskPublishEmbeddingV4Kafka``."""

    model_version = "v4"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store and publish the embedding, then record duration and count metrics.

        Raises:
            KeyError: when the context has no entry for ``model_version``.
            AssertionError: when that entry is ``None``.
        """
        start_time = time.perf_counter_ns()
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        await TaskPublishEmbeddingV4Kafka._publish_to_kafka(post, embedding)
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v4.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v4.count").add(1)
class TaskWriteMMEmbeddingSinkV5(TaskWriteMMEmbeddingSinkBase):
    """Write the ``v5_1`` embedding and publish it via ``TaskPublishEmbeddingV5Kafka``."""

    model_version = "v5_1"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store and publish the embedding, then record duration and count metrics.

        Raises:
            KeyError: when the context has no entry for ``model_version``.
            AssertionError: when that entry is ``None``.
        """
        start_time = time.perf_counter_ns()
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        await TaskPublishEmbeddingV5Kafka._publish_to_kafka(post, embedding)
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v5.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v5.count").add(1)
class TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies(TaskWriteMMEmbeddingSinkBase):
    """Write the ``v5_1`` embedding; publish to Kafka only for non-reply posts.

    A post counts as a reply when ``post.ancestors`` is non-empty.
    """

    model_version = "v5_1"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding, conditionally publish it, then record metrics.

        Raises:
            KeyError: when the context has no entry for ``model_version``.
            AssertionError: when that entry is ``None``.
        """
        start_time = time.perf_counter_ns()
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        is_reply = bool(post.ancestors)
        if not is_reply:
            await TaskPublishEmbeddingV5Kafka._publish_to_kafka(post, embedding)
        else:
            Metrics.counter(
                "task.write_post_embedding_sink_v5.kafka_skipped_reply.count"
            ).add(1)
            logger.info(
                f"Skipping Kafka publish for reply post {post.id} (written to Manhattan only)"
            )
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version}, kafka={'yes' if not is_reply else 'no'})"
        )
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v5.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v5.count").add(1)

View File

@@ -0,0 +1,342 @@
import logging
import time
from grox.tasks.task import Task
from grox.tasks.disable_rules import DisableTaskForNonPtosProd
from monitor.metrics import Metrics
from grox.schedules.types import TaskContext
from grox.data_loaders.data_types import (
Image,
Video,
SafetyPolicyCategory,
SafetyPolicyType,
)
from strato_http.queries.data_types import (
SafetyPostAnnotations,
SafetyPostAnnotationsResult,
SafetyBoolMetadata,
SafetyPtosViolatedPolicy,
SafetyPolicy,
FoundMetadata,
)
from strato_http.queries.safety_post_annotations_result import (
StratoSafetyPostAnnotationsResultMh,
StratoSafetyPostAnnotationsResultDirectMh,
StratoSafetyPostAnnotationsResultKafka,
)
from strato_http.queries.grok_ptos_action_with_labels import (
StratoGrokPtosActionWithLabels,
)
from strato_http.queries.grok_ptos_delete_labels import StratoGrokPtosDeleteLabels
logger = logging.getLogger(__name__)
class TaskWriteSafetyPostAnnotationsResultSink(Task):
DISABLE_RULES = [DisableTaskForNonPtosProd]
_strato_mh = StratoSafetyPostAnnotationsResultMh()
_strato_direct_mh = StratoSafetyPostAnnotationsResultDirectMh()
_strato_kafka = StratoSafetyPostAnnotationsResultKafka()
_strato_grok_ptos_action_with_labels = StratoGrokPtosActionWithLabels()
_strato_grok_ptos_delete_labels = StratoGrokPtosDeleteLabels()
@staticmethod
def _build_found_metadata(post) -> FoundMetadata:
    """Count the images, videos and v2 cards attached to ``post``.

    Missing or empty ``post.media`` / ``post.cardsV2`` count as zero.
    """
    attachments = post.media or []
    image_total = sum(isinstance(item, Image) for item in attachments)
    video_total = sum(isinstance(item, Video) for item in attachments)
    return FoundMetadata(
        imageCount=image_total,
        videoCount=video_total,
        cardV2Count=len(post.cardsV2 or []),
    )
@classmethod
def _compute_bool_metadata_from_violations(
    cls, safety_annotations
) -> SafetyBoolMetadata:
    """Derive gore / nsfw / soft-nsfw / spam flags from classified violations.

    Only violations that carry a ``safetyPolicy`` contribute. Each match also
    bumps its ``detected_*`` counter, once per matching violation.
    """
    gore = nsfw = soft_nsfw = spam = False
    policies = (
        safety_annotations.violatedPolicies if safety_annotations else None
    ) or []
    for violation in policies:
        policy = violation.safetyPolicy
        if not policy:
            # Without a resolved policy none of the rules below can match.
            continue
        category = violation.category
        policy_type = policy.policyType
        is_real_violation = policy_type != SafetyPolicyType.NoViolation

        if category == SafetyPolicyCategory.ViolentMedia and is_real_violation:
            gore = True
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.detected_gore.count"
            ).add(1)
        if (
            category == SafetyPolicyCategory.AdultContent
            and policy_type == SafetyPolicyType.AdultContentSexualHard
        ):
            nsfw = True
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.detected_nsfw.count"
            ).add(1)
        if (
            category == SafetyPolicyCategory.AdultContent
            and policy_type == SafetyPolicyType.AdultContentSexualSoft
        ):
            soft_nsfw = True
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.detected_soft_nsfw.count"
            ).add(1)
        if category == SafetyPolicyCategory.Spam and is_real_violation:
            spam = True
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.detected_spam.count"
            ).add(1)
        if (
            category == SafetyPolicyCategory.IllegalAndRegulatedBehaviors
            and is_real_violation
        ):
            # Illegal/regulated behaviour is folded into the spam flag.
            spam = True
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.detected_spam_illegal.count"
            ).add(1)
    return SafetyBoolMetadata(
        isGore=gore, isNsfw=nsfw, isSoftNsfw=soft_nsfw, isSpam=spam
    )
@classmethod
def _merge_bool_metadata(
    cls, existing: SafetyBoolMetadata | None, new: SafetyBoolMetadata
) -> SafetyBoolMetadata:
    """OR the flags of ``existing`` and ``new``; return ``new`` when nothing exists.

    A flag is True in the result when it is truthy in either input.
    """
    if existing is None:
        return new
    merged_gore = bool(existing.isGore or new.isGore)
    merged_nsfw = bool(existing.isNsfw or new.isNsfw)
    merged_soft_nsfw = bool(existing.isSoftNsfw or new.isSoftNsfw)
    merged_spam = bool(existing.isSpam or new.isSpam)
    return SafetyBoolMetadata(
        isGore=merged_gore,
        isNsfw=merged_nsfw,
        isSoftNsfw=merged_soft_nsfw,
        isSpam=merged_spam,
    )
# Result sink: appends the current safety annotations for a post to any
# previously stored result, merges the boolean flags, triggers the label
# action / label deletion calls, then writes the final result to MH and Kafka.
# Returns early (after bumping the top-level counter) when the payload has no
# post or the context carries no safety annotations.
@classmethod
async def _exec(cls, ctx: TaskContext) -> None:
Metrics.counter("task.write_safety_post_annotations_result_sink.count").add(1)
post = ctx.payload.post
if not post:
return
safety_annotations = ctx.safety_annotations
if not safety_annotations:
return
post_id = int(post.id)
# Load whatever was stored earlier for this post so it can be extended.
existing_result = await cls._strato_direct_mh.fetch(post_id)
if existing_result:
Metrics.counter(
"task.write_safety_post_annotations_result_sink.existing_found.count"
).add(1)
else:
Metrics.counter(
"task.write_safety_post_annotations_result_sink.existing_not_found.count"
).add(1)
# Sort in place, highest score first; a missing score sorts as 0.
if safety_annotations.violatedPolicies:
safety_annotations.violatedPolicies.sort(
key=lambda x: x.score or 0, reverse=True
)
task_type_suffix = (
ctx.payload.task_type.value if ctx.payload.task_type else "unknown"
)
identifier = f"{task_type_suffix}"
timestamp_ms = int(time.time() * 1000)
found_metadata = cls._build_found_metadata(post)
new_annotation = SafetyPostAnnotations(
tweetId=post_id,
violatedPolicies=[
policy.model_dump()
for policy in (safety_annotations.violatedPolicies or [])
],
foundMetadata=found_metadata,
identifier=identifier,
timestamp=timestamp_ms,
)
# Start from a copy of the stored annotation list (if any) and append ours.
annotations_list = (
list(existing_result.safetyPostAnnotations)
if existing_result and existing_result.safetyPostAnnotations
else []
)
annotations_list.append(new_annotation)
new_bool_metadata = cls._compute_bool_metadata_from_violations(
safety_annotations
)
existing_bool_metadata = (
existing_result.safetyBoolMetadata if existing_result else None
)
merged_bool_metadata = cls._merge_bool_metadata(
existing_bool_metadata, new_bool_metadata
)
# Build a "category:policyType" summary string that is only used in log lines.
# NOTE(review): this loop iterates violatedPolicies without the `or []` guard
# used when building new_annotation above — confirm it can never be None here.
# NOTE(review): assumes policyType is set whenever safetyPolicy is — confirm.
violation_details = []
for v in safety_annotations.violatedPolicies:
policy_type = v.safetyPolicy.policyType.name if v.safetyPolicy else "none"
violation_details.append(f"{v.category.value}:{policy_type}")
violations_summary = ", ".join(violation_details)
# The label action receives only the annotation built in this run.
action_result = await cls._strato_grok_ptos_action_with_labels.execute(
new_annotation
)
# Three outcomes are logged/counted: labels applied, a result with no labels,
# and a falsy result (logged as a failure).
if action_result and len(action_result.applied_labels) > 0:
logger.info(
f"grokPtosActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post_id} (result_sink), violations=[{violations_summary}]"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.count"
).add(1)
for label in action_result.applied_labels:
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.applied_label.count"
).add(1, attributes={"label": label})
elif action_result:
logger.info(
f"grokPtosActionWithLabels did not apply any labels: (debugString='{action_result.debug_string}') for post {post_id} (result_sink), violations=[{violations_summary}] "
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.empty.count"
).add(1)
else:
logger.info(
f"grokPtosActionWithLabels failed for post {post_id} (result_sink), violations=[{violations_summary}] "
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.failed.count"
).add(1)
# Safemodel enforcement: only when the sex-and-nudity classifier is positive
# and the annotations of THIS run did not already set isNsfw (flags from the
# stored result are not consulted for this check).
# NOTE(review): assumes ctx.safemodel_sex_nudity is always populated — confirm.
ptos_already_nsfw = new_bool_metadata.isNsfw
if ctx.safemodel_sex_nudity.positive and not ptos_already_nsfw:
# Confidence is scaled by 100 and rounded to an integer score.
safemodel_confidence_int = round(ctx.safemodel_sex_nudity.confidence * 100)
safemodel_annotation = SafetyPostAnnotations(
tweetId=post_id,
violatedPolicies=[
SafetyPtosViolatedPolicy(
category=SafetyPolicyCategory.AdultContent,
score=safemodel_confidence_int,
reason="safemodel sex-and-nudity classifier detected adult content",
safetyPolicy=SafetyPolicy(
policyType=SafetyPolicyType.AdultContentSexualHard,
confidenceScore=safemodel_confidence_int,
reason="safemodel sex-and-nudity classifier detected adult content",
),
).model_dump(),
],
foundMetadata=found_metadata,
identifier="safemodel-sex-nudity",
timestamp=timestamp_ms,
)
annotations_list.append(safemodel_annotation)
# Force isNsfw on in the merged flags; the other flags keep their merged value.
merged_bool_metadata = cls._merge_bool_metadata(
merged_bool_metadata,
SafetyBoolMetadata(
isGore=False, isNsfw=True, isSoftNsfw=False, isSpam=False
),
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.safemodel_enforced.count"
).add(1)
# Second label-action call, this time for the synthetic safemodel annotation.
safemodel_action_result = (
await cls._strato_grok_ptos_action_with_labels.execute(
safemodel_annotation
)
)
if (
safemodel_action_result
and len(safemodel_action_result.applied_labels) > 0
):
logger.info(
f"safemodel enforce: grokPtosActionWithLabels applied labels: debugString='{safemodel_action_result.debug_string}', "
f"appliedLabels={safemodel_action_result.applied_labels} for post {post_id}"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.safemodel_action.count"
).add(1)
for label in safemodel_action_result.applied_labels:
Metrics.counter(
"task.write_safety_post_annotations_result_sink.safemodel_action.applied_label.count"
).add(1, attributes={"label": label})
elif safemodel_action_result:
logger.info(
f"safemodel enforce: grokPtosActionWithLabels no labels applied (debugString='{safemodel_action_result.debug_string}') for post {post_id}"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.safemodel_action.empty.count"
).add(1)
else:
logger.info(
f"safemodel enforce: grokPtosActionWithLabels returned None for post {post_id}"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.safemodel_action.failed.count"
).add(1)
# The final record holds the full annotation history plus the merged flags.
final_result = SafetyPostAnnotationsResult(
tweetId=post_id,
safetyPostAnnotations=annotations_list,
safetyBoolMetadata=merged_bool_metadata,
)
# The delete-labels call sees the complete final result, not just this run.
delete_labels_result = await cls._strato_grok_ptos_delete_labels.execute(
final_result
)
if delete_labels_result:
logger.info(
f"grokPtosDeleteLabels returned '{delete_labels_result}' for post {post_id} (result_sink)"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_delete_labels.count"
).add(1)
else:
logger.info(
f"grokPtosDeleteLabels returned no result for post {post_id} (result_sink)"
)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.grok_ptos_delete_labels.empty.count"
).add(1)
# Persist: MH first, then Kafka. Success counters are bumped only after each
# awaited write returns, so an exception skips the remaining counters.
await cls._strato_mh.put(post_id, final_result)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.mh.success.count"
).add(1)
await cls._strato_kafka.insert(post_id, final_result)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.kafka.success.count"
).add(1)
Metrics.counter(
"task.write_safety_post_annotations_result_sink.success.count"
).add(1)

20
home-mixer/ads/mod.rs Normal file
View File

@@ -0,0 +1,20 @@
mod partition_organic_blender;
mod safe_gap_blender;
pub(crate) mod util;
pub use partition_organic_blender::PartitionOrganicAdsBlender;
pub use safe_gap_blender::SafeGapAdsBlender;
use util::{record_ad_risk_stats, record_post_verdict_stats};
use xai_home_mixer_proto::{FeedItem, ScoredPost};
use xai_recsys_proto::AdIndexInfo;
/// Strategy for mixing ads into a list of scored posts.
///
/// Implementors supply [`AdsBlender::blend_inner`]; callers use
/// [`AdsBlender::blend`], which records stats before delegating.
pub trait AdsBlender: Send + Sync {
    /// Produces the blended feed from the given posts and ads.
    fn blend_inner(
        &self,
        scored_posts: Vec<ScoredPost>,
        ads: Vec<AdIndexInfo>,
    ) -> Vec<FeedItem>;

    /// Records per-post verdict stats and per-ad risk stats, then returns
    /// whatever [`AdsBlender::blend_inner`] produces for the same inputs.
    fn blend(
        &self,
        scored_posts: Vec<ScoredPost>,
        ads: Vec<AdIndexInfo>,
    ) -> Vec<FeedItem> {
        record_post_verdict_stats(scored_posts.as_slice());
        record_ad_risk_stats(ads.as_slice());
        self.blend_inner(scored_posts, ads)
    }
}

View File

@@ -0,0 +1,190 @@
use super::AdsBlender;
use super::util::*;
use crate::params::RESULT_SIZE;
use xai_home_mixer_proto::{FeedItem, ScoredPost, feed_item};
use xai_recsys_proto::AdIndexInfo;
use xai_stats_receiver::global_stats_receiver;
const ENFORCEMENT_METRIC: &str = "PartitionOrganic.enforcement";
/// Blender that delegates to this module's [`blend_impl`] using the shared
/// `MIN_POSTS_FOR_ADS` threshold.
pub struct PartitionOrganicAdsBlender;

impl AdsBlender for PartitionOrganicAdsBlender {
    /// Runs [`blend_impl`] with `MIN_POSTS_FOR_ADS` as the minimum post count.
    fn blend_inner(
        &self,
        scored_posts: Vec<ScoredPost>,
        ads: Vec<AdIndexInfo>,
    ) -> Vec<FeedItem> {
        let min_posts = MIN_POSTS_FOR_ADS;
        blend_impl(scored_posts, ads, min_posts)
    }
}
pub(crate) fn blend_impl(
scored_posts: Vec<ScoredPost>,
ads: Vec<AdIndexInfo>,
min_posts: usize,
) -> Vec<FeedItem> {
let n = scored_posts.len();
if ads.is_empty() || n < min_posts {
return posts_to_feed_items(scored_posts);
}
let spacing = compute_spacing(&ads);
let safe_count = scored_posts.iter().filter(|p| !has_avoid(p)).count();
let max_from_safe = safe_count / 2;
let expected_from_spacing = if spacing.requested > 0 {
n.saturating_sub(1) / spacing.requested
} else {
0
};
let actual_ads = ads.len().min(expected_from_spacing).min(max_from_safe);
if actual_ads == 0 {
return posts_to_feed_items(scored_posts);
}
let mut safe: Vec<ScoredPost> = Vec::new();
let mut unsafe_posts: Vec<ScoredPost> = Vec::new();
for post in scored_posts {
if has_avoid(&post) {
unsafe_posts.push(post);
} else {
safe.push(post);
}
}
let num_safe = safe.len();
let group_size = num_safe / actual_ads;
let mut safe_opts: Vec<Option<ScoredPost>> = safe.into_iter().map(Some).collect();
let mut triples: Vec<(AdIndexInfo, ScoredPost, ScoredPost)> = Vec::new();
let mut bsr_drop: u64 = 0;
let mut bsr_ok: u64 = 0;
let mut handle_drop: u64 = 0;
let mut keyword_drop: u64 = 0;
let mut group_idx = 0;
for ad in ads {
if group_idx >= actual_ads {
break;
}
let group_start = group_idx * group_size;
let above_ref = safe_opts[group_start].as_ref();
let below_ref = safe_opts[group_start + 1].as_ref();
if should_drop_bsr_low(&ad, above_ref, below_ref) {
bsr_drop += 1;
continue;
}
if is_bsr_low_ad(&ad) {
bsr_ok += 1;
}
if should_drop_handle(&ad, above_ref, below_ref) {
handle_drop += 1;
continue;
}
if should_drop_keyword(&ad, above_ref, below_ref) {
keyword_drop += 1;
continue;
}
let above = safe_opts[group_start].take().unwrap();
let below = safe_opts[group_start + 1].take().unwrap();
triples.push((ad, above, below));
group_idx += 1;
}
let placed_ads = triples.len();
emit_enforcement_metrics(bsr_drop, bsr_ok, handle_drop, keyword_drop);
if placed_ads == 0 {
let mut all_posts: Vec<ScoredPost> = safe_opts.into_iter().flatten().collect();
all_posts.extend(unsafe_posts);
all_posts.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
return posts_to_feed_items(all_posts);
}
let mut filler: Vec<ScoredPost> =
Vec::with_capacity(num_safe - 2 * placed_ads + unsafe_posts.len());
filler.extend(safe_opts.into_iter().flatten());
filler.extend(unsafe_posts);
filler.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let inter_ad_gaps = placed_ads;
let filler_per_gap = filler.len() / inter_ad_gaps;
let remainder = filler.len() % inter_ad_gaps;
let mut filler_iter = filler.into_iter();
let mut items: Vec<FeedItem> = Vec::with_capacity(n + placed_ads);
for (i, (ad, above, below)) in triples.into_iter().enumerate() {
items.push(FeedItem {
position: 0,
item: Some(feed_item::Item::Post(above)),
});
items.push(FeedItem {
position: 0,
item: Some(feed_item::Item::Ad(ad)),
});
items.push(FeedItem {
position: 0,
item: Some(feed_item::Item::Post(below)),
});
let count = filler_per_gap + if i >= inter_ad_gaps - remainder { 1 } else { 0 };
for _ in 0..count {
if let Some(post) = filler_iter.next() {
items.push(FeedItem {
position: 0,
item: Some(feed_item::Item::Post(post)),
});
}
}
}
items.truncate(RESULT_SIZE);
if matches!(items.last(), Some(item) if matches!(item.item, Some(feed_item::Item::Ad(_)))) {
items.pop();
}
for (i, item) in items.iter_mut().enumerate() {
item.position = i as i32;
}
items
}
/// Reports the non-zero enforcement tallies under `ENFORCEMENT_METRIC`,
/// tagged with an `action` of `drop`, `ok`, `handle_drop` or `keyword_drop`.
///
/// Does nothing when no global stats receiver is available.
fn emit_enforcement_metrics(bsr_drop: u64, bsr_ok: u64, handle_drop: u64, keyword_drop: u64) {
    if let Some(receiver) = global_stats_receiver() {
        let tallies = [
            ("drop", bsr_drop),
            ("ok", bsr_ok),
            ("handle_drop", handle_drop),
            ("keyword_drop", keyword_drop),
        ];
        for (action, count) in tallies {
            if count > 0 {
                receiver.incr(ENFORCEMENT_METRIC, &[("action", action)], count);
            }
        }
    }
}

View File

@@ -0,0 +1,95 @@
use super::AdsBlender;
use super::util::*;
use xai_home_mixer_proto::{FeedItem, ScoredPost};
use xai_recsys_proto::AdIndexInfo;
/// Blender that delegates to this module's [`blend_impl`] using the shared
/// `MIN_POSTS_FOR_ADS` threshold.
pub struct SafeGapAdsBlender;

impl AdsBlender for SafeGapAdsBlender {
    /// Runs [`blend_impl`] with `MIN_POSTS_FOR_ADS` as the minimum post count.
    fn blend_inner(
        &self,
        scored_posts: Vec<ScoredPost>,
        ads: Vec<AdIndexInfo>,
    ) -> Vec<FeedItem> {
        let min_posts = MIN_POSTS_FOR_ADS;
        blend_impl(scored_posts, ads, min_posts)
    }
}
/// Blends ads into `scored_posts` by inserting them at "safe gaps" (see
/// `find_safe_gaps`) while leaving the post order untouched.
///
/// Returns posts only when `ads` is empty or there are fewer than `min_posts`
/// posts. Otherwise the first ad's `insert_position` (clamped to be
/// non-negative) is the ideal slot for the first ad, later slots follow the
/// spacing from `compute_spacing`, and `interleave_and_finalize` builds the
/// final feed.
pub(crate) fn blend_impl(
    scored_posts: Vec<ScoredPost>,
    ads: Vec<AdIndexInfo>,
    min_posts: usize,
) -> Vec<FeedItem> {
    let first_ideal = match ads.first() {
        Some(first_ad) if scored_posts.len() >= min_posts => {
            first_ad.insert_position.max(0) as usize
        }
        _ => return posts_to_feed_items(scored_posts),
    };

    let spacing = compute_spacing(&ads);
    let safe_gaps = find_safe_gaps(&scored_posts);
    let placements = assign_ads_to_gaps(&safe_gaps, ads.len(), &spacing, first_ideal);
    interleave_and_finalize(scored_posts, ads, &placements)
}
/// Chooses up to `num_ads` gap indices from `safe_gaps` (expected in
/// ascending order, since it is searched with `find_best_gap`).
///
/// The first ad aims for `first_ideal` with a minimum gap of 1. Each later ad
/// aims `spacing.requested` past the previous *ideal* position, and may not
/// land before either `previous ideal + spacing.min` or
/// `previous actual + DEFAULT_SPACING.min`. Selection stops as soon as the
/// gaps are exhausted or no gap satisfies the minimum.
///
/// Returns the chosen gap values in the order they were assigned.
pub(crate) fn assign_ads_to_gaps(
    safe_gaps: &[usize],
    num_ads: usize,
    spacing: &AdSpacing,
    first_ideal: usize,
) -> Vec<usize> {
    let mut chosen: Vec<usize> = Vec::new();
    // Index into `safe_gaps` of the first gap still available.
    let mut cursor = 0usize;
    let mut last_ideal = first_ideal;

    for _ in 0..num_ads {
        let Some(remaining) = safe_gaps.get(cursor..).filter(|rest| !rest.is_empty()) else {
            break;
        };

        let (target, floor) = if let Some(&prev_slot) = chosen.last() {
            let target = last_ideal + spacing.requested;
            let floor = (last_ideal + spacing.min).max(prev_slot + DEFAULT_SPACING.min);
            (target, floor)
        } else {
            (first_ideal, 1)
        };

        let Some((offset, slot)) = find_best_gap(remaining, target, floor) else {
            break;
        };
        chosen.push(slot);
        // Gaps up to and including the chosen one are no longer available.
        cursor += offset + 1;
        last_ideal = target;
    }

    chosen
}
/// Picks the gap closest to `ideal` among the entries of `gaps` that are at
/// least `min`.
///
/// `gaps` must be sorted ascending (it is searched with `partition_point`).
/// When `ideal` lies between two eligible gaps the nearer one wins, with ties
/// going to the smaller gap. When `ideal` is beyond every eligible gap the
/// largest one is used; when it is at or before the first, the first is used.
///
/// Returns `Some((index_in_gaps, gap_value))`, or `None` when no gap is
/// `>= min`.
pub(crate) fn find_best_gap(gaps: &[usize], ideal: usize, min: usize) -> Option<(usize, usize)> {
    let first_allowed = gaps.partition_point(|&gap| gap < min);
    let eligible = gaps
        .get(first_allowed..)
        .filter(|rest| !rest.is_empty())?;

    // Index of the first eligible gap that is >= ideal.
    let at_or_after = eligible.partition_point(|&gap| gap < ideal);
    let pick = match at_or_after {
        0 => 0,
        idx if idx >= eligible.len() => eligible.len() - 1,
        idx => {
            // before < ideal <= after, so neither subtraction can underflow.
            let (before, after) = (eligible[idx - 1], eligible[idx]);
            if ideal - before <= after - ideal {
                idx - 1
            } else {
                idx
            }
        }
    };

    Some((first_allowed + pick, eligible[pick]))
}

228
home-mixer/ads/util.rs Normal file
View File

@@ -0,0 +1,228 @@
use crate::params::RESULT_SIZE;
use std::sync::LazyLock;
use xai_home_mixer_proto::{BrandSafetyVerdict, FeedItem, ScoredPost, feed_item};
use xai_post_text::TweetTokenizer;
use xai_recsys_proto::{AdIndexInfo, BrandSafetyRiskLevel};
use xai_stats_receiver::global_stats_receiver;
/// Shared tokenizer, built lazily on first use with `TweetTokenizer::new`.
static TWEET_TOKENIZER: LazyLock<TweetTokenizer> = LazyLock::new(TweetTokenizer::new);
/// Post-count threshold the blenders pass as `min_posts`; with fewer posts
/// than this, their `blend_impl` returns posts only.
pub(crate) const MIN_POSTS_FOR_ADS: usize = 5;
/// Smallest ad-position difference `compute_spacing` accepts as a requested
/// spacing; anything smaller falls back to `DEFAULT_SPACING`.
pub(crate) const MIN_REQUESTED_GAP: usize = 3;
/// Spacing used when it cannot be derived from the ads' insert positions.
pub(crate) const DEFAULT_SPACING: AdSpacing = AdSpacing {
requested: 3,
min: 2,
};
/// Distance constraints between consecutive ads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct AdSpacing {
// Preferred distance between consecutive ads.
pub(crate) requested: usize,
// Lower bound on that distance; `compute_spacing` sets it to half of
// `requested`, rounded up.
pub(crate) min: usize,
}
/// Returns `true` when the post's brand-safety verdict is `MediumRisk`.
pub(crate) fn has_avoid(post: &ScoredPost) -> bool {
    let verdict = post.brand_safety_verdict();
    verdict == BrandSafetyVerdict::MediumRisk
}
/// Lists the gaps between adjacent posts where neither neighbour is flagged
/// by [`has_avoid`].
///
/// Gap `g` sits between `scored_posts[g - 1]` and `scored_posts[g]`, so the
/// returned values are in `1..scored_posts.len()` and in ascending order.
/// Fewer than two posts yields an empty list.
pub(crate) fn find_safe_gaps(scored_posts: &[ScoredPost]) -> Vec<usize> {
    scored_posts
        .windows(2)
        .enumerate()
        // The `&&` short-circuits, so the second post is only inspected when
        // the first one is clean.
        .filter(|(_, pair)| !has_avoid(&pair[0]) && !has_avoid(&pair[1]))
        .map(|(before, _)| before + 1)
        .collect()
}
/// Derives the ad spacing from the insert positions of the first four ads.
///
/// The requested spacing is the smallest strictly positive difference between
/// neighbouring positions once sorted; the minimum spacing is half of that,
/// rounded up. Falls back to [`DEFAULT_SPACING`] when there are fewer than two
/// ads, when all inspected positions are equal, or when the smallest
/// difference is below [`MIN_REQUESTED_GAP`].
pub(crate) fn compute_spacing(ads: &[AdIndexInfo]) -> AdSpacing {
    if ads.len() < 2 {
        return DEFAULT_SPACING;
    }
    let mut positions: Vec<i32> = ads.iter().take(4).map(|a| a.insert_position).collect();
    positions.sort_unstable();
    let min_diff = positions
        .windows(2)
        // `abs_diff` yields a `u32`, so widely separated positions (e.g. a
        // large negative next to a large positive value) cannot overflow the
        // way a plain `i32` subtraction would.
        .map(|w| usize::try_from(w[1].abs_diff(w[0])).unwrap_or(usize::MAX))
        .filter(|&d| d > 0)
        .min();
    match min_diff {
        Some(requested) if requested >= MIN_REQUESTED_GAP => AdSpacing {
            requested,
            min: requested.div_ceil(2),
        },
        _ => DEFAULT_SPACING,
    }
}
/// Returns `true` when the ad's adjacency control reports a brand-safety risk
/// of `BsrLow` or `BsrIas`. An ad without adjacency control is treated as
/// `BsrUnknown` and therefore returns `false`.
pub(crate) fn is_bsr_low_ad(ad: &AdIndexInfo) -> bool {
    let risk = match ad.ad_adjacency_control.as_ref() {
        Some(control) => control.brand_safety_risk(),
        None => BrandSafetyRiskLevel::BsrUnknown,
    };
    match risk {
        BrandSafetyRiskLevel::BsrLow | BrandSafetyRiskLevel::BsrIas => true,
        _ => false,
    }
}
pub(crate) fn should_drop_bsr_low(
ad: &AdIndexInfo,
above: Option<&ScoredPost>,
below: Option<&ScoredPost>,
) -> bool {
let risk = ad
.ad_adjacency_control
.as_ref()
.map(|c| c.brand_safety_risk())
.unwrap_or(BrandSafetyRiskLevel::BsrUnknown);
if !matches!(
risk,
BrandSafetyRiskLevel::BsrLow | BrandSafetyRiskLevel::BsrIas
) {
return false;
}
let is_lr = |p: &ScoredPost| p.brand_safety_verdict() == BrandSafetyVerdict::LowRisk;
above.map(is_lr).unwrap_or(false) || below.map(is_lr).unwrap_or(false)
}
/// Decides whether an ad must be dropped because a neighbouring post was
/// written by an author listed in the ad's adjacency-control `handles`.
///
/// Returns `false` when the ad has no adjacency control or an empty handle
/// list. A missing neighbour never causes a drop.
pub(crate) fn should_drop_handle(
    ad: &AdIndexInfo,
    above: Option<&ScoredPost>,
    below: Option<&ScoredPost>,
) -> bool {
    let Some(ctrl) = ad.ad_adjacency_control.as_ref() else {
        return false;
    };
    if ctrl.handles.is_empty() {
        return false;
    }
    let listed = &ctrl.handles;
    [above, below]
        .into_iter()
        .flatten()
        .any(|post| listed.contains(&(post.author_id as i64)))
}
/// Decides whether an ad must be dropped because a neighbouring post's text
/// contains one of the ad's adjacency-control `keywords`.
///
/// Keywords and post text are tokenized with the shared `TWEET_TOKENIZER`;
/// a keyword matches when its token sequence is found in the post's tokens
/// (`contains_keyword_sequence`). Keywords that tokenize to nothing are
/// ignored, as are posts with empty text or no tokens. Returns `false` when
/// the ad has no adjacency control or no usable keywords.
pub(crate) fn should_drop_keyword(
    ad: &AdIndexInfo,
    above: Option<&ScoredPost>,
    below: Option<&ScoredPost>,
) -> bool {
    let Some(ctrl) = ad.ad_adjacency_control.as_ref() else {
        return false;
    };
    if ctrl.keywords.is_empty() {
        return false;
    }

    let tokenizer = &*TWEET_TOKENIZER;
    let keyword_sequences: Vec<_> = ctrl
        .keywords
        .iter()
        .map(|kw| tokenizer.tokenize(kw))
        .filter(|seq| !seq.is_empty())
        .collect();
    if keyword_sequences.is_empty() {
        return false;
    }

    let mentions_keyword = |post: &ScoredPost| -> bool {
        if post.tweet_text.is_empty() {
            return false;
        }
        let tweet_tokens = tokenizer.tokenize(&post.tweet_text);
        !tweet_tokens.is_empty()
            && keyword_sequences
                .iter()
                .any(|kw_tokens| tweet_tokens.contains_keyword_sequence(kw_tokens))
    };

    // `above` is checked first; `below` is only tokenized if `above` did not match.
    [above, below].into_iter().flatten().any(mentions_keyword)
}
/// Wraps each post in a [`FeedItem`], keeping the input order and numbering
/// positions from 0.
pub(crate) fn posts_to_feed_items(scored_posts: Vec<ScoredPost>) -> Vec<FeedItem> {
    let mut items = Vec::with_capacity(scored_posts.len());
    for (index, post) in scored_posts.into_iter().enumerate() {
        items.push(FeedItem {
            position: index as i32,
            item: Some(feed_item::Item::Post(post)),
        });
    }
    items
}
/// Builds the final feed by inserting ads in front of the posts named in
/// `placements`.
///
/// `placements[k]` is the index of the post that the k-th ad precedes; the
/// list is consumed front to back, so it is expected in ascending order. Ads
/// are taken from `ads` in order. The result is truncated to `RESULT_SIZE`, a
/// trailing ad is removed, and positions are numbered from 0.
///
/// # Panics
///
/// Panics if more placements match than there are ads.
pub(crate) fn interleave_and_finalize(
    scored_posts: Vec<ScoredPost>,
    ads: Vec<AdIndexInfo>,
    placements: &[usize],
) -> Vec<FeedItem> {
    let mut items: Vec<FeedItem> = Vec::with_capacity(scored_posts.len() + placements.len());
    let mut pending_ads = ads.into_iter();
    let mut pending_slots = placements.iter().copied().peekable();

    for (index, post) in scored_posts.into_iter().enumerate() {
        // Consume the next placement only when it names this post.
        if pending_slots.next_if_eq(&index).is_some() {
            let ad = pending_ads.next().unwrap();
            items.push(FeedItem {
                position: 0,
                item: Some(feed_item::Item::Ad(ad)),
            });
        }
        items.push(FeedItem {
            position: 0,
            item: Some(feed_item::Item::Post(post)),
        });
    }

    items.truncate(RESULT_SIZE);
    // Never end the feed on an ad.
    if let Some(FeedItem {
        item: Some(feed_item::Item::Ad(_)),
        ..
    }) = items.last()
    {
        items.pop();
    }
    for (index, entry) in items.iter_mut().enumerate() {
        entry.position = index as i32;
    }
    items
}
const VERDICT_METRIC: &str = "AdsBlender.post_brand_safety_verdict";
const RISK_METRIC: &str = "AdsBlender.ad_brand_safety_risk";
/// Increments `VERDICT_METRIC` once per post, tagged with the string name of
/// the post's brand-safety verdict. Does nothing when no global stats
/// receiver is available.
pub(crate) fn record_post_verdict_stats(posts: &[ScoredPost]) {
    if let Some(receiver) = global_stats_receiver() {
        posts
            .iter()
            .map(|post| post.brand_safety_verdict().as_str_name())
            .for_each(|label| receiver.incr(VERDICT_METRIC, &[("verdict", label)], 1));
    }
}
/// Increments `RISK_METRIC` once per ad, tagged with the string name of the
/// ad's brand-safety risk level (`BsrUnknown` when the ad has no adjacency
/// control). Does nothing when no global stats receiver is available.
pub(crate) fn record_ad_risk_stats(ads: &[AdIndexInfo]) {
    if let Some(receiver) = global_stats_receiver() {
        for ad in ads {
            let risk_level = match ad.ad_adjacency_control.as_ref() {
                Some(control) => control.brand_safety_risk(),
                None => BrandSafetyRiskLevel::BsrUnknown,
            };
            receiver.incr(RISK_METRIC, &[("risk", risk_level.as_str_name())], 1);
        }
    }
}

View File

@@ -0,0 +1,176 @@
use crate::models::brand_safety::{
BrandSafetyVerdict, botmaker_rule_category, botmaker_rule_id_from, compute_verdict,
truncate_description, worst_verdict,
};
use crate::models::candidate::{PostCandidate, SafetyLabelInfo};
use crate::models::query::ScoredPostsQuery;
use crate::params::*;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::{
MokaCache, TweetAgeExpiry, build_moka_cache_tweet_age,
};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
use xai_safety_label_store::SafetyLabelStoreClient;
/// Size argument passed to `build_moka_cache_tweet_age` for the hydrator cache.
const CACHE_SIZE: u64 = 1_000_000;
/// Cached brand-safety result for one candidate: the verdict plus the safety
/// labels that were attached to it.
#[derive(Clone, Default)]
pub struct CachedBrandSafety {
// Verdict stored for the candidate (the type's default when none was set).
verdict: BrandSafetyVerdict,
// Safety labels copied from the hydrated candidate.
safety_labels: Vec<SafetyLabelInfo>,
}
/// Hydrates brand-safety verdicts and safety labels for post candidates from
/// the safety label store, with results cached per tweet id.
pub struct AdsBrandSafetyHydrator {
    /// Client used to fetch safety labels in batch.
    pub client: Arc<dyn SafetyLabelStoreClient>,
    /// Cache of hydrated results keyed by tweet id.
    pub cache: MokaCache<u64, CachedBrandSafety>,
}

impl AdsBrandSafetyHydrator {
    /// Creates a hydrator with a cache of `CACHE_SIZE` built by
    /// `build_moka_cache_tweet_age`, using a 5-minute age threshold, a
    /// 1-minute TTL for new tweets and a 1-hour TTL for old tweets.
    pub fn new(client: Arc<dyn SafetyLabelStoreClient>) -> Self {
        let expiry = TweetAgeExpiry {
            age_threshold: Duration::from_secs(5 * 60),
            new_tweet_ttl: Duration::from_secs(60),
            old_tweet_ttl: Duration::from_secs(60 * 60),
        };
        Self {
            client,
            cache: build_moka_cache_tweet_age(CACHE_SIZE, expiry),
        }
    }
}
#[async_trait]
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for AdsBrandSafetyHydrator {
    type CacheKey = u64;
    type CacheValue = CachedBrandSafety;

    /// Enabled when the `EnableAdsBrandSafetyHydrator` param is on and the
    /// `vf_brand_safety_dark_traffic` decider is not enabled (a missing
    /// decider counts as not enabled).
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        query.params.get(EnableAdsBrandSafetyHydrator)
            && !query
                .decider
                .as_ref()
                .is_some_and(|d| d.enabled("vf_brand_safety_dark_traffic"))
    }

    fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
        &self.cache
    }

    /// Results are cached under the candidate's own tweet id.
    fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
        candidate.tweet_id
    }

    /// Extracts the cacheable part of a hydrated candidate; a missing verdict
    /// is stored as the verdict type's default.
    fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
        CachedBrandSafety {
            verdict: hydrated.brand_safety_verdict.unwrap_or_default(),
            safety_labels: hydrated.safety_labels.clone(),
        }
    }

    /// Rebuilds a partial candidate carrying only the cached fields.
    fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
        PostCandidate {
            brand_safety_verdict: Some(value.verdict),
            safety_labels: value.safety_labels,
            ..Default::default()
        }
    }

    /// Fetches safety labels for every candidate's primary tweet (the
    /// retweeted tweet when present, otherwise the tweet itself) and quoted
    /// tweet in one batch call, then derives a verdict and label list per
    /// candidate.
    ///
    /// Per-candidate outcome:
    /// - `Err` when the primary tweet's lookup failed (or the whole batch did).
    /// - Otherwise `Ok` with the worst of the primary and quoted verdicts; a
    ///   failed quoted-tweet lookup contributes `MediumRisk`.
    /// - Labels of both tweets are merged, sorted by label type and
    ///   de-duplicated per label type; when both tweets carry the same label
    ///   type the primary tweet's entry is the one kept.
    async fn hydrate_from_client(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let mut all_ids: HashSet<u64> = HashSet::new();
        for c in candidates {
            all_ids.insert(c.retweeted_tweet_id.unwrap_or(c.tweet_id));
            if let Some(qt_id) = c.quoted_tweet_id {
                all_ids.insert(qt_id);
            }
        }
        // Fix an iteration order so results can be paired back with their ids.
        let all_ids_vec: Vec<u64> = all_ids.into_iter().collect();
        let all_ids_i64: Vec<i64> = all_ids_vec.iter().map(|&id| id as i64).collect();

        let mut label_map: HashMap<u64, xai_safety_label_store::types::SafetyLabelMap> =
            HashMap::new();
        let mut error_map: HashMap<u64, String> = HashMap::new();
        match self.client.batch_get_all_labels(&all_ids_i64).await {
            Ok(per_key_results) => {
                // NOTE(review): relies on the client returning one result per
                // requested id, in request order — confirm against the client.
                for (&id, result) in all_ids_vec.iter().zip(per_key_results) {
                    match result {
                        Ok(labels) => {
                            label_map.insert(id, labels);
                        }
                        Err(e) => {
                            error_map.insert(id, e.to_string());
                        }
                    }
                }
            }
            Err(e) => {
                // A failed batch call marks every id as failed.
                let err_str = e.to_string();
                for &id in &all_ids_vec {
                    error_map.insert(id, err_str.clone());
                }
            }
        }

        candidates
            .iter()
            .map(|c| {
                let primary_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id);
                if let Some(err) = error_map.get(&primary_id) {
                    return Err(format!("safety label lookup error for tweet {primary_id}: {err}"));
                }
                let empty = HashMap::new();
                let primary_labels = label_map.get(&primary_id).unwrap_or(&empty);
                let mut verdict = compute_verdict(primary_labels, primary_id);

                let mut quoted_labels = None;
                if let Some(qt_id) = c.quoted_tweet_id {
                    if error_map.contains_key(&qt_id) {
                        verdict = worst_verdict(&verdict, &BrandSafetyVerdict::MediumRisk);
                    } else {
                        let qt_labels = label_map.get(&qt_id).unwrap_or(&empty);
                        verdict = worst_verdict(&verdict, &compute_verdict(qt_labels, qt_id));
                        quoted_labels = Some(qt_labels);
                    }
                }

                // Primary labels first, then the quoted tweet's (if available).
                let mut safety_labels: Vec<SafetyLabelInfo> = primary_labels
                    .iter()
                    .chain(quoted_labels.into_iter().flatten())
                    .map(|(k, v)| SafetyLabelInfo {
                        label_type: *k,
                        description: v.source.as_deref().map(truncate_description),
                        source: botmaker_rule_id_from(v)
                            .map(|id| botmaker_rule_category(id).to_string()),
                    })
                    .collect();
                // The sort must be stable: it keeps the primary tweet's entry
                // ahead of the quoted tweet's entry for the same label type, so
                // `dedup_by` (which drops the later of two equal neighbours)
                // deterministically keeps the primary one.
                safety_labels.sort_by_key(|l| i32::from(l.label_type));
                safety_labels.dedup_by(|a, b| a.label_type == b.label_type);

                Ok(PostCandidate {
                    brand_safety_verdict: Some(verdict),
                    safety_labels,
                    ..Default::default()
                })
            })
            .collect()
    }

    /// Copies the hydrated verdict and labels onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.brand_safety_verdict = hydrated.brand_safety_verdict;
        candidate.safety_labels = hydrated.safety_labels;
    }
}

View File

@@ -0,0 +1,108 @@
use crate::models::brand_safety::{
BrandSafetyVerdict, botmaker_rule_category, botmaker_rule_id_from, compute_verdict,
truncate_description, worst_verdict,
};
use crate::models::candidate::{PostCandidate, SafetyLabelInfo};
use crate::models::query::ScoredPostsQuery;
use crate::params::*;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_visibility_filtering::vf_safety_labels_client::VfClient;
/// Hydrates brand-safety verdicts and safety labels for post candidates using
/// a `VfClient`.
pub struct AdsBrandSafetyVfHydrator {
// Client whose `get_safety_labels` is called once per hydration batch.
pub client: Arc<dyn VfClient>,
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for AdsBrandSafetyVfHydrator {
    /// Enabled when the `EnableAdsBrandSafetyHydrator` param is on and the
    /// `vf_brand_safety_dark_traffic` decider is present and enabled.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        query.params.get(EnableAdsBrandSafetyHydrator)
            && query
                .decider
                .as_ref()
                .is_some_and(|d| d.enabled("vf_brand_safety_dark_traffic"))
    }

    /// Fetches safety labels for every candidate's primary tweet (the
    /// retweeted tweet when present, otherwise the tweet itself) and quoted
    /// tweet in one call, then derives a verdict and label list per candidate.
    ///
    /// Per-candidate outcome:
    /// - `Err` for every candidate when the call itself fails, and for any
    ///   candidate whose primary tweet is listed in the batch failures.
    /// - Otherwise `Ok` with the worst of the primary and quoted verdicts; a
    ///   failed quoted-tweet lookup contributes `MediumRisk`. A tweet absent
    ///   from both labels and failures is treated as having no labels.
    /// - Labels of both tweets are merged, sorted by label type and
    ///   de-duplicated per label type; when both tweets carry the same label
    ///   type the primary tweet's entry is the one kept.
    async fn hydrate(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let mut all_ids: HashSet<u64> = HashSet::new();
        for c in candidates {
            all_ids.insert(c.retweeted_tweet_id.unwrap_or(c.tweet_id));
            if let Some(qt_id) = c.quoted_tweet_id {
                all_ids.insert(qt_id);
            }
        }
        let tweet_ids: Vec<u64> = all_ids.into_iter().collect();

        let batch = match self.client.get_safety_labels(tweet_ids).await {
            Ok(batch) => batch,
            Err(e) => {
                let err = format!("VF get_safety_labels failed: {e}");
                return candidates.iter().map(|_| Err(err.clone())).collect();
            }
        };
        let failed_ids: HashSet<u64> = batch.failures.keys().copied().collect();
        let label_map = batch.labels;

        candidates
            .iter()
            .map(|c| {
                let primary_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id);
                if failed_ids.contains(&primary_id) {
                    return Err(format!("VF lookup failed for tweet {primary_id}"));
                }
                let empty = HashMap::new();
                let primary_labels = label_map.get(&primary_id).unwrap_or(&empty);
                let mut verdict = compute_verdict(primary_labels, primary_id);

                let mut quoted_labels = None;
                if let Some(qt_id) = c.quoted_tweet_id {
                    if failed_ids.contains(&qt_id) {
                        verdict = worst_verdict(&verdict, &BrandSafetyVerdict::MediumRisk);
                    } else {
                        let qt_labels = label_map.get(&qt_id).unwrap_or(&empty);
                        verdict = worst_verdict(&verdict, &compute_verdict(qt_labels, qt_id));
                        quoted_labels = Some(qt_labels);
                    }
                }

                // Primary labels first, then the quoted tweet's (if available).
                let mut safety_labels: Vec<SafetyLabelInfo> = primary_labels
                    .iter()
                    .chain(quoted_labels.into_iter().flatten())
                    .map(|(k, v)| SafetyLabelInfo {
                        label_type: *k,
                        description: v.source.as_deref().map(truncate_description),
                        source: botmaker_rule_id_from(v)
                            .map(|id| botmaker_rule_category(id).to_string()),
                    })
                    .collect();
                // The sort must be stable: it keeps the primary tweet's entry
                // ahead of the quoted tweet's entry for the same label type, so
                // `dedup_by` (which drops the later of two equal neighbours)
                // deterministically keeps the primary one.
                safety_labels.sort_by_key(|l| i32::from(l.label_type));
                safety_labels.dedup_by(|a, b| a.label_type == b.label_type);

                Ok(PostCandidate {
                    brand_safety_verdict: Some(verdict),
                    safety_labels,
                    ..Default::default()
                })
            })
            .collect()
    }

    /// Copies the hydrated verdict and labels onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.brand_safety_verdict = hydrated.brand_safety_verdict;
        candidate.safety_labels = hydrated.safety_labels;
    }
}

View File

@@ -0,0 +1,57 @@
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::clients::SocialGraphClientOps;
use xai_candidate_pipeline::hydrator::Hydrator;
/// Hydrates `author_blocks_viewer` on post candidates via the social graph.
pub struct BlockedByHydrator {
    /// Client used for the `check_blocked_by` lookup.
    pub socialgraph_client: Arc<dyn SocialGraphClientOps>,
}

impl BlockedByHydrator {
    /// Wraps the given social-graph client. The function is `async` but does
    /// not await anything.
    pub async fn new(socialgraph_client: Arc<dyn SocialGraphClientOps>) -> Self {
        let hydrator = Self { socialgraph_client };
        hydrator
    }
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for BlockedByHydrator {
    /// Skipped when the query already carries cached posts.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        !query.has_cached_posts
    }

    /// Asks the social graph which of the candidates' authors are returned by
    /// `check_blocked_by` for the querying user, and reports the answer per
    /// candidate as `author_blocks_viewer`.
    ///
    /// If the lookup fails, every candidate gets the same error string.
    async fn hydrate(
        &self,
        query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let author_ids: Vec<u64> = candidates.iter().map(|c| c.author_id).collect();
        let lookup = self
            .socialgraph_client
            .check_blocked_by(query.user_id, &author_ids)
            .await;

        match lookup {
            Err(e) => {
                let err_msg = e.to_string();
                candidates.iter().map(|_| Err(err_msg.clone())).collect()
            }
            Ok(blocked_by_user_ids) => candidates
                .iter()
                .map(|candidate| {
                    let is_blocked = blocked_by_user_ids.contains(&candidate.author_id);
                    Ok(PostCandidate {
                        author_blocks_viewer: Some(is_blocked),
                        ..Default::default()
                    })
                })
                .collect(),
        }
    }

    /// Copies the hydrated flag onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.author_blocks_viewer = hydrated.author_blocks_viewer;
    }
}

View File

@@ -1,52 +1,108 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
use xai_stats_receiver::global_stats_receiver;
const FOUND_SCOPE: [(&str, &str); 1] = [("hydration", "found")];
const MISSING_SCOPE: [(&str, &str); 1] = [("hydration", "missing")];
pub struct CoreDataCandidateHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
pub cache: MokaCache<u64, CoreDataCacheValue>,
}
impl CoreDataCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
let cache = default_moka_cache();
Self { tes_client, cache }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
type CacheKey = u64;
type CacheValue = CoreDataCacheValue;
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.has_cached_posts
}
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
&self.cache
}
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
candidate.tweet_id
}
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
CoreDataCacheValue {
author_id: hydrated.author_id,
retweeted_user_id: hydrated.retweeted_user_id,
retweeted_tweet_id: hydrated.retweeted_tweet_id,
in_reply_to_tweet_id: hydrated.in_reply_to_tweet_id,
tweet_text: hydrated.tweet_text.clone(),
}
}
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
PostCandidate {
author_id: value.author_id,
retweeted_user_id: value.retweeted_user_id,
retweeted_tweet_id: value.retweeted_tweet_id,
in_reply_to_tweet_id: value.in_reply_to_tweet_id,
tweet_text: value.tweet_text,
..Default::default()
}
}
async fn hydrate_from_client(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
) -> Vec<Result<PostCandidate, String>> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let tweet_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
let post_features = client.get_tweet_core_datas(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
let mut hydrated_count = 0usize;
let mut missing_count = 0usize;
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let core_data = post_features.and_then(|x| x.as_ref());
let text = core_data.map(|x| x.text.clone());
let hydrated = PostCandidate {
author_id: core_data.map(|x| x.author_id).unwrap_or_default(),
retweeted_user_id: core_data.and_then(|x| x.source_user_id),
retweeted_tweet_id: core_data.and_then(|x| x.source_tweet_id),
in_reply_to_tweet_id: core_data.and_then(|x| x.in_reply_to_tweet_id),
tweet_text: text.unwrap_or_default(),
..Default::default()
};
hydrated_candidates.push(hydrated);
match post_features {
Some(Ok(Some(core_data))) => {
hydrated_count += 1;
let text = core_data.text.clone();
let hydrated = PostCandidate {
author_id: core_data.author_id,
retweeted_user_id: core_data.source_user_id,
retweeted_tweet_id: core_data.source_tweet_id,
in_reply_to_tweet_id: core_data.in_reply_to_tweet_id,
tweet_text: text,
..Default::default()
};
hydrated_candidates.push(Ok(hydrated));
}
Some(Ok(None)) | None => {
missing_count += 1;
hydrated_candidates.push(Ok(PostCandidate::default()));
}
Some(Err(err)) => {
hydrated_candidates.push(Err(err.to_string()));
}
}
}
Ok(hydrated_candidates)
self.record_hydration_stats(hydrated_count, missing_count);
hydrated_candidates
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
@@ -56,3 +112,22 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
candidate.tweet_text = hydrated.tweet_text;
}
}
/// Core tweet data kept in `CoreDataCandidateHydrator`'s cache for a single post.
///
/// Carries exactly the fields that the hydrator copies between the cache and a
/// `PostCandidate` (see `cache_value` / `hydrate_from_cache` on the hydrator).
#[derive(Clone, Debug)]
pub struct CoreDataCacheValue {
/// Id of the post's author.
pub author_id: u64,
/// Source user id reported by core data (`source_user_id`), if any.
pub retweeted_user_id: Option<u64>,
/// Source tweet id reported by core data (`source_tweet_id`), if any.
pub retweeted_tweet_id: Option<u64>,
/// Id of the tweet this post replies to, if any.
pub in_reply_to_tweet_id: Option<u64>,
/// Text of the post.
pub tweet_text: String,
}
impl CoreDataCandidateHydrator {
    /// Reports how many posts were found vs. missing in one client hydration pass.
    ///
    /// Both counters are emitted under the metric `"<hydrator name>.hydrate"`, using
    /// `FOUND_SCOPE` for `hydrated_count` and `MISSING_SCOPE` for `missing_count`.
    /// When no global stats receiver is available the call is a no-op.
    fn record_hydration_stats(&self, hydrated_count: usize, missing_count: usize) {
        let Some(stats) = global_stats_receiver() else {
            return;
        };

        let metric = format!("{}.hydrate", self.name());
        stats.incr(metric.as_str(), &FOUND_SCOPE, hydrated_count as u64);
        stats.incr(metric.as_str(), &MISSING_SCOPE, missing_count as u64);
    }
}

View File

@@ -0,0 +1,113 @@
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::{CandidateHelpers, PostCandidate};
use crate::models::query::ScoredPostsQuery;
use crate::params::EnableContextFeatures;
use std::sync::Arc;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::{
MokaCache, TweetAgeExpiry, build_moka_cache_tweet_age,
};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
/// Engagement counters stored in `EngagementCountsHydrator`'s cache for one tweet id.
///
/// Each counter is optional; `None` means no value was available when the entry
/// was written.
#[derive(Clone, Debug)]
pub struct CachedCounts {
// Favorite (like) count.
fav_count: Option<i64>,
// Reply count.
reply_count: Option<i64>,
// Repost count (populated from the client's `retweet_count`).
repost_count: Option<i64>,
// Quote count.
quote_count: Option<i64>,
}
/// Hydrates favorite / reply / repost / quote counts on post candidates.
///
/// Counts are fetched through the TES client and cached per original tweet id.
pub struct EngagementCountsHydrator {
    /// Client used to fetch API counts for tweets.
    pub tes_client: Arc<dyn TESClient + Send + Sync>,
    /// Cache of previously fetched counts, keyed by tweet id.
    cache: MokaCache<u64, CachedCounts>,
}

impl EngagementCountsHydrator {
    /// Builds the hydrator with a tweet-age-aware cache of up to 1,000,000 entries.
    ///
    /// NOTE(review): the expiry presumably applies `new_tweet_ttl` (5 min) to tweets
    /// younger than `age_threshold` (30 min) and `old_tweet_ttl` (10 min) otherwise;
    /// confirm against `build_moka_cache_tweet_age`.
    pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
        let expiry = TweetAgeExpiry {
            age_threshold: Duration::from_secs(30 * 60),
            new_tweet_ttl: Duration::from_secs(5 * 60),
            old_tweet_ttl: Duration::from_secs(10 * 60),
        };

        Self {
            tes_client,
            cache: build_moka_cache_tweet_age(1_000_000, expiry),
        }
    }
}
#[async_trait]
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for EngagementCountsHydrator {
    type CacheKey = u64;
    type CacheValue = CachedCounts;

    /// Enabled when context features are requested (or the request is shadow
    /// traffic) and the query does not already carry cached posts.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        let wants_counts = query.params.get(EnableContextFeatures) || query.is_shadow_traffic;
        wants_counts && !query.has_cached_posts
    }

    fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
        &self.cache
    }

    /// Counts are cached under the original tweet id of the candidate.
    fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
        candidate.get_original_tweet_id()
    }

    /// Extracts the four counters from a hydrated candidate for caching.
    fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
        let PostCandidate {
            fav_count,
            reply_count,
            repost_count,
            quote_count,
            ..
        } = hydrated;

        CachedCounts {
            fav_count: *fav_count,
            reply_count: *reply_count,
            repost_count: *repost_count,
            quote_count: *quote_count,
        }
    }

    /// Rebuilds a partial candidate (counters only) from a cached entry.
    fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
        let CachedCounts {
            fav_count,
            reply_count,
            repost_count,
            quote_count,
        } = value;

        PostCandidate {
            fav_count,
            reply_count,
            repost_count,
            quote_count,
            ..Default::default()
        }
    }

    /// Fetches counts for every candidate's original tweet id in one client call.
    ///
    /// Always yields `Ok` per candidate: when the id is absent from the response,
    /// the lookup failed, or the service returned no counts, all four counters
    /// are `None`.
    ///
    /// NOTE(review): per-id client errors are flattened into empty counts here,
    /// whereas sibling hydrators surface them as `Err`. Confirm this best-effort
    /// behaviour is intended.
    async fn hydrate_from_client(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let original_ids: Vec<u64> = candidates
            .iter()
            .map(|c| c.get_original_tweet_id())
            .collect();
        let fetched = self.tes_client.get_api_counts(original_ids.clone()).await;

        let mut hydrated = Vec::with_capacity(original_ids.len());
        for id in &original_ids {
            let counts = match fetched.get(id) {
                Some(Ok(Some(found))) => Some(found),
                _ => None,
            };
            hydrated.push(Ok(PostCandidate {
                fav_count: counts.and_then(|c| c.favorite_count),
                reply_count: counts.and_then(|c| c.reply_count),
                repost_count: counts.and_then(|c| c.retweet_count),
                quote_count: counts.and_then(|c| c.quote_count),
                ..Default::default()
            }));
        }
        hydrated
    }

    /// Copies the four counters from the hydrated partial onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        let PostCandidate {
            fav_count,
            reply_count,
            repost_count,
            quote_count,
            ..
        } = hydrated;

        candidate.fav_count = fav_count;
        candidate.reply_count = reply_count;
        candidate.repost_count = repost_count;
        candidate.quote_count = quote_count;
    }
}

View File

@@ -0,0 +1,134 @@
use crate::filters::topic_ids_filter::TopicFilteringOverrideMap;
use crate::models::candidate::PostCandidate;
use crate::models::candidate_features::{FilteredTopicsByExperiment, TopicFilteringExperiment};
use crate::models::query::ScoredPostsQuery;
use crate::params::{EnableNewUserTopicFiltering, TopicFilteringId, TopicFilteringOverrides};
use std::collections::HashMap;
use std::sync::Arc;
use tonic::async_trait;
use tracing::warn;
use xai_candidate_pipeline::component_library::clients::StratoClient;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_strato::{StratoResult, StratoValue, decode};
/// Decodes one Strato fetch result into `(experiment topics, unfiltered topics)`.
///
/// * `result` - raw bytes for one tweet, or the fetch error.
/// * `experiment` - experiment whose topic ids fill the first tuple slot.
/// * `need_unfiltered` - when `true`, the second slot is filled with the topic
///   ids of `TopicFilteringExperiment::Unfiltered`; otherwise it is `None`.
///
/// Returns `(None, None)` when the fetch failed (a warning is logged), the
/// payload is empty, decoding fails, or the decoded value is absent.
fn decode_topics_pair(
    result: &Result<Vec<u8>, Box<dyn std::error::Error>>,
    experiment: TopicFilteringExperiment,
    need_unfiltered: bool,
) -> (Option<Vec<i64>>, Option<Vec<i64>>) {
    let payload = match result {
        Ok(payload) => payload,
        Err(e) => {
            warn!("FilteredTopicsHydrator: strato fetch error: {}", e);
            return (None, None);
        }
    };
    if payload.is_empty() {
        return (None, None);
    }

    let decoded: StratoResult<StratoValue<FilteredTopicsByExperiment>> = decode(payload);
    let StratoResult::Ok(envelope) = decoded else {
        return (None, None);
    };
    let Some(by_experiment) = envelope.v else {
        return (None, None);
    };

    let for_experiment = by_experiment.topic_ids_for_experiment(experiment).cloned();
    let unfiltered = need_unfiltered
        .then(|| {
            by_experiment
                .topic_ids_for_experiment(TopicFilteringExperiment::Unfiltered)
                .cloned()
        })
        .flatten();

    (for_experiment, unfiltered)
}
/// Hydrates `filtered_topic_ids` / `unfiltered_topic_ids` on post candidates
/// from a batched Strato lookup.
pub struct FilteredTopicsHydrator {
/// Strato client used for `batch_get_filtered_topics_by_experiment`.
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for FilteredTopicsHydrator {
fn enable(&self, query: &ScoredPostsQuery) -> bool {
query.is_topic_request()
|| query.has_excluded_topics()
|| (query.params.get(EnableNewUserTopicFiltering) && query.has_new_user_topic_ids())
}
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Vec<Result<PostCandidate, String>> {
let experiment = if query.is_bulk_topic_request() || query.has_excluded_topics() {
TopicFilteringExperiment::Unfiltered
} else {
let default_experiment =
TopicFilteringExperiment::parse(&query.params.get(TopicFilteringId));
let override_map =
TopicFilteringOverrideMap::parse(&query.params.get(TopicFilteringOverrides));
override_map.resolve(&query.topic_ids, default_experiment)
};
let client = &self.strato_client;
let need_unfiltered = experiment != TopicFilteringExperiment::Unfiltered;
let mut all_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
let retweet_offset = all_ids.len();
for c in candidates {
if let Some(rt_id) = c.retweeted_tweet_id {
all_ids.push(rt_id);
}
}
let all_results = client
.batch_get_filtered_topics_by_experiment(&all_ids)
.await;
let mut retweet_topics: HashMap<u64, Vec<i64>> = HashMap::new();
let mut retweet_unfiltered: HashMap<u64, Vec<i64>> = HashMap::new();
let mut rt_idx = retweet_offset;
for c in candidates {
if let Some(rt_id) = c.retweeted_tweet_id {
let (exp, unf) =
decode_topics_pair(&all_results[rt_idx], experiment, need_unfiltered);
if let Some(topics) = exp {
retweet_topics.insert(rt_id, topics);
}
if let Some(topics) = unf {
retweet_unfiltered.insert(rt_id, topics);
}
rt_idx += 1;
}
}
candidates
.iter()
.enumerate()
.map(|(i, c)| {
let (topics, unf_topics) = if let Some(rt_id) = c.retweeted_tweet_id {
(
retweet_topics.get(&rt_id).cloned(),
retweet_unfiltered.get(&rt_id).cloned(),
)
} else {
decode_topics_pair(&all_results[i], experiment, need_unfiltered)
};
Ok(PostCandidate {
filtered_topic_ids: topics,
unfiltered_topic_ids: unf_topics,
..Default::default()
})
})
.collect()
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.filtered_topic_ids = hydrated.filtered_topic_ids;
candidate.unfiltered_topic_ids = hydrated.unfiltered_topic_ids;
}
}

View File

@@ -0,0 +1,98 @@
use crate::models::candidate::PostCandidate;
use crate::models::in_network_reply::InNetworkReply;
use crate::models::query::ScoredPostsQuery;
use crate::params::{
EnableFollowingRepliedUsersFacepile, FollowingRepliedUsersFacepileMaxPosts,
FollowingRepliedUsersFacepileMinUsers,
};
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
/// Minimum viewer follower count (inclusive) for `FollowingRepliedUsersHydrator`
/// to be enabled.
const VIEWER_FOLLOWERS_THRESHOLD: i64 = 1000;
/// Hydrates `following_replied_user_ids` on candidates from the query's
/// in-network replies.
pub struct FollowingRepliedUsersHydrator;

impl FollowingRepliedUsersHydrator {
    /// Groups reply authors by the tweet they replied to.
    ///
    /// Each author list in the returned map is sorted ascending and free of
    /// duplicates.
    fn build_reply_author_map(replies: &[InNetworkReply]) -> HashMap<u64, Vec<u64>> {
        let mut authors_by_parent: HashMap<u64, Vec<u64>> = HashMap::new();

        replies.iter().for_each(|reply| {
            authors_by_parent
                .entry(reply.in_reply_to_tweet_id)
                .or_default()
                .push(reply.author_id)
        });

        authors_by_parent.values_mut().for_each(|ids| {
            ids.sort_unstable();
            ids.dedup();
        });

        authors_by_parent
    }
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for FollowingRepliedUsersHydrator {
    /// Enabled when the viewer has at least `VIEWER_FOLLOWERS_THRESHOLD`
    /// followers and the facepile feature is switched on.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        let follower_count = query.user_features.follower_count;
        if !follower_count.is_some_and(|n| n >= VIEWER_FOLLOWERS_THRESHOLD) {
            return false;
        }
        query.params.get(EnableFollowingRepliedUsersFacepile)
    }

    /// For each candidate, lists the authors of in-network replies to it.
    ///
    /// A candidate receives a non-empty list only when it is a root tweet (not
    /// itself a reply), at least `min_users` distinct repliers other than its
    /// own author exist, and fewer than `max_posts` earlier candidates were
    /// already selected. Every other candidate receives an empty list.
    async fn hydrate(
        &self,
        query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        // Negative parameter values are clamped to zero before the cast.
        let min_users = query
            .params
            .get(FollowingRepliedUsersFacepileMinUsers)
            .max(0) as usize;
        let max_posts = query
            .params
            .get(FollowingRepliedUsersFacepileMaxPosts)
            .max(0) as usize;

        let no_replies = Vec::new();
        let replies = match query.in_network_replies.get() {
            Some(replies) => replies,
            None => &no_replies,
        };
        let repliers_by_tweet = Self::build_reply_author_map(replies);

        // Number of candidates that may still be given a facepile.
        let mut remaining = max_posts;

        candidates
            .iter()
            .map(|candidate| {
                let repliers: Vec<u64> = match repliers_by_tweet.get(&candidate.tweet_id) {
                    Some(ids) => ids
                        .iter()
                        .copied()
                        .filter(|id| *id != candidate.author_id)
                        .collect(),
                    None => Vec::new(),
                };

                let qualifies =
                    candidate.in_reply_to_tweet_id.is_none() && repliers.len() >= min_users;
                let following_replied_user_ids = if qualifies && remaining > 0 {
                    remaining -= 1;
                    repliers
                } else {
                    Vec::new()
                };

                Ok(PostCandidate {
                    following_replied_user_ids,
                    ..Default::default()
                })
            })
            .collect()
    }

    /// Copies the replier list from the hydrated partial onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.following_replied_user_ids = hydrated.following_replied_user_ids;
    }
}

View File

@@ -1,28 +1,67 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::GizmoduckClient;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
pub struct GizmoduckCandidateHydrator {
pub gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
pub cache: MokaCache<GizmoduckCacheKey, GizmoduckCacheValue>,
}
impl GizmoduckCandidateHydrator {
pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>) -> Self {
Self { gizmoduck_client }
let cache = default_moka_cache();
Self {
gizmoduck_client,
cache,
}
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
type CacheKey = GizmoduckCacheKey;
type CacheValue = GizmoduckCacheValue;
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.has_cached_posts
}
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
&self.cache
}
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
GizmoduckCacheKey {
author_id: candidate.author_id,
retweeted_user_id: candidate.retweeted_user_id,
}
}
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
GizmoduckCacheValue {
author_followers_count: hydrated.author_followers_count,
author_screen_name: hydrated.author_screen_name.clone(),
retweeted_screen_name: hydrated.retweeted_screen_name.clone(),
}
}
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
PostCandidate {
author_followers_count: value.author_followers_count,
author_screen_name: value.author_screen_name,
retweeted_screen_name: value.retweeted_screen_name,
..Default::default()
}
}
async fn hydrate_from_client(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
) -> Vec<Result<PostCandidate, String>> {
let client = &self.gizmoduck_client;
let author_ids: Vec<_> = candidates.iter().map(|c| c.author_id).collect();
@@ -37,40 +76,54 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
user_ids_to_fetch.dedup();
let users = client.get_users(user_ids_to_fetch).await;
let users = users.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let user = users
.get(&(candidate.author_id as i64))
.and_then(|user| user.as_ref());
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let author_followers_count: Option<i32> =
user_counts.map(|x| x.followers_count).map(|x| x as i32);
let author_screen_name: Option<String> = user_profile.map(|x| x.screen_name.clone());
let user = users.get(&(candidate.author_id as i64));
let user = match user {
Some(Ok(Some(user))) => Ok(Some(user)),
Some(Ok(None)) | None => Ok(None),
Some(Err(err)) => Err(err.to_string()),
};
let retweet_user = candidate
.retweeted_user_id
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)))
.and_then(|user| user.as_ref());
let retweet_profile =
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let retweeted_screen_name: Option<String> =
retweet_profile.map(|x| x.screen_name.clone());
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)));
let retweet_user = match retweet_user {
Some(Ok(Some(user))) => Ok(Some(user)),
Some(Ok(None)) | None => Ok(None),
Some(Err(err)) => Err(err.to_string()),
};
let hydrated = PostCandidate {
author_followers_count,
author_screen_name,
retweeted_screen_name,
..Default::default()
let hydrated = match (user, retweet_user) {
(Ok(user), Ok(retweet_user)) => {
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let author_followers_count: Option<i32> =
user_counts.map(|x| x.followers_count).map(|x| x as i32);
let author_screen_name: Option<String> =
user_profile.map(|x| x.screen_name.clone());
let retweet_profile =
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let retweeted_screen_name: Option<String> =
retweet_profile.map(|x| x.screen_name.clone());
Ok(PostCandidate {
author_followers_count,
author_screen_name,
retweeted_screen_name,
..Default::default()
})
}
(Err(err), _) | (_, Err(err)) => Err(err),
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
hydrated_candidates
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
@@ -79,3 +132,16 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
candidate.retweeted_screen_name = hydrated.retweeted_screen_name;
}
}
/// Cache key for `GizmoduckCandidateHydrator`: the pair of user ids a
/// candidate's hydrated user fields depend on.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct GizmoduckCacheKey {
/// Id of the candidate's author.
pub author_id: u64,
/// Id of the retweeted user, when the candidate has one.
pub retweeted_user_id: Option<u64>,
}
/// User-derived fields cached by `GizmoduckCandidateHydrator` for one
/// `GizmoduckCacheKey`.
#[derive(Clone, Debug)]
pub struct GizmoduckCacheValue {
/// Follower count of the author, if known.
pub author_followers_count: Option<i32>,
/// Screen name of the author, if known.
pub author_screen_name: Option<String>,
/// Screen name of the retweeted user, if any and known.
pub retweeted_screen_name: Option<String>,
}

View File

@@ -0,0 +1,85 @@
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::{CandidateHelpers, PostCandidate};
use crate::models::query::ScoredPostsQuery;
use crate::params::EnableHasMediaHydration;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
/// Hydrates the `has_media` flag on post candidates via the TES client, with
/// results cached per original tweet id.
pub struct HasMediaHydrator {
    /// Client used to look up whether tweets carry media.
    pub tes_client: Arc<dyn TESClient + Send + Sync>,
    /// Cache of `has_media` values keyed by tweet id.
    pub cache: MokaCache<u64, Option<bool>>,
}

impl HasMediaHydrator {
    /// Creates the hydrator with the default Moka cache configuration.
    pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
        Self {
            tes_client,
            cache: default_moka_cache(),
        }
    }
}
#[async_trait]
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for HasMediaHydrator {
    type CacheKey = u64;
    type CacheValue = Option<bool>;

    /// Enabled when has-media hydration is requested (or the request is shadow
    /// traffic) and the query does not already carry cached posts.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        let requested = query.params.get(EnableHasMediaHydration) || query.is_shadow_traffic;
        requested && !query.has_cached_posts
    }

    fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
        &self.cache
    }

    /// The flag is cached under the candidate's original tweet id.
    fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
        candidate.get_original_tweet_id()
    }

    fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
        hydrated.has_media
    }

    /// Rebuilds a partial candidate carrying only the cached `has_media` value.
    fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
        PostCandidate {
            has_media: value,
            ..Default::default()
        }
    }

    /// Looks up `has_media` for every candidate's original tweet id in one call.
    ///
    /// Yields one result per candidate, in order: `Ok` with the returned flag,
    /// or `Err` when the client reported an error for that id or the id is
    /// absent from the response.
    async fn hydrate_from_client(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let original_ids: Vec<u64> = candidates
            .iter()
            .map(|c| c.get_original_tweet_id())
            .collect();
        let lookup = self.tes_client.get_has_media(original_ids.clone()).await;

        original_ids
            .into_iter()
            .map(|tweet_id| match lookup.get(&tweet_id) {
                Some(Ok(flag)) => Ok(PostCandidate {
                    has_media: *flag,
                    ..Default::default()
                }),
                Some(Err(err)) => Err(err.to_string()),
                None => Err(format!("Missing has_media for tweet_id={}", tweet_id)),
            })
            .collect()
    }

    /// Copies `has_media` from the hydrated partial onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.has_media = hydrated.has_media;
    }
}

View File

@@ -1,5 +1,5 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
@@ -8,13 +8,16 @@ pub struct InNetworkCandidateHydrator;
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
#[xai_stats_macro::receive_stats]
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.has_cached_posts
}
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let viewer_id = query.user_id as u64;
) -> Vec<Result<PostCandidate, String>> {
let viewer_id = query.user_id;
let followed_ids: HashSet<u64> = query
.user_features
.followed_user_ids
@@ -23,19 +26,17 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
.map(|id| id as u64)
.collect();
let hydrated_candidates = candidates
candidates
.iter()
.map(|candidate| {
let is_self = candidate.author_id == viewer_id;
let is_in_network = is_self || followed_ids.contains(&candidate.author_id);
PostCandidate {
Ok(PostCandidate {
in_network: Some(is_in_network),
..Default::default()
}
})
})
.collect();
Ok(hydrated_candidates)
.collect()
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {

View File

@@ -0,0 +1,84 @@
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::{CandidateHelpers, PostCandidate};
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
/// Hydrates the `language_code` of post candidates via the TES client, with
/// results cached per original tweet id.
pub struct LanguageCodeHydrator {
    /// Client used to look up tweet language codes.
    pub tes_client: Arc<dyn TESClient + Send + Sync>,
    /// Cache of language codes keyed by tweet id.
    pub cache: MokaCache<u64, Option<String>>,
}

impl LanguageCodeHydrator {
    /// Creates the hydrator with the default Moka cache configuration.
    pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
        Self {
            tes_client,
            cache: default_moka_cache(),
        }
    }
}
#[async_trait]
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for LanguageCodeHydrator {
    type CacheKey = u64;
    type CacheValue = Option<String>;

    /// Enabled unless the query already carries cached posts.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        !query.has_cached_posts
    }

    fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
        &self.cache
    }

    /// The language code is cached under the candidate's original tweet id.
    fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
        candidate.get_original_tweet_id()
    }

    fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
        hydrated.language_code.clone()
    }

    /// Rebuilds a partial candidate carrying only the cached language code.
    fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
        PostCandidate {
            language_code: value,
            ..Default::default()
        }
    }

    /// Looks up the language code for every candidate's original tweet id.
    ///
    /// Yields one result per candidate, in order: `Ok` with the returned code,
    /// or `Err` when the client reported an error for that id or the id is
    /// absent from the response.
    async fn hydrate_from_client(
        &self,
        _query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let original_ids: Vec<u64> = candidates
            .iter()
            .map(|c| c.get_original_tweet_id())
            .collect();
        let lookup = self
            .tes_client
            .get_language_code(original_ids.clone())
            .await;

        original_ids
            .into_iter()
            .map(|tweet_id| match lookup.get(&tweet_id) {
                Some(Ok(code)) => Ok(PostCandidate {
                    language_code: code.clone(),
                    ..Default::default()
                }),
                Some(Err(err)) => Err(err.to_string()),
                None => Err(format!("Missing language_code for tweet_id={}", tweet_id)),
            })
            .collect()
    }

    /// Copies `language_code` from the hydrated partial onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.language_code = hydrated.language_code;
    }
}

View File

@@ -1,6 +1,16 @@
pub mod ads_brand_safety_hydrator;
pub mod ads_brand_safety_vf_hydrator;
pub mod blocked_by_hydrator;
pub mod core_data_candidate_hydrator;
pub mod filtered_topics_hydrator;
pub mod following_replied_users_hydrator;
pub mod gizmoduck_hydrator;
pub mod has_media_hydrator;
pub mod in_network_candidate_hydrator;
pub mod language_code_hydrator;
pub mod mutual_follow_jaccard_hydrator;
pub mod quote_hydrator;
pub mod subscription_hydrator;
pub mod tweet_type_metrics_hydrator;
pub mod vf_candidate_hydrator;
pub mod video_duration_candidate_hydrator;

View File

@@ -0,0 +1,118 @@
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use crate::params::EnableMutualFollowJaccardHydration;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::clients::StratoClient;
use xai_candidate_pipeline::hydrator::Hydrator;
/// Minimum number of hash values a minhash signature must contain to be used;
/// shorter viewer or author signatures are rejected by the hydrator.
const MIN_HASHES: usize = 256;
/// Hydrates `mutual_follow_jaccard` on candidates by comparing the viewer's
/// minhash signature with each author's signature fetched from Strato.
pub struct MutualFollowJaccardHydrator {
/// Strato client used for `batch_get_minhash_with_count`.
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
/// Estimates Jaccard similarity from two minhash signatures.
///
/// Compares the signatures position by position over their common prefix and
/// returns the fraction of positions that hold equal values. Returns `0.0`
/// when either signature is empty.
fn jaccard_from_minhash(a: &[i64], b: &[i64]) -> f64 {
    match a.len().min(b.len()) {
        0 => 0.0,
        compared => {
            let mut agreeing = 0usize;
            for (lhs, rhs) in a.iter().zip(b) {
                if lhs == rhs {
                    agreeing += 1;
                }
            }
            agreeing as f64 / compared as f64
        }
    }
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for MutualFollowJaccardHydrator {
    /// Enabled when the feature is switched on and the query carries a viewer
    /// minhash.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        query.params.get(EnableMutualFollowJaccardHydration) && query.viewer_minhash.is_some()
    }

    /// Computes the viewer/author minhash similarity for every candidate.
    ///
    /// * Viewer signature missing or shorter than `MIN_HASHES`: every candidate
    ///   gets `Ok` with `mutual_follow_jaccard: None` and no lookup is made.
    /// * Author signature found and long enough: `Ok` with the similarity.
    /// * Author has no signature: `Ok` with `None`.
    /// * Author signature too short, lookup failed, or no result was returned
    ///   for the author: `Err` with a description.
    async fn hydrate(
        &self,
        query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        let usable_viewer = query
            .viewer_minhash
            .as_ref()
            .filter(|mh| mh.len() >= MIN_HASHES);
        let Some(viewer_minhash) = usable_viewer else {
            return candidates
                .iter()
                .map(|_| {
                    Ok(PostCandidate {
                        mutual_follow_jaccard: None,
                        ..Default::default()
                    })
                })
                .collect();
        };

        // Each author is looked up once, however many candidates they wrote.
        let distinct_authors: HashSet<i64> =
            candidates.iter().map(|c| c.author_id as i64).collect();
        let unique_author_ids: Vec<i64> = distinct_authors.into_iter().collect();

        let fetched = self
            .strato_client
            .batch_get_minhash_with_count(&unique_author_ids)
            .await;

        // Results are paired with the requested ids by position.
        let author_result: HashMap<i64, Result<Option<Vec<i64>>, String>> = unique_author_ids
            .iter()
            .zip(fetched)
            .map(|(uid, outcome)| {
                let entry = match outcome {
                    Err(e) => Err(e.to_string()),
                    Ok(None) => Ok(None),
                    Ok(Some((minhash, _count))) => {
                        if minhash.len() >= MIN_HASHES {
                            Ok(Some(minhash))
                        } else {
                            Err(format!(
                                "Invalid minhash length {} (need >= {}) for author_id={}",
                                minhash.len(),
                                MIN_HASHES,
                                uid,
                            ))
                        }
                    }
                };
                (*uid, entry)
            })
            .collect();

        let mut hydrated = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            let author_id = candidate.author_id as i64;
            let result = match author_result.get(&author_id) {
                None => Err(format!(
                    "Missing minhash fetch result for author_id={}",
                    author_id,
                )),
                Some(Err(err)) => Err(err.clone()),
                Some(Ok(None)) => Ok(PostCandidate {
                    mutual_follow_jaccard: None,
                    ..Default::default()
                }),
                Some(Ok(Some(author_mh))) => Ok(PostCandidate {
                    mutual_follow_jaccard: Some(jaccard_from_minhash(viewer_minhash, author_mh)),
                    ..Default::default()
                }),
            };
            hydrated.push(result);
        }
        hydrated
    }

    /// Copies the similarity from the hydrated partial onto the candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.mutual_follow_jaccard = hydrated.mutual_follow_jaccard;
    }
}

View File

@@ -0,0 +1,176 @@
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use crate::params::EnableQuotedVqvDurationCheck;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::clients::SocialGraphClientOps;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, Hydrator};
/// Hydrates quote-related fields on post candidates: the quoted tweet and user
/// ids, whether the quoted author blocks the viewer, and the quoted video
/// duration.
pub struct QuoteHydrator {
/// Client used to resolve quoted tweets and minimum video durations.
pub tes_client: Arc<dyn TESClient + Send + Sync>,
/// Client used to check which quoted authors block the viewer.
pub socialgraph_client: Arc<dyn SocialGraphClientOps>,
/// Cache of resolved quote ids, keyed by the quoting tweet's id.
pub cache: MokaCache<u64, QuoteCacheValue>,
}
impl QuoteHydrator {
    /// Creates the hydrator with the default Moka cache configuration.
    pub async fn new(
        tes_client: Arc<dyn TESClient + Send + Sync>,
        socialgraph_client: Arc<dyn SocialGraphClientOps>,
    ) -> Self {
        Self {
            tes_client,
            socialgraph_client,
            cache: default_moka_cache(),
        }
    }

    /// Fetches the minimum video duration for each quoted tweet, best effort.
    ///
    /// The lookup is bounded by a 200 ms timeout; on timeout an empty map is
    /// returned. Ids whose individual lookup failed are left out of the map.
    /// An empty input returns an empty map without calling the client.
    async fn get_quoted_video_durations(
        &self,
        quoted_tweet_ids: Vec<u64>,
    ) -> HashMap<u64, Option<i32>> {
        if quoted_tweet_ids.is_empty() {
            return HashMap::new();
        }

        let lookup = self.tes_client.get_min_video_durations(quoted_tweet_ids);
        let Ok(durations) =
            tokio::time::timeout(std::time::Duration::from_millis(200), lookup).await
        else {
            return HashMap::new();
        };

        let mut by_tweet = HashMap::new();
        for (id, outcome) in durations {
            if let Ok(duration) = outcome {
                by_tweet.insert(id, duration.map(|v| v as i32));
            }
        }
        by_tweet
    }

    /// Returns the subset of `quoted_user_ids` that block `viewer_id`.
    ///
    /// An empty input returns an empty set without calling the client; a
    /// failed lookup also yields an empty set.
    async fn get_blocked_by(&self, viewer_id: u64, quoted_user_ids: Vec<u64>) -> HashSet<u64> {
        if quoted_user_ids.is_empty() {
            return HashSet::new();
        }

        let blockers = self
            .socialgraph_client
            .check_blocked_by(viewer_id, &quoted_user_ids)
            .await;
        blockers.unwrap_or_default()
    }
}
/// Quote resolution cached by `QuoteHydrator` for one tweet id.
///
/// Both fields are `None` when the tweet was resolved as not quoting anything.
#[derive(Clone, Debug)]
pub struct QuoteCacheValue {
/// Id of the quoted tweet, if any.
pub quoted_tweet_id: Option<u64>,
/// Id of the quoted tweet's author, if any.
pub quoted_user_id: Option<u64>,
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for QuoteHydrator {
    /// Enabled unless the query already carries cached posts.
    fn enable(&self, query: &ScoredPostsQuery) -> bool {
        !query.has_cached_posts
    }

    /// Resolves quote information for every candidate.
    ///
    /// Steps:
    /// 1. Resolve `(quoted_tweet_id, quoted_user_id)` per candidate from the
    ///    cache, falling back to one batched TES call for the cache misses.
    /// 2. Concurrently check which quoted authors block the viewer and, when
    ///    `EnableQuotedVqvDurationCheck` is on, fetch quoted video durations.
    /// 3. Emit one `Ok` partial candidate per input candidate, in order.
    ///
    /// A miss whose TES lookup returned an error is reported as "no quote" for
    /// this request but is NOT written to the cache, so a transient failure
    /// does not stick to the tweet for the lifetime of the cache entry.
    async fn hydrate(
        &self,
        query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        // `(quoted_tweet_id, quoted_user_id)` per candidate, in candidate order.
        let mut resolved: Vec<(Option<u64>, Option<u64>)> = Vec::with_capacity(candidates.len());
        // `(position in resolved, tweet id)` for every cache miss. Tracking the
        // position avoids rescanning the miss list for every candidate.
        let mut misses: Vec<(usize, u64)> = Vec::new();

        for (slot, candidate) in candidates.iter().enumerate() {
            let tweet_id = candidate.tweet_id;
            match self.cache.get(&tweet_id).await {
                Some(cached) => resolved.push((cached.quoted_tweet_id, cached.quoted_user_id)),
                None => {
                    misses.push((slot, tweet_id));
                    resolved.push((None, None));
                }
            }
        }

        if !misses.is_empty() {
            let miss_ids: Vec<u64> = misses.iter().map(|&(_, id)| id).collect();
            let quoted_tweets = self.tes_client.get_quoted_tweets(miss_ids).await;

            for (slot, tweet_id) in misses {
                let value = match quoted_tweets.get(&tweet_id) {
                    Some(Ok(Some(qt))) => QuoteCacheValue {
                        quoted_tweet_id: Some(qt.tweet_id),
                        quoted_user_id: Some(qt.user_id),
                    },
                    // Resolved as "no quote", or the id is absent from the response.
                    Some(Ok(None)) | None => QuoteCacheValue {
                        quoted_tweet_id: None,
                        quoted_user_id: None,
                    },
                    // Lookup failed: keep `(None, None)` for this request only.
                    Some(Err(_)) => continue,
                };

                // `slot` was produced while filling `resolved`, so it is in bounds.
                resolved[slot] = (value.quoted_tweet_id, value.quoted_user_id);
                self.cache.insert(tweet_id, value).await;
            }
        }

        let quoted_user_ids: Vec<u64> = resolved
            .iter()
            .filter_map(|&(_, user_id)| user_id)
            .collect::<HashSet<u64>>()
            .into_iter()
            .collect();

        let fetch_quoted_duration = query.params.get(EnableQuotedVqvDurationCheck);
        let quoted_tweet_ids: Vec<u64> = if fetch_quoted_duration {
            resolved
                .iter()
                .filter_map(|&(tweet_id, _)| tweet_id)
                .collect::<HashSet<u64>>()
                .into_iter()
                .collect()
        } else {
            Vec::new()
        };

        let (blocked_by, quoted_durations) = tokio::join!(
            self.get_blocked_by(query.user_id, quoted_user_ids),
            self.get_quoted_video_durations(quoted_tweet_ids),
        );

        resolved
            .iter()
            .map(|&(qt_tweet_id, qt_user_id)| {
                let quoted_author_blocks_viewer =
                    qt_user_id.is_some_and(|uid| blocked_by.contains(&uid));
                let quoted_video_duration_ms =
                    qt_tweet_id.and_then(|id| quoted_durations.get(&id).copied().flatten());

                Ok(PostCandidate {
                    quoted_tweet_id: qt_tweet_id,
                    quoted_user_id: qt_user_id,
                    quoted_author_blocks_viewer: Some(quoted_author_blocks_viewer),
                    quoted_video_duration_ms,
                    ..Default::default()
                })
            })
            .collect()
    }

    /// Copies the four quote-related fields from the hydrated partial onto the
    /// candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        candidate.quoted_tweet_id = hydrated.quoted_tweet_id;
        candidate.quoted_user_id = hydrated.quoted_user_id;
        candidate.quoted_author_blocks_viewer = hydrated.quoted_author_blocks_viewer;
        candidate.quoted_video_duration_ms = hydrated.quoted_video_duration_ms;
    }
}

View File

@@ -1,47 +1,79 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
pub struct SubscriptionHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
pub cache: MokaCache<u64, Option<u64>>,
}
impl SubscriptionHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
let cache = default_moka_cache();
Self { tes_client, cache }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
type CacheKey = u64;
type CacheValue = Option<u64>;
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.has_cached_posts
}
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
&self.cache
}
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
candidate.tweet_id
}
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
hydrated.subscription_author_id
}
/// Rebuilds a hydrated candidate from a cached value.
///
/// Only `subscription_author_id` is set; every other field takes its
/// `Default` value.
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
PostCandidate {
subscription_author_id: value,
..Default::default()
}
}
async fn hydrate_from_client(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
) -> Vec<Result<PostCandidate, String>> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let tweet_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
let post_features = client.get_subscription_author_ids(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let subscription_author_id = post_features.and_then(|x| *x);
let hydrated = PostCandidate {
subscription_author_id,
..Default::default()
let hydrated = match post_features {
Some(Ok(value)) => Ok(PostCandidate {
subscription_author_id: *value,
..Default::default()
}),
None => Err(format!(
"Missing subscription author id for tweet_id={}",
tweet_id
)),
Some(Err(err)) => Err(err.to_string()),
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
hydrated_candidates
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {

View File

@@ -0,0 +1,180 @@
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use crate::util::tweet_type_metrics::*;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::duration_since_creation_opt;
use xai_candidate_pipeline::hydrator::Hydrator;
// Tweet-age thresholds in milliseconds, all derived from the one-hour base.
const THIRTY_MINUTES_MS: u64 = ONE_HOUR_MS / 2;
const ONE_HOUR_MS: u64 = 3_600_000;
const SIX_HOURS_MS: u64 = ONE_HOUR_MS * 6;
const TWELVE_HOURS_MS: u64 = ONE_HOUR_MS * 12;
const TWENTY_FOUR_HOURS_MS: u64 = ONE_HOUR_MS * 24;
/// Stateless hydrator (unit struct) that fills `tweet_type_metrics` on each
/// candidate with a byte-packed set of tweet-type flags.
pub struct TweetTypeMetricsHydrator;
impl TweetTypeMetricsHydrator {
pub fn new() -> Self {
Self
}
/// Collects the tweet-type flag indices that apply to `candidate` for `query`.
///
/// The returned set always contains `ANY_CANDIDATE`. Further flags are added for:
/// - structural properties (retweet, reply, subscription post, ancestors,
///   in-network, non-zero score);
/// - exactly one author-follower bucket when `author_followers_count` is set;
/// - `VIDEO` plus exactly one duration bucket when `min_video_duration_ms` is set;
/// - cumulative tweet-age flags when an age can be derived from `tweet_id`;
/// - cumulative flags describing how small `query.served_ids` is.
pub fn create_tweet_type_bitset(
    candidate: &PostCandidate,
    query: &ScoredPostsQuery,
) -> HashSet<usize> {
    let mut flags = HashSet::new();
    flags.insert(ANY_CANDIDATE);

    // Independent yes/no properties, paired with the flag each one enables.
    // A missing `in_network` value counts as in-network here.
    let simple_flags = [
        (candidate.retweeted_tweet_id.is_some(), RETWEET),
        (candidate.in_reply_to_tweet_id.is_some(), REPLY),
        (candidate.subscription_author_id.is_some(), SUBSCRIPTION_POST),
        (
            matches!(candidate.score, Some(score) if score != 0.0),
            FULL_SCORING_SUCCEEDED,
        ),
        (!candidate.ancestors.is_empty(), HAS_ANCESTORS),
        (candidate.in_network.unwrap_or(true), IN_NETWORK),
    ];
    flags.extend(
        simple_flags
            .iter()
            .filter(|(holds, _)| *holds)
            .map(|&(_, flag)| flag),
    );

    // Author follower count: the buckets are disjoint, so exactly one is set.
    if let Some(followers) = candidate.author_followers_count {
        // NOTE(review): `as u32` wraps negative or oversized values instead of
        // rejecting them — confirm the field's type and value range.
        let follower_bucket = match followers as u32 {
            0..=99 => AUTHOR_FOLLOWERS_0_100,
            100..=999 => AUTHOR_FOLLOWERS_100_1K,
            1_000..=9_999 => AUTHOR_FOLLOWERS_1K_10K,
            10_000..=99_999 => AUTHOR_FOLLOWERS_10K_100K,
            100_000..=999_999 => AUTHOR_FOLLOWERS_100K_1M,
            _ => AUTHOR_FOLLOWERS_1M_PLUS,
        };
        flags.insert(follower_bucket);
    }

    // Video presence plus one duration bucket.
    if let Some(duration_ms) = candidate.min_video_duration_ms {
        flags.insert(VIDEO);
        // NOTE(review): same wrapping `as u32` cast as above — a negative
        // duration would land in the "> 60 s" bucket.
        let duration_bucket = match duration_ms as u32 {
            0..=10_000 => VIDEO_LTE_10_SEC,
            10_001..=60_000 => VIDEO_BT_10_60_SEC,
            _ => VIDEO_GT_60_SEC,
        };
        flags.insert(duration_bucket);
    }

    // Tweet age: the "<=" flags are cumulative (a 10-minute-old tweet gets all
    // four); ages strictly between 12h and 24h get no age flag at all.
    if let Some(age) = duration_since_creation_opt(candidate.tweet_id) {
        let age_ms = age.as_millis() as u64;
        let age_upper_bounds = [
            (THIRTY_MINUTES_MS, TWEET_AGE_LTE_30_MINUTES),
            (ONE_HOUR_MS, TWEET_AGE_LTE_1_HOUR),
            (SIX_HOURS_MS, TWEET_AGE_LTE_6_HOURS),
            (TWELVE_HOURS_MS, TWEET_AGE_LTE_12_HOURS),
        ];
        for &(limit_ms, flag) in &age_upper_bounds {
            if age_ms <= limit_ms {
                flags.insert(flag);
            }
        }
        if age_ms >= TWENTY_FOUR_HOURS_MS {
            flags.insert(TWEET_AGE_GTE_24_HOURS);
        }
    }

    // Served-history size: these flags are cumulative as well.
    let served_size = query.served_ids.len();
    if served_size == 0 {
        flags.insert(EMPTY_REQUEST);
    }
    let served_size_limits = [
        (3, NEAR_EMPTY),
        (20, SERVED_SIZE_LESS_THAN_20),
        (10, SERVED_SIZE_LESS_THAN_10),
        (5, SERVED_SIZE_LESS_THAN_5),
    ];
    for &(limit, flag) in &served_size_limits {
        if served_size < limit {
            flags.insert(flag);
        }
    }

    flags
}
/// Packs a set of bit indices into a little-endian-by-bit byte vector.
///
/// Bit `i` is stored in byte `i / 8` at position `i % 8` (least-significant
/// bit first). The result is just long enough to hold the highest index; an
/// empty set yields an empty vector.
pub fn bitset_to_bytes(bits: &HashSet<usize>) -> Vec<u8> {
    let Some(&highest) = bits.iter().max() else {
        return Vec::new();
    };

    let mut packed = vec![0u8; highest / 8 + 1];
    bits.iter()
        .for_each(|&index| packed[index >> 3] |= 1u8 << (index & 7));
    packed
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for TweetTypeMetricsHydrator {
    /// Produces one `Ok` result per input candidate, in input order.
    ///
    /// Each result carries only `tweet_type_metrics` (the byte-packed flag
    /// set); every other field is left at its `Default` value.
    async fn hydrate(
        &self,
        query: &ScoredPostsQuery,
        candidates: &[PostCandidate],
    ) -> Vec<Result<PostCandidate, String>> {
        candidates
            .iter()
            .map(|candidate| {
                let flags = Self::create_tweet_type_bitset(candidate, query);
                Ok(PostCandidate {
                    tweet_type_metrics: Some(Self::bitset_to_bytes(&flags)),
                    ..Default::default()
                })
            })
            .collect()
    }

    /// Copies the computed `tweet_type_metrics` onto the live candidate.
    fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
        let PostCandidate {
            tweet_type_metrics, ..
        } = hydrated;
        candidate.tweet_type_metrics = tweet_type_metrics;
    }
}

View File

@@ -1,5 +1,6 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use anyhow::Result;
use futures::future::join;
use std::collections::HashMap;
use std::sync::Arc;
@@ -7,7 +8,7 @@ use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_twittercontext_proto::GetTwitterContextViewer;
use xai_twittercontext_proto::TwitterContextViewer;
use xai_visibility_filtering::models::FilteredReason;
use xai_visibility_filtering::models::{Action, FilteredReason};
use xai_visibility_filtering::vf_client::SafetyLevel;
use xai_visibility_filtering::vf_client::SafetyLevel::{TimelineHome, TimelineHomeRecommendations};
use xai_visibility_filtering::vf_client::VisibilityFilteringClient;
@@ -23,44 +24,55 @@ impl VFCandidateHydrator {
async fn fetch_vf_results(
client: &Arc<dyn VisibilityFilteringClient + Send + Sync>,
tweet_ids: Vec<i64>,
tweet_ids: Vec<u64>,
safety_level: SafetyLevel,
for_user_id: i64,
for_user_id: u64,
context: Option<TwitterContextViewer>,
) -> Result<HashMap<i64, Option<FilteredReason>>, String> {
) -> HashMap<u64, Result<Option<FilteredReason>>> {
if tweet_ids.is_empty() {
return Ok(HashMap::new());
return HashMap::new();
}
client
.get_result(tweet_ids, safety_level, for_user_id, context)
.await
.map_err(|e| e.to_string())
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
) -> Vec<Result<PostCandidate, String>> {
let context = query.get_viewer();
let user_id = query.user_id;
let client = &self.vf_client;
let mut in_network_ids = Vec::new();
let mut oon_ids = Vec::new();
let mut in_network_ids: Vec<u64> = Vec::new();
let mut oon_ids: Vec<u64> = Vec::new();
for candidate in candidates.iter() {
if candidate.in_network.unwrap_or(false) {
in_network_ids.push(candidate.tweet_id);
} else {
oon_ids.push(candidate.tweet_id);
}
for &ancestor_id in &candidate.ancestors {
oon_ids.push(ancestor_id);
}
if let Some(quoted_id) = candidate.quoted_tweet_id {
oon_ids.push(quoted_id);
}
if let Some(retweeted_id) = candidate.retweeted_tweet_id {
in_network_ids.push(retweeted_id);
}
}
oon_ids.sort_unstable();
oon_ids.dedup();
let in_network_future = Self::fetch_vf_results(
client,
in_network_ids,
@@ -78,24 +90,73 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
);
let (in_network_result, oon_result) = join(in_network_future, oon_future).await;
let mut result: HashMap<i64, Option<FilteredReason>> = HashMap::new();
result.extend(in_network_result?);
result.extend(oon_result?);
let mut all_results: HashMap<u64, Result<Option<FilteredReason>>> = HashMap::new();
all_results.extend(in_network_result);
all_results.extend(oon_result);
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let visibility_reason = result.get(&candidate.tweet_id);
let visibility_reason = visibility_reason.unwrap_or(&None);
let hydrated = PostCandidate {
visibility_reason: visibility_reason.clone(),
..Default::default()
let primary_result = all_results.get(&candidate.tweet_id);
let visibility_reason = match primary_result {
Some(Ok(Some(reason))) => Some(reason.clone()),
_ => None,
};
let drop_ancillary = should_drop_ancillary(candidate, &all_results);
let hydrated = match primary_result {
Some(Err(err)) => Err(err.to_string()),
_ => Ok(PostCandidate {
visibility_reason,
drop_ancillary_posts: Some(drop_ancillary),
..Default::default()
}),
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
hydrated_candidates
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.visibility_reason = hydrated.visibility_reason;
candidate.drop_ancillary_posts = hydrated.drop_ancillary_posts;
}
}
/// Reports whether any post related to `candidate` has a droppable filter verdict.
///
/// The related posts are the candidate's ancestors, its quoted post and its
/// retweeted post (when present). A related post counts only when
/// `vf_results` holds a successful lookup with a filtered reason that
/// `should_drop_reason` accepts; missing entries, errors and `None` reasons
/// never trigger a drop.
fn should_drop_ancillary(
    candidate: &PostCandidate,
    vf_results: &HashMap<u64, Result<Option<FilteredReason>>>,
) -> bool {
    candidate
        .ancestors
        .iter()
        .copied()
        .chain(candidate.quoted_tweet_id)
        .chain(candidate.retweeted_tweet_id)
        .any(|related_id| {
            matches!(
                vf_results.get(&related_id),
                Some(Ok(Some(reason))) if should_drop_reason(reason)
            )
        })
}
/// Decides whether a filtered reason warrants a drop.
///
/// A `SafetyResult` drops only when its action is `Action::Drop`; every other
/// kind of `FilteredReason` is treated as a drop.
fn should_drop_reason(reason: &FilteredReason) -> bool {
    if let FilteredReason::SafetyResult(safety_result) = reason {
        matches!(safety_result.action, Action::Drop(_))
    } else {
        true
    }
}

View File

@@ -1,62 +1,85 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::candidate_features::MediaInfo;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use crate::models::candidate::{CandidateHelpers, PostCandidate};
use crate::models::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
/// Hydrator state for looking up the minimum video duration of a post.
pub struct VideoDurationCandidateHydrator {
/// Shared handle to the tweet entity service client used for lookups.
pub tes_client: Arc<dyn TESClient + Send + Sync>,
/// Cache from a `u64` key to an `Option<i32>` value; the accompanying
/// `CachedHydrator` impl keys it by `get_original_tweet_id()` and stores
/// `min_video_duration_ms`.
pub cache: MokaCache<u64, Option<i32>>,
}
impl VideoDurationCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
let cache = default_moka_cache();
Self { tes_client, cache }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
type CacheKey = u64;
type CacheValue = Option<i32>;
/// Returns `true` only when `query.has_cached_posts` is `false`.
// NOTE(review): presumably the pipeline skips this hydrator when this returns
// `false` — confirm against the `CachedHydrator` trait.
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.has_cached_posts
}
/// Exposes this hydrator's `cache` field as the backing `CacheStore`.
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
&self.cache
}
/// Keys the cache by the id returned from `get_original_tweet_id()`.
// NOTE(review): presumably this resolves a retweet to the post it retweets,
// so all retweets share one cache entry — confirm against `CandidateHelpers`.
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
candidate.get_original_tweet_id()
}
/// Extracts the value to cache: the hydrated candidate's `min_video_duration_ms`.
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
hydrated.min_video_duration_ms
}
/// Rebuilds a hydrated candidate from a cached value.
///
/// Only `min_video_duration_ms` is set; every other field takes its
/// `Default` value.
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
PostCandidate {
min_video_duration_ms: value,
..Default::default()
}
}
async fn hydrate_from_client(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
) -> Vec<Result<PostCandidate, String>> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let tweet_ids: Vec<u64> = candidates
.iter()
.map(|c| c.get_original_tweet_id())
.collect();
let post_features = client.get_tweet_media_entities(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let durations = client.get_min_video_durations(tweet_ids.clone()).await;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let media_entities = post_features.and_then(|x| x.as_ref());
let video_duration_ms = media_entities.and_then(|entities| {
entities.iter().find_map(|entity| {
if let Some(MediaInfo::VideoInfo(video_info)) = &entity.media_info {
Some(video_info.duration_millis)
} else {
None
}
})
});
let hydrated = PostCandidate {
video_duration_ms,
..Default::default()
let hydrated = match durations.get(&tweet_id) {
Some(Ok(min_video_duration_ms)) => Ok(PostCandidate {
min_video_duration_ms: min_video_duration_ms.map(|v| v as i32),
..Default::default()
}),
None => Err(format!(
"Missing min video duration for tweet_id={}",
tweet_id
)),
Some(Err(err)) => Err(err.to_string()),
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
hydrated_candidates
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.video_duration_ms = hydrated.video_duration_ms;
candidate.min_video_duration_ms = hydrated.min_video_duration_ms;
}
}

View File

@@ -0,0 +1,278 @@
use crate::clients::ad_index_client::{AdIndexClient, MockAdIndexClient, ProdAdIndexClient};
use crate::clients::kafka_publisher_client::{KafkaPublisherClient, MockKafkaPublisherClient};
use crate::clients::past_request_timestamps_client::{
MockPastRequestTimestampsClient, PastRequestTimestampsClient, ProdPastRequestTimestampsClient,
};
use crate::clients::prompts_client::{MockPromptsClient, ProdPromptsClient, PromptsClient};
use crate::clients::served_history_client::{
MockServedHistoryClient, ProdServedHistoryClient, ServedHistoryClient,
};
use crate::clients::tweet_entity_service_client::{MockTESClient, ProdTESClient, TESClient};
use crate::clients::who_to_follow_client::{
MockWhoToFollowClient, ProdWhoToFollowClient, WhoToFollowClient,
};
use crate::models::query::ScoredPostsQuery;
use crate::params;
use crate::query_hydrators::past_request_timestamps_query_hydrator::PastRequestTimestampsQueryHydrator;
use crate::query_hydrators::served_history_query_hydrator::ServedHistoryQueryHydrator;
use crate::scored_posts_server::ScoredPostsServer;
use crate::selectors::BlenderSelector;
use crate::side_effects::ads_injection_logging_side_effect::AdsInjectionLoggingSideEffect;
use crate::side_effects::client_events_kafka_side_effect::ClientEventsKafkaSideEffect;
use crate::side_effects::for_you_response_stats_side_effect::ForYouResponseStatsSideEffect;
use crate::side_effects::publish_seen_ids_to_kafka_side_effect::PublishSeenIdsToKafkaSideEffect;
use crate::side_effects::served_candidates_kafka_side_effect::ServedCandidatesKafkaSideEffect;
use crate::side_effects::truncate_served_history_side_effect::TruncateServedHistorySideEffect;
use crate::side_effects::update_past_request_timestamps_side_effect::UpdatePastRequestTimestampsSideEffect;
use crate::side_effects::update_served_history_side_effect::UpdateServedHistorySideEffect;
use crate::sources::ads_source::AdsSource;
use crate::sources::prompts_source::PromptsSource;
use crate::sources::push_to_home_source::PushToHomeSource;
use crate::sources::scored_posts_source::ScoredPostsSource;
use crate::sources::who_to_follow_source::WhoToFollowSource;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_candidate_pipeline::component_library::clients::{
MockReplyMixerClient, ProdReplyMixerClient, ReplyMixerClient,
};
use xai_candidate_pipeline::filter::Filter;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_candidate_pipeline::scorer::Scorer;
use xai_candidate_pipeline::selector::Selector;
use xai_candidate_pipeline::side_effect::SideEffect;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto::FeedItem;
/// Candidate pipeline that produces `FeedItem`s for a `ScoredPostsQuery`.
///
/// It holds only query hydrators, sources, a selector and side effects; its
/// `CandidatePipeline` impl returns empty slices for hydrators, filters and
/// scorers.
pub struct ForYouCandidatePipeline {
/// Query hydrators, in the order they were registered.
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
/// Candidate sources, in the order they were registered.
sources: Vec<Box<dyn Source<ScoredPostsQuery, FeedItem>>>,
/// Selector used to pick the final feed items.
selector: BlenderSelector,
/// Side effects, shared behind an `Arc` so callers can clone the handle cheaply.
side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>>,
}
impl ForYouCandidatePipeline {
/// Builds the production pipeline for the given `datacenter`.
///
/// All production clients and Kafka-backed side effects are created inside a
/// single `tokio::join!`, then handed to [`Self::build`].
///
/// # Panics
///
/// Panics (via `expect`) if any of the AdIndex, ServedHistory, WhoToFollow,
/// Prompts, PastRequestTimestamps, TES or ReplyMixer clients fails to
/// initialise.
pub async fn new(scored_posts_server: Arc<ScoredPostsServer>, datacenter: &str) -> Self {
// The tuple below is positional: its order must match the order of the
// futures passed to `tokio::join!`.
let (
ad_index_client,
ads_injection_logging,
publish_seen_ids,
served_candidates,
client_events,
served_history_client,
who_to_follow_client,
prompts_client,
past_request_timestamps_client,
tes_client,
reply_mixer_client,
) = tokio::join!(
// Each `async` block wraps a concrete client in an `Arc` and casts it to
// the trait-object type that `build` expects.
async {
Arc::new(
ProdAdIndexClient::new(datacenter)
.await
.expect("Failed to create AdIndex client"),
) as Arc<dyn AdIndexClient + Send + Sync>
},
AdsInjectionLoggingSideEffect::prod(),
PublishSeenIdsToKafkaSideEffect::prod(),
ServedCandidatesKafkaSideEffect::prod(),
ClientEventsKafkaSideEffect::prod(),
async {
Arc::new(
ProdServedHistoryClient::new(datacenter)
.await
.expect("Failed to create ServedHistoryClient"),
) as Arc<dyn ServedHistoryClient>
},
async {
Arc::new(
ProdWhoToFollowClient::new(datacenter)
.await
.expect("Failed to create WhoToFollowClient"),
) as Arc<dyn WhoToFollowClient + Send + Sync>
},
async {
Arc::new(
ProdPromptsClient::new(datacenter)
.await
.expect("Failed to create PromptsClient"),
) as Arc<dyn PromptsClient + Send + Sync>
},
async {
Arc::new(
ProdPastRequestTimestampsClient::new(datacenter)
.await
.expect("Failed to create PastRequestTimestampsClient"),
) as Arc<dyn PastRequestTimestampsClient>
},
async {
Arc::new(
ProdTESClient::new(None, datacenter)
.await
.expect("Failed to create TES client"),
) as Arc<dyn TESClient + Send + Sync>
},
async {
Arc::new(
ProdReplyMixerClient::new(datacenter)
.await
.expect("Failed to create ReplyMixer client"),
) as Arc<dyn ReplyMixerClient>
},
);
// Argument order must match `build`'s parameter list.
Self::build(
scored_posts_server,
ad_index_client,
ads_injection_logging,
publish_seen_ids,
served_candidates,
client_events,
served_history_client,
who_to_follow_client,
prompts_client,
past_request_timestamps_client,
tes_client,
reply_mixer_client,
)
}
/// Wires already-constructed clients and side effects into a pipeline.
///
/// Shared by the production and mock constructors. It registers two query
/// hydrators, five sources, a `BlenderSelector` and eight side effects, in
/// the order written below.
fn build(
scored_posts_server: Arc<ScoredPostsServer>,
ad_index_client: Arc<dyn AdIndexClient + Send + Sync>,
ads_injection_logging: AdsInjectionLoggingSideEffect,
publish_seen_ids: PublishSeenIdsToKafkaSideEffect,
served_candidates: ServedCandidatesKafkaSideEffect,
client_events: ClientEventsKafkaSideEffect,
served_history_client: Arc<dyn ServedHistoryClient>,
who_to_follow_client: Arc<dyn WhoToFollowClient + Send + Sync>,
prompts_client: Arc<dyn PromptsClient + Send + Sync>,
past_request_timestamps_client: Arc<dyn PastRequestTimestampsClient>,
tes_client: Arc<dyn TESClient + Send + Sync>,
reply_mixer_client: Arc<dyn ReplyMixerClient>,
) -> Self {
// The hydrators take clones because the same clients are moved into side
// effects further down.
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
Box::new(ServedHistoryQueryHydrator::from_client(Arc::clone(
&served_history_client,
))),
Box::new(PastRequestTimestampsQueryHydrator::new(Arc::clone(
&past_request_timestamps_client,
))),
];
let sources: Vec<Box<dyn Source<ScoredPostsQuery, FeedItem>>> = vec![
Box::new(ScoredPostsSource {
scored_posts_server,
}),
Box::new(AdsSource { ad_index_client }),
Box::new(WhoToFollowSource {
who_to_follow_client,
}),
Box::new(PromptsSource { prompts_client }),
Box::new(PushToHomeSource {
tes_client,
reply_mixer_client,
}),
];
let selector = BlenderSelector::new();
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>> =
Arc::new(vec![
Box::new(ads_injection_logging),
Box::new(publish_seen_ids),
Box::new(served_candidates),
Box::new(client_events),
Box::new(ForYouResponseStatsSideEffect),
Box::new(UpdatePastRequestTimestampsSideEffect::new(
past_request_timestamps_client,
)),
// The served-history client is cloned for the update side effect and
// moved into the truncate side effect, its last use.
Box::new(UpdateServedHistorySideEffect::new(Arc::clone(
&served_history_client,
))),
Box::new(TruncateServedHistorySideEffect::new(served_history_client)),
]);
Self {
query_hydrators,
sources,
selector,
side_effects,
}
}
/// Builds a pipeline wired entirely with mock clients.
///
/// The four Kafka-backed side effects share a single
/// `MockKafkaPublisherClient`; every other dependency gets its own mock.
/// Only `scored_posts_server` is supplied by the caller.
pub async fn mock(scored_posts_server: Arc<ScoredPostsServer>) -> Self {
    let kafka: Arc<dyn KafkaPublisherClient> = Arc::new(MockKafkaPublisherClient);

    // Each `Arc::new(Mock…)` coerces to the trait-object type that the
    // matching `build` parameter declares.
    Self::build(
        scored_posts_server,
        Arc::new(MockAdIndexClient),
        AdsInjectionLoggingSideEffect::new(Arc::clone(&kafka)),
        PublishSeenIdsToKafkaSideEffect::new(Arc::clone(&kafka)),
        ServedCandidatesKafkaSideEffect::new(Arc::clone(&kafka)),
        ClientEventsKafkaSideEffect::new(kafka),
        Arc::new(MockServedHistoryClient),
        Arc::new(MockWhoToFollowClient),
        Arc::new(MockPromptsClient),
        Arc::new(MockPastRequestTimestampsClient),
        Arc::new(MockTESClient::default()),
        Arc::new(MockReplyMixerClient),
    )
}
}
#[async_trait]
impl CandidatePipeline<ScoredPostsQuery, FeedItem> for ForYouCandidatePipeline {
    /// Query hydrators registered at construction time.
    fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>>] {
        self.query_hydrators.as_slice()
    }

    /// Candidate sources registered at construction time.
    fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, FeedItem>>] {
        self.sources.as_slice()
    }

    /// This pipeline registers no candidate hydrators.
    fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, FeedItem>>] {
        &[]
    }

    /// This pipeline registers no pre-selection filters.
    fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, FeedItem>>] {
        &[]
    }

    /// This pipeline registers no scorers.
    fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, FeedItem>>] {
        &[]
    }

    /// The selector stored on the pipeline, exposed as a trait object.
    fn selector(&self) -> &dyn Selector<ScoredPostsQuery, FeedItem> {
        &self.selector
    }

    /// This pipeline registers no post-selection hydrators.
    fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, FeedItem>>] {
        &[]
    }

    /// This pipeline registers no post-selection filters.
    fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, FeedItem>>] {
        &[]
    }

    /// Returns a new handle to the shared side-effect list (reference-count bump only).
    fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>> {
        Arc::clone(&self.side_effects)
    }

    /// Maximum number of results, taken from `params::FOR_YOU_MAX_RESULT_SIZE`.
    fn result_size(&self) -> usize {
        params::FOR_YOU_MAX_RESULT_SIZE
    }
}

View File

@@ -1,5 +1,2 @@
pub mod candidate;
pub mod candidate_features;
pub mod for_you_candidate_pipeline;
pub mod phoenix_candidate_pipeline;
pub mod query;
pub mod query_features;

View File

@@ -1,51 +1,130 @@
use crate::candidate_hydrators::ads_brand_safety_hydrator::AdsBrandSafetyHydrator;
use crate::candidate_hydrators::ads_brand_safety_vf_hydrator::AdsBrandSafetyVfHydrator;
use crate::candidate_hydrators::blocked_by_hydrator::BlockedByHydrator;
use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;
use crate::candidate_hydrators::filtered_topics_hydrator::FilteredTopicsHydrator;
use crate::candidate_hydrators::following_replied_users_hydrator::FollowingRepliedUsersHydrator;
use crate::candidate_hydrators::gizmoduck_hydrator::GizmoduckCandidateHydrator;
use crate::candidate_hydrators::has_media_hydrator::HasMediaHydrator;
use crate::candidate_hydrators::in_network_candidate_hydrator::InNetworkCandidateHydrator;
use crate::candidate_hydrators::language_code_hydrator::LanguageCodeHydrator;
use crate::candidate_hydrators::mutual_follow_jaccard_hydrator::MutualFollowJaccardHydrator;
use crate::candidate_hydrators::quote_hydrator::QuoteHydrator;
use crate::candidate_hydrators::subscription_hydrator::SubscriptionHydrator;
use crate::candidate_hydrators::tweet_type_metrics_hydrator::TweetTypeMetricsHydrator;
use crate::candidate_hydrators::vf_candidate_hydrator::VFCandidateHydrator;
use crate::candidate_hydrators::video_duration_candidate_hydrator::VideoDurationCandidateHydrator;
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::{GizmoduckClient, ProdGizmoduckClient};
use crate::clients::phoenix_prediction_client::{
PhoenixPredictionClient, ProdPhoenixPredictionClient,
use crate::clients::followed_grok_topics_store_client::{
FollowedGrokTopicsStoreClient, MockFollowedGrokTopicsStoreClient,
ProdFollowedGrokTopicsStoreClient,
};
use crate::clients::phoenix_retrieval_client::{
PhoenixRetrievalClient, ProdPhoenixRetrievalClient,
use crate::clients::followed_starter_packs_store_client::{
FollowedStarterPacksStoreClient, MockFollowedStarterPacksStoreClient,
ProdFollowedStarterPacksStoreClient,
};
use crate::clients::gender_prediction_client::{
GenderPredictionGrpcClient, MockGenderPredictionGrpcClient, ProdGenderPredictionGrpcClient,
};
use crate::clients::gizmoduck_client::{GizmoduckClient, MockGizmoduckClient, ProdGizmoduckClient};
use crate::clients::impressed_posts_client::ImpressedPostsClient;
use crate::clients::kafka_publisher_client::{
KafkaCluster, KafkaPublisherClient, MockKafkaPublisherClient, PHOENIX_SCORES_TOPIC,
ProdKafkaPublisherClient, RERANKING_TOPIC,
};
use crate::clients::s2s::{S2S_CHAIN_PATH, S2S_CRT_PATH, S2S_KEY_PATH};
use crate::clients::socialgraph_client::SocialGraphClient;
use crate::clients::strato_client::{ProdStratoClient, StratoClient};
use crate::clients::thunder_client::ThunderClient;
use crate::clients::tweet_entity_service_client::{ProdTESClient, TESClient};
use crate::clients::uas_fetcher::UserActionSequenceFetcher;
use crate::clients::tweet_entity_service_client::{MockTESClient, ProdTESClient, TESClient};
use crate::clients::user_action_aggregation_client::{
MockUserActionAggregationClient, ProdUserActionAggregationClient, UserActionAggregationClient,
};
use crate::clients::user_demographics_client::{
MockUserDemographicsClient, ProdUserDemographicsClient, UserDemographicsClient,
};
use crate::clients::user_inferred_gender_store_client::{
MockUserInferredGenderStoreClient, ProdUserInferredGenderStoreClient,
UserInferredGenderStoreClient,
};
use crate::clients::vm_ranker_client::{MockVMRankerClient, ProdVMRankerClient, VMRankerClient};
use crate::filters::age_filter::AgeFilter;
use crate::filters::ancillary_vf_filter::AncillaryVFFilter;
use crate::filters::author_socialgraph_filter::AuthorSocialgraphFilter;
use crate::filters::core_data_hydration_filter::CoreDataHydrationFilter;
use crate::filters::dedup_conversation_filter::DedupConversationFilter;
use crate::filters::drop_duplicates_filter::DropDuplicatesFilter;
use crate::filters::ineligible_subscription_filter::IneligibleSubscriptionFilter;
use crate::filters::muted_keyword_filter::MutedKeywordFilter;
use crate::filters::new_user_topic_ids_filter::NewUserTopicIdsFilter;
use crate::filters::previously_seen_posts_backup_filter::PreviouslySeenPostsBackupFilter;
use crate::filters::previously_seen_posts_filter::PreviouslySeenPostsFilter;
use crate::filters::previously_served_posts_filter::PreviouslyServedPostsFilter;
use crate::filters::retweet_deduplication_filter::RetweetDeduplicationFilter;
use crate::filters::self_tweet_filter::SelfTweetFilter;
use crate::filters::topic_ids_filter::TopicIdsFilter;
use crate::filters::vf_filter::VFFilter;
use crate::filters::video_filter::VideoFilter;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use crate::params;
use crate::query_hydrators::user_action_seq_query_hydrator::UserActionSeqQueryHydrator;
use crate::query_hydrators::user_features_query_hydrator::UserFeaturesQueryHydrator;
use crate::scorers::author_diversity_scorer::AuthorDiversityScorer;
use crate::scorers::oon_scorer::OONScorer;
use crate::query_hydrators::blocked_user_ids_query_hydrator::BlockedUserIdsQueryHydrator;
use crate::query_hydrators::cached_posts_query_hydrator::CachedPostsQueryHydrator;
use crate::query_hydrators::followed_grok_topics_query_hydrator::FollowedGrokTopicsQueryHydrator;
use crate::query_hydrators::followed_starter_packs_query_hydrator::FollowedStarterPacksQueryHydrator;
use crate::query_hydrators::followed_user_ids_query_hydrator::FollowedUserIdsQueryHydrator;
use crate::query_hydrators::ip_query_hydrator::IpQueryHydrator;
use crate::query_hydrators::impressed_posts_query_hydrator::ImpressedPostsQueryHydrator;
use crate::query_hydrators::impression_bloom_filter_query_hydrator::ImpressionBloomFilterQueryHydrator;
use crate::query_hydrators::inferred_grok_topics_query_hydrator::InferredGrokTopicsQueryHydrator;
use crate::query_hydrators::muted_user_ids_query_hydrator::MutedUserIdsQueryHydrator;
use crate::query_hydrators::mutual_follow_query_hydrator::MutualFollowQueryHydrator;
use crate::query_hydrators::retrieval_sequence_query_hydrator::RetrievalSequenceQueryHydrator;
use crate::query_hydrators::scoring_sequence_query_hydrator::ScoringSequenceQueryHydrator;
use crate::query_hydrators::subscribed_user_ids_query_hydrator::SubscribedUserIdsQueryHydrator;
use crate::query_hydrators::user_demographics_query_hydrator::UserDemographicsQueryHydrator;
use crate::query_hydrators::user_inferred_gender_query_hydrator::UserInferredGenderQueryHydrator;
use crate::scorers::phoenix_scorer::PhoenixScorer;
use crate::scorers::weighted_scorer::WeightedScorer;
use crate::scorers::ranking_scorer::RankingScorer;
use crate::scorers::vm_ranker::VMRanker;
use crate::selectors::TopKScoreSelector;
use crate::side_effects::cache_request_info_side_effect::CacheRequestInfoSideEffect;
use crate::side_effects::mutual_follow_stats_side_effect::MutualFollowStatsSideEffect;
use crate::side_effects::phoenix_experiments_side_effect::PhoenixExperimentsSideEffect;
use crate::side_effects::phoenix_request_cache_side_effect::PhoenixRequestCacheSideEffect;
use crate::side_effects::redis_post_candidate_cache_side_effect::RedisPostCandidateCacheSideEffect;
use crate::side_effects::reranking_kafka_side_effect::RerankingKafkaSideEffect;
use crate::side_effects::scored_stats_side_effect::ScoredStatsSideEffect;
use crate::sources::cached_posts_source::CachedPostsSource;
use crate::sources::phoenix_moe_source::PhoenixMOESource;
use crate::sources::phoenix_source::PhoenixSource;
use crate::sources::phoenix_topics_source::PhoenixTopicsSource;
use crate::sources::thunder_source::ThunderSource;
use crate::sources::tweet_mixer_source::TweetMixerSource;
use xai_candidate_pipeline::component_library::clients::{
MockTweetMixerClient, ProdTweetMixerClient, TweetMixerClient,
};
use std::sync::Arc;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_candidate_pipeline::component_library::clients::ThunderClient;
use xai_candidate_pipeline::component_library::clients::egress_prediction_client::EgressPhoenixPredictionClient;
use xai_candidate_pipeline::component_library::clients::phoenix_prediction_client::{
MockPredictClient, PhoenixPredictionClient, ProdPhoenixPredictionClient,
};
use xai_candidate_pipeline::component_library::clients::phoenix_retrieval_client::{
MockRetrievalClient, PhoenixRetrievalClient, PhoenixRetrievalCluster,
ProdPhoenixRetrievalClient,
};
use xai_candidate_pipeline::component_library::clients::redis_client::{
MockRedisClient, RedisClient,
};
use xai_candidate_pipeline::component_library::clients::{
ImpressionBloomFilterClient, MockImpressionBloomFilterClient, ProdImpressionBloomFilterClient,
};
use xai_candidate_pipeline::component_library::clients::{
MockSocialGraphClient, SocialGraphClient, SocialGraphClientOps,
};
use xai_candidate_pipeline::component_library::clients::{
MockStratoClient, ProdStratoClient, StratoClient,
};
use xai_candidate_pipeline::filter::Filter;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
@@ -53,9 +132,13 @@ use xai_candidate_pipeline::scorer::Scorer;
use xai_candidate_pipeline::selector::Selector;
use xai_candidate_pipeline::side_effect::SideEffect;
use xai_candidate_pipeline::source::Source;
use xai_geo_ip::GeoIpLocationClient;
use xai_redis_client::{XdsRedisClient, XdsRedisConfig};
use xai_visibility_filtering::vf_client::{
ProdVisibilityFilteringClient, VisibilityFilteringClient,
MockVisibilityFilteringClient, ProdVisibilityFilteringClient, VisibilityFilteringClient,
};
use xai_visibility_filtering::vf_safety_labels_client::{MockVfClient, ProdVfClient, VfClient};
use xai_x_rpc::wily_lookup_service::ShardCoordinate;
pub struct PhoenixCandidatePipeline {
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
@@ -70,42 +153,124 @@ pub struct PhoenixCandidatePipeline {
}
impl PhoenixCandidatePipeline {
async fn build_with_clients(
uas_fetcher: Arc<UserActionSequenceFetcher>,
pub(crate) async fn build_with_clients(
user_action_aggregation_client: Arc<dyn UserActionAggregationClient + Send + Sync>,
phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
egress_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
thunder_client: Arc<ThunderClient>,
strato_client: Arc<dyn StratoClient + Send + Sync>,
tweet_mixer_client: Arc<dyn TweetMixerClient>,
tes_client: Arc<dyn TESClient + Send + Sync>,
gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
redis_client: Arc<dyn RedisClient + Send + Sync>,
phoenix_kafka_client: Arc<dyn KafkaPublisherClient>,
reranking_kafka_client: Arc<dyn KafkaPublisherClient>,
socialgraph_client: Arc<dyn SocialGraphClientOps>,
vm_ranker_client: Arc<dyn VMRankerClient>,
safety_label_client: Arc<dyn xai_safety_label_store::SafetyLabelStoreClient>,
vf_safety_labels_client: Arc<dyn VfClient>,
phoenix_request_cache_redis_atla_client: Arc<dyn RedisClient + Send + Sync>,
phoenix_request_cache_redis_pdxa_client: Arc<dyn RedisClient + Send + Sync>,
impression_bloom_filter_client: Arc<dyn ImpressionBloomFilterClient>,
ip_client: Arc<GeoIpLocationClient>,
user_demographics_client: Arc<dyn UserDemographicsClient>,
user_inferred_gender_store_client: Arc<dyn UserInferredGenderStoreClient>,
user_inferred_gender_grpc_client: Arc<dyn GenderPredictionGrpcClient>,
impressed_posts_client: Arc<dyn ImpressedPostsClient>,
followed_grok_topics_client: Arc<dyn FollowedGrokTopicsStoreClient>,
followed_starter_packs_client: Arc<dyn FollowedStarterPacksStoreClient>,
) -> PhoenixCandidatePipeline {
// Query Hydrators
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
Box::new(UserActionSeqQueryHydrator::new(uas_fetcher)),
Box::new(UserFeaturesQueryHydrator {
Box::new(ScoringSequenceQueryHydrator::new(
user_action_aggregation_client.clone(),
)),
Box::new(RetrievalSequenceQueryHydrator::new(
user_action_aggregation_client,
)),
Box::new(BlockedUserIdsQueryHydrator {
socialgraph_client: socialgraph_client.clone(),
}),
Box::new(MutedUserIdsQueryHydrator {
socialgraph_client: socialgraph_client.clone(),
}),
Box::new(FollowedUserIdsQueryHydrator {
socialgraph_client: socialgraph_client.clone(),
}),
Box::new(SubscribedUserIdsQueryHydrator {
socialgraph_client: socialgraph_client.clone(),
}),
Box::new(CachedPostsQueryHydrator {
redis_client: redis_client.clone(),
}),
Box::new(MutualFollowQueryHydrator {
strato_client: strato_client.clone(),
}),
Box::new(UserDemographicsQueryHydrator {
client: user_demographics_client,
}),
Box::new(FollowedGrokTopicsQueryHydrator::new(
followed_grok_topics_client,
)),
Box::new(FollowedStarterPacksQueryHydrator::new(
followed_starter_packs_client,
)),
Box::new(InferredGrokTopicsQueryHydrator {
strato_client: strato_client.clone(),
}),
Box::new(ImpressionBloomFilterQueryHydrator {
client: impression_bloom_filter_client,
}),
Box::new(IpQueryHydrator {
client: ip_client,
}),
Box::new(UserInferredGenderQueryHydrator::new(
user_inferred_gender_store_client,
user_inferred_gender_grpc_client,
)),
];
// Sources
let _impressed_posts_hydrator = ImpressedPostsQueryHydrator {
client: impressed_posts_client,
};
let phoenix_source = Box::new(PhoenixSource {
phoenix_retrieval_client: phoenix_retrieval_client.clone(),
});
let phoenix_topics_source = Box::new(PhoenixTopicsSource {
phoenix_retrieval_client: phoenix_retrieval_client.clone(),
});
let phoenix_moe_source = Box::new(PhoenixMOESource {
phoenix_retrieval_client,
});
let thunder_source = Box::new(ThunderSource { thunder_client });
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> =
vec![phoenix_source, thunder_source];
let tweet_mixer_source = Box::new(TweetMixerSource { tweet_mixer_client });
let cached_posts_source = Box::new(CachedPostsSource);
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> = vec![
thunder_source,
tweet_mixer_source,
phoenix_source,
phoenix_topics_source,
phoenix_moe_source,
cached_posts_source,
];
// Hydrators
let hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(InNetworkCandidateHydrator),
Box::new(CoreDataCandidateHydrator::new(tes_client.clone()).await),
Box::new(QuoteHydrator::new(tes_client.clone(), socialgraph_client.clone()).await),
Box::new(VideoDurationCandidateHydrator::new(tes_client.clone()).await),
Box::new(HasMediaHydrator::new(tes_client.clone()).await),
Box::new(SubscriptionHydrator::new(tes_client.clone()).await),
Box::new(GizmoduckCandidateHydrator::new(gizmoduck_client).await),
Box::new(BlockedByHydrator::new(socialgraph_client).await),
Box::new(FilteredTopicsHydrator {
strato_client: strato_client.clone(),
}),
Box::new(LanguageCodeHydrator::new(tes_client.clone()).await),
];
// Filters
let filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(DropDuplicatesFilter),
Box::new(CoreDataHydrationFilter),
@@ -114,37 +279,63 @@ impl PhoenixCandidatePipeline {
Box::new(RetweetDeduplicationFilter),
Box::new(IneligibleSubscriptionFilter),
Box::new(PreviouslySeenPostsFilter),
Box::new(PreviouslySeenPostsBackupFilter),
Box::new(PreviouslyServedPostsFilter),
Box::new(MutedKeywordFilter::new()),
Box::new(AuthorSocialgraphFilter),
Box::new(VideoFilter),
Box::new(TopicIdsFilter),
Box::new(NewUserTopicIdsFilter),
];
// Scorers
let phoenix_scorer = Box::new(PhoenixScorer { phoenix_client });
let weighted_scorer = Box::new(WeightedScorer);
let author_diversity_scorer = Box::new(AuthorDiversityScorer::default());
let oon_scorer = Box::new(OONScorer);
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> = vec![
phoenix_scorer,
weighted_scorer,
author_diversity_scorer,
oon_scorer,
];
let phoenix_scorer = Box::new(PhoenixScorer {
phoenix_client: phoenix_client.clone(),
egress_client: Arc::clone(&egress_client),
});
let ranking_scorer = Box::new(RankingScorer);
let vm_ranker = Box::new(VMRanker {
client: vm_ranker_client,
});
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> =
vec![phoenix_scorer, ranking_scorer, vm_ranker];
// Selector
let selector = TopKScoreSelector;
// Post-selection hydrators
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFCandidateHydrator::new(vf_client.clone()).await)];
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(VFCandidateHydrator::new(vf_client.clone()).await),
Box::new(AdsBrandSafetyHydrator::new(safety_label_client)),
Box::new(AdsBrandSafetyVfHydrator {
client: vf_safety_labels_client,
}),
Box::new(TweetTypeMetricsHydrator::new()),
Box::new(FollowingRepliedUsersHydrator),
Box::new(MutualFollowJaccardHydrator {
strato_client: strato_client.clone(),
}),
];
// Post-selection filters
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFFilter), Box::new(DedupConversationFilter)];
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(VFFilter),
Box::new(AncillaryVFFilter),
Box::new(DedupConversationFilter),
];
// Side Effects
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> =
Arc::new(vec![Box::new(CacheRequestInfoSideEffect { strato_client })]);
Arc::new(vec![
Box::new(PhoenixExperimentsSideEffect::new(
phoenix_client,
egress_client,
phoenix_kafka_client,
)),
Box::new(RerankingKafkaSideEffect::new(reranking_kafka_client)),
Box::new(RedisPostCandidateCacheSideEffect::new(redis_client)),
Box::new(ScoredStatsSideEffect),
Box::new(MutualFollowStatsSideEffect),
Box::new(PhoenixRequestCacheSideEffect::new(
phoenix_request_cache_redis_atla_client,
phoenix_request_cache_redis_pdxa_client,
)),
]);
PhoenixCandidatePipeline {
query_hydrators,
@@ -159,54 +350,379 @@ impl PhoenixCandidatePipeline {
}
}
// Builds a PhoenixCandidatePipeline from production clients and passes them to
// `build_with_clients`. Most client constructors are followed by `.expect(...)`, so a
// failure to create one of those clients panics with the given message.
//
// NOTE(review): this span appears to interleave two versions of `prod`: a zero-argument
// variant (the first signature and the sequential client construction right below it) and a
// variant taking `shard_coordinate` and `datacenter` that builds its clients inside
// `tokio::join!`. Confirm which lines are current before editing.
pub async fn prod() -> PhoenixCandidatePipeline {
let uas_fetcher =
Arc::new(UserActionSequenceFetcher::new().expect("Failed to create UAS fetcher"));
let _sgs_client = Arc::new(SocialGraphClient::new());
let phoenix_client = Arc::new(
ProdPhoenixPredictionClient::new()
.await
.expect("Failed to create Phoenix prediction client"),
);
let phoenix_retrieval_client = Arc::new(
ProdPhoenixRetrievalClient::new()
.await
.expect("Failed to create Phoenix retrieval client"),
);
let thunder_client = Arc::new(ThunderClient::new().await);
let strato_client = Arc::new(
ProdStratoClient::new()
.await
.expect("Failed to create Strato client"),
);
let tes_client = Arc::new(
ProdTESClient::new()
.await
.expect("Failed to create TES client"),
);
let gizmoduck_client = Arc::new(
ProdGizmoduckClient::new()
.await
.expect("Failed to create Gizmoduck client"),
);
let vf_client = Arc::new(
ProdVisibilityFilteringClient::new(
S2S_CHAIN_PATH.clone(),
S2S_CRT_PATH.clone(),
S2S_KEY_PATH.clone()
)
.await
.expect("Failed to create VF client"),
pub async fn prod(
shard_coordinate: Option<ShardCoordinate>,
datacenter: &str,
) -> PhoenixCandidatePipeline {
// EDS resource names handed to the three xDS Redis clients below; all three are empty
// strings in this listing.
let local_cache_eds = String::new();
let atla_phoenix_cache_eds = "";
let pdxa_phoenix_cache_eds = "";
// Each client is constructed in its own async block and all blocks are awaited together by
// `tokio::join!`. The names in this tuple are bound positionally, so their order must match
// the order of the async blocks.
let (
flock_socialgraph_client,
user_action_aggregation_client,
phoenix_client,
egress_client,
phoenix_retrieval_client,
thunder_client,
strato_client,
tweet_mixer_client,
tes_client,
gizmoduck_client,
vf_client,
redis_client,
phoenix_request_cache_redis_atla_client,
phoenix_request_cache_redis_pdxa_client,
phoenix_kafka_client,
reranking_kafka_client,
vm_ranker_client,
safety_label_client,
vf_safety_labels_client,
impression_bloom_filter_client,
ip_client,
user_demographics_client,
user_inferred_gender_store_client,
user_inferred_gender_grpc_client,
impressed_posts_client,
followed_grok_topics_client,
followed_starter_packs_client,
) = tokio::join!(
// Social graph client, built with the datacenter and the S2S certificate paths.
async {
Arc::new(
SocialGraphClient::new(
datacenter,
&S2S_CHAIN_PATH,
&S2S_CRT_PATH,
&S2S_KEY_PATH,
)
.await
.expect("Failed to create flock SocialGraphClient"),
) as Arc<dyn SocialGraphClientOps>
},
async {
Arc::new(
ProdUserActionAggregationClient::new()
.await
.expect("Failed to create User Action Aggregation client"),
) as Arc<dyn UserActionAggregationClient + Send + Sync>
},
async {
Arc::new(
ProdPhoenixPredictionClient::new()
.await
.expect("Failed to create Phoenix prediction client"),
) as Arc<dyn PhoenixPredictionClient + Send + Sync>
},
// The egress client is exposed through the same PhoenixPredictionClient trait object
// type as the direct prediction client above.
async {
Arc::new(
EgressPhoenixPredictionClient::connect()
.await
.expect("Failed to connect to egress sidecar"),
) as Arc<dyn PhoenixPredictionClient + Send + Sync>
},
async {
Arc::new(
ProdPhoenixRetrievalClient::new(Some((
PhoenixRetrievalCluster::Experiment1Fou,
PhoenixRetrievalCluster::Experiment1Lap7,
)))
.await
.expect("Failed to create Phoenix retrieval client"),
) as Arc<dyn PhoenixRetrievalClient + Send + Sync>
},
async { Arc::new(ThunderClient::new().await) },
async {
Arc::new(
ProdStratoClient::new(shard_coordinate, datacenter)
.await
.expect("Failed to create Strato client"),
) as Arc<dyn StratoClient + Send + Sync>
},
async {
Arc::new(
ProdTweetMixerClient::new(datacenter)
.await
.expect("Failed to create TweetMixer client"),
) as Arc<dyn TweetMixerClient>
},
async {
Arc::new(
ProdTESClient::new(shard_coordinate, datacenter)
.await
.expect("Failed to create TES client"),
) as Arc<dyn TESClient + Send + Sync>
},
async {
Arc::new(
ProdGizmoduckClient::new(
shard_coordinate,
datacenter,
Some("home-mixer.prod".to_string()),
)
.await
.expect("Failed to create Gizmoduck client"),
) as Arc<dyn GizmoduckClient + Send + Sync>
},
async {
Arc::new(
ProdVisibilityFilteringClient::new(
S2S_CHAIN_PATH.clone(),
S2S_CRT_PATH.clone(),
S2S_KEY_PATH.clone(),
"home-mixer.prod".to_string(),
datacenter.to_string(),
)
.await
.expect("Failed to create VF client"),
) as Arc<dyn VisibilityFilteringClient + Send + Sync>
},
// Three xDS Redis clients follow: local cache, then the atla and pdxa Phoenix caches.
async {
Arc::new(
XdsRedisClient::new(XdsRedisConfig {
eds_resource_name: local_cache_eds.clone(),
})
.await
.expect("Failed to create xDS Redis client for local cache"),
) as Arc<dyn RedisClient + Send + Sync>
},
async {
Arc::new(
XdsRedisClient::new(XdsRedisConfig {
eds_resource_name: atla_phoenix_cache_eds.into(),
})
.await
.expect("Failed to create xDS Redis client for atla phoenix cache"),
) as Arc<dyn RedisClient + Send + Sync>
},
async {
Arc::new(
XdsRedisClient::new(XdsRedisConfig {
eds_resource_name: pdxa_phoenix_cache_eds.into(),
})
.await
.expect("Failed to create xDS Redis client for pdxa phoenix cache"),
) as Arc<dyn RedisClient + Send + Sync>
},
// The two Kafka publishers are constructed without an `expect`, unlike most clients here.
async {
Arc::new(
ProdKafkaPublisherClient::new(PHOENIX_SCORES_TOPIC, KafkaCluster::Aiml).await,
) as Arc<dyn KafkaPublisherClient>
},
async {
Arc::new(
ProdKafkaPublisherClient::new(RERANKING_TOPIC, KafkaCluster::Phoenix).await,
) as Arc<dyn KafkaPublisherClient>
},
async {
Arc::new(
ProdVMRankerClient::new()
.await
.expect("Failed to create VMRanker client"),
) as Arc<dyn VMRankerClient>
},
// Several store clients below each build their own S2sConfig from the same three S2S paths.
async {
let s2s = xai_manhattan::s2s::S2sConfig {
client_cert_path: S2S_CRT_PATH.clone(),
client_key_path: S2S_KEY_PATH.clone(),
ca_cert_path: S2S_CHAIN_PATH.clone(),
};
Arc::new(
xai_safety_label_store::ProdSafetyLabelStoreClient::new(datacenter, s2s)
.await
.expect("Failed to create SafetyLabelStore client"),
) as Arc<dyn xai_safety_label_store::SafetyLabelStoreClient>
},
async {
Arc::new(
ProdVfClient::new(datacenter)
.await
.expect("Failed to create VF SafetyLabels client")
.with_timeout_ms(500)
.with_max_batch_size(150),
) as Arc<dyn VfClient>
},
async {
Arc::new(
ProdImpressionBloomFilterClient::new(datacenter)
.await
.expect("Failed to create ImpressionBloomFilter client"),
) as Arc<dyn ImpressionBloomFilterClient>
},
async {
Arc::new(
GeoIpLocationClient::new(
&S2S_CHAIN_PATH,
&S2S_CRT_PATH,
&S2S_KEY_PATH,
datacenter,
)
.await
.expect("Failed to create GeoIpLocationClient"),
)
},
async {
let s2s = xai_manhattan::s2s::S2sConfig {
client_cert_path: S2S_CRT_PATH.clone(),
client_key_path: S2S_KEY_PATH.clone(),
ca_cert_path: S2S_CHAIN_PATH.clone(),
};
Arc::new(
ProdUserDemographicsClient::new(datacenter, s2s)
.await
.expect("Failed to create UserDemographics client"),
) as Arc<dyn UserDemographicsClient>
},
async {
let s2s = xai_manhattan::s2s::S2sConfig {
client_cert_path: S2S_CRT_PATH.clone(),
client_key_path: S2S_KEY_PATH.clone(),
ca_cert_path: S2S_CHAIN_PATH.clone(),
};
Arc::new(
ProdUserInferredGenderStoreClient::new(datacenter, s2s)
.await
.expect("Failed to create UserInferredGenderStore client"),
) as Arc<dyn UserInferredGenderStoreClient>
},
async {
Arc::new(
ProdGenderPredictionGrpcClient::new()
.await
.expect("Failed to create GenderPredictionGrpcClient"),
) as Arc<dyn GenderPredictionGrpcClient>
},
async {
Arc::new(
crate::clients::impressed_posts_client::ProdImpressedPostsClient::new(
datacenter,
)
.await
.expect("Failed to create ImpressedPosts client"),
) as Arc<dyn ImpressedPostsClient>
},
async {
let s2s = xai_manhattan::s2s::S2sConfig {
client_cert_path: S2S_CRT_PATH.clone(),
client_key_path: S2S_KEY_PATH.clone(),
ca_cert_path: S2S_CHAIN_PATH.clone(),
};
Arc::new(
ProdFollowedGrokTopicsStoreClient::new(datacenter, s2s)
.await
.expect("Failed to create FollowedGrokTopicsStore client"),
) as Arc<dyn FollowedGrokTopicsStoreClient>
},
async {
let s2s = xai_manhattan::s2s::S2sConfig {
client_cert_path: S2S_CRT_PATH.clone(),
client_key_path: S2S_KEY_PATH.clone(),
ca_cert_path: S2S_CHAIN_PATH.clone(),
};
Arc::new(
ProdFollowedStarterPacksStoreClient::new(datacenter, s2s)
.await
.expect("Failed to create FollowedStarterPacksStore client"),
) as Arc<dyn FollowedStarterPacksStoreClient>
},
);
// Hand the constructed clients to `build_with_clients`. The argument order here differs
// from the tuple order above: the Kafka, social graph, VM ranker and safety-label clients
// are passed before the two request-cache Redis clients.
PhoenixCandidatePipeline::build_with_clients(
// NOTE(review): `uas_fetcher` is bound only in the zero-argument variant at the top of
// this span, not by the `tokio::join!` tuple — likely a leftover line; confirm.
uas_fetcher,
user_action_aggregation_client,
phoenix_client,
egress_client,
phoenix_retrieval_client,
thunder_client,
strato_client,
tweet_mixer_client,
tes_client,
gizmoduck_client,
vf_client,
redis_client,
phoenix_kafka_client,
reranking_kafka_client,
flock_socialgraph_client,
vm_ranker_client,
safety_label_client,
vf_safety_labels_client,
phoenix_request_cache_redis_atla_client,
phoenix_request_cache_redis_pdxa_client,
impression_bloom_filter_client,
ip_client,
user_demographics_client,
user_inferred_gender_store_client,
user_inferred_gender_grpc_client,
impressed_posts_client,
followed_grok_topics_client,
followed_starter_packs_client,
)
.await
}
/// Assembles a `PhoenixCandidatePipeline` in which every dependency is a mock client.
///
/// A single mock prediction client is shared between the Phoenix and the egress
/// prediction arguments of `build_with_clients`.
pub async fn mock() -> PhoenixCandidatePipeline {
    // Sequence, prediction and candidate-source mocks.
    let action_aggregation = Arc::new(MockUserActionAggregationClient);
    let predictor = Arc::new(MockPredictClient);
    let retrieval = Arc::new(MockRetrievalClient);
    let thunder = Arc::new(ThunderClient::mock());
    let strato = Arc::new(MockStratoClient::default());
    let tweet_mixer = Arc::new(MockTweetMixerClient) as Arc<dyn TweetMixerClient>;

    // Hydration, visibility and caching mocks.
    let tes = Arc::new(MockTESClient::default());
    let gizmoduck = Arc::new(MockGizmoduckClient::default());
    let visibility = Arc::new(MockVisibilityFilteringClient);
    let post_cache_redis = Arc::new(MockRedisClient::default());

    // Kafka publishers, social graph, ranker and safety-label mocks.
    let scores_kafka = Arc::new(MockKafkaPublisherClient) as Arc<dyn KafkaPublisherClient>;
    let reranking_kafka = Arc::new(MockKafkaPublisherClient) as Arc<dyn KafkaPublisherClient>;
    let social_graph = Arc::new(MockSocialGraphClient) as Arc<dyn SocialGraphClientOps>;
    let vm_ranker = Arc::new(MockVMRankerClient) as Arc<dyn VMRankerClient>;
    let safety_labels = Arc::new(xai_safety_label_store::MockSafetyLabelStoreClient)
        as Arc<dyn xai_safety_label_store::SafetyLabelStoreClient>;
    let vf_safety_labels = Arc::new(MockVfClient) as Arc<dyn VfClient>;

    // Request-cache Redis mocks (atla and pdxa).
    let request_cache_atla = Arc::new(MockRedisClient::default());
    let request_cache_pdxa =
        Arc::new(MockRedisClient::default()) as Arc<dyn RedisClient + Send + Sync>;

    // Viewer-level query hydration mocks.
    let bloom_filter = Arc::new(MockImpressionBloomFilterClient::default())
        as Arc<dyn ImpressionBloomFilterClient>;
    let geo_ip = Arc::new(GeoIpLocationClient::mock());
    let demographics = Arc::new(MockUserDemographicsClient) as Arc<dyn UserDemographicsClient>;
    let gender_store =
        Arc::new(MockUserInferredGenderStoreClient) as Arc<dyn UserInferredGenderStoreClient>;
    let gender_grpc =
        Arc::new(MockGenderPredictionGrpcClient) as Arc<dyn GenderPredictionGrpcClient>;
    let impressed_posts =
        Arc::new(crate::clients::impressed_posts_client::MockImpressedPostsClient::default())
            as Arc<dyn ImpressedPostsClient>;
    let grok_topics =
        Arc::new(MockFollowedGrokTopicsStoreClient) as Arc<dyn FollowedGrokTopicsStoreClient>;
    let starter_packs =
        Arc::new(MockFollowedStarterPacksStoreClient) as Arc<dyn FollowedStarterPacksStoreClient>;

    Self::build_with_clients(
        action_aggregation,
        // The same mock serves as both the Phoenix and the egress prediction client.
        Arc::clone(&predictor),
        predictor,
        retrieval,
        thunder,
        strato,
        tweet_mixer,
        tes,
        gizmoduck,
        visibility,
        post_cache_redis,
        scores_kafka,
        reranking_kafka,
        social_graph,
        vm_ranker,
        safety_labels,
        vf_safety_labels,
        request_cache_atla,
        request_cache_pdxa,
        bloom_filter,
        geo_ip,
        demographics,
        gender_store,
        gender_grpc,
        impressed_posts,
        grok_topics,
        starter_packs,
    )
    .await
}
@@ -253,3 +769,4 @@ impl CandidatePipeline<ScoredPostsQuery, PostCandidate> for PhoenixCandidatePipe
params::RESULT_SIZE
}
}

View File

@@ -1,8 +1,7 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::snowflake;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::component_library::utils::duration_since_creation_opt;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes tweets older than a specified duration.
@@ -15,24 +14,23 @@ impl AgeFilter {
Self { max_age }
}
// Returns true when `duration_since_creation_opt(tweet_id)` yields an age that is at most
// `self.max_age`, and false when that helper yields no age.
// NOTE(review): two signature/call pairs appear back to back here (an `i64` id with
// `snowflake::duration_since_creation_opt` and a `u64` id with the unqualified helper) —
// this looks like two versions of the method interleaved; confirm which is current.
fn is_within_age(&self, tweet_id: i64) -> bool {
snowflake::duration_since_creation_opt(tweet_id)
fn is_within_age(&self, tweet_id: u64) -> bool {
duration_since_creation_opt(tweet_id)
.map(|age| age <= self.max_age)
.unwrap_or(false)
}
}
#[async_trait]
// Filter implementation for AgeFilter: splits the candidates by `is_within_age` applied to
// each candidate's `tweet_id`. The query is not consulted.
// NOTE(review): both an `async fn filter` returning `Result<FilterResult<_>, String>` and a
// plain `fn filter` returning `FilterResult<_>` are listed, with matching `Ok(...)` and bare
// return lines — this looks like two versions interleaved; confirm which is current.
impl Filter<ScoredPostsQuery, PostCandidate> for AgeFilter {
async fn filter(
fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
) -> FilterResult<PostCandidate> {
// `partition` puts candidates for which the predicate is true into `kept` and the rest
// into `removed`.
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| self.is_within_age(c.tweet_id));
Ok(FilterResult { kept, removed })
FilterResult { kept, removed }
}
}

View File

@@ -0,0 +1,19 @@
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes candidates flagged with `drop_ancillary_posts`.
pub struct AncillaryVFFilter;

impl Filter<ScoredPostsQuery, PostCandidate> for AncillaryVFFilter {
    /// Splits `candidates` on their `drop_ancillary_posts` flag.
    ///
    /// A candidate is removed only when the flag is `Some(true)`; a missing flag counts as
    /// `false`, so the candidate is kept. Relative order is preserved within both lists.
    fn filter(
        &self,
        _query: &ScoredPostsQuery,
        candidates: Vec<PostCandidate>,
    ) -> FilterResult<PostCandidate> {
        let mut kept = Vec::new();
        let mut removed = Vec::new();

        for candidate in candidates {
            let should_drop = candidate.drop_ancillary_posts.unwrap_or(false);
            if should_drop {
                removed.push(candidate);
            } else {
                kept.push(candidate);
            }
        }

        FilterResult { kept, removed }
    }
}

View File

@@ -1,27 +1,26 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use std::collections::HashSet;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
// Remove candidates that are blocked or muted by the viewer
pub struct AuthorSocialgraphFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
async fn filter(
fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let viewer_blocked_user_ids = query.user_features.blocked_user_ids.clone();
let viewer_muted_user_ids = query.user_features.muted_user_ids.clone();
if viewer_blocked_user_ids.is_empty() && viewer_muted_user_ids.is_empty() {
return Ok(FilterResult {
kept: candidates,
removed: Vec::new(),
});
}
) -> FilterResult<PostCandidate> {
let viewer_blocked_user_ids: HashSet<i64> = query
.user_features
.blocked_user_ids
.iter()
.copied()
.collect();
let viewer_muted_user_ids: HashSet<i64> =
query.user_features.muted_user_ids.iter().copied().collect();
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
@@ -30,13 +29,33 @@ impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
let author_id = candidate.author_id as i64;
let muted = viewer_muted_user_ids.contains(&author_id);
let blocked = viewer_blocked_user_ids.contains(&author_id);
if muted || blocked {
let author_blocks_viewer = candidate.author_blocks_viewer.unwrap_or(false);
let quoted_author_blocks_viewer =
candidate.quoted_author_blocks_viewer.unwrap_or(false);
let viewer_blocks_quoted_author = candidate
.quoted_user_id
.map(|uid| viewer_blocked_user_ids.contains(&(uid as i64)))
.unwrap_or(false);
let viewer_blocks_retweeted_user = candidate
.retweeted_user_id
.map(|uid| viewer_blocked_user_ids.contains(&(uid as i64)))
.unwrap_or(false);
if muted
|| blocked
|| author_blocks_viewer
|| quoted_author_blocks_viewer
|| viewer_blocks_quoted_author
|| viewer_blocks_retweeted_user
{
removed.push(candidate);
} else {
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
FilterResult { kept, removed }
}
}

View File

@@ -1,20 +1,16 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
// Filter that keeps only candidates with a non-zero `author_id`; the query is not consulted.
// NOTE(review): two variants of `filter` are listed together here. One (async, returning a
// `Result`) also requires `tweet_text` to be non-empty after trimming; the other checks only
// `author_id`. This looks like two versions interleaved; confirm which is current.
pub struct CoreDataHydrationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for CoreDataHydrationFilter {
async fn filter(
fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (kept, removed) = candidates
.into_iter()
.partition(|c| c.author_id != 0 && !c.tweet_text.trim().is_empty());
Ok(FilterResult { kept, removed })
) -> FilterResult<PostCandidate> {
// Candidates satisfying the predicate go to `kept`; all others go to `removed`.
let (kept, removed) = candidates.into_iter().partition(|c| c.author_id != 0);
FilterResult { kept, removed }
}
}

View File

@@ -1,19 +1,16 @@
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::models::candidate::PostCandidate;
use crate::models::query::ScoredPostsQuery;
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Keeps only the highest-scored candidate per branch of a conversation tree
pub struct DedupConversationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
async fn filter(
fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
) -> FilterResult<PostCandidate> {
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
let mut best_per_convo: HashMap<u64, (usize, f64)> = HashMap::new();
@@ -37,7 +34,7 @@ impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
}
}
Ok(FilterResult { kept, removed })
FilterResult { kept, removed }
}
}
@@ -47,5 +44,5 @@ fn get_conversation_id(candidate: &PostCandidate) -> u64 {
.iter()
.copied()
.min()
.unwrap_or(candidate.tweet_id as u64)
.unwrap_or(candidate.tweet_id)
}

Some files were not shown because too many files have changed in this diff Show More