mirror of
https://github.com/xai-org/x-algorithm.git
synced 2026-06-20 10:52:15 +08:00
Open-source X Recommendation Algorithm
This commit is contained in:
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
21
README.md
21
README.md
@@ -6,6 +6,7 @@ This repository contains the core recommendation system powering the "For You" f
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Updates — May 15th, 2026](#updates--may-15th-2026)
|
||||
- [Overview](#overview)
|
||||
- [System Architecture](#system-architecture)
|
||||
- [Components](#components)
|
||||
@@ -22,6 +23,26 @@ This repository contains the core recommendation system powering the "For You" f
|
||||
|
||||
---
|
||||
|
||||
## Updates — May 15th, 2026
|
||||
|
||||
This release updates the For You algorithm code, including a runnable end-to-end inference pipeline alongside new components for content understanding, ads, and candidate sourcing.
|
||||
|
||||
1. **End-to-end inference pipeline:** A new [`phoenix/run_pipeline.py`](phoenix/run_pipeline.py) replaces the separate `run_ranker.py` and `run_retrieval.py` scripts with a single entry point that runs **retrieval → ranking** from exported checkpoints, mirroring how the two stages are composed in production.
|
||||
|
||||
2. **Pre-trained model artifacts:** A pre-trained mini Phoenix model (256-dim embeddings, 4 attention heads, 2 transformer layers) is now packaged as a ~3 GB archive distributed via Git LFS, enabling out-of-the-box inference without training your own model first.
|
||||
|
||||
3. **Grox content-understanding pipeline:** A new [`grox/`](grox/) service is included, providing classifiers, embedders, and a task-execution engine for content understanding workloads such as spam detection, post-category classification, and PTOS policy enforcement.
|
||||
|
||||
4. **Ads blending system:** A new [`home-mixer/ads/`](home-mixer/ads/) module handles ad injection and positioning within the feed, with brand-safety tracking that respects sensitive content boundaries.
|
||||
|
||||
5. **Query hydrators:** Home mixer now hydrates user context including followed topics, starter packs, impression bloom filters, IP, mutual follow graphs, and served history.
|
||||
|
||||
6. **Candidate hydrators:** Additional hydrators for engagement counts, brand safety signals, language codes, media detection, quote post expansion, mutual follow scores, and more.
|
||||
|
||||
7. **Candidate sources:** Adds sources for ads, who to follow, Phoenix MoE, Phoenix topics, and prompts, and updates the existing Thunder and Phoenix sources.
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The For You feed algorithm retrieves, ranks, and filters posts from two sources:
|
||||
|
||||
@@ -2,23 +2,39 @@ use crate::filter::Filter;
|
||||
use crate::hydrator::Hydrator;
|
||||
use crate::query_hydrator::QueryHydrator;
|
||||
use crate::scorer::Scorer;
|
||||
use crate::selector::SelectResult;
|
||||
use crate::selector::Selector;
|
||||
use crate::side_effect::{SideEffect, SideEffectInput};
|
||||
use crate::source::Source;
|
||||
use crate::util;
|
||||
use futures::future::join_all;
|
||||
use log::{error, info, warn};
|
||||
use std::any::type_name_of_val;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tonic::async_trait;
|
||||
use tracing::{Span, field::Empty, info};
|
||||
use xai_stats_receiver::{HistogramBuckets, global_stats_receiver};
|
||||
|
||||
const FINAL_RESULT_SIZE_SCOPE: [(&str, &str); 1] = [("requests", "result_size")];
|
||||
const FINAL_RESULT_EMPTY_SCOPE: [(&str, &str); 1] = [("requests", "result_empty")];
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum PipelineStage {
|
||||
QueryHydrator,
|
||||
DependentQueryHydrator,
|
||||
Source,
|
||||
Hydrator,
|
||||
PostSelectionHydrator,
|
||||
Filter,
|
||||
PostSelectionFilter,
|
||||
Scorer,
|
||||
Selector,
|
||||
SideEffect,
|
||||
}
|
||||
|
||||
pub struct PipelineComponents {
|
||||
pub stage: PipelineStage,
|
||||
pub components: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct PipelineResult<Q, C> {
|
||||
@@ -28,18 +44,36 @@ pub struct PipelineResult<Q, C> {
|
||||
pub query: Arc<Q>,
|
||||
}
|
||||
|
||||
/// Provides a stable request identifier for logging/tracing.
|
||||
pub trait HasRequestId {
|
||||
fn request_id(&self) -> &str;
|
||||
impl<Q: Default, C> PipelineResult<Q, C> {
|
||||
/// Create an empty result with a default query. Useful for short-circuiting
|
||||
/// requests (e.g. test users) without running the pipeline.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
retrieved_candidates: vec![],
|
||||
filtered_candidates: vec![],
|
||||
selected_candidates: vec![],
|
||||
query: Arc::new(Q::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub trait PipelineQuery: Clone + Send + Sync + 'static {
|
||||
fn params(&self) -> &xai_feature_switches::Params;
|
||||
fn decider(&self) -> Option<&xai_decider::Decider>;
|
||||
}
|
||||
|
||||
pub trait PipelineCandidate: Clone + Send + Sync + 'static {}
|
||||
impl<T> PipelineCandidate for T where T: Clone + Send + Sync + 'static {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CandidatePipeline<Q, C>: Send + Sync
|
||||
where
|
||||
Q: HasRequestId + Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
|
||||
fn dependent_query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>] {
|
||||
&[]
|
||||
}
|
||||
fn sources(&self) -> &[Box<dyn Source<Q, C>>];
|
||||
fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
|
||||
fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
|
||||
@@ -49,37 +83,48 @@ where
|
||||
fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
|
||||
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
|
||||
fn result_size(&self) -> usize;
|
||||
fn finalize(&self, _query: &Q, _candidates: &mut Vec<C>) {}
|
||||
|
||||
#[xai_stats_macro::receive_stats(latency=Bucket500To2500)]
|
||||
async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
|
||||
let hydrated_query = self.hydrate_query(query).await;
|
||||
let hydrated_query = self.hydrate_dependent_query(hydrated_query).await;
|
||||
|
||||
let candidates = self.fetch_candidates(&hydrated_query).await;
|
||||
|
||||
let hydrated_candidates = self.hydrate(&hydrated_query, candidates).await;
|
||||
|
||||
let (kept_candidates, mut filtered_candidates) = self
|
||||
.filter(&hydrated_query, hydrated_candidates.clone())
|
||||
.await;
|
||||
let (kept_candidates, mut filtered_candidates) =
|
||||
self.filter(&hydrated_query, hydrated_candidates.clone());
|
||||
|
||||
let scored_candidates = self.score(&hydrated_query, kept_candidates).await;
|
||||
|
||||
let selected_candidates = self.select(&hydrated_query, scored_candidates);
|
||||
let SelectResult {
|
||||
selected: selected_candidates,
|
||||
non_selected: mut non_selected_candidates,
|
||||
} = self.select(&hydrated_query, scored_candidates);
|
||||
|
||||
let post_selection_hydrated_candidates = self
|
||||
.hydrate_post_selection(&hydrated_query, selected_candidates)
|
||||
.await;
|
||||
|
||||
let (mut final_candidates, post_selection_filtered_candidates) = self
|
||||
.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates)
|
||||
.await;
|
||||
let (mut final_candidates, post_selection_filtered_candidates) =
|
||||
self.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates);
|
||||
filtered_candidates.extend(post_selection_filtered_candidates);
|
||||
|
||||
final_candidates.truncate(self.result_size());
|
||||
let truncated_candidates =
|
||||
final_candidates.split_off(self.result_size().min(final_candidates.len()));
|
||||
non_selected_candidates.extend(truncated_candidates);
|
||||
|
||||
self.finalize(&hydrated_query, &mut final_candidates);
|
||||
|
||||
self.stat_result_size(&final_candidates);
|
||||
|
||||
let arc_hydrated_query = Arc::new(hydrated_query);
|
||||
let input = Arc::new(SideEffectInput {
|
||||
query: arc_hydrated_query.clone(),
|
||||
selected_candidates: final_candidates.clone(),
|
||||
non_selected_candidates, // candidates are moved so we don't need to clone them
|
||||
});
|
||||
self.run_side_effects(input);
|
||||
|
||||
@@ -91,78 +136,158 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Return all configured components grouped by stage.
|
||||
fn components(&self) -> Vec<PipelineComponents> {
|
||||
fn stage<T: ?Sized>(
|
||||
stage: PipelineStage,
|
||||
items: &[Box<T>],
|
||||
name: impl Fn(&T) -> &str,
|
||||
) -> PipelineComponents {
|
||||
PipelineComponents {
|
||||
stage,
|
||||
components: items
|
||||
.iter()
|
||||
.map(|item| name(item.as_ref()).to_string())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
vec![
|
||||
stage(PipelineStage::QueryHydrator, self.query_hydrators(), |h| {
|
||||
h.name()
|
||||
}),
|
||||
stage(
|
||||
PipelineStage::DependentQueryHydrator,
|
||||
self.dependent_query_hydrators(),
|
||||
|h| h.name(),
|
||||
),
|
||||
stage(PipelineStage::Source, self.sources(), |s| s.name()),
|
||||
stage(PipelineStage::Hydrator, self.hydrators(), |h| h.name()),
|
||||
stage(PipelineStage::Filter, self.filters(), |f| f.name()),
|
||||
stage(PipelineStage::Scorer, self.scorers(), |s| s.name()),
|
||||
PipelineComponents {
|
||||
stage: PipelineStage::Selector,
|
||||
components: vec![self.selector().name().to_string()],
|
||||
},
|
||||
stage(
|
||||
PipelineStage::PostSelectionHydrator,
|
||||
self.post_selection_hydrators(),
|
||||
|h| h.name(),
|
||||
),
|
||||
stage(
|
||||
PipelineStage::PostSelectionFilter,
|
||||
self.post_selection_filters(),
|
||||
|f| f.name(),
|
||||
),
|
||||
stage(
|
||||
PipelineStage::SideEffect,
|
||||
self.side_effects().as_ref(),
|
||||
|s| s.name(),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
}
|
||||
|
||||
// -------------------------- Pipeline Execution --------------------------
|
||||
|
||||
/// Run all query hydrators in parallel and merge results into the query.
|
||||
#[tracing::instrument(skip_all, name = "query_hydrators", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
))]
|
||||
async fn hydrate_query(&self, query: Q) -> Q {
|
||||
let request_id = query.request_id().to_string();
|
||||
let hydrators: Vec<_> = self
|
||||
.query_hydrators()
|
||||
.iter()
|
||||
.filter(|h| h.enable(&query))
|
||||
.collect();
|
||||
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(&query));
|
||||
let start = Instant::now();
|
||||
let all = self.query_hydrators();
|
||||
Self::record_enabled_components(all.iter(), |h| h.enable(&query), |h| h.name());
|
||||
let hydrators: Vec<_> = all.iter().filter(|h| h.enable(&query)).collect();
|
||||
let hydrate_futures = hydrators.iter().map(|h| h.run(&query));
|
||||
let results = join_all(hydrate_futures).await;
|
||||
|
||||
let mut hydrated_query = query;
|
||||
for (hydrator, result) in hydrators.iter().zip(results) {
|
||||
match result {
|
||||
Ok(hydrated) => {
|
||||
hydrator.update(&mut hydrated_query, hydrated);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"request_id={} stage={:?} component={} failed: {}",
|
||||
request_id,
|
||||
PipelineStage::QueryHydrator,
|
||||
hydrator.name(),
|
||||
err
|
||||
);
|
||||
}
|
||||
if let Ok(hydrated) = result {
|
||||
hydrator.update(&mut hydrated_query, hydrated);
|
||||
}
|
||||
}
|
||||
self.log_stage(start);
|
||||
hydrated_query
|
||||
}
|
||||
|
||||
/// Run dependent query hydrators in parallel and merge results into the query.
|
||||
///
|
||||
/// This stage runs **after** [`hydrate_query`], so the incoming query
|
||||
/// already has all initial features populated.
|
||||
#[tracing::instrument(skip_all, name = "dependent_query_hydrators", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
))]
|
||||
async fn hydrate_dependent_query(&self, query: Q) -> Q {
|
||||
let all = self.dependent_query_hydrators();
|
||||
if all.is_empty() {
|
||||
return query;
|
||||
}
|
||||
let start = Instant::now();
|
||||
Self::record_enabled_components(all.iter(), |h| h.enable(&query), |h| h.name());
|
||||
let hydrators: Vec<_> = all.iter().filter(|h| h.enable(&query)).collect();
|
||||
let hydrate_futures = hydrators.iter().map(|h| h.run(&query));
|
||||
let results = join_all(hydrate_futures).await;
|
||||
|
||||
let mut hydrated_query = query;
|
||||
for (hydrator, result) in hydrators.iter().zip(results) {
|
||||
if let Ok(hydrated) = result {
|
||||
hydrator.update(&mut hydrated_query, hydrated);
|
||||
}
|
||||
}
|
||||
self.log_stage(start);
|
||||
hydrated_query
|
||||
}
|
||||
|
||||
/// Run all candidate sources in parallel and collect results.
|
||||
#[tracing::instrument(skip_all, name = "sources", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
candidate_count = Empty,
|
||||
))]
|
||||
async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
|
||||
let request_id = query.request_id().to_string();
|
||||
let sources: Vec<_> = self.sources().iter().filter(|s| s.enable(query)).collect();
|
||||
let source_futures = sources.iter().map(|s| s.get_candidates(query));
|
||||
let start = Instant::now();
|
||||
let all = self.sources();
|
||||
Self::record_enabled_components(all.iter(), |s| s.enable(query), |s| s.name());
|
||||
let sources: Vec<_> = all.iter().filter(|s| s.enable(query)).collect();
|
||||
let source_futures = sources.iter().map(|s| s.run(query));
|
||||
let results = join_all(source_futures).await;
|
||||
|
||||
let mut collected = Vec::new();
|
||||
for (source, result) in sources.iter().zip(results) {
|
||||
match result {
|
||||
Ok(mut candidates) => {
|
||||
info!(
|
||||
"request_id={} stage={:?} component={} fetched {} candidates",
|
||||
request_id,
|
||||
PipelineStage::Source,
|
||||
source.name(),
|
||||
candidates.len()
|
||||
);
|
||||
collected.append(&mut candidates);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"request_id={} stage={:?} component={} failed: {}",
|
||||
request_id,
|
||||
PipelineStage::Source,
|
||||
source.name(),
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
for mut candidates in results.into_iter().flatten() {
|
||||
collected.append(&mut candidates);
|
||||
}
|
||||
Span::current().record("candidate_count", collected.len());
|
||||
self.log_stage_size(start, collected.len());
|
||||
collected
|
||||
}
|
||||
|
||||
/// Run all candidate hydrators in parallel and merge results into candidates.
|
||||
#[tracing::instrument(skip_all, name = "hydrators", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
))]
|
||||
async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
|
||||
self.run_hydrators(query, candidates, self.hydrators(), PipelineStage::Hydrator)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Run post-selection candidate hydrators in parallel and merge results into candidates.
|
||||
#[tracing::instrument(skip_all, name = "post_selection_hydrators", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
))]
|
||||
async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
|
||||
self.run_hydrators(
|
||||
query,
|
||||
@@ -179,139 +304,114 @@ where
|
||||
query: &Q,
|
||||
mut candidates: Vec<C>,
|
||||
hydrators: &[Box<dyn Hydrator<Q, C>>],
|
||||
stage: PipelineStage,
|
||||
_stage: PipelineStage,
|
||||
) -> Vec<C> {
|
||||
let request_id = query.request_id().to_string();
|
||||
let start = Instant::now();
|
||||
Self::record_enabled_components(hydrators.iter(), |h| h.enable(query), |h| h.name());
|
||||
let hydrators: Vec<_> = hydrators.iter().filter(|h| h.enable(query)).collect();
|
||||
let expected_len = candidates.len();
|
||||
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(query, &candidates));
|
||||
let hydrate_futures = hydrators.iter().map(|h| h.run(query, &candidates));
|
||||
let results = join_all(hydrate_futures).await;
|
||||
for (hydrator, result) in hydrators.iter().zip(results) {
|
||||
match result {
|
||||
Ok(hydrated) => {
|
||||
if hydrated.len() == expected_len {
|
||||
hydrator.update_all(&mut candidates, hydrated);
|
||||
} else {
|
||||
warn!(
|
||||
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
|
||||
request_id,
|
||||
stage,
|
||||
hydrator.name(),
|
||||
expected_len,
|
||||
hydrated.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"request_id={} stage={:?} component={} failed: {}",
|
||||
request_id,
|
||||
stage,
|
||||
hydrator.name(),
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
hydrator.update_all(&mut candidates, result);
|
||||
}
|
||||
self.log_stage_size(start, candidates.len());
|
||||
candidates
|
||||
}
|
||||
|
||||
/// Run all filters sequentially. Each filter partitions candidates into kept and removed.
|
||||
async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
|
||||
#[tracing::instrument(skip_all, name = "filters", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
input_count = candidates.len(),
|
||||
kept_count = Empty,
|
||||
removed_count = Empty,
|
||||
filter_rate = Empty,
|
||||
))]
|
||||
fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
|
||||
self.run_filters(query, candidates, self.filters(), PipelineStage::Filter)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Run post-selection filters sequentially on candidates that have already been selected.
|
||||
async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
|
||||
#[tracing::instrument(skip_all, name = "post_selection_filters", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
input_count = candidates.len(),
|
||||
kept_count = Empty,
|
||||
removed_count = Empty,
|
||||
filter_rate = Empty,
|
||||
))]
|
||||
fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
|
||||
self.run_filters(
|
||||
query,
|
||||
candidates,
|
||||
self.post_selection_filters(),
|
||||
PipelineStage::PostSelectionFilter,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Shared helper to run filters sequentially from a provided filter list.
|
||||
async fn run_filters(
|
||||
fn run_filters(
|
||||
&self,
|
||||
query: &Q,
|
||||
mut candidates: Vec<C>,
|
||||
filters: &[Box<dyn Filter<Q, C>>],
|
||||
stage: PipelineStage,
|
||||
_stage: PipelineStage,
|
||||
) -> (Vec<C>, Vec<C>) {
|
||||
let request_id = query.request_id().to_string();
|
||||
Self::record_enabled_components(filters.iter(), |f| f.enable(query), |f| f.name());
|
||||
let mut all_removed = Vec::new();
|
||||
let mut removed_per_filter: Vec<(String, usize)> = Vec::new();
|
||||
for filter in filters.iter().filter(|f| f.enable(query)) {
|
||||
let backup = candidates.clone();
|
||||
match filter.filter(query, candidates).await {
|
||||
Ok(result) => {
|
||||
candidates = result.kept;
|
||||
all_removed.extend(result.removed);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"request_id={} stage={:?} component={} failed: {}",
|
||||
request_id,
|
||||
stage,
|
||||
filter.name(),
|
||||
err
|
||||
);
|
||||
candidates = backup;
|
||||
}
|
||||
let result = filter.run(query, candidates);
|
||||
if !result.removed.is_empty() {
|
||||
removed_per_filter.push((filter.name().to_string(), result.removed.len()));
|
||||
}
|
||||
candidates = result.kept;
|
||||
all_removed.extend(result.removed);
|
||||
}
|
||||
info!(
|
||||
"request_id={} stage={:?} kept {}, removed {}",
|
||||
request_id,
|
||||
stage,
|
||||
candidates.len(),
|
||||
all_removed.len()
|
||||
);
|
||||
let kept = candidates.len();
|
||||
let removed = all_removed.len();
|
||||
let total = kept + removed;
|
||||
let rate = if total > 0 {
|
||||
removed as f64 / total as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
Span::current().record("kept_count", kept);
|
||||
Span::current().record("removed_count", removed);
|
||||
Span::current().record("filter_rate", format!("{:.3}", rate).as_str());
|
||||
self.log_filters(kept, removed, &removed_per_filter);
|
||||
(candidates, all_removed)
|
||||
}
|
||||
|
||||
/// Run all scorers sequentially and apply their results to candidates.
|
||||
#[tracing::instrument(skip_all, name = "scorers", fields(
|
||||
total_count = Empty,
|
||||
enabled_count = Empty,
|
||||
disabled = Empty,
|
||||
))]
|
||||
async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
|
||||
let request_id = query.request_id().to_string();
|
||||
let expected_len = candidates.len();
|
||||
for scorer in self.scorers().iter().filter(|s| s.enable(query)) {
|
||||
match scorer.score(query, &candidates).await {
|
||||
Ok(scored) => {
|
||||
if scored.len() == expected_len {
|
||||
scorer.update_all(&mut candidates, scored);
|
||||
} else {
|
||||
warn!(
|
||||
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
|
||||
request_id,
|
||||
PipelineStage::Scorer,
|
||||
scorer.name(),
|
||||
expected_len,
|
||||
scored.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"request_id={} stage={:?} component={} failed: {}",
|
||||
request_id,
|
||||
PipelineStage::Scorer,
|
||||
scorer.name(),
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
let start = Instant::now();
|
||||
let all = self.scorers();
|
||||
Self::record_enabled_components(all.iter(), |s| s.enable(query), |s| s.name());
|
||||
for scorer in all.iter().filter(|s| s.enable(query)) {
|
||||
let scored = scorer.run(query, &candidates).await;
|
||||
scorer.update_all(&mut candidates, scored);
|
||||
}
|
||||
self.log_stage_size(start, candidates.len());
|
||||
candidates
|
||||
}
|
||||
|
||||
/// Select (sort/truncate) candidates using the configured selector
|
||||
fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
|
||||
fn select(&self, query: &Q, candidates: Vec<C>) -> SelectResult<C> {
|
||||
if self.selector().enable(query) {
|
||||
self.selector().select(query, candidates)
|
||||
self.selector().run(query, candidates)
|
||||
} else {
|
||||
candidates
|
||||
SelectResult {
|
||||
selected: candidates,
|
||||
non_selected: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,4 +426,68 @@ where
|
||||
let _ = join_all(futures).await;
|
||||
});
|
||||
}
|
||||
|
||||
// -------------------------- Helpers --------------------------
|
||||
|
||||
/// Iterate components, applying `is_enabled` to each, and record
|
||||
/// `total_count`, `enabled_count`, and (if any are disabled) the
|
||||
/// comma-joined names of disabled components on the current tracing span.
|
||||
fn record_enabled_components<'a, T: 'a>(
|
||||
items: impl Iterator<Item = &'a T>,
|
||||
is_enabled: impl Fn(&T) -> bool,
|
||||
get_name: impl Fn(&T) -> &str,
|
||||
) {
|
||||
let mut total = 0usize;
|
||||
let mut disabled: Vec<&str> = Vec::new();
|
||||
for item in items {
|
||||
total += 1;
|
||||
if !is_enabled(item) {
|
||||
disabled.push(get_name(item));
|
||||
}
|
||||
}
|
||||
let span = Span::current();
|
||||
span.record("total_count", total);
|
||||
span.record("enabled_count", total - disabled.len());
|
||||
if !disabled.is_empty() {
|
||||
span.record("disabled", disabled.join(",").as_str());
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------- Logging and Stats --------------------------
|
||||
|
||||
fn log_stage(&self, start: Instant) {
|
||||
info!("latency_ms={}", start.elapsed().as_millis());
|
||||
}
|
||||
|
||||
fn log_stage_size(&self, start: Instant, size: usize) {
|
||||
info!("latency_ms={} size={}", start.elapsed().as_millis(), size);
|
||||
}
|
||||
|
||||
fn log_filters(&self, kept: usize, removed: usize, removed_per_filter: &[(String, usize)]) {
|
||||
let removed_summary = removed_per_filter
|
||||
.iter()
|
||||
.map(|(name, removed)| format!("{}={}", name, removed))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
info!(
|
||||
"kept {}, removed {} removed_per_filter [{}]",
|
||||
kept, removed, removed_summary,
|
||||
);
|
||||
}
|
||||
|
||||
fn stat_result_size(&self, final_candidates: &[C]) {
|
||||
if let Some(receiver) = global_stats_receiver() {
|
||||
let response_size = final_candidates.len();
|
||||
let metric_name = format!("{}.execute", self.name());
|
||||
receiver.observe(
|
||||
metric_name.as_str(),
|
||||
&FINAL_RESULT_SIZE_SCOPE,
|
||||
response_size as f64,
|
||||
HistogramBuckets::Bucket0To50,
|
||||
);
|
||||
if response_size == 0 {
|
||||
receiver.incr(metric_name.as_str(), &FINAL_RESULT_EMPTY_SCOPE, 1u64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use std::any::{Any, type_name_of_val};
|
||||
use tonic::async_trait;
|
||||
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use std::any::{Any, type_name_of_val};
|
||||
use tracing::{Span, field::Empty};
|
||||
use xai_stats_receiver::global_stats_receiver;
|
||||
|
||||
const KEPT_SCOPE: [(&str, &str); 1] = [("requests", "kept")];
|
||||
const REMOVED_SCOPE: [(&str, &str); 1] = [("requests", "removed")];
|
||||
|
||||
pub struct FilterResult<C> {
|
||||
pub kept: Vec<C>,
|
||||
@@ -9,24 +13,58 @@ pub struct FilterResult<C> {
|
||||
}
|
||||
|
||||
/// Filters run sequentially and partition candidates into kept and removed sets
|
||||
#[async_trait]
|
||||
pub trait Filter<Q, C>: Any + Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Decide if this filter should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[xai_stats_macro::receive_stats(latency=Bucket0To50)]
|
||||
#[tracing::instrument(skip_all, name = "filter", fields(
|
||||
name = self.name(),
|
||||
input_count = candidates.len(),
|
||||
kept_count = Empty,
|
||||
removed_count = Empty,
|
||||
filter_rate = Empty,
|
||||
))]
|
||||
fn run(&self, query: &Q, candidates: Vec<C>) -> FilterResult<C> {
|
||||
let result = self.filter(query, candidates);
|
||||
let total = result.kept.len() + result.removed.len();
|
||||
let rate = if total > 0 {
|
||||
result.removed.len() as f64 / total as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let span = Span::current();
|
||||
span.record("kept_count", result.kept.len());
|
||||
span.record("removed_count", result.removed.len());
|
||||
span.record("filter_rate", format!("{:.3}", rate).as_str());
|
||||
self.stat(&result);
|
||||
result
|
||||
}
|
||||
|
||||
/// Filter candidates by evaluating each against some criteria.
|
||||
/// Returns a FilterResult containing kept candidates (which continue to the next stage)
|
||||
/// and removed candidates (which are excluded from further processing).
|
||||
async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;
|
||||
fn filter(&self, query: &Q, candidates: Vec<C>) -> FilterResult<C>;
|
||||
|
||||
/// Returns a stable name for logging/metrics.
|
||||
fn name(&self) -> &'static str {
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
}
|
||||
|
||||
fn stat(&self, result: &FilterResult<C>) {
|
||||
if let Some(receiver) = global_stats_receiver() {
|
||||
let metric_name = format!("{}.run", self.name());
|
||||
receiver.incr(metric_name.as_str(), &KEPT_SCOPE, result.kept.len() as u64);
|
||||
receiver.incr(
|
||||
metric_name.as_str(),
|
||||
&REMOVED_SCOPE,
|
||||
result.removed.len() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use std::any::{Any, type_name_of_val};
|
||||
use std::hash::Hash;
|
||||
use tonic::async_trait;
|
||||
use tracing::warn;
|
||||
use xai_stats_receiver::global_stats_receiver;
|
||||
|
||||
// Hydrators run in parallel and update candidate fields
|
||||
#[async_trait]
|
||||
pub trait Hydrator<Q, C>: Any + Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Decide if this hydrator should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
@@ -19,17 +23,41 @@ where
|
||||
///
|
||||
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
|
||||
/// Dropping candidates in a hydrator is not allowed - use a filter stage instead.
|
||||
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
|
||||
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
|
||||
|
||||
#[xai_stats_macro::receive_stats(latency=Bucket50To500, size=Bucket500To2500)]
|
||||
#[tracing::instrument(skip_all, name = "hydrator", fields(name = self.name()))]
|
||||
async fn run(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
|
||||
let hydrated = self.hydrate(query, candidates).await;
|
||||
let expected_len = candidates.len();
|
||||
if hydrated.len() == expected_len {
|
||||
hydrated
|
||||
} else {
|
||||
let message = format!(
|
||||
"Hydrator length_mismatch expected={} got={}",
|
||||
expected_len,
|
||||
hydrated.len()
|
||||
);
|
||||
warn!(
|
||||
"Skipped: length_mismatch expected={} got={}",
|
||||
expected_len,
|
||||
hydrated.len()
|
||||
);
|
||||
vec![Err(message); expected_len]
|
||||
}
|
||||
}
|
||||
|
||||
/// Update a single candidate with the hydrated fields.
|
||||
/// Only the fields this hydrator is responsible for should be copied.
|
||||
fn update(&self, candidate: &mut C, hydrated: C);
|
||||
|
||||
/// Update all candidates with the hydrated fields from `hydrated`.
|
||||
/// Update all successfully hydrated candidates with the fields from `hydrated`.
|
||||
/// Default implementation iterates and calls `update` for each pair.
|
||||
fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
|
||||
for (c, h) in candidates.iter_mut().zip(hydrated) {
|
||||
self.update(c, h);
|
||||
fn update_all(&self, candidates: &mut [C], hydrated: Vec<Result<C, String>>) {
|
||||
for (candidate, hydrated) in candidates.iter_mut().zip(hydrated) {
|
||||
if let Ok(hydrated) = hydrated {
|
||||
self.update(candidate, hydrated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,3 +65,125 @@ where
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
}
|
||||
}
|
||||
|
||||
const CACHE_HIT_SCOPE: [(&str, &str); 1] = [("requests", "cache_hit")];
|
||||
const CACHE_MISS_SCOPE: [(&str, &str); 1] = [("requests", "cache_miss")];
|
||||
|
||||
#[async_trait]
|
||||
pub trait CacheStore<K, V>: Send + Sync {
|
||||
async fn get(&self, key: &K) -> Option<V>;
|
||||
async fn insert(&self, key: K, value: V);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CachedHydrator<Q, C>: Any + Send + Sync
|
||||
where
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
type CacheKey: Eq + Hash + Send + Sync + 'static;
|
||||
type CacheValue: Clone + Send + Sync + 'static;
|
||||
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue>;
|
||||
fn cache_key(&self, candidate: &C) -> Self::CacheKey;
|
||||
fn cache_value(&self, hydrated: &C) -> Self::CacheValue;
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> C;
|
||||
async fn hydrate_from_client(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
|
||||
|
||||
fn update(&self, candidate: &mut C, hydrated: C);
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
}
|
||||
|
||||
fn stat_cache(&self, cache_hits: usize, cache_misses: usize) {
|
||||
if let Some(receiver) = global_stats_receiver() {
|
||||
let metric_name = format!("{}.cache", self.name());
|
||||
if cache_hits > 0 {
|
||||
receiver.incr(metric_name.as_str(), &CACHE_HIT_SCOPE, cache_hits as u64);
|
||||
}
|
||||
if cache_misses > 0 {
|
||||
receiver.incr(metric_name.as_str(), &CACHE_MISS_SCOPE, cache_misses as u64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<Q, C, T> Hydrator<Q, C> for T
|
||||
where
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
T: CachedHydrator<Q, C> + ?Sized,
|
||||
{
|
||||
fn enable(&self, query: &Q) -> bool {
|
||||
CachedHydrator::enable(self, query)
|
||||
}
|
||||
|
||||
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
|
||||
let mut results = vec![None; candidates.len()];
|
||||
let mut missing_candidates = Vec::new();
|
||||
let mut missing_keys = Vec::new();
|
||||
let mut missing_indices = Vec::new();
|
||||
let mut cache_hits = 0usize;
|
||||
let mut cache_misses = 0usize;
|
||||
|
||||
for (index, candidate) in candidates.iter().enumerate() {
|
||||
let key = self.cache_key(candidate);
|
||||
match self.cache_store().get(&key).await {
|
||||
Some(value) => {
|
||||
results[index] = Some(Ok(self.hydrate_from_cache(value)));
|
||||
cache_hits += 1;
|
||||
}
|
||||
None => {
|
||||
missing_candidates.push(candidate.clone());
|
||||
missing_keys.push(key);
|
||||
missing_indices.push(index);
|
||||
cache_misses += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.stat_cache(cache_hits, cache_misses);
|
||||
|
||||
if !missing_candidates.is_empty() {
|
||||
let hydrated_missing = self.hydrate_from_client(query, &missing_candidates).await;
|
||||
if hydrated_missing.len() != missing_candidates.len() {
|
||||
let message = format!(
|
||||
"CachedHydrator length_mismatch expected={} got={}",
|
||||
missing_candidates.len(),
|
||||
hydrated_missing.len()
|
||||
);
|
||||
return vec![Err(message); candidates.len()];
|
||||
}
|
||||
|
||||
for ((index, key), hydrated) in missing_indices
|
||||
.into_iter()
|
||||
.zip(missing_keys.into_iter())
|
||||
.zip(hydrated_missing.into_iter())
|
||||
{
|
||||
if let Ok(ref hydrated_candidate) = hydrated {
|
||||
let value = self.cache_value(hydrated_candidate);
|
||||
self.cache_store().insert(key, value).await;
|
||||
}
|
||||
results[index] = Some(hydrated);
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
.into_iter()
|
||||
.map(|result| {
|
||||
result.unwrap_or_else(|| Err("Missing hydration result for candidate".to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut C, hydrated: C) {
|
||||
CachedHydrator::update(self, candidate, hydrated);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,32 @@
|
||||
use std::any::{Any, type_name_of_val};
|
||||
use tonic::async_trait;
|
||||
|
||||
use crate::candidate_pipeline::PipelineQuery;
|
||||
use crate::util;
|
||||
use tracing::error;
|
||||
|
||||
#[async_trait]
|
||||
pub trait QueryHydrator<Q>: Any + Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
{
|
||||
/// Decide if this query hydrator should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[xai_stats_macro::receive_stats]
|
||||
#[tracing::instrument(skip_all, name = "query_hydrator", fields(name = self.name()))]
|
||||
async fn run(&self, query: &Q) -> Result<Q, String> {
|
||||
match self.hydrate(query).await {
|
||||
Ok(hydrated) => Ok(hydrated),
|
||||
Err(err) => {
|
||||
error!("Failed: {}", err);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hydrate the query by performing async operations.
|
||||
/// Returns a new query with this hydrator's fields populated.
|
||||
async fn hydrate(&self, query: &Q) -> Result<Q, String>;
|
||||
|
||||
@@ -1,35 +1,61 @@
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use std::any::type_name_of_val;
|
||||
use tonic::async_trait;
|
||||
use tracing::warn;
|
||||
|
||||
/// Scorers update candidate fields (like a score field) and run sequentially
|
||||
#[async_trait]
|
||||
pub trait Scorer<Q, C>: Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Decide if this scorer should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[xai_stats_macro::receive_stats]
|
||||
#[tracing::instrument(skip_all, name = "scorer", fields(name = self.name()))]
|
||||
async fn run(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>> {
|
||||
let scored = self.score(query, candidates).await;
|
||||
let expected_len = candidates.len();
|
||||
if scored.len() == expected_len {
|
||||
scored
|
||||
} else {
|
||||
let message = format!(
|
||||
"Scorer length_mismatch expected={} got={}",
|
||||
expected_len,
|
||||
scored.len()
|
||||
);
|
||||
warn!(
|
||||
"Skipped: length_mismatch expected={} got={}",
|
||||
expected_len,
|
||||
scored.len()
|
||||
);
|
||||
vec![Err(message); expected_len]
|
||||
}
|
||||
}
|
||||
|
||||
/// Score candidates by performing async operations.
|
||||
/// Returns candidates with this scorer's fields populated.
|
||||
///
|
||||
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
|
||||
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
|
||||
async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
|
||||
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
|
||||
async fn score(&self, query: &Q, candidates: &[C]) -> Vec<Result<C, String>>;
|
||||
|
||||
/// Update a single candidate with the scored fields.
|
||||
/// Only the fields this scorer is responsible for should be copied.
|
||||
fn update(&self, candidate: &mut C, scored: C);
|
||||
|
||||
/// Update all candidates with the scored fields from `scored`.
|
||||
/// Update all successfully scored candidates with the fields from `scored`.
|
||||
/// Default implementation iterates and calls `update` for each pair.
|
||||
fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
|
||||
for (c, s) in candidates.iter_mut().zip(scored) {
|
||||
self.update(c, s);
|
||||
fn update_all(&self, candidates: &mut [C], scored: Vec<Result<C, String>>) {
    // Pair each candidate with its result by position and skip the failures,
    // so candidates whose scoring failed keep their existing fields.
    candidates
        .iter_mut()
        .zip(scored)
        .filter_map(|(slot, outcome)| outcome.ok().map(|fresh| (slot, fresh)))
        .for_each(|(slot, fresh)| self.update(slot, fresh));
}
|
||||
|
||||
|
||||
@@ -1,25 +1,65 @@
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use std::any::type_name_of_val;
|
||||
use tracing::{Span, field::Empty};
|
||||
|
||||
/// Outcome of a selector run: the candidates that were kept and the ones that were not.
pub struct SelectResult<C> {
    /// Candidates chosen by the selector.
    pub selected: Vec<C>,
    /// Candidates the selector left out.
    pub non_selected: Vec<C>,
}

impl<C> SelectResult<C> {
    /// Number of selected candidates.
    pub fn len(&self) -> usize {
        self.selected.len()
    }

    /// Returns `true` when no candidate was selected, i.e. exactly when `len() == 0`.
    ///
    /// Previously this also required `non_selected` to be empty, so a result with
    /// nothing selected could report `len() == 0` while claiming not to be empty.
    pub fn is_empty(&self) -> bool {
        self.selected.is_empty()
    }
}
|
||||
|
||||
pub trait Selector<Q, C>: Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Default selection: sort and truncate based on provided configs
|
||||
fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
|
||||
let mut sorted = self.sort(candidates);
|
||||
if let Some(limit) = self.size() {
|
||||
sorted.truncate(limit);
|
||||
}
|
||||
sorted
|
||||
}
|
||||
|
||||
/// Decide if this selector should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[xai_stats_macro::receive_stats(latency=Bucket0To50, size=Bucket0To50)]
|
||||
#[tracing::instrument(skip_all, name = "selector", fields(
|
||||
name = self.name(),
|
||||
input_count = candidates.len(),
|
||||
selected_count = Empty,
|
||||
non_selected_count = Empty,
|
||||
))]
|
||||
fn run(&self, query: &Q, candidates: Vec<C>) -> SelectResult<C> {
|
||||
let result = self.select(query, candidates);
|
||||
let span = Span::current();
|
||||
span.record("selected_count", result.selected.len());
|
||||
span.record("non_selected_count", result.non_selected.len());
|
||||
result
|
||||
}
|
||||
|
||||
// Returns (selected, non_selected).
|
||||
fn select(&self, _query: &Q, candidates: Vec<C>) -> SelectResult<C> {
|
||||
let mut sorted = self.sort(candidates);
|
||||
if let Some(limit) = self.size() {
|
||||
let non_selected = sorted.split_off(limit.min(sorted.len()));
|
||||
SelectResult {
|
||||
selected: sorted,
|
||||
non_selected,
|
||||
}
|
||||
} else {
|
||||
SelectResult {
|
||||
selected: sorted,
|
||||
non_selected: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the score from a candidate to use for sorting.
|
||||
fn score(&self, candidate: &C) -> f64;
|
||||
|
||||
|
||||
@@ -1,27 +1,35 @@
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use std::any::type_name_of_val;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
|
||||
// A side-effect is an action that runs without affecting the pipeline result that is returned
|
||||
|
||||
/// Input handed to `SideEffect::run`: the query plus two candidate lists.
#[derive(Clone)]
|
||||
pub struct SideEffectInput<Q, C> {
|
||||
/// The query, held behind an `Arc`.
pub query: Arc<Q>,
|
||||
// NOTE(review): presumably the candidates the selector kept — confirm against the pipeline.
pub selected_candidates: Vec<C>,
|
||||
// NOTE(review): presumably the candidates the selector left out — confirm against the pipeline.
pub non_selected_candidates: Vec<C>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SideEffect<Q, C>: Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Decide if this side-effect should be run
|
||||
fn enable(&self, _query: Arc<Q>) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
|
||||
/// Entry point annotated with `receive_stats`; it simply forwards `input` to `side_effect`
/// and returns that result unchanged.
#[xai_stats_macro::receive_stats]
|
||||
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String> {
|
||||
self.side_effect(input).await
|
||||
}
|
||||
|
||||
async fn side_effect(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
|
||||
@@ -1,20 +1,37 @@
|
||||
use std::any::{Any, type_name_of_val};
|
||||
use tonic::async_trait;
|
||||
|
||||
use crate::candidate_pipeline::{PipelineCandidate, PipelineQuery};
|
||||
use crate::util;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Source<Q, C>: Any + Send + Sync
|
||||
where
|
||||
Q: Clone + Send + Sync + 'static,
|
||||
C: Clone + Send + Sync + 'static,
|
||||
Q: PipelineQuery,
|
||||
C: PipelineCandidate,
|
||||
{
|
||||
/// Decide if this source should run for the given query
|
||||
fn enable(&self, _query: &Q) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;
|
||||
#[xai_stats_macro::receive_stats(size=Bucket500To1000)]
|
||||
#[tracing::instrument(skip_all, name = "source", fields(name = self.name()))]
|
||||
async fn run(&self, query: &Q) -> Result<Vec<C>, String> {
|
||||
match self.source(query).await {
|
||||
Ok(candidates) => {
|
||||
info!("Fetched {} candidates", candidates.len());
|
||||
Ok(candidates)
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Failed: {}", err);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn source(&self, query: &Q) -> Result<Vec<C>, String>;
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
util::short_type_name(type_name_of_val(self))
|
||||
|
||||
3
candidate-pipeline/util.rs
Normal file
3
candidate-pipeline/util.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
/// Strip the module path from a fully qualified type name.
///
/// Only the path in front of the first `<` is shortened, so generic arguments
/// are kept verbatim: `a::b::Cached<c::Key>` becomes `Cached<c::Key>`.
/// Splitting on the last `::` anywhere in the string would cut inside the
/// generic arguments and return `Key>`. Names without generics are shortened
/// to their final path segment.
pub fn short_type_name(full: &'static str) -> &'static str {
    // Everything from the first `<` onwards belongs to the generic arguments.
    let path_end = full.find('<').unwrap_or(full.len());
    match full[..path_end].rfind("::") {
        // `::` is two ASCII bytes, so `sep + 2` is always a valid char boundary.
        Some(sep) => &full[sep + 2..],
        None => full,
    }
}
|
||||
0
grox/__init__.py
Normal file
0
grox/__init__.py
Normal file
160
grox/classifiers/content/banger_initial_screen.py
Normal file
160
grox/classifiers/content/banger_initial_screen.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import re
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.lm.post import PostRenderer
|
||||
from grox.lm.user import UserRenderer
|
||||
from grox.lm.convo import Role, Message, Conversation
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grok_sampler.config import GrokModelConfig
|
||||
from grox.prompts.template import BangerMiniVlmScreenScore
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
ContentCategoryType,
|
||||
ContentCategoryResult,
|
||||
TweetBoolMetadata,
|
||||
ContentCategoryScore,
|
||||
)
|
||||
from grox.classifiers.content.classifier import ContentClassifier
|
||||
from monitor.metrics import Metrics
|
||||
from pydantic import BaseModel
|
||||
from strato_http.queries.grok_topics import StratoGrokTopics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Schema for the JSON object that BangerInitialScreenClassifier._parse validates
# from the text between the <json> tags of the model output.
class BangerInitialScreenResult(BaseModel):
|
||||
# Compared against 0.4 in _parse to set the `positive` flag; also recorded as a histogram.
quality_score: float
|
||||
# Copied into the result's `summary` field by _parse.
description: str
|
||||
# Copied into the result's `tags` field by _parse.
tags: list[str]
|
||||
# Each dict is read with the keys "id", "name" and "score" in _parse.
taxonomy_categories: list[dict] | None = None
|
||||
# The remaining optional fields are passed through to the result unchanged.
tweet_bool_metadata: TweetBoolMetadata | None = None
|
||||
is_image_editable_by_grok: bool | None = None
|
||||
slop_score: int | None = None
|
||||
has_minor_score: float | None = None
|
||||
|
||||
|
||||
class BangerInitialScreenClassifier(ContentClassifier):
    """Initial quality screen that asks the VLM for a scored JSON verdict on a post.

    One model call produces a result for each configured category
    (BANGER_INITIAL_SCREEN and GROK_RANKER); all of them share the same fields.
    """

    # Group 1 is the free-text reasoning, group 2 the JSON payload inside <json> tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(self):
        model_settings = grox_config.get_model(ModelName.VLM_PRIMARY)
        model_settings.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_settings.model_dump()))
        super().__init__(
            categories=[
                ContentCategoryType.BANGER_INITIAL_SCREEN,
                ContentCategoryType.GROK_RANKER,
            ],
            llm=sampler,
        )
        # Topics for the request currently being classified; set by classify().
        self._topics = None

    @staticmethod
    def build_convo(post: Post, topics: list | None = None) -> Conversation:
        """Build the system / user / assistant conversation for `post`.

        `topics` is forwarded to the system prompt template as its "topics" parameter.
        """
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        system_prompt = BangerMiniVlmScreenScore().render(params={"topics": topics})
        convo.messages.append(Message(role=Role.SYSTEM, content=[system_prompt]))

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post))
        user_msg.content.append(
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        convo.messages.append(user_msg)

        # The conversation ends with an assistant turn holding a single empty string.
        assistant_msg = Message(role=Role.ASSISTANT, content=[""], separator="")
        convo.messages.append(assistant_msg)
        return convo

    async def classify(
        self, post: Post, topics: list | None = None
    ) -> list[ContentCategoryResult]:
        """Classify `post`, making `topics` available to the prompt.

        NOTE(review): topics are stashed on the instance, so overlapping
        classify() calls on one classifier could see each other's topics —
        confirm callers do not run them concurrently on a shared instance.
        """
        self._topics = topics
        return await super().classify(post)

    async def _classify_for_eval(self, post: Post) -> str:
        """Sample the raw model output for `post` with no topics, logging before and after."""
        self._topics = None
        convo = await self._to_convo(post)
        logger.info(f"Banger initial screen conversation for post {post.id}")
        result = await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )
        logger.info(f"Banger initial screen result for post {post.id}: {result}")
        return result

    async def _to_convo(self, post: Post) -> Conversation:
        return self.build_convo(post, topics=self._topics)

    async def _sample(self, convo: Conversation) -> str:
        prompt = convo.interleave()
        return await self.llm.sample(prompt, conversation_id=convo.conversation_id)

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Turn the raw model output into one result per configured category.

        Raises:
            ValueError: if the output has no <json>...</json> section.
        """
        match = self.result_pattern.search(output)
        if match is None:
            raise ValueError(f"Invalid output: {output}")

        reasoning = match.group(1).strip()
        logger.info(
            f"Banger initial screen result reasoning for post {post.id}: {reasoning}"
        )
        result = BangerInitialScreenResult.model_validate_json(match.group(2).strip())
        score = result.quality_score

        score_histogram = Metrics.histogram(
            "banger_initial_screen_score",
            explicit_bucket_boundaries_advisory=[
                0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1,
            ],
        )
        score_histogram.record(score)

        taxonomy_categories = [
            ContentCategoryScore(
                id=tc["id"],
                name=tc["name"],
                score=tc["score"],
                category_id=None,
            )
            for tc in (result.taxonomy_categories or [])
        ]

        # Every category gets the same verdict; only the category label differs.
        shared_fields = dict(
            positive=score >= 0.4,
            score=score,
            reason=reasoning,
            summary=result.description,
            tags=result.tags,
            taxonomy_categories=taxonomy_categories,
            tweet_bool_metadata=result.tweet_bool_metadata,
            is_image_editable_by_grok=result.is_image_editable_by_grok,
            slop_score=result.slop_score,
            has_minor_score=result.has_minor_score,
        )
        return [
            ContentCategoryResult(category=cat, **shared_fields)
            for cat in self.categories
        ]
|
||||
98
grox/classifiers/content/classifier.py
Normal file
98
grox/classifiers/content/classifier.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from grok_sampler.llm import LiteLLM
|
||||
from grox.data_loaders.data_types import (
|
||||
ContentCategoryResult,
|
||||
ContentCategoryType,
|
||||
Post,
|
||||
)
|
||||
from grox.lm.convo import Conversation
|
||||
from monitor.metrics import Metrics
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentClassifier(ABC):
    """Base class for LLM-backed post classifiers.

    Subclasses supply the three pipeline steps (`_to_convo`, `_sample`,
    `_parse`); `classify` wraps them with logging and metrics.
    """

    def __init__(self, categories: list[ContentCategoryType], llm: LiteLLM):
        self.categories = categories
        self.llm = llm

    @property
    def model_name(self) -> str:
        return self.llm.model_config.model_name

    def _count_per_category(self, metric_name: str) -> None:
        """Add 1 to `metric_name` once for each configured category."""
        for category in self.categories:
            Metrics.counter(metric_name).add(
                1, attributes={"category": category.value.lower()}
            )

    async def classify(self, post: Post) -> list[ContentCategoryResult]:
        """Classify `post`, emitting request/intake/success/error counters and latency.

        Any exception from the pipeline is counted, logged with its traceback,
        and re-raised unchanged.
        """
        cls_name = self.__class__.__name__
        logger.info(
            f"[{cls_name}] started processing content classify request: {post.id}"
        )
        self._count_per_category("content.classification.request.count")
        self._count_per_category("content.classification.intake.count")

        start = time.perf_counter()
        try:
            res = await self._classify(post)
        except Exception:
            self._count_per_category("content.classification.error.count")
            logger.error(
                f"[{cls_name}] error processing content classify request: {post.id} {traceback.format_exc()}"
            )
            raise

        self._count_per_category("content.classification.success.count")
        end = time.perf_counter()
        elapsed = end - start
        logger.info(
            f"[{cls_name}] finished processing content classify request: {post.id} in {elapsed:.2f} seconds"
        )
        Metrics.histogram("content.classification.latency").record(
            elapsed, attributes={"class": cls_name}
        )
        self._post_process_for_logging(res, start, end)
        return res

    def _post_process_for_logging(
        self, res: list[BaseModel], start_time, end_time
    ) -> None:
        """Record latency per category and count each result by category and outcome."""
        elapsed = end_time - start_time
        for category in self.categories:
            Metrics.histogram("content.classification.latency").record(
                elapsed, attributes={"category": category.value.lower()}
            )

        for result in res:
            labels = {
                "category": result.category.value.lower(),
                "positive": str(result.positive),
            }
            Metrics.counter("content.classification.result.count").add(
                1, attributes=labels
            )

    @abstractmethod
    async def _to_convo(self, post: Post) -> Conversation:
        """Build the conversation sent to the model for `post`."""
        pass

    @abstractmethod
    async def _sample(self, convo: Conversation) -> str:
        """Run the model on `convo` and return its raw text output."""
        pass

    @abstractmethod
    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Convert the raw model output into classification results."""
        pass

    async def _classify(self, post: Post) -> list[ContentCategoryResult]:
        """Run the three pipeline steps in order: build, sample, parse."""
        convo = await self._to_convo(post)
        raw_output = await self._sample(convo)
        return await self._parse(post, raw_output)
|
||||
92
grox/classifiers/content/post_safety_screen_deluxe.py
Normal file
92
grox/classifiers/content/post_safety_screen_deluxe.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import re
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.lm.post import PostRenderer
|
||||
from grox.lm.user import UserRenderer
|
||||
from grox.lm.convo import Role, Message, Conversation
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grok_sampler.config import GrokModelConfig
|
||||
from grox.prompts.template import PostSafetyDeluxe
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
ContentCategoryType,
|
||||
ContentCategoryResult,
|
||||
TweetBoolMetadata,
|
||||
)
|
||||
from grox.classifiers.content.classifier import ContentClassifier
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Schema for the JSON object that PostSafetyDeluxeClassifier._parse validates
# from the text between the <json> tags of the model output.
class PostSafetyScreenResult(BaseModel):
|
||||
# Passed through unchanged to the returned ContentCategoryResult.
tweet_bool_metadata: TweetBoolMetadata
|
||||
|
||||
|
||||
class PostSafetyDeluxeClassifier(ContentClassifier):
    """Safety screen that extracts `tweet_bool_metadata` for a post from the VLM output."""

    # Group 1 is the free-text reasoning, group 2 the JSON payload inside <json> tags.
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(self):
        model_settings = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        model_settings.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**model_settings.model_dump()))
        super().__init__(
            categories=[ContentCategoryType.POST_SAFETY_SCREEN],
            llm=sampler,
        )

    @staticmethod
    def build_convo(post: Post) -> Conversation:
        """Build the system / user / assistant conversation for `post`."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        system_msg = Message(role=Role.SYSTEM, content=[PostSafetyDeluxe().render()])
        convo.messages.append(system_msg)

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post))
        user_msg.content.append(
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        convo.messages.append(user_msg)

        # The conversation ends with an assistant turn that has no content.
        convo.messages.append(Message(role=Role.ASSISTANT, content=[]))
        return convo

    async def _to_convo(self, post: Post) -> Conversation:
        return self.build_convo(post)

    async def _sample(self, convo: Conversation) -> str:
        prompt = convo.interleave()
        return await self.llm.sample(prompt, conversation_id=convo.conversation_id)

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Parse the model output into a single POST_SAFETY_SCREEN result.

        `positive` and `score` are fixed at False / 0.0; the payload is carried
        in `tweet_bool_metadata`.

        Raises:
            ValueError: if the output has no <json>...</json> section.
        """
        match = self.result_pattern.search(output)
        if match is None:
            raise ValueError(f"Invalid output: {output}")

        reasoning = match.group(1).strip()
        logger.info(f"Post Safety Screen reasoning for post {post.id} : {reasoning}")

        result = PostSafetyScreenResult.model_validate_json(match.group(2).strip())
        logger.info(f"Post Safety Screen result for post {post.id} : {result}")

        screen_result = ContentCategoryResult(
            category=ContentCategoryType.POST_SAFETY_SCREEN,
            positive=False,
            score=0.0,
            tweet_bool_metadata=result.tweet_bool_metadata,
        )
        return [screen_result]
|
||||
158
grox/classifiers/content/reply_ranking.py
Normal file
158
grox/classifiers/content/reply_ranking.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
|
||||
import json_repair
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
from grox.lm.convo import Role, Message, Conversation
|
||||
from grox.lm.thread import ThreadRenderer
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grok_sampler.config import GrokModelConfig
|
||||
from grox.prompts.template import ReplyScoringSystem
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
ReplyScoreResult,
|
||||
)
|
||||
from monitor.metrics import Metrics
|
||||
from pydantic import ValidationError
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplyScorer:
    """Scores a reply with a VLM, falling back to a second model and lenient parsing."""

    model_name = "GROK"

    def __init__(self):
        primary_cfg = grox_config.get_model(ModelName.VLM_MINI_CRITICAL)
        primary_cfg.temperature = 0.000001
        self.vlm = VisionSampler(GrokModelConfig(**primary_cfg.model_dump()))

        fallback_cfg = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        fallback_cfg.temperature = 0.000001
        self.vlm_fallback = VisionSampler(
            GrokModelConfig(**fallback_cfg.model_dump())
        )

    @staticmethod
    def _find_score_object(output: str) -> str | None:
        """Return the outermost `{...}` span of `output` if it mentions "score", else None."""
        found = re.search(r"\{.*\}", output, re.DOTALL)
        if found and "score" in found.group(0):
            return found.group(0)
        return None

    async def score(self, post: Post) -> list[ReplyScoreResult]:
        """Score `post` and record the resulting score in the `ranked_replies_scores` histogram."""
        convo = await self._to_convo(post)
        raw_output = await self._sample(convo, post)
        parsed = await self._parse(raw_output)

        score_histogram = Metrics.histogram(
            "ranked_replies_scores",
            explicit_bucket_boundaries_advisory=[0.0, 1.0, 2.0, 3.0],
        )
        score_histogram.record(parsed[0].score)

        return parsed

    async def _to_convo(self, post: Post, non_reasoning: bool = False) -> Conversation:
        """Build the scoring conversation; `non_reasoning` selects the fallback layout."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        system_prompt = ReplyScoringSystem().render(
            params={"large_account_follower_threshold": ""}
        )
        if non_reasoning:
            system_prompt = "" + system_prompt
        convo.messages.append(Message(role=Role.SYSTEM, content=[system_prompt]))

        thread_msg = ThreadRenderer.render(
            post, role=Role.HUMAN, separator="\n\n", include_signals=True
        )
        convo.messages.append(thread_msg)

        # The fallback layout ends with an empty-string assistant turn, the default
        # layout with an assistant turn that has no content.
        if non_reasoning:
            assistant_msg = Message(role=Role.ASSISTANT, content=[""], separator="")
        else:
            assistant_msg = Message(role=Role.ASSISTANT, content=[])
        convo.messages.append(assistant_msg)
        return convo

    async def _sample(self, convo: Conversation, post: Post) -> str:
        """Sample the primary model; retry on the fallback model if no score object came back."""
        output = await self.vlm.sample(
            convo.interleave(),
            conversation_id=convo.conversation_id,
            json_schema=json.dumps(ReplyScoreResult.model_json_schema()),
        )
        if self._find_score_object(output) is None:
            fallback_convo = await self._to_convo(post, non_reasoning=True)
            output = await self.vlm_fallback.sample(
                fallback_convo.interleave(),
                conversation_id=fallback_convo.conversation_id,
                json_schema=json.dumps(ReplyScoreResult.model_json_schema()),
            )
        return output

    async def _clean_output(self, output: str) -> str:
        """Strip a trailing `<|eos|>` marker, surrounding whitespace and Markdown code fences."""
        text = output.removesuffix("<|eos|>").strip()
        # Check the longer fence first so "```json" is not cut down to "json".
        for fence in ("```json", "```"):
            if text.startswith(fence):
                text = text[len(fence):]
                break
        text = text.removesuffix("```")
        return text.strip()

    async def _parse(self, output: str) -> list[ReplyScoreResult]:
        """Extract a score (and reason) from `output`, trying progressively looser strategies.

        Order: strict JSON validation, then `json_repair`, then regex extraction.

        Raises:
            ValueError: if no score could be extracted by any strategy.
        """
        score = None
        reason = ""

        candidate = self._find_score_object(output)
        raw_result = candidate.strip() if candidate is not None else output

        cleaned_result = await self._clean_output(raw_result)

        try:
            result = ReplyScoreResult.model_validate_json(cleaned_result)
            score = result.score
            reason = result.reason
        except (ValidationError, ValueError):
            try:
                repaired = json_repair.repair_json(cleaned_result, return_objects=True)
                if isinstance(repaired, dict) and "score" in repaired:
                    result = ReplyScoreResult.model_validate(repaired)
                    score = result.score
                    reason = result.reason
                    Metrics.counter("task.reply_ranker.json_repaired.count").add(1)
            except Exception:
                # Best effort only: the regex extraction below is the last resort.
                pass

        if score is None:
            score_match = re.search(r'"score":\s*([\d.]+)', cleaned_result)
            if score_match:
                try:
                    score = float(score_match.group(1).strip())
                except ValueError:
                    score = None
                    Metrics.counter("task.reply_ranker.invalid.count").add(
                        1,
                        attributes={
                            "filter": "reply_ranking",
                            "reason": "invalid_score_format",
                        },
                    )
            if not reason:
                reason_match = re.search(
                    r'"reason":\s*"((?:[^"\\]|\\.)*)"', cleaned_result, re.DOTALL
                )
                if reason_match:
                    reason = reason_match.group(1)

        # A score of 0 is valid; only a missing score is an error.
        if not score and score != 0:
            logger.error(f"Invalid output format: {output}")
            Metrics.counter("task.reply_ranker.invalid.count").add(
                1, attributes={"filter": "reply_ranking", "reason": "invalid_format"}
            )
            raise ValueError(f"Invalid output: {output}")

        return [ReplyScoreResult(score=score, reason=reason)]
|
||||
288
grox/classifiers/content/safety_ptos.py
Normal file
288
grox/classifiers/content/safety_ptos.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.lm.convo import Role, Message, Conversation
|
||||
from grox.lm.post import PostRenderer
|
||||
from grox.lm.user import UserRenderer
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grok_sampler.config import GrokModelConfig, EapiModelConfig
|
||||
from grox.prompts.template import (
|
||||
SafetyPtos,
|
||||
ViolentMediaPolicy,
|
||||
AdultContentPolicy,
|
||||
SpamPolicy,
|
||||
IllegalAndRegulatedBehaviorsPolicy,
|
||||
HateOrAbusePolicy,
|
||||
ViolentSpeechPolicy,
|
||||
SuicideOrSelfHarmPolicy,
|
||||
)
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grok_sampler.eapi_sampler import EapiSampler
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
SafetyPostAnnotations,
|
||||
ContentCategoryResult,
|
||||
ContentCategoryType,
|
||||
SafetyPtosViolatedPolicy,
|
||||
SafetyPolicy,
|
||||
SafetyPolicyCategory,
|
||||
)
|
||||
from grox.classifiers.content.classifier import ContentClassifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_THINKING_RESTRICTION_LINES = {
|
||||
"",
|
||||
"",
|
||||
}
|
||||
|
||||
|
||||
def _strip_thinking_restrictions(text: str) -> str:
|
||||
lines = text.splitlines(keepends=True)
|
||||
return "".join(
|
||||
line for line in lines if line.strip() not in _THINKING_RESTRICTION_LINES
|
||||
).lstrip("\n")
|
||||
|
||||
|
||||
def _render_safety_ptos_for_reasoning() -> str:
    """Render the ``SafetyPtos`` template and filter it with ``_strip_thinking_restrictions``."""
    rendered_prompt = SafetyPtos().render()
    return _strip_thinking_restrictions(rendered_prompt)
|
||||
|
||||
|
||||
class SafetyPtosCategoryClassifier(ContentClassifier):
    """Classifier that prompts a vision sampler with the ``SafetyPtos`` template.

    The model reply must contain a ``<json>...</json>`` section, which
    ``classify_post`` validates into ``SafetyPostAnnotations``.  With
    ``deluxe=True`` the system prompt is first passed through
    ``_strip_thinking_restrictions`` and the assistant turn is left empty.
    """

    # Greedy with DOTALL: group(2) is the text after the last "<json>" that is
    # still followed by a "</json>", up to the final "</json>".
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    def __init__(
        self,
        model_name: ModelName = ModelName.VLM_SAFETY,
        deluxe: bool = False,
    ):
        """Build the vision sampler for ``model_name`` and register the SAFETY_PTOS category."""
        self.deluxe = deluxe
        sampler_config = grox_config.get_model(model_name)
        # Near-zero temperature — presumably to keep sampling close to deterministic.
        sampler_config.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**sampler_config.model_dump()))
        super().__init__(categories=[ContentCategoryType.SAFETY_PTOS], llm=sampler)

    def build_convo(self, post: Post) -> Conversation:
        """Assemble the system / user / assistant messages for ``post``."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)

        if self.deluxe:
            system_prompt = _render_safety_ptos_for_reasoning()
        else:
            system_prompt = SafetyPtos().render()
        convo.messages.append(Message(role=Role.SYSTEM, content=[system_prompt]))

        # User turn: rendered author, rendered post (with reply context), then the ask.
        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post, include_reply_to=True))
        user_msg.content.append(
            f"\n\nAnalyze the post {post.id} and provide the requested JSON object for the post."
        )
        convo.messages.append(user_msg)

        # Trailing assistant turn: empty content in deluxe mode, otherwise a
        # single empty string with an empty separator.
        if self.deluxe:
            assistant_msg = Message(role=Role.ASSISTANT, content=[])
        else:
            assistant_msg = Message(role=Role.ASSISTANT, content=[""], separator="")
        convo.messages.append(assistant_msg)

        return convo

    async def classify_post(self, post: Post) -> SafetyPostAnnotations:
        """Sample the model for ``post`` and parse its ``<json>`` payload.

        Raises:
            ValueError: if the reply contains no ``<json>...</json>`` section.
        """
        convo = await self._to_convo(post)
        result = await self._sample(convo, post)
        mode = "deluxe" if self.deluxe else "standard"
        logger.info(
            f"safety ptos category classifier ({mode}) result for {post.id}: {result}"
        )

        found = self.result_pattern.search(result)
        if found is None:
            raise ValueError(
                f"Invalid output for safety ptos category classifier ({mode}): {result}"
            )
        return SafetyPostAnnotations.model_validate_json(found.group(2).strip())

    async def _to_convo(self, post: Post) -> Conversation:
        """Async wrapper around ``build_convo``."""
        return self.build_convo(post)

    async def _sample(self, convo: Conversation, post: Post = None) -> str:
        """Send the interleaved conversation to the sampler; ``post`` is not used here."""
        prompt = convo.interleave()
        return await self.llm.sample(prompt, conversation_id=convo.conversation_id)

    async def _parse(self, post: Post, output: str) -> List[ContentCategoryResult]:
        """Return one positive SAFETY_PTOS result (score 0.0) when ``output`` has a ``<json>`` section.

        Raises:
            ValueError: if ``output`` contains no ``<json>...</json>`` section.
        """
        if self.result_pattern.search(output) is None:
            mode = "deluxe" if self.deluxe else "standard"
            raise ValueError(
                f"Invalid parsing for safety ptos category classifier ({mode}): {output}"
            )
        return [
            ContentCategoryResult(
                category=ContentCategoryType.SAFETY_PTOS, positive=True, score=0.0
            )
        ]
|
||||
|
||||
|
||||
class SafetyPtosPolicyClassifier(ContentClassifier):
    """Second-stage classifier that evaluates one violated policy category of a post.

    For a given ``SafetyPtosViolatedPolicy`` it selects the matching policy
    prompt, samples a model, and validates the ``<json>...</json>`` section of
    the reply into a ``SafetyPolicy``.  With ``deluxe=True`` two extra EAPI
    reasoning samplers are created, and categories listed in
    ``DELUXE_4_2_CATEGORIES`` are sampled through ``_sample_4_2``.
    """

    # Greedy with DOTALL: group(2) is the text after the last "<json>" that is
    # still followed by a "</json>", up to the final "</json>".
    result_pattern = re.compile(r"(.*)<json>(.*)</json>", re.DOTALL)

    # Single source of truth: policy category -> prompt template class.
    _POLICY_PROMPTS = {
        SafetyPolicyCategory.ViolentMedia: ViolentMediaPolicy,
        SafetyPolicyCategory.AdultContent: AdultContentPolicy,
        SafetyPolicyCategory.Spam: SpamPolicy,
        SafetyPolicyCategory.IllegalAndRegulatedBehaviors: IllegalAndRegulatedBehaviorsPolicy,
        SafetyPolicyCategory.HateOrAbuse: HateOrAbusePolicy,
        SafetyPolicyCategory.ViolentSpeech: ViolentSpeechPolicy,
        SafetyPolicyCategory.SuicideOrSelfHarm: SuicideOrSelfHarmPolicy,
    }

    # Categories that have a policy prompt; derived so it cannot drift from the table.
    SUPPORTED_POLICY_CATEGORIES = set(_POLICY_PROMPTS)

    # Categories that use the "4.2" reasoning sampler when running in deluxe mode.
    DELUXE_4_2_CATEGORIES = {
        SafetyPolicyCategory.AdultContent,
        SafetyPolicyCategory.ViolentMedia,
    }

    def __init__(self, deluxe: bool = False):
        """Create the primary vision sampler and, in deluxe mode, the EAPI reasoning samplers."""
        self.deluxe = deluxe

        vlm_config = grox_config.get_model(ModelName.VLM_PRIMARY_CRITICAL)
        # Near-zero temperature — presumably to keep sampling close to deterministic.
        vlm_config.temperature = 0.000001
        vlm = VisionSampler(GrokModelConfig(**vlm_config.model_dump()))
        super().__init__(categories=[ContentCategoryType.SAFETY_PTOS], llm=vlm)

        if deluxe:
            eapi_config_reasoning = grox_config.get_eapi_model(
                ModelName.EAPI_REASONING_INTERNAL
            )
            self.eapi_reasoning = EapiSampler(
                EapiModelConfig(**eapi_config_reasoning.model_dump())
            )

            eapi_config_reasoning_x_algo = grox_config.get_eapi_model(
                ModelName.EAPI_REASONING
            )
            self.eapi_reasoning_x_algo = EapiSampler(
                EapiModelConfig(**eapi_config_reasoning_x_algo.model_dump())
            )

    @staticmethod
    def _get_policy_prompt(violation: SafetyPtosViolatedPolicy) -> str:
        """Render the policy prompt for ``violation.category``.

        Raises:
            ValueError: if the category has no entry in the prompt table.
        """
        template_cls = SafetyPtosPolicyClassifier._POLICY_PROMPTS.get(violation.category)
        if template_cls is None:
            raise ValueError(
                f"No policy prompt available for category: {violation.category.value}"
            )
        return template_cls().render()

    def build_convo(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> Conversation:
        """Assemble the system / user / assistant messages for one policy check."""
        content = self._get_policy_prompt(violation)
        if self.deluxe:
            content = _strip_thinking_restrictions(content)

        convo = Conversation(conversation_id=uuid.uuid4().hex)
        convo.messages.append(Message(role=Role.SYSTEM, content=[content]))

        user_msg = Message(role=Role.USER, content=[])
        user_msg.content.extend(UserRenderer.render(post.user))
        user_msg.content.extend(PostRenderer.render(post, include_reply_to=True))
        user_msg.content.append(
            f"\n\nAnalyze the post {post.id} for the specific safety policy violation category: {violation.category.value}"
        )
        user_msg.content.append(
            "\n\nProvide the requested JSON object for the specific safety policy type."
        )
        convo.messages.append(user_msg)

        # Trailing assistant turn: empty content in deluxe mode, otherwise a
        # single empty string with an empty separator.
        if self.deluxe:
            convo.messages.append(Message(role=Role.ASSISTANT, content=[]))
        else:
            convo.messages.append(
                Message(role=Role.ASSISTANT, content=[""], separator="")
            )

        return convo

    async def classify_policy_for_violation(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> SafetyPolicy | None:
        """Classify ``post`` against the policy for ``violation.category``.

        Returns:
            The parsed ``SafetyPolicy``, or ``None`` when the category is not in
            ``SUPPORTED_POLICY_CATEGORIES``.

        Raises:
            ValueError: if the model reply has no ``<json>...</json>`` section.
        """
        if violation.category not in self.SUPPORTED_POLICY_CATEGORIES:
            return None

        convo = await self._to_convo(post, violation)

        # NOTE(review): the original condition started with a dangling `and`
        # (its first operand appears to have been removed), which is a
        # SyntaxError.  Confirm whether an additional gate belongs here.
        if self.deluxe and violation.category in self.DELUXE_4_2_CATEGORIES:
            mode = "deluxe-4.2"
            result = await self._sample_4_2(convo, post)
        else:
            mode = "deluxe" if self.deluxe else "standard"
            result = await self._sample(convo, post)

        logger.info(
            f"safety ptos policy classifier ({mode}) result for post {post.id}, violation {violation.category}: {result}"
        )

        match = self.result_pattern.search(result)
        if match is None:
            raise ValueError(
                f"Invalid output for safety ptos policy ({mode}): {result}"
            )
        return SafetyPolicy.model_validate_json(match.group(2).strip())

    async def _to_convo(
        self, post: Post, violation: SafetyPtosViolatedPolicy
    ) -> Conversation:
        """Async wrapper around ``build_convo``."""
        return self.build_convo(post, violation)

    async def _sample_4_2(self, convo: Conversation, post: Post) -> str:
        """Sample via ``eapi_reasoning_x_algo``; on any error, log it and fall back to ``self.llm``.

        ``eapi_reasoning_x_algo`` only exists when the instance was created with
        ``deluxe=True``.  ``post`` is not used here.
        """
        try:
            return await self.eapi_reasoning_x_algo.sample(
                convo.interleaveToEapi(), conversation_id=convo.conversation_id
            )
        except Exception:
            logger.error(
                f"Failed to call 4.2 reasoning, error: {traceback.format_exc()}"
            )
            return await self.llm.sample(
                convo.interleave(), conversation_id=convo.conversation_id
            )

    async def _sample(self, convo: Conversation, post: Post | None = None) -> str:
        """Send the interleaved conversation to the primary sampler; ``post`` is not used here."""
        return await self.llm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )

    async def _parse(self, post: Post, output: str) -> List[ContentCategoryResult]:
        """Always returns an empty list; results come from ``classify_policy_for_violation``."""
        return []
|
||||
104
grox/classifiers/content/spam.py
Normal file
104
grox/classifiers/content/spam.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from pydantic import ValidationError
|
||||
|
||||
from grox.lm.convo import Role, Message, Conversation
|
||||
from grox.lm.thread import ThreadRenderer
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grok_sampler.config import GrokModelConfig
|
||||
from grox.prompts.template import SpamSystemLowFollower
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
ContentCategoryType,
|
||||
ContentCategoryResult,
|
||||
SpamSampleResult,
|
||||
)
|
||||
from grox.classifiers.content.classifier import ContentClassifier
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpamEapiLowFollowerClassifier(ContentClassifier):
    """Spam classifier driven by the ``SpamSystemLowFollower`` prompt.

    Produces exactly one ``SPAM_COMMENT`` result per post: positive with score
    1.0 when the model's decision is the string ``"spam"``, otherwise negative
    with score 0.0.
    """

    def __init__(self, model_name: ModelName = ModelName.VLM_PRIMARY):
        """Build the vision sampler for ``model_name`` and register the SPAM_COMMENT category."""
        sampler_config = grox_config.get_model(model_name)
        # Near-zero temperature — presumably to keep sampling close to deterministic.
        sampler_config.temperature = 0.000001
        sampler = VisionSampler(GrokModelConfig(**sampler_config.model_dump()))
        super().__init__(categories=[ContentCategoryType.SPAM_COMMENT], llm=sampler)

    @property
    def model_name(self) -> str:
        """Fixed identifier reported for this classifier."""
        return "grox"

    async def _classify(self, post: Post) -> list[ContentCategoryResult]:
        """Run prompt -> sample -> parse and keep the single SPAM_COMMENT result."""
        convo = await self._to_convo(post)
        raw_output = await self._sample(convo)
        spam_results = []
        for item in await self._parse(post, raw_output):
            if item.category == ContentCategoryType.SPAM_COMMENT:
                spam_results.append(item)
        assert len(spam_results) == 1
        return spam_results

    async def _to_convo(self, post: Post) -> Conversation:
        """Build a two-message conversation: system prompt, then the rendered thread."""
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        system_msg = Message(
            role=Role.SYSTEM, content=[SpamSystemLowFollower().render()]
        )
        convo.messages.append(system_msg)
        thread_msg = ThreadRenderer.render(post, role=Role.HUMAN, separator="\n\n")
        convo.messages.append(thread_msg)
        return convo

    async def _sample(self, convo: Conversation) -> str:
        """Send the interleaved conversation to the sampler."""
        prompt = convo.interleave()
        return await self.llm.sample(prompt, conversation_id=convo.conversation_id)

    async def _clean_output(self, output: str) -> str:
        """Strip a trailing ``<|eos|>`` marker, surrounding whitespace and a Markdown code fence."""
        text = output.removesuffix("<|eos|>").strip()
        # Opening fence: try the longer "```json" first, then a bare "```".
        for fence in ("```json", "```"):
            if text.startswith(fence):
                text = text[len(fence):]
                break
        text = text.removesuffix("```")
        return text.strip()

    async def _parse(self, post: Post, output: str) -> list[ContentCategoryResult]:
        """Turn raw model output into a single SPAM_COMMENT result.

        Strict JSON validation is attempted first; when that fails, a regex
        pulls only the ``"decision"`` value and the summary stays empty.

        Raises:
            ValueError: if no decision could be extracted.
        """
        decision = None
        summary = ""

        cleaned_result = await self._clean_output(output)
        try:
            sample = SpamSampleResult.model_validate_json(cleaned_result)
        except ValidationError:
            found = re.search(r'"decision":\s*"(.*?)"', cleaned_result)
            if found:
                decision = found.group(1).strip()
        else:
            decision = sample.decision
            summary = sample.reason

        if not decision:
            raise ValueError(f"Invalid output format: {output}")

        is_spam = decision == "spam"
        if is_spam:
            logger.info(f"Spam found for low follower user: {post.id}")

        verdict = ContentCategoryResult(
            category=ContentCategoryType.SPAM_COMMENT,
            positive=is_spam,
            score=1.0 if is_spam else 0.0,
            summary=summary,
        )
        return [verdict]
|
||||
394
grox/data_loaders/asr_processor.py
Normal file
394
grox/data_loaders/asr_processor.py
Normal file
@@ -0,0 +1,394 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from multiprocessing import Event, Process, Queue
|
||||
from multiprocessing.synchronize import Event as MultiprocessingEvent
|
||||
from queue import Empty
|
||||
|
||||
import aiohttp
|
||||
from cachetools import TTLCache
|
||||
from pydantic import BaseModel
|
||||
|
||||
from grox.config.config import grox_config
|
||||
from grox.schedules.init import init_proc
|
||||
from monitor.logging import Logging
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ASRRequest(BaseModel):
    """One transcription job placed on the task queue for an ASR worker."""

    # Identifier echoed back in the matching _ASRResult.
    post_id: str
    # Passed to ffmpeg as the ``-i`` input.
    video_url: str
    # Forwarded to ffmpeg as ``-t`` when greater than zero.
    max_audio_duration_s: float
|
||||
|
||||
|
||||
class _ASRResult(BaseModel):
    """Outcome of one ASR job, sent from a worker back over the response queue."""

    # Identifier of the request this result answers.
    post_id: str
    # Cleaned transcript text; left as None when the job failed or was skipped.
    transcript: str | None = None
    # Failure / skip reason (for example "no_audio_stream"); None on success.
    error: str | None = None
|
||||
|
||||
|
||||
def _extract_wav_from_url(
|
||||
video_url: str, max_duration_s: float | None = None
|
||||
) -> bytes | None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wav_path = os.path.join(tmpdir, "audio.wav")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-timeout",
|
||||
"60000000",
|
||||
"-rw_timeout",
|
||||
"60000000",
|
||||
"-reconnect",
|
||||
"1",
|
||||
"-reconnect_streamed",
|
||||
"1",
|
||||
"-reconnect_delay_max",
|
||||
"5",
|
||||
"-i",
|
||||
video_url,
|
||||
"-vn",
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ar",
|
||||
"16000",
|
||||
"-ac",
|
||||
"1",
|
||||
]
|
||||
if max_duration_s is not None and max_duration_s > 0:
|
||||
cmd += ["-t", str(max_duration_s)]
|
||||
cmd.append(wav_path)
|
||||
result = subprocess.run(cmd, capture_output=True, timeout=180)
|
||||
if result.returncode != 0:
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
raise subprocess.CalledProcessError(
|
||||
result.returncode, cmd, result.stdout, result.stderr
|
||||
)
|
||||
with open(wav_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _clean_asr(raw: str) -> str:
|
||||
if "<asr_text>" in raw:
|
||||
raw = raw.split("<asr_text>", 1)[1]
|
||||
if "</asr_text>" in raw:
|
||||
raw = raw.split("</asr_text>", 1)[0]
|
||||
return raw.strip()
|
||||
|
||||
|
||||
class _ASRWorker:
    """Worker-process side of the ASR pipeline.

    Each worker process pulls ``(_ASRRequest, logging-context)`` tuples from
    ``task_queue``, extracts audio with ffmpeg, posts it to the configured ASR
    endpoint and puts one ``_ASRResult`` on ``resp_queue`` per request.
    """

    def __init__(
        self, task_queue: Queue, resp_queue: Queue, shutdown_event: MultiprocessingEvent
    ):
        self._task_queue: Queue[tuple[_ASRRequest, dict[str, str]]] = task_queue
        self._resp_queue: Queue[_ASRResult] = resp_queue
        self._shutdown_event: MultiprocessingEvent = shutdown_event

    async def _transcribe(self, request: _ASRRequest) -> str | None:
        """Extract audio for ``request`` and transcribe it via the ASR endpoint.

        Returns:
            The cleaned transcript, or ``None`` when no audio could be extracted.

        Raises:
            Exception: if the ASR endpoint answers with a non-200 status.
            Errors from ffmpeg extraction and the HTTP client propagate as-is.
        """
        asr_config = grox_config.asr

        t_start = time.monotonic()
        # ffmpeg is blocking, so run it in the default executor.
        # get_running_loop() is the supported call from inside a coroutine.
        loop = asyncio.get_running_loop()
        wav_bytes = await loop.run_in_executor(
            None, _extract_wav_from_url, request.video_url, request.max_audio_duration_s
        )
        t_extract = time.monotonic() - t_start
        Metrics.histogram("asr_proc.extract_duration_s").record(t_extract)
        if wav_bytes is None:
            return None
        Metrics.histogram("asr_proc.audio_bytes").record(len(wav_bytes))
        logger.debug(
            f"Extracted audio in {t_extract:.2f}s, size={len(wav_bytes)} bytes"
        )

        t_start = time.monotonic()
        b64_audio = base64.b64encode(wav_bytes).decode()
        body = {
            "model": "default",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "audio_url",
                            "audio_url": {"url": f"data:audio/wav;base64,{b64_audio}"},
                        },
                        {"type": "text", "text": "Transcribe this audio."},
                    ],
                }
            ],
            "temperature": asr_config.temperature,
            "max_tokens": asr_config.max_tokens,
        }

        async with self._session.post(
            f"{asr_config.endpoint}/v1/chat/completions",
            json=body,
            timeout=aiohttp.ClientTimeout(total=asr_config.timeout),
        ) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                raise Exception(
                    f"ASR request failed with status {resp.status}: {error_text}"
                )

            data = await resp.json()
            raw_transcript = data["choices"][0]["message"]["content"].strip()
            transcript = _clean_asr(raw_transcript)

            t_transcribe = time.monotonic() - t_start
            Metrics.histogram("asr_proc.transcribe_duration_s").record(t_transcribe)
            Metrics.histogram("asr_proc.transcript_chars").record(len(transcript))

            # Token accounting is optional in the response; record what is present.
            if "usage" in data:
                usage = data["usage"]
                for field in ("prompt_tokens", "completion_tokens", "total_tokens"):
                    if field in usage:
                        Metrics.histogram(f"asr_proc.{field}").record(usage[field])

            return transcript

    def _report_failure(
        self, request: _ASRRequest, attributes: dict[str, str], reason: str, error: str
    ) -> None:
        """Count a failure under ``reason`` and send ``error`` back for ``request``."""
        Metrics.counter("asr_proc.error.count").add(
            1, attributes={**attributes, "reason": reason}
        )
        self._resp_queue.put(_ASRResult(post_id=request.post_id, error=error))

    async def _process(self, request: _ASRRequest, ctx: dict[str, str]) -> None:
        """Handle one request end to end and put its ``_ASRResult`` on the response queue.

        Every handled outcome (success, skip, or error) produces a result so
        the submitting side can resolve the future it holds for this post.
        """
        attributes = {"pid": str(os.getpid())}
        with Metrics.tracer("asr_proc").start_as_current_span("asr.process"):
            Logging.set_context(**ctx)
            start = time.perf_counter()
            try:
                Metrics.counter("asr_proc.total.count").add(1, attributes=attributes)
                transcript = await self._transcribe(request)
                if transcript is None:
                    logger.debug(
                        f"Video has no audio stream for post {request.post_id}, skipping ASR"
                    )
                    Metrics.counter("asr_proc.skip.count").add(
                        1, attributes={**attributes, "reason": "no_audio_stream"}
                    )
                    self._resp_queue.put(
                        _ASRResult(post_id=request.post_id, error="no_audio_stream")
                    )
                else:
                    Metrics.counter("asr_proc.success.count").add(
                        1, attributes=attributes
                    )
                    self._resp_queue.put(
                        _ASRResult(post_id=request.post_id, transcript=transcript)
                    )
            except subprocess.TimeoutExpired:
                logger.warning(
                    f"FFmpeg timeout extracting audio for post {request.post_id}"
                )
                self._report_failure(
                    request, attributes, "ffmpeg_timeout", "ffmpeg_timeout"
                )
            except subprocess.CalledProcessError as e:
                # ffmpeg's stderr is arbitrary bytes.  A strict decode could
                # raise inside this handler, and then no result would be sent.
                error_msg = e.stderr.decode(errors="replace") if e.stderr else str(e)
                logger.warning(f"FFmpeg error for post {request.post_id}: {error_msg}")
                self._report_failure(
                    request, attributes, "ffmpeg_error", f"ffmpeg_error: {error_msg}"
                )
            except asyncio.TimeoutError:
                logger.warning(f"ASR request timed out for post {request.post_id}")
                self._report_failure(request, attributes, "asr_timeout", "asr_timeout")
            except Exception as e:
                logger.error(
                    f"ASR processing failed for post {request.post_id}: {traceback.format_exc()}"
                )
                self._report_failure(request, attributes, "unknown", str(e))
            finally:
                end = time.perf_counter()
                Metrics.histogram("asr_proc.duration").record(end - start)

    async def _init_run(self) -> None:
        """Per-process initialisation: run ``init_proc`` and open the HTTP session."""
        await init_proc("asr_proc")
        self._session = aiohttp.ClientSession()

    async def _run(self) -> None:
        """Poll the task queue and start one asyncio task per request.

        The loop continues until shutdown has been signalled and the task
        queue is empty; tasks still in flight are awaited before returning.
        """
        logger.info("starting ASR worker process loop")
        pending: set[asyncio.Task] = set()
        while not self._is_shutdown() or not self._task_queue.empty():
            try:
                request, ctx = self._task_queue.get(block=False)
            except Empty:
                # Non-blocking get plus a short sleep keeps the event loop responsive.
                await asyncio.sleep(0.01)
                continue
            try:
                task = asyncio.create_task(self._process(request, ctx))
                # Keep a reference until the task finishes.
                pending.add(task)
                task.add_done_callback(pending.discard)
            except Exception:
                logger.error(
                    f"error processing ASR request {request.post_id}: {traceback.format_exc()}"
                )
        if pending:
            logger.info(f"ASR worker draining {len(pending)} in-flight tasks")
            await asyncio.gather(*pending, return_exceptions=True)
        logger.warning("ASR worker process loop done")

    def run(self) -> None:
        """Process entry point: initialise, run the loop, then close the HTTP session."""

        async def wrapper():
            await self._init_run()
            try:
                await self._run()
            finally:
                await self._session.close()

        asyncio.run(wrapper())

    def _start_loop(self) -> Process:
        """Spawn and start one worker process running ``run``."""
        process = Process(target=self.run)
        process.start()
        return process

    def start(self) -> list[Process]:
        """Start ``grox_config.asr.max_workers`` worker processes and return them."""
        return [self._start_loop() for _ in range(grox_config.asr.max_workers)]

    def _is_shutdown(self) -> bool:
        """Return whether shutdown was signalled; any error reading the event counts as shutdown."""
        try:
            return self._shutdown_event.is_set()
        except BrokenPipeError:
            logger.error("Broken pipe error, assuming shutdown")
            return True
        except Exception:
            logger.error(
                f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
            )
            return True
|
||||
|
||||
|
||||
class ASRProcessor:
    """Parent-process facade for the ASR worker pool.

    All state lives on the class.  ``start`` launches the workers and a result
    loop; ``process`` submits a request and awaits its transcript; ``stop``
    signals shutdown.  Failed or skipped jobs resolve to ``None``.
    """

    _task_queue: Queue = Queue()
    _resp_queue: Queue = Queue()
    _shutdown_event = Event()
    # post_id -> future shared by every caller waiting on that post.
    _inflights: dict[str, asyncio.Future[str | None]] = {}
    _workers: list[Process] = []
    _initialized = False
    _result_task: asyncio.Task | None = None
    # Successful transcripts only, keyed by post_id.
    _cache: TTLCache = TTLCache(maxsize=1_000, ttl=300)

    @classmethod
    async def process(
        cls, post_id: str, video_url: str, max_audio_duration_s: float | None = None
    ) -> str | None:
        """Return the transcript for ``post_id``, using the cache when possible.

        Args:
            post_id: Key used for caching and for matching results to requests.
            video_url: Media URL handed to the worker.
            max_audio_duration_s: Duration cap; defaults to
                ``grox_config.asr.max_audio_duration_s`` when ``None``.

        Returns:
            The transcript, or ``None`` when the job failed or was skipped.

        Raises:
            RuntimeError: if ``start`` has not been called.
        """
        if not cls._initialized:
            raise RuntimeError("ASR processor not initialized")

        cached = cls._cache.get(post_id)
        if cached is not None:
            Metrics.counter("asr_proc.cache_hit.count").add(1)
            return cached

        if max_audio_duration_s is None:
            max_audio_duration_s = grox_config.asr.max_audio_duration_s

        future = cls._submit(post_id, video_url, max_audio_duration_s)
        # The future may be shared with other callers for the same post, so
        # shield it: cancelling this caller must not cancel it for the others.
        transcript = await asyncio.shield(future)

        if transcript is not None:
            cls._cache[post_id] = transcript

        return transcript

    @classmethod
    def _submit(
        cls, post_id: str, video_url: str, max_audio_duration_s: float
    ) -> asyncio.Future[str | None]:
        """Queue a request for ``post_id``, or reuse the future already pending for it."""
        existing = cls._inflights.get(post_id)
        # Only reuse a future that can still deliver a result.
        if existing is not None and not existing.done():
            return existing

        request = _ASRRequest(
            post_id=post_id,
            video_url=video_url,
            max_audio_duration_s=max_audio_duration_s,
        )
        cls._task_queue.put((request, Logging.get_context()))

        future: asyncio.Future[str | None] = asyncio.get_running_loop().create_future()
        cls._inflights[post_id] = future
        return future

    @classmethod
    async def _result_loop(cls) -> None:
        """Move worker results from the response queue onto the waiting futures.

        Runs until shutdown has been signalled and no request is in flight.
        """
        logger.info("ASR processor result loop started")
        while not cls._shutdown_event.is_set() or cls._inflights:
            try:
                result: _ASRResult = cls._resp_queue.get(block=False)
                future = cls._inflights.pop(result.post_id, None)
                if not future:
                    logger.warning(f"no future found for post {result.post_id}")
                    continue
                if result.error:
                    if result.error == "no_audio_stream":
                        logger.debug(
                            f"ASR skipped for post {result.post_id}: no audio stream"
                        )
                    else:
                        logger.warning(
                            f"ASR failed for post {result.post_id}: {result.error}"
                        )
                    transcript = None
                else:
                    transcript = result.transcript
                # set_result() raises InvalidStateError on a finished future.
                if not future.done():
                    future.set_result(transcript)
            except Empty:
                # Non-blocking get plus a short sleep keeps the event loop responsive.
                await asyncio.sleep(0.01)
            except Exception:
                logger.error(f"Error processing ASR result: {traceback.format_exc()}")
        logger.warning("ASR processor result loop done")

    @classmethod
    def start(cls) -> None:
        """Launch the worker processes and the result loop (no-op if already started).

        Uses ``asyncio.create_task``, so it must be called while an event loop
        is running.
        """
        if cls._initialized:
            logger.warning("ASR processor already initialized")
            return

        logger.info(
            f"starting ASR processor with {grox_config.asr.max_workers} workers"
        )
        cls._workers = _ASRWorker(
            cls._task_queue, cls._resp_queue, cls._shutdown_event
        ).start()
        cls._result_task = asyncio.create_task(cls._result_loop())
        cls._initialized = True

    @classmethod
    async def stop(cls, timeout: float = 5) -> None:
        """Signal shutdown, join each live worker for up to ``timeout`` seconds, stop the result loop.

        Requests still in flight afterwards resolve to ``None`` so callers
        awaiting ``process`` are released instead of waiting forever.
        """
        logger.warning("stopping ASR processor")
        cls._shutdown_event.set()
        for worker in cls._workers:
            if worker.is_alive():
                worker.join(timeout)
        if cls._result_task and not cls._result_task.done():
            cls._result_task.cancel()
        # With the result loop gone nothing else will complete these futures.
        for leftover in cls._inflights.values():
            if not leftover.done():
                leftover.set_result(None)
        cls._inflights.clear()
        logger.warning("ASR processor stopped")
|
||||
232
grox/data_loaders/kafka_loader.py
Normal file
232
grox/data_loaders/kafka_loader.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import override
|
||||
from collections.abc import AsyncGenerator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from aiokafka import TopicPartition
|
||||
from kafka_cli.config import KafkaMessage
|
||||
from grox.config.config import KafkaTopicName, grox_config
|
||||
from kafka_cli.consumer import KafkaConsumer
|
||||
from grox.data_loaders.data_types import Post, User, GroxContentAnalysis
|
||||
from grox.data_loaders.message_queue_loader import (
|
||||
MessageQueueLoader,
|
||||
MessageQueuePayload,
|
||||
)
|
||||
from monitor.metrics import Metrics
|
||||
from thrifts.serdes import SerDesError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAX_WORKING_THREADS = 12
|
||||
|
||||
|
||||
class _Payload(MessageQueuePayload):
    """``MessageQueuePayload`` extended with the Kafka coordinates of its source message."""

    # Topic/partition the message was read from.
    tp: TopicPartition
    # Offset of the message within that partition.
    offset: int
|
||||
|
||||
|
||||
class KafkaLoader(MessageQueueLoader):
    """Base loader that prefetches Kafka messages into an in-memory asyncio queue.

    A background task polls the consumer whenever the queue drops below the
    configured threshold, converts messages through the subclass hook
    ``_messages_to_payloads`` and enqueues them; ``poll`` hands them out.
    """

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__()
        self._initialized = False
        self._shutdown_event = asyncio.Event()
        self.topic_name = topic_name
        self.loader_config = grox_config.get_kafka_loader_topic(topic_name)
        self.consumer_config = grox_config.get_kafka_consumer_topic(topic_name)
        self.consumer = KafkaConsumer(self.consumer_config)
        self.loaded_messages: dict[str, tuple[TopicPartition, int]] = {}
        self.queue: asyncio.Queue[MessageQueuePayload] = asyncio.Queue()
        self._prefetcher_task: asyncio.Task | None = None

    def _is_shutdown(self) -> bool:
        """Return whether ``stop`` was requested; any error reading the event counts as shutdown."""
        try:
            return self._shutdown_event.is_set()
        except Exception:
            logger.error(
                f"Error checking if KafkaLoader is shutdown: {traceback.format_exc()}"
            )
            return True

    async def start(self):
        """Start the consumer and the prefetcher task.

        ``_initialized`` is set only after both succeed, so a failed start does
        not leave the loader marked as initialized.
        """
        logger.info(f"Initializing KafkaLoader, topic: {self.topic_name}")
        await self.consumer.start()
        self._prefetcher_task = asyncio.create_task(self._prefetcher())
        self._initialized = True
        logger.info(f"KafkaLoader initialized, topic: {self.topic_name}")

    async def stop(self):
        """Signal shutdown, wait up to 5 seconds for the prefetcher, then stop the consumer."""
        logger.warning(f"Stopping KafkaLoader, topic: {self.topic_name}")
        self._shutdown_event.set()
        try:
            if self._prefetcher_task:
                await asyncio.wait_for(self._prefetcher_task, 5)
        except asyncio.TimeoutError:
            logger.warning(
                f"Waiting prefetcher to stop timed out, topic: {self.topic_name}"
            )
        await self.consumer.stop()
        logger.warning(f"KafkaLoader stopped, topic: {self.topic_name}")

    async def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        """Yield queued payloads, or ``None`` when the queue is momentarily empty.

        Iteration ends once shutdown has been signalled and the queue is drained.
        """
        while not self._shutdown_event.is_set() or not self.queue.empty():
            try:
                yield self.queue.get_nowait()
            except asyncio.QueueEmpty:
                logger.debug(
                    f"Queue is empty, waiting for prefetcher to fill, topic: {self.topic_name}"
                )
                yield None
            except Exception:
                logger.error(
                    f"Error polling from kafka, topic: {self.topic_name}, error: {traceback.format_exc()}"
                )
                yield None

    async def ack(self, mid: str, success: bool = True):
        """No-op in this loader."""
        pass

    @abstractmethod
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Convert raw Kafka messages into payloads; implemented by subclasses."""
        pass

    def _process_messages(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Convert ``messages`` in slices on a thread pool, preserving input order."""
        group_size = max(1, len(messages) // MAX_WORKING_THREADS)
        message_groups = [
            messages[i : i + group_size] for i in range(0, len(messages), group_size)
        ]

        with ThreadPoolExecutor(max_workers=MAX_WORKING_THREADS) as executor:
            payloads = []
            # executor.map yields results in submission order.
            for result in executor.map(self._messages_to_payloads, message_groups):
                payloads.extend(result)
            return payloads

    async def _prefetcher(self) -> None:
        """Background loop that refills the queue whenever it runs low.

        Any error while polling, converting or enqueuing is logged and followed
        by a short pause; the loop itself keeps running until shutdown.
        """
        logger.info(f"Starting prefetcher, topic: {self.topic_name}")
        prefetching_threshold = self.loader_config.prefetching_threshold
        prefetching_batch_size = self.loader_config.prefetching_batch_size
        while not self._is_shutdown():
            if self.queue.qsize() < prefetching_threshold:
                logger.debug(
                    f"Inventory low at {self.queue.qsize()}, prefetching {prefetching_batch_size} messages, topic: {self.topic_name}"
                )
                try:
                    messages = await self.consumer.poll(prefetching_batch_size)
                    try:
                        payloads = self._process_messages(messages)
                    except SerDesError:
                        logger.error(
                            f"Error processing messages, error: {traceback.format_exc()}"
                        )
                        raise
                    # Enqueue a plain MessageQueuePayload: tp/offset are dropped here.
                    for payload in payloads:
                        await self.queue.put(
                            MessageQueuePayload(
                                mid=payload.mid,
                                user=payload.user,
                                post=payload.post,
                                user_context=payload.user_context,
                                grox_content_analysis=payload.grox_content_analysis,
                                deadline_ts_secs=payload.deadline_ts_secs,
                            )
                        )
                    # Report what was actually enqueued, not the requested batch size.
                    logger.debug(
                        f"Prefetched {len(payloads)} messages, inventory now at {self.queue.qsize()}, topic: {self.topic_name}"
                    )
                except Exception:
                    logger.error(
                        f"Error prefetching messages, error: {traceback.format_exc()}"
                    )
                    await asyncio.sleep(0.1)
            else:
                await asyncio.sleep(0.1)
        logger.warning("Prefetcher stopped")
|
||||
|
||||
|
||||
class KafkaPostLoader(KafkaLoader):
    """Loader whose messages decode to ``Post`` via ``from_thrift_content_understanding_metadata``."""

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__(topic_name)

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one payload per message with a fresh id and a deadline relative to now."""
        payloads = []
        for message in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_content_understanding_metadata(message.value),
                    tp=message.tp,
                    offset=message.offset,
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
|
||||
|
||||
|
||||
class KafkaPostEmbeddingRequestLoader(KafkaLoader):
    """Kafka loader whose messages decode into posts from post-embedding requests."""

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__(topic_name)

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one ``_Payload`` per message, each with a fresh hex id and its own deadline."""
        payloads: list[_Payload] = []
        for message in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_post_embedding_request(message.value),
                    tp=message.tp,
                    offset=message.offset,
                    # Deadline = current wall-clock second + configured task budget.
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
|
||||
|
||||
|
||||
class KafkaGroxContentAnalysisLoader(KafkaLoader):
    """Kafka loader whose messages decode into ``GroxContentAnalysis`` objects."""

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__(topic_name)

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one ``_Payload`` per message, each with a fresh hex id and its own deadline."""
        payloads: list[_Payload] = []
        for message in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    grox_content_analysis=GroxContentAnalysis.from_thrift_content_understanding_metadata(
                        message.value
                    ),
                    tp=message.tp,
                    offset=message.offset,
                    # Deadline = current wall-clock second + configured task budget.
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
|
||||
|
||||
|
||||
class KafkaTweetEmbeddingLoader(KafkaLoader):
    """Kafka loader whose messages decode into posts from tweet-embedding records."""

    def __init__(self, topic_name: KafkaTopicName):
        super().__init__(topic_name)

    @override
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        """Build one ``_Payload`` per message, each with a fresh hex id and its own deadline."""
        payloads: list[_Payload] = []
        for message in messages:
            payloads.append(
                _Payload(
                    mid=uuid.uuid4().hex,
                    post=Post.from_thrift_tweet_embedding(message.value),
                    tp=message.tp,
                    offset=message.offset,
                    # Deadline = current wall-clock second + configured task budget.
                    deadline_ts_secs=int(time.time())
                    + self.loader_config.task_deadline_secs,
                )
            )
        return payloads
|
||||
39
grox/data_loaders/message_queue_loader.py
Normal file
39
grox/data_loaders/message_queue_loader.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from grox.data_loaders.data_types import Post, User, UserContext, GroxContentAnalysis
|
||||
from collections.abc import AsyncGenerator
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageQueuePayload(BaseModel):
    """Pydantic model for one item handed out by a message-queue loader.

    ``mid`` and ``deadline_ts_secs`` are required; every content field is
    optional and defaults to ``None``.
    """

    # Identifier of this message (presumably the value later passed to a
    # loader's ``ack`` -- confirm against the concrete loaders).
    mid: str
    # Optional content carried by the message; any subset may be present.
    post: Post | None = None
    user: User | None = None
    user_context: UserContext | None = None
    grox_content_analysis: GroxContentAnalysis | None = None

    # Integer deadline in seconds (per the field name); NOTE(review): the time
    # base is not visible in this class -- confirm with the producers.
    deadline_ts_secs: int
|
||||
|
||||
|
||||
class MessageQueueLoader(ABC):
    """Abstract interface for loaders that produce ``MessageQueuePayload`` items.

    Concrete subclasses must implement ``start``, ``stop``, ``poll`` and ``ack``.
    """

    def __init__(self):
        """Nothing is initialised at this level."""

    @abstractmethod
    async def start(self):
        """Start the loader; behaviour is defined by the subclass."""

    @abstractmethod
    async def stop(self):
        """Stop the loader; behaviour is defined by the subclass."""

    @abstractmethod
    def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        """Return an async generator yielding payloads, or ``None`` entries."""

    @abstractmethod
    async def ack(self, mid: str, success: bool = True):
        """Acknowledge the message identified by ``mid`` with a success flag."""
|
||||
154
grox/data_loaders/strato_loader.py
Normal file
154
grox/data_loaders/strato_loader.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from grox.data_loaders.data_types import Post, User
|
||||
from strato_http.queries.data_types import (
|
||||
ReplyRankingScore,
|
||||
ReplyRankingScoreKafka,
|
||||
)
|
||||
from strato_http.queries.content_understanding_author_metadata import (
|
||||
StratoContentUnderstandingAuthorMetadata,
|
||||
)
|
||||
from strato_http.queries.content_understanding_post_quote_metadata import (
|
||||
StratoContentUnderstandingPostQuoteMetadata,
|
||||
)
|
||||
from strato_http.queries.content_understanding_metadata_v2 import (
|
||||
StratoContentUnderstandingMetadataV2,
|
||||
)
|
||||
from strato_http.queries.reply_ranking_score import StratoReplyRankingScore
|
||||
from strato_http.queries.reply_spam_annotation import StratoReplySpamAnnotation
|
||||
from strato_http.queries.reply_ranking_score_kafka_v2 import (
|
||||
StratoReplyRankingScoreV2Kafka,
|
||||
)
|
||||
from strato_http.queries.safety_label import StratoSafetyLabel
|
||||
from strato_http.queries.user_recent_posts import StratoUserRecentPosts
|
||||
from grox.data_loaders.mappers.post_mapper import PostMapper
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TweetStratoLoader:
    """Fetches a single post from Strato and maps it to a ``Post``."""

    content_understanding_metadata_strato = StratoContentUnderstandingMetadataV2()
    content_understanding_post_quote_metadata_strato = (
        StratoContentUnderstandingPostQuoteMetadata()
    )

    @classmethod
    async def load_post(
        cls, tweet_id: str, include_ancestors: bool = True
    ) -> Post | None:
        """Load the post with id ``tweet_id``.

        With ``include_ancestors`` the content-understanding metadata query is
        used; otherwise the post-with-quote query is used. Returns ``None``
        when the selected query yields a falsy result.
        """
        if include_ancestors:
            raw = await cls.content_understanding_metadata_strato.fetch(int(tweet_id))
            if not raw:
                return None
            return PostMapper.from_strato_content_understanding_metadata(raw)

        raw = await cls.content_understanding_post_quote_metadata_strato.fetch(
            int(tweet_id)
        )
        if not raw:
            return None
        return PostMapper.from_strato_post_with_quote_metadata(raw)
|
||||
|
||||
|
||||
class UserStratoLoader:
    """Fetches author metadata from Strato and maps it to a ``User``."""

    strato = StratoContentUnderstandingAuthorMetadata()

    @classmethod
    async def load_user(cls, user_id: int) -> User | None:
        """Return the mapped user, or ``None`` (with a warning) when the fetch is falsy."""
        strato_user = await cls.strato.fetch(user_id)
        if strato_user:
            return PostMapper._from_strato_user_metadata_to_user(strato_user)
        logger.warning(f"failed to hydrate user with {user_id=}, not found")
        return None
|
||||
|
||||
|
||||
class ReplyRankingScoreStratoLoader:
    """Writes reply-ranking scores through two Strato clients."""

    strato = StratoReplyRankingScore()
    reply_ranking_v2_kafka_strato = StratoReplyRankingScoreV2Kafka()

    @classmethod
    async def save_reply_ranking_score(
        cls, post_id: str, reply_ranking_score: ReplyRankingScore
    ):
        """``put`` the score under the integer form of ``post_id``."""
        score_store = cls.strato
        await score_store.put(int(post_id), reply_ranking_score)

    @classmethod
    async def save_reply_ranking_kafka_v2(
        cls, post_id: str, reply_ranking_score_kafka: ReplyRankingScoreKafka
    ):
        """``insert`` the Kafka-shaped score under the integer form of ``post_id``."""
        kafka_store = cls.reply_ranking_v2_kafka_strato
        await kafka_store.insert(int(post_id), reply_ranking_score_kafka)
|
||||
|
||||
|
||||
class ReplySpamStratoLoader:
    """Writes reply-spam annotations through a Strato client."""

    strato = StratoReplySpamAnnotation()

    @classmethod
    async def save_spam_reply_annotation(
        cls, post_id: str, score: float, positive: bool, reason: str
    ):
        """``put`` the annotation fields under the integer form of ``post_id``."""
        annotation_store = cls.strato
        await annotation_store.put(int(post_id), score, positive, reason)
|
||||
|
||||
|
||||
class UserRecentPostsLoader:
    """Loads a user's recent posts, hydrates them and attaches safety labels."""

    recent_posts_strato = StratoUserRecentPosts()
    post_hydrator = StratoContentUnderstandingPostQuoteMetadata()
    safety_label = StratoSafetyLabel()

    @classmethod
    async def load(cls, user_id: int, limit: int = 10) -> list[Post]:
        """Return the hydrated recent posts of ``user_id``.

        Args:
            user_id: User whose recent posts are requested.
            limit: Passed to the recent-posts query as both ``limit`` and
                ``max_per_type``.

        Returns:
            The posts that could be fetched and mapped, each with
            ``safety_labels`` populated. Posts whose hydration or mapping
            fails are skipped with a warning; an empty list is returned when
            nothing is found.
        """
        res = await cls.recent_posts_strato.fetch(
            user_id, limit=limit, max_per_type=limit
        )
        if not res or "v" not in res:
            logger.warning(f"No recent posts found for {user_id=}")
            return []

        post_ids: list[int] = []
        for _post_type, posts in res["v"]:
            # Retweets reference the original post; everything else uses its own id.
            id_key = "inReactionToPostId" if _post_type == "TypeRetweet" else "postId"
            for post in posts:
                if id_key in post:
                    post_ids.append(post[id_key])

        if not post_ids:
            return []

        tasks = [cls.post_hydrator.fetch(post_id) for post_id in post_ids]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        hydrated: list[Post] = []
        for post_id, result in zip(post_ids, results):
            # With return_exceptions=True a result may be any BaseException
            # (e.g. asyncio.CancelledError), not only an Exception subclass.
            if isinstance(result, BaseException):
                logger.warning(
                    f"Failed to hydrate recent post {post_id} for {user_id=}: {result}"
                )
                continue
            if result is None:
                continue
            try:
                hydrated.append(PostMapper.from_strato_post_with_quote_metadata(result))
            except Exception:
                logger.warning(
                    f"Failed to map recent post {post_id} for {user_id=}", exc_info=True
                )

        for post in hydrated:
            post.safety_labels = await cls.safety_label.scan(post.id)

        return hydrated
|
||||
370
grox/dispatcher.py
Normal file
370
grox/dispatcher.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import multiprocessing
|
||||
from queue import Empty, Queue
|
||||
from threading import Event
|
||||
from multiprocessing import Process
|
||||
|
||||
from tenacity import retry, wait_incrementing, stop_after_attempt
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
from grox.config.config import TaskGeneratorType, grox_config
|
||||
from grox.schedules.init import init_proc
|
||||
from grox.schedules.types import TaskResult, TaskPayload
|
||||
from grox.schedules.context import ScheduleContext
|
||||
from grox.generators.task_generator import TaskGenerator, PriorityTaskGenerator
|
||||
from grox.generators.stream_generator import (
|
||||
PostStreamTaskGenerator,
|
||||
PostStreamRecoveryTaskGenerator,
|
||||
PostStreamTestTaskGenerator,
|
||||
PostStreamDelayedTaskGenerator,
|
||||
PostSafetyStreamTaskGenerator,
|
||||
ReplyRankingRecoveryTaskGenerator,
|
||||
PostEmbeddingRequestWithSummaryStreamTaskGenerator,
|
||||
PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator,
|
||||
MinTractionPostStreamForGroxTaskGenerator,
|
||||
MinTractionPostStreamForGroxPtosTaskGenerator,
|
||||
PostEmbeddingV5StreamTaskGenerator,
|
||||
PostEmbeddingV5ForReplyStreamTaskGenerator,
|
||||
MinTractionPostStreamForGroxMultiModalTaskGenerator,
|
||||
PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator,
|
||||
SafetyPtosRecoveryStreamTaskGenerator,
|
||||
SafetyPtosDeluxeStreamTaskGenerator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
def __init__(self, context: ScheduleContext):
    """Store the dispatcher config plus the queues and events carried by ``context``.

    The process handle starts as ``None``.
    """
    self.config = grox_config.dispatcher
    self.context = context

    task_queue: Queue[TaskPayload] = context["task_queue"]
    resp_queue: Queue[TaskResult] = context["resp_queue"]
    shutdown_event: Event = context["shutdown_event"]
    queue_connection_shutdown_event: Event = context["queue_connection_shutdown_event"]

    self._task_queue = task_queue
    self._resp_queue = resp_queue
    self._shutdown_event = shutdown_event
    self._queue_connection_shutdown_event = queue_connection_shutdown_event
    self._process = None
|
||||
|
||||
def _is_shutdown(self) -> bool:
|
||||
try:
|
||||
return self._shutdown_event.is_set()
|
||||
except BrokenPipeError:
|
||||
logger.error("Broken pipe error, assuming shutdown")
|
||||
return True
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
|
||||
)
|
||||
return True
|
||||
|
||||
def _is_queue_connection_shutdown(self) -> bool:
|
||||
try:
|
||||
return self._queue_connection_shutdown_event.is_set()
|
||||
except BrokenPipeError:
|
||||
logger.error("Broken pipe error, assuming queue connection shutdown")
|
||||
return True
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error checking shutdown event, assuming queue connection shutdown: {traceback.format_exc()}"
|
||||
)
|
||||
return True
|
||||
|
||||
@retry(
    wait=wait_incrementing(start=1, increment=3, max=9),
    stop=stop_after_attempt(3),
)
async def _init_run(self):
    """Initialise the dispatcher process state and start the task generator.

    Retried by ``tenacity`` for up to three attempts with an incrementing wait.
    """
    await init_proc("dispatcher")
    self._in_flights: set[str] = set()
    generator = self._get_task_generators()
    self._task_generator = generator
    await generator.start()
|
||||
|
||||
def _get_task_generators(self) -> TaskGenerator:
    """Build a ``PriorityTaskGenerator`` from the configured task generators.

    Each config entry is turned into ``(generator(max_qps), weight)``.

    Raises:
        ValueError: if an entry's type has no generator class in the table.
    """
    # Ordered (type, class) table; the first entry equal to the configured type wins.
    generator_classes = (
        (TaskGeneratorType.POST_STREAM, PostStreamTaskGenerator),
        (TaskGeneratorType.POST_STREAM_RECOVERY, PostStreamRecoveryTaskGenerator),
        (TaskGeneratorType.POST_STREAM_TEST, PostStreamTestTaskGenerator),
        (TaskGeneratorType.POST_STREAM_DELAYED, PostStreamDelayedTaskGenerator),
        (TaskGeneratorType.POST_SAFETY_STREAM, PostSafetyStreamTaskGenerator),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX,
            MinTractionPostStreamForGroxTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_PTOS,
            MinTractionPostStreamForGroxPtosTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_MULTI_MODAL,
            MinTractionPostStreamForGroxMultiModalTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_FOR_REPLY_RECOVERY,
            PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator,
        ),
        (TaskGeneratorType.SAFETY_PTOS_RECOVERY, SafetyPtosRecoveryStreamTaskGenerator),
        (TaskGeneratorType.SAFETY_PTOS_DELUXE, SafetyPtosDeluxeStreamTaskGenerator),
        (TaskGeneratorType.POST_EMBEDDING_V5_STREAM, PostEmbeddingV5StreamTaskGenerator),
        (
            TaskGeneratorType.POST_EMBEDDING_V5_FOR_REPLY_STREAM,
            PostEmbeddingV5ForReplyStreamTaskGenerator,
        ),
        (TaskGeneratorType.REPLY_RANKING_RECOVERY, ReplyRankingRecoveryTaskGenerator),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY,
            PostEmbeddingRequestWithSummaryStreamTaskGenerator,
        ),
        (
            TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_RECOVERY,
            PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator,
        ),
    )

    generators: list[tuple[TaskGenerator, int]] = []
    for task_generator_config in self.config.task_generators:
        configured_type = task_generator_config.type
        generator_cls = next(
            (cls for kind, cls in generator_classes if configured_type == kind),
            None,
        )
        if generator_cls is None:
            raise ValueError(f"Invalid task generator type: {configured_type}")
        generators.append(
            (
                generator_cls(task_generator_config.max_qps),
                task_generator_config.weight,
            )
        )
    return PriorityTaskGenerator(generators)
|
||||
|
||||
async def _submit_task(self, task_payload: TaskPayload) -> None:
    """Mark the task as in flight, emit metrics and put it on the task queue."""
    inflight_gauge = Metrics.gauge("dispatcher.inflight.count")
    self._in_flights.add(task_payload.payload_id)
    inflight_gauge.set(len(self._in_flights))

    # Tasks without a type are reported under the "none" label.
    if task_payload.task_type:
        task_type_label = task_payload.task_type.value
    else:
        task_type_label = "none"
    sent_counter = Metrics.counter("dispatcher.task.sent.count")
    sent_counter.add(1, attributes={"task_type": task_type_label})

    self._task_queue.put(task_payload)
    logger.debug(
        f"Submitted task: {task_payload.payload_id}, queue size: {self._task_queue.qsize()}"
    )
|
||||
|
||||
async def _fill_loop(self):
    """Pull tasks from the task generator and submit them until shutdown.

    ``None`` items from the generator trigger a 10 ms pause. Before each
    submission the loop waits until the number of in-flight tasks is below
    ``config.max_in_flight``. Errors raised while polling or submitting are
    logged and the polling is restarted after a short pause.
    """
    logger.info("Starting fill loop")
    while not self._is_shutdown():
        try:
            async for task_payload in self._task_generator.poll():
                if task_payload is None:
                    await asyncio.sleep(0.01)
                    continue
                # Back-pressure: hold the task until an in-flight slot frees up.
                while len(self._in_flights) >= self.config.max_in_flight:
                    await asyncio.sleep(0.01)
                await self._submit_task(task_payload)
        except Exception:
            logger.error(
                f"Error polling from task queues: {traceback.format_exc()}"
            )
            # Yield before retrying so a persistently failing generator cannot
            # spin this loop and starve the other dispatcher coroutines.
            await asyncio.sleep(0.1)
|
||||
|
||||
async def _poll_result(self) -> TaskResult | None:
    """Take one result from the response queue without blocking.

    Returns ``None`` when the queue is empty or when any error occurs while
    reading or recording the result (errors other than ``Empty`` are logged).
    """
    outcome = None
    try:
        received = self._resp_queue.get(block=False)
        logger.debug(f"Dispatcher received result: {received.task.payload_id}")
        Metrics.counter("dispatcher.result.received.count").add(1)
        outcome = received
    except Empty:
        pass
    except BrokenPipeError:
        logger.error("Broken pipe error, shutting down")
    except Exception:
        logger.error(f"failed to poll result: {traceback.format_exc()}")
    return outcome
|
||||
|
||||
async def _result_loop(self) -> None:
    """Consume task results until shutdown is requested and nothing is in flight.

    Successful results are acknowledged. Failed results are resubmitted while
    ``task.attempt`` is below ``config.max_attempts``; after that they are
    dropped from the in-flight set, counted as final failures and acknowledged
    if their origin can be identified.
    """
    logger.info("Starting result loop")
    max_attempts = self.config.max_attempts
    inflight_gauge = Metrics.gauge("dispatcher.inflight.count")

    def release(payload_id: str) -> None:
        # Drop the id from the in-flight set (if present) and refresh the gauge.
        if payload_id in self._in_flights:
            self._in_flights.remove(payload_id)
            inflight_gauge.set(len(self._in_flights))

    while not self._is_shutdown() or self._in_flights:
        try:
            result = await self._poll_result()
            if result is None:
                await asyncio.sleep(0.01)
                continue

            task = result.task
            if result.success:
                Metrics.counter("dispatcher.result.success.count").add(1)
                release(task.payload_id)
                await self._task_generator.ack(result)
                continue

            if task.attempt < max_attempts:
                Metrics.counter("dispatcher.result.failed.count").add(1)
                logger.warning(
                    f"Task {task.payload_id} failed, retrying... (attempt {task.attempt})"
                )
                task.attempt += 1
                await self._submit_task(task)
                continue

            # Out of attempts: give up on the task.
            release(task.payload_id)
            logger.error(
                f"Task {task.payload_id} failed after {max_attempts} attempts, error is {result.error}"
            )
            origin = self._task_generator.identify_task_origin(result)
            origin_label = origin.value if origin else "unknown"
            Metrics.counter("dispatcher.result.failed.final.count").add(
                1,
                attributes={"origin": origin_label},
            )
            if origin is not None:
                await self._task_generator.ack(result)
            else:
                logger.warning(
                    f"No origin found for task {task.payload_id}, skipping ack"
                )
        except Empty:
            await asyncio.sleep(0.1)
        except BrokenPipeError:
            logger.error("Broken pipe error, shutting down")
            break
    logger.warning("Result loop finished")
|
||||
|
||||
async def _wait_for_queue_connection_shutdown(self):
    """Poll once per second for queue-connection shutdown, then stop the task generator."""
    while True:
        if self._is_queue_connection_shutdown():
            break
        await asyncio.sleep(1)
    logger.warning("Shutdowning task generators")
    await self._task_generator.stop()
    logger.warning("Task generators stopped")
|
||||
|
||||
async def _run(self, started_event: Event):
|
||||
await self._init_run()
|
||||
started_event.set()
|
||||
await asyncio.gather(
|
||||
self._fill_loop(),
|
||||
self._result_loop(),
|
||||
self._wait_for_queue_connection_shutdown(),
|
||||
)
|
||||
|
||||
def run(self, started_event: Event):
    """Synchronous entry point: drive ``_run`` to completion on a new event loop."""
    main_coro = self._run(started_event)
    asyncio.run(main_coro)
|
||||
|
||||
async def start(self):
    """Spawn the dispatcher process and wait until it reports that it started.

    The wait runs in a worker thread so the calling event loop is not blocked,
    and it is re-checked every second so that a child which exits before
    signalling is detected instead of being waited on forever.

    Raises:
        RuntimeError: if the dispatcher process exits before signalling start.
    """
    logger.info("Starting Grox dispatcher...")
    started_event = multiprocessing.Event()
    self._process = Process(
        target=self.run, args=(started_event,), name="grox-dispatcher"
    )
    self._process.start()
    # Event.wait(timeout) returns True once the event is set.
    while not await asyncio.to_thread(started_event.wait, 1.0):
        if not self._process.is_alive():
            raise RuntimeError(
                "Grox dispatcher process exited before it finished starting"
            )
    logger.info("Grox dispatcher started")
|
||||
|
||||
async def stop(self):
    """Wait up to ``config.graceful_shutdown_timeout`` for the dispatcher process to exit.

    The join runs in a worker thread so the calling event loop stays
    responsive. The process is not terminated here; if it is still alive
    after the timeout an error is logged instead of reporting it as stopped.
    """
    logger.warning("Stopping Grox dispatcher...")
    if self._process and self._process.is_alive():
        await asyncio.to_thread(
            self._process.join, self.config.graceful_shutdown_timeout
        )
        if self._process.is_alive():
            logger.error(
                "Dispatcher process still alive after graceful shutdown timeout"
            )
            return
    else:
        logger.warning("Dispatcher process is not alive, skipping join")
    logger.warning("Dispatcher stopped")
|
||||
287
grox/embedder/multimodal_post_embedder_v2.py
Normal file
287
grox/embedder/multimodal_post_embedder_v2.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
import numpy as np
|
||||
from grox.lm.post import LitePostRenderer, MMEmbedPostRenderer, EvalPostRenderer
|
||||
from grox.lm.convo import Image as ConvoImage, Video as ConvoVideo, Content
|
||||
from embed.embed_cli import XaiEmbeddingClient
|
||||
from monitor.metrics import Metrics
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.data_loaders.media_loader import MediaLoader
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.data_loaders.media_description_loader import MediaDescriptionLoader
|
||||
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
|
||||
StratoContentUnderstandingUnifiedPostAnnotations,
|
||||
)
|
||||
from grox.config.config import EmbeddingModelConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultimodalPostEmbedderV2:
|
||||
def __init__(
    self,
    model: str = "qwen3",
    use_grok_summary: bool = False,
    renderer_version="mmembed_summary",
    use_media_descriptions: bool = False,
    use_post_context_summary: bool = False,
    use_grok_summary_path: str = "",
    custom_endpoint: str = "",
    instruction: str = "",
):
    """Create the embedding clients and record the rendering options.

    Args:
        model: Client selector read by ``_get_client`` ("qwen3", "qwen3_8b", "v4";
            any other value falls back to the primary/video clients).
        use_grok_summary: Stored flag; read by ``embed``.
        renderer_version: Renderer selector read by ``embed``.
        use_media_descriptions: Stored flag; read by ``embed``.
        use_post_context_summary: Stored flag; read by ``embed``.
        use_grok_summary_path: Optional ``.jsonl`` file whose lines carry
            ``post_id`` and ``summary``; when given, the summaries are loaded
            into ``grok_summary_versioned``.
        custom_endpoint: When non-empty, a custom embedding client is created
            for this endpoint and ``use_custom_embed`` is set.
        instruction: Stored instruction text; read by ``embed``.
    """
    embed_config = grox_config.get_embedding_model(ModelName.EMBED_PRIMARY)
    video_embed_config = grox_config.get_embedding_model(ModelName.EMBED_PRIMARY_VIDEO)
    # NOTE(review): these assignments mutate the config objects returned by
    # grox_config; if those are shared instances the change is global -- confirm.
    embed_config.text_max_len = 8192
    video_embed_config.text_max_len = 8192
    qwen_3_embed_06b_config = grox_config.get_embedding_model(
        ModelName.EMBED_SMALL
    )
    qwen_3_embed_8b_config = grox_config.get_embedding_model(
        ModelName.EMBED_LARGE
    )
    recsys_v4_embed_config = grox_config.get_embedding_model(
        ModelName.RECSYS_EMBED_V4
    )
    qwen_3_embed_06b_config.text_max_len = 4096
    qwen_3_embed_8b_config.text_max_len = 4096
    recsys_v4_embed_config.text_max_len = 4096
    if custom_endpoint:
        custom_embed_config = EmbeddingModelConfig(
            model_name="custom", endpoint=custom_endpoint, text_max_len=4096
        )
        self._custom_embed_client = XaiEmbeddingClient(config=custom_embed_config)
    self.use_custom_embed = bool(custom_endpoint)
    self.renderer_version = renderer_version
    self._client = XaiEmbeddingClient(config=embed_config)
    self._video_client = XaiEmbeddingClient(config=video_embed_config)
    self._qwen_3_embed_06b_client = XaiEmbeddingClient(
        config=qwen_3_embed_06b_config
    )
    self._qwen_3_embed_8b_client = XaiEmbeddingClient(config=qwen_3_embed_8b_config)
    self._recsys_v4_embed_client = XaiEmbeddingClient(config=recsys_v4_embed_config)
    self.model = model
    self.use_grok_summary = use_grok_summary
    self.use_media_descriptions = use_media_descriptions
    self.use_grok_summary_versioned = False
    self.instruction = instruction
    self.use_post_context_summary = use_post_context_summary

    if use_grok_summary_path:
        assert os.path.exists(use_grok_summary_path), (
            f"Grok summary path {use_grok_summary_path} does not exist"
        )
        assert use_grok_summary_path.endswith(".jsonl"), (
            f"Grok summary path {use_grok_summary_path} is not a jsonl file"
        )
        self.grok_summary_versioned: dict[str, str] = {}
        self.use_grok_summary_versioned = True
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(use_grok_summary_path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing empty line).
                if not line.strip():
                    continue
                json_line = json.loads(line)
                self.grok_summary_versioned[str(json_line["post_id"]).strip()] = (
                    json_line["summary"]
                )
|
||||
|
||||
def _get_client(
|
||||
self, num_text: int, num_image: int, num_video: int
|
||||
) -> XaiEmbeddingClient:
|
||||
if self.use_custom_embed:
|
||||
return self._custom_embed_client
|
||||
if self.model == "qwen3":
|
||||
return self._qwen_3_embed_06b_client
|
||||
if self.model == "qwen3_8b":
|
||||
return self._qwen_3_embed_8b_client
|
||||
if self.model == "v4":
|
||||
return self._recsys_v4_embed_client
|
||||
|
||||
if num_video > 0:
|
||||
logger.info(
|
||||
f"Using video client for post with {num_text} text, {num_image} images, and {num_video} videos"
|
||||
)
|
||||
return self._video_client
|
||||
return self._client
|
||||
|
||||
def document_original(
    self, content: list[Content]
) -> tuple[list[tuple[str, str | bytes]], int, int, int]:
    """Build the "default" document layout from rendered post content.

    Images are emitted as an ``("text", "Image: \\n")`` marker followed by the
    image bytes; videos with combined bytes are emitted as ``("video", ...)``
    and contribute a textual timing description; all string content is
    concatenated, stripped and appended last.

    Returns:
        ``(document, num_text, num_image, num_video)``.
    """

    def describe_video(video: ConvoVideo) -> str:
        # One sentence per sampled frame, with its timestamp and subtitle (if any).
        sentences = [
            f"The video lasts for {video.total_duration:.2f} seconds.",
            f"The following frames are sampled at every {video.duration:.2f} second interval.",
        ]
        for idx in range(len(video.frames)):
            has_subtitle = video.subtitles and idx < len(video.subtitles)
            subtitle = video.subtitles[idx] if has_subtitle else None
            if subtitle:
                subtitle_str = f"with subtitle: {subtitle}"
            else:
                subtitle_str = "(no subtitles)"
            sentences.append(f"At {idx * video.duration:.2f} seconds, {subtitle_str}.")
        sentences.append("The frames are listed below:")
        return " ".join(sentences)

    document = []
    text_pieces: list[str] = []
    num_text = num_image = num_video = 0
    for item in content:
        if isinstance(item, ConvoImage):
            document.append(("text", "Image: \n"))
            document.append(("image", item.content))
            num_image += 1
        elif isinstance(item, ConvoVideo):
            # Videos without combined bytes are skipped entirely.
            if item.combined_video_bytes:
                text_pieces.append(describe_video(item))
                document.append(("video", item.combined_video_bytes))
                num_video += 1
        elif isinstance(item, str):
            text_pieces.append(item)
            num_text += 1

    merged_text = "".join(text_pieces).strip()
    document.append(("text", ""))
    document.append(("text", merged_text))
    return document, num_text, num_image, num_video
|
||||
|
||||
def document_v1(
    self, content: list[Content]
) -> tuple[list[tuple[str, str | bytes]], int, int, int]:
    """Build the "v1" document layout from rendered post content.

    Strings become stripped ``("text", ...)`` parts, images become
    ``("image", ...)`` parts, and each video is expanded into its frames as
    images, each followed by a ``"subtitle: ..."`` text part when a non-blank
    subtitle exists for that frame.

    Returns:
        ``(document, num_text, num_image, num_video)``.
    """
    document = []
    num_text = num_image = num_video = 0

    for item in content:
        if isinstance(item, str):
            document.append(("text", item.strip()))
            num_text += 1
        elif isinstance(item, ConvoImage):
            document.append(("image", item.content))
            num_image += 1
        elif isinstance(item, ConvoVideo):
            frame_parts: list[tuple[str, str | bytes]] = []
            subtitles = item.subtitles
            for pos, frame in enumerate(item.frames):
                frame_parts.append(("image", frame))
                if not subtitles or pos >= len(subtitles):
                    continue
                subtitle = subtitles[pos]
                if subtitle is not None and subtitle.strip() != "":
                    frame_parts.append(("text", "subtitle: " + subtitle + " "))
            document.extend(frame_parts)
            num_video += 1

    return document, num_text, num_image, num_video
|
||||
|
||||
async def _create_embeddings_for_post(
    self,
    content: list[Content],
    is_query: bool = False,
    document_version: str = "v1",
) -> tuple[list[tuple[str, str | bytes]], np.ndarray]:
    """Turn rendered content into a document and embed it.

    Returns:
        ``(document, embeddings)`` where ``embeddings`` is whatever the
        selected client returns for the single-document batch.

    Raises:
        ValueError: if ``document_version`` is neither "default" nor "v1".
    """
    build_document = self._get_document_fn(document_version)
    document, num_text, num_image, num_video = build_document(content)
    logger.info(
        f"creating embeddings for post with {num_text} text, {num_image} images, and {num_video} videos"
    )
    client = self._get_client(num_text, num_image, num_video)
    embeddings = await client.create_embeddings_async([document], is_query=is_query)
    return document, embeddings
|
||||
|
||||
async def hydrate_grok_post_summary(self, post: Post):
    """Fetch the post's annotations and copy their description into ``post.summary``.

    ``post.summary`` is left untouched when the fetch result is falsy.
    """
    annotations = await StratoContentUnderstandingUnifiedPostAnnotations().fetch(
        int(post.id)
    )
    if not annotations:
        return
    post.summary = annotations["annotations"]["description"]
|
||||
|
||||
def get_detailed_instruct(self, instruction: str) -> str:
    """Wrap ``instruction`` in the instruct/query prompt prefix."""
    return "Instruct: {}\nQuery: Please embed the following post:".format(instruction)
|
||||
|
||||
def _get_document_fn(self, document_version: str):
|
||||
if document_version == "default":
|
||||
return self.document_original
|
||||
if document_version == "v1":
|
||||
return self.document_v1
|
||||
raise ValueError(f"document_version not found: {document_version}")
|
||||
|
||||
async def embed_texts_batch(
    self, texts: list[str], is_query: bool = True, document_version: str = "v1"
) -> list[list[float]]:
    """Embed several plain texts in one client call.

    Each text becomes its own single-item document. Returns one flattened
    embedding (as a list of floats) per input text, or ``[]`` for no input.
    """
    if not texts:
        return []
    build_document = self._get_document_fn(document_version)
    # Each builder returns (document, num_text, num_image, num_video); keep the document.
    documents: list[list[tuple[str, str | bytes]]] = [
        build_document([text])[0] for text in texts
    ]
    client = self._get_client(num_text=1, num_image=0, num_video=0)
    embeddings = await client.create_embeddings_async(documents, is_query=is_query)
    flattened = []
    for embedding in embeddings:
        flattened.append(embedding.flatten().tolist())
    return flattened
|
||||
|
||||
async def embed(
    self, post: Post, is_query: bool = False, document_version: str = "v1"
) -> tuple[list[tuple[str, str | bytes]], list[float]]:
    """Render ``post`` into embedding content, embed it and return both.

    The content list is produced by the renderer selected through
    ``self.renderer_version`` (``"lite"``, ``"eval"`` or
    ``"mmembed_summary"``) and then extended with optional summary and
    media-description snippets, depending on the ``use_*`` flags.

    Args:
        post: Post to embed.
        is_query: Forwarded to ``_create_embeddings_for_post``.
        document_version: Document builder label, forwarded as well.

    Returns:
        ``(document, embedding)`` where ``embedding`` is the flattened vector
        converted to a list.

    Raises:
        ValueError: If no instruction is configured and ``renderer_version``
            matches no known renderer, i.e. there is no content to embed.
            (Previously this surfaced later as ``UnboundLocalError``.)
    """
    if self.instruction:
        content: list[Content] = [self.get_detailed_instruct(self.instruction)]

    # NOTE(review): a matching renderer *replaces* the instruction content
    # built above, so the instruction only survives when no renderer matches.
    # This mirrors the existing behavior — confirm it is intended.
    if self.renderer_version == "lite":
        content = LitePostRenderer.render_for_embedding(post)
    elif self.renderer_version == "eval":
        content = EvalPostRenderer.render_for_embedding(post)
    elif self.renderer_version == "mmembed_summary":
        content = await MMEmbedPostRenderer.render_for_embedding(
            post, use_grok_summary=self.use_grok_summary
        )
    elif not self.instruction:
        # Without this guard ``content`` would be unbound below.
        raise ValueError(
            f"renderer_version not supported and no instruction configured: {self.renderer_version}"
        )

    if self.use_grok_summary_versioned:
        if str(post.id) in self.grok_summary_versioned:
            content.append(
                f"\nPost summary and description: {self.grok_summary_versioned[str(post.id)]}"
            )

    if self.use_grok_summary and not self.use_grok_summary_versioned:
        # Hydration sets post.summary only when a record is found; otherwise
        # whatever post.summary already holds is rendered into the text.
        await self.hydrate_grok_post_summary(post)
        content.append(f"\nPost summary and description: {post.summary}")

    if self.use_post_context_summary:
        content.append(f"\nPost summary and description: {post.summary}")

    if self.use_media_descriptions:
        await MediaDescriptionLoader.hydrate_media_descriptions(post)
        content.append(
            f"\nThe post has these associated media descriptions: \n{post.media_descriptions}"
        )

    # Only the embedding call itself is timed, not rendering or hydration.
    start_time = time.perf_counter_ns()

    document, embedding = await self._create_embeddings_for_post(
        content, is_query, document_version
    )

    duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
    logger.info(f"Embedding finished in {duration_ms:.2f} ms")
    Metrics.histogram("post_embedding_duration_ms").record(duration_ms)
    return document, embedding.flatten().tolist()
|
||||
120
grox/embedder/multimodal_post_embedder_v5.py
Normal file
120
grox/embedder/multimodal_post_embedder_v5.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from embed.embed_http import ChatTemplate, XaiEmbeddingClientHttp
|
||||
from embed.embed_http import EmbeddingModelConfig as HttpModelConfig
|
||||
from grox.config.config import ModelName, grox_config
|
||||
from grox.data_loaders.data_types import Post, Video
|
||||
from grox.lm.post_v5 import V5EmbedPostRenderer
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = ""
|
||||
TRUNCATE_DIM = 1024
|
||||
|
||||
|
||||
class MultimodalPostEmbedderV5:
    """Embeds a post (text plus images) through the V5 HTTP embedding model.

    Rendering is delegated to ``V5EmbedPostRenderer``; encoding goes through
    an ``XaiEmbeddingClientHttp`` configured from
    ``ModelName.RECSYS_EMBED_V5``. Resulting vectors are truncated to
    ``TRUNCATE_DIM`` components and L2-normalised.
    """

    @staticmethod
    def has_video(post: Post) -> bool:
        """Return True when any entry of ``post.media`` is a ``Video``."""
        return any(isinstance(m, Video) for m in post.media or [])

    def __init__(
        self,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_images: int | None = None,
    ):
        """Build the HTTP embedding client.

        Args:
            system_prompt: System prompt placed in the client's chat template.
            max_images: Passed to the renderer as ``max_images`` on each
                ``embed`` call.
        """
        self.truncate_dim = TRUNCATE_DIM
        self.max_images = max_images
        embed_config = grox_config.get_embedding_model(ModelName.RECSYS_EMBED_V5)
        http_config = HttpModelConfig(
            model_name=embed_config.model_name,
            endpoint=embed_config.endpoint,
            text_max_len=4096,
            timeout_seconds=60.0,
        )
        chat_template = ChatTemplate(system_prompt=system_prompt)
        self._client = XaiEmbeddingClientHttp(
            config=http_config, chat_template=chat_template
        )

    def _maybe_truncate(self, embedding: np.ndarray) -> list[float]:
        """Truncate to ``self.truncate_dim`` components and L2-normalise.

        The input is flattened to 1-D first. For a 1-D vector this is a
        no-op; it makes an embedding that arrives with extra singleton axes
        (e.g. shape ``(1, D)``) behave like a plain vector instead of
        skipping truncation and returning a nested list.
        NOTE(review): the shape returned by the client is not visible here —
        confirm it is a single vector per call.

        Args:
            embedding: Raw embedding from the client.

        Returns:
            The (possibly truncated) unit-norm vector as a flat list. An
            all-zero vector is returned unnormalised.
        """
        vector = np.asarray(embedding).reshape(-1)
        if self.truncate_dim > 0 and len(vector) > self.truncate_dim:
            vector = vector[: self.truncate_dim]
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
        return vector.tolist()

    async def embed(
        self,
        post: Post,
        transcript: str | None = None,
        is_query: bool = False,
        **kwargs,
    ) -> tuple[list[tuple[str, str | bytes]], list[float]]:
        """Render, encode and post-process the embedding for ``post``.

        Args:
            post: Post to embed.
            transcript: Optional transcript appended to the rendered text
                when non-empty.
            is_query: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            ``(document, embedding)``. ``document`` lists the rendered text
            followed by one ``("image", img)`` entry per image. ``embedding``
            is ``[]`` when the post has neither text nor images.
        """
        total_start = time.perf_counter()

        render_start = time.perf_counter()
        text_with_pads, images = V5EmbedPostRenderer.render_for_embedding(
            post, max_images=self.max_images
        )
        render_duration_ms = (time.perf_counter() - render_start) * 1000
        Metrics.histogram("post_embedding_v5.render_duration_ms").record(
            render_duration_ms
        )

        if transcript:
            text_with_pads += f"\nTranscript: {transcript}"

        document: list[tuple[str, str | bytes]] = [("text", text_with_pads)]
        for img in images:
            document.append(("image", img))

        if not text_with_pads and not images:
            # Nothing to encode: skip the model call and the remaining metrics.
            logger.warning(f"Post {post.id} has no text or media content")
            return document, []

        encode_start = time.perf_counter()
        embedding = await self._client.encode_with_embedded_pads_async(
            text_with_pads, images if images else None
        )
        encode_duration_ms = (time.perf_counter() - encode_start) * 1000
        Metrics.histogram("post_embedding_v5.encode_duration_ms").record(
            encode_duration_ms
        )

        truncate_start = time.perf_counter()
        result = self._maybe_truncate(embedding)
        truncate_duration_ms = (time.perf_counter() - truncate_start) * 1000
        Metrics.histogram("post_embedding_v5.truncate_duration_ms").record(
            truncate_duration_ms
        )

        total_duration_ms = (time.perf_counter() - total_start) * 1000
        Metrics.histogram("post_embedding_v5.total_duration_ms").record(
            total_duration_ms
        )
        Metrics.counter("post_embedding_v5.image_count").add(len(images))

        total_image_bytes = sum(len(img) for img in images) if images else 0
        Metrics.histogram("post_embedding_v5.image_payload_bytes").record(
            total_image_bytes
        )

        # has_transcript mirrors the `if transcript:` check above, so an empty
        # transcript (which is not appended) is reported as False.
        logger.info(
            f"Embedding V5 post={post.id}: total={total_duration_ms:.1f}ms "
            f"(render={render_duration_ms:.1f}ms, encode={encode_duration_ms:.1f}ms, truncate={truncate_duration_ms:.1f}ms), "
            f"images={len(images)}, image_bytes={total_image_bytes:,}, text_len={len(text_with_pads)}, has_transcript={bool(transcript)}"
        )

        return document, result
|
||||
137
grox/engine.py
Normal file
137
grox/engine.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import multiprocessing
|
||||
from queue import Empty, Queue
|
||||
from threading import Event
|
||||
from multiprocessing import Process
|
||||
|
||||
from monitor.logging import Logging
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
from grox.config.config import grox_config
|
||||
from grox.schedules.init import init_proc
|
||||
from grox.schedules.types import TaskResult, TaskPayload
|
||||
from grox.plans.plan_master import PlanMaster
|
||||
from grox.schedules.context import ScheduleContext
|
||||
from grox.data_loaders.media_processor import MediaProcessor
|
||||
from grox.data_loaders.asr_processor import ASRProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Engine:
    """Runs task execution in a separate ``grox-engine`` process.

    The owning process calls ``start`` / ``stop``. The child process runs
    ``run``, which polls ``task_queue`` from the schedule context, executes
    each task through ``PlanMaster.exec`` and puts the outcome on
    ``resp_queue``.
    """

    def __init__(self, context: ScheduleContext):
        """Bind the engine to the queues and shutdown event in ``context``."""
        self.config = grox_config.engine
        self.context = context
        self._task_queue: Queue[TaskPayload] = self.context["task_queue"]
        self._resp_queue: Queue[TaskResult] = self.context["resp_queue"]
        self._shutdown_event: Event = self.context["shutdown_event"]
        self._process = None

    def _is_shutdown(self) -> bool:
        """Return whether shutdown was requested; any error counts as shutdown."""
        try:
            return self._shutdown_event.is_set()
        except BrokenPipeError:
            logger.error("Broken pipe error, assuming shutdown")
            return True
        except Exception:
            logger.error(
                f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
            )
            return True

    async def _init_run(self):
        """Initialise the engine process and start the media/ASR processors."""
        await init_proc("engine")
        MediaProcessor.start()
        ASRProcessor.start()

    async def _process_task(self, task: TaskPayload) -> TaskResult:
        """Execute ``task`` via ``PlanMaster`` and record its duration in seconds."""
        logger.debug(f"engine started processing task")
        start = time.perf_counter()
        res = await PlanMaster.exec(task)
        end = time.perf_counter()
        logger.debug(f"engine finished processing task in {end - start:.2f} seconds")
        Metrics.histogram("engine.task.processing_time").record(end - start)
        return res

    async def _run_task(self, task: TaskPayload):
        """Process one task inside a tracing span and publish its result.

        On any exception a failed ``TaskResult`` carrying the error text is
        published instead, so every polled task yields exactly one response
        unless publishing itself fails.
        """
        start = time.perf_counter()
        with Metrics.tracer("engine").start_as_current_span("task.root"):
            Logging.set_context(task=task.payload_id)
            if task.post:
                Logging.set_context(post=task.post.id)
            if task.user:
                Logging.set_context(user=task.user.id)
            if task.user_context:
                Logging.set_context(user=task.user_context.user.id)
            try:
                res = await self._process_task(task)
                self._resp_queue.put(res)
                Metrics.counter("engine.task.success.count").add(1)
            except Exception as e:
                logger.error(f"failed to process task, error: {traceback.format_exc()}")
                self._resp_queue.put(
                    TaskResult(
                        task=task,
                        success=False,
                        error=str(e),
                        # start was captured before processing; "finished" is now.
                        task_started_at=start,
                        task_finished_at=time.perf_counter(),
                    )
                )
                Metrics.counter("engine.task.failed.count").add(1)

    async def _poll_task(self) -> TaskPayload | None:
        """Fetch one task without blocking; return None when none is available.

        Errors while polling are logged and also reported as None.
        """
        logger.debug(f"engine polling task, queue size: {self._task_queue.qsize()}")
        try:
            task = self._task_queue.get(block=False)
            logger.debug(f"engine received task: {task.payload_id}")
            Metrics.counter("engine.task.received.count").add(1)
            return task
        except Empty:
            logger.debug("engine polling task returned None")
            return None
        except BrokenPipeError:
            logger.error("Broken pipe error, shutting down")
            return None
        except Exception:
            logger.error(f"failed to poll task: {traceback.format_exc()}")
            return None

    async def _run(self, started_event: Event):
        """Main loop of the engine process.

        Sets ``started_event`` once initialisation is done, then keeps
        spawning one asyncio task per polled payload until shutdown has been
        requested and the task queue is empty.
        """
        await self._init_run()
        started_event.set()
        # The event loop keeps only weak references to tasks, so hold a strong
        # reference until each task finishes to stop it from being collected.
        in_flight: set[asyncio.Task] = set()
        while not self._is_shutdown() or not self._task_queue.empty():
            task = await self._poll_task()
            if task is None:
                await asyncio.sleep(0.1)
                continue
            running = asyncio.create_task(self._run_task(task))
            in_flight.add(running)
            running.add_done_callback(in_flight.discard)
        logger.warning("engine stopped")

    def run(self, started_event: Event):
        """Process entry point: run the loop, then exit the process immediately."""
        asyncio.run(self._run(started_event))
        # os._exit skips interpreter cleanup (atexit handlers, buffered I/O).
        os._exit(0)

    async def start(self):
        """Spawn the engine process and wait until it reports it has started."""
        logger.info("Starting Grox engine...")
        started_event = multiprocessing.Event()
        self._process = Process(
            target=self.run, args=(started_event,), name="grox-engine"
        )
        self._process.start()
        # NOTE(review): this wait is synchronous and blocks the calling event
        # loop until the child has finished initialising.
        started_event.wait()
        logger.info("Grox engine started")

    async def stop(self):
        """Join the engine process (bounded by the configured timeout) and stop processors."""
        logger.warning("Stopping Grox engine...")
        if self._process and self._process.is_alive():
            self._process.join(self.config.graceful_shutdown_timeout)
        else:
            logger.warning("Engine process is not alive, skipping join")
        await MediaProcessor.stop()
        await ASRProcessor.stop()
        logger.warning("Engine stopped")
|
||||
222
grox/generators/stream_generator.py
Normal file
222
grox/generators/stream_generator.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import abc
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from grox.config.config import KafkaTopicName, TaskGeneratorType
|
||||
from grox.schedules.types import TaskResult, TaskPayload, TaskEligibility
|
||||
from grox.data_loaders.kafka_loader import (
|
||||
KafkaAdPostLoader,
|
||||
KafkaPostLoader,
|
||||
KafkaPostEmbeddingRequestLoader,
|
||||
)
|
||||
from grox.generators.task_generator import TaskGenerator
|
||||
from grox.data_loaders.message_queue_loader import MessageQueueLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamTaskGenerator(TaskGenerator, metaclass=abc.ABCMeta):
    """Base class for task generators fed by a ``MessageQueueLoader``.

    Subclasses provide the loader through ``_get_loader`` and declare
    ``ELIGIBILITIES_TO_INJECT``; a copy of that collection is attached to
    every ``TaskPayload`` emitted by ``_poll``.
    """

    ELIGIBILITIES_TO_INJECT: set[TaskEligibility]

    def __init__(self, max_qps: int | None):
        super().__init__(max_qps)
        # The loader is created eagerly, at construction time.
        self._loader = self._get_loader()

    @abc.abstractmethod
    def _get_loader(self) -> MessageQueueLoader:
        """Return the loader this generator reads messages from."""

    async def start(self) -> None:
        """Start the underlying loader."""
        logger.info("Starting StreamTaskGenerator")
        await self._loader.start()

    async def stop(self) -> None:
        """Mark the generator as shut down, then stop the loader."""
        logger.info("Stopping StreamTaskGenerator")
        await super().stop()
        await self._loader.stop()

    async def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield one ``TaskPayload`` per loader message, or None for empty polls."""
        async for message in self._loader.poll():
            if message:
                yield TaskPayload(
                    payload_id=message.mid,
                    post=message.post,
                    user=message.user,
                    user_context=message.user_context,
                    deadline_ts_secs=message.deadline_ts_secs,
                    task_type=self.TASK_GENERATOR_TYPE,
                    # Copy so payloads never share the class-level collection.
                    eligibilities=self.ELIGIBILITIES_TO_INJECT.copy(),
                    grox_content_analysis=message.grox_content_analysis,
                )
            else:
                yield None

    async def ack(self, result: TaskResult):
        """Acknowledge the message behind ``result`` with its success flag."""
        loader = self._loader
        await loader.ack(result.task.payload_id, result.success)
|
||||
|
||||
|
||||
class PostStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the realtime unified-posts topic.

    Tags payloads with the ``SPAM_COMMENT`` and ``REPLY_RANKING`` eligibilities.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM
    ELIGIBILITIES_TO_INJECT = {
        TaskEligibility.SPAM_COMMENT,
        TaskEligibility.REPLY_RANKING,
    }

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class MinTractionPostStreamForGroxTaskGenerator(StreamTaskGenerator):
    """Generator over the min-traction-for-Grox posts topic.

    Tags payloads with the ``BANGER_INITIAL_SCREEN`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX
        )
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class MinTractionPostStreamForGroxPtosTaskGenerator(StreamTaskGenerator):
    """Generator over the min-traction-for-Grox PTOS posts topic.

    Tags payloads with the ``SAFETY_PTOS`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_PTOS
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_PTOS
        )
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class MinTractionPostStreamForGroxMultiModalTaskGenerator(StreamTaskGenerator):
    """Generator over the min-traction-for-Grox multi-modal posts topic.

    Tags payloads with the ``POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY`` eligibility.
    """

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_MIN_TRACTION_STREAM_FOR_GROX_MULTI_MODAL
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY}

    def _get_loader(self):
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_MULTI_MODAL
        )
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostEmbeddingRequestWithSummaryForReplyRecoveryStreamTaskGenerator(
    StreamTaskGenerator
):
    """Generator over the recovery topic for reply embedding requests with summary.

    Tags payloads with the ``POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY`` eligibility.
    """

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_FOR_REPLY_RECOVERY
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY}

    def _get_loader(self):
        topic = (
            KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY_FOR_REPLY_RECOVERY
        )
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostStreamRecoveryTaskGenerator(StreamTaskGenerator):
    """Generator over the unified-posts recovery topic.

    Tags payloads with the ``BANGER_INITIAL_SCREEN`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_RECOVERY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class SafetyPtosRecoveryStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the safety PTOS recovery topic.

    Tags payloads with the ``SAFETY_PTOS`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.SAFETY_PTOS_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = KafkaTopicName.SAFETY_PTOS_RECOVERY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class SafetyPtosDeluxeStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the safety PTOS deluxe topic.

    Tags payloads with the ``SAFETY_PTOS`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.SAFETY_PTOS_DELUXE
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.SAFETY_PTOS}

    def _get_loader(self):
        topic = KafkaTopicName.SAFETY_PTOS_DELUXE
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostStreamTestTaskGenerator(StreamTaskGenerator):
    """Generator over the unified-posts test topic.

    Tags payloads with the ``BANGER_INITIAL_SCREEN`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_TEST
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.BANGER_INITIAL_SCREEN}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_TEST
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostSafetyStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the popular unified-posts topic.

    Tags payloads with the ``POST_SAFETY`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_SAFETY_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_SAFETY}

    def _get_loader(self):
        topic = KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_POPULAR
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostStreamDelayedTaskGenerator(StreamTaskGenerator):
    """Generator over the delayed unified-posts topic.

    Injects no eligibilities of its own.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_STREAM_DELAYED
    # `{}` is an empty dict, not a set; the base class declares this attribute
    # as set[TaskEligibility] and copies it into each TaskPayload.
    ELIGIBILITIES_TO_INJECT = set()

    def _get_loader(self):
        return KafkaPostLoader(
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_DELAYED
        )
|
||||
|
||||
|
||||
class ReplyRankingRecoveryTaskGenerator(StreamTaskGenerator):
    """Generator over the reply-ranking recovery topic.

    Tags payloads with the ``REPLY_RANKING`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.REPLY_RANKING_RECOVERY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.REPLY_RANKING}

    def _get_loader(self):
        topic = KafkaTopicName.REPLY_RANKING_RECOVERY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostEmbeddingRequestWithSummaryStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the multimodal embedding-requests-with-summary topic.

    Tags payloads with the ``POST_EMBEDDING_WITH_SUMMARY`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY}

    def _get_loader(self):
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostEmbeddingRequestWithSummaryRecoveryStreamTaskGenerator(StreamTaskGenerator):
    """Generator over the recovery topic for embedding requests with summary.

    Tags payloads with the ``POST_EMBEDDING_WITH_SUMMARY`` eligibility.
    """

    TASK_GENERATOR_TYPE = (
        TaskGeneratorType.POST_EMBEDDING_REQUEST_STREAM_WITH_SUMMARY_RECOVERY
    )
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.POST_EMBEDDING_WITH_SUMMARY}

    def _get_loader(self):
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY_RECOVERY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostEmbeddingV5StreamTaskGenerator(StreamTaskGenerator):
    """Generator for V5 post embeddings.

    Tags payloads with the ``MM_EMB_V5`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_V5_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.MM_EMB_V5}

    def _get_loader(self):
        # NOTE(review): same topic constant as
        # PostEmbeddingRequestWithSummaryStreamTaskGenerator — confirm the
        # shared topic is intended.
        topic = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_REQUESTS_WITH_SUMMARY
        return KafkaPostLoader(topic)
|
||||
|
||||
|
||||
class PostEmbeddingV5ForReplyStreamTaskGenerator(StreamTaskGenerator):
    """Generator for V5 reply embeddings.

    Tags payloads with the ``MM_EMB_V5_FOR_REPLY`` eligibility.
    """

    TASK_GENERATOR_TYPE = TaskGeneratorType.POST_EMBEDDING_V5_FOR_REPLY_STREAM
    ELIGIBILITIES_TO_INJECT = {TaskEligibility.MM_EMB_V5_FOR_REPLY}

    def _get_loader(self):
        # NOTE(review): same topic constant as
        # MinTractionPostStreamForGroxMultiModalTaskGenerator — confirm the
        # shared topic is intended.
        topic = (
            KafkaTopicName.CONTENT_UNDERSTANDING_REALTIME_UNIFIED_POSTS_MIN_TRACTION_FOR_GROX_MULTI_MODAL
        )
        return KafkaPostLoader(topic)
|
||||
153
grox/generators/task_generator.py
Normal file
153
grox/generators/task_generator.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import traceback
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from grox.config.config import TaskGeneratorType
|
||||
from grox.schedules.types import TaskResult, TaskPayload
|
||||
from limits import RateLimitItemPerSecond, storage, strategies
|
||||
from typing import AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
limiter = strategies.FixedWindowRateLimiter(storage.MemoryStorage())
|
||||
|
||||
|
||||
class TaskGenerator(ABC):
    """Abstract source of ``TaskPayload`` objects with optional rate limiting.

    Subclasses implement ``_poll``; consumers iterate ``poll``, which passes
    payloads through and, when ``max_qps`` was given, throttles them with the
    module-level ``limiter`` keyed by the concrete class name.
    """

    TASK_GENERATOR_TYPE: TaskGeneratorType | None = None

    def __init__(self, max_qps: int | None):
        self._shutdown_event = asyncio.Event()
        # One rate-limit bucket per concrete class, not per instance.
        self._limiter_key = type(self).__name__
        if max_qps:
            self._limit = RateLimitItemPerSecond(max_qps, 1)
        else:
            self._limit = None

    def is_shutdown(self) -> bool:
        """Return whether ``stop`` was called; errors are treated as shutdown."""
        try:
            stopped = self._shutdown_event.is_set()
        except Exception:
            logger.error(
                f"Error checking if task generator is shutdown: {traceback.format_exc()}"
            )
            return True
        return stopped

    async def start(self) -> None:
        """Hook for subclasses; does nothing by default."""

    async def stop(self) -> None:
        """Flag the generator as shut down."""
        logger.info(f"Stopping task generator {self.__class__.__name__}")
        self._shutdown_event.set()

    async def poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield payloads from ``_poll``, honouring the rate limit.

        Falsy payloads are reported as ``None``. While the limit is
        exhausted, ``None`` is yielded every 10 ms until a slot frees up.
        """
        async for item in self._poll():
            if not item:
                yield None
            else:
                if self._limit:
                    while not limiter.test(self._limit, self._limiter_key):
                        yield None
                        await asyncio.sleep(0.01)
                    limiter.hit(self._limit, self._limiter_key)
                yield item

    @abstractmethod
    def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Return the raw, unthrottled payload stream."""

    async def ack(self, result: TaskResult):
        """Hook for acknowledging a finished task; does nothing by default."""

    def identify_task_origin(self, result: TaskResult) -> TaskGeneratorType | None:
        """Return this generator's ``TASK_GENERATOR_TYPE``."""
        return self.TASK_GENERATOR_TYPE
|
||||
|
||||
|
||||
class PriorityTaskGenerator(TaskGenerator):
    """Multiplexes several task generators using weighted random selection.

    Each round draws generators at random in proportion to their weights
    until one yields a payload; generators that returned nothing or failed
    are excluded for the rest of that round. The label of the generator that
    produced a payload is remembered so ``ack`` and ``identify_task_origin``
    can be routed back to it.
    """

    def __init__(self, generators: list[tuple[TaskGenerator, int]]):
        """Register ``(generator, weight)`` pairs.

        Raises:
            ValueError: If ``generators`` is empty or any weight is not positive.
        """
        if not generators:
            raise ValueError("No generators provided")
        if any(weight <= 0 for _, weight in generators):
            raise ValueError("All weights must be positive")
        super().__init__(None)
        self._generators: dict[str, TaskGenerator] = {}
        self._weights: dict[str, int] = {}
        for i, (gen, weight) in enumerate(generators):
            label = f"GEN_{i}"
            self._generators[label] = gen
            self._weights[label] = weight
        # payload_id -> label of the generator that produced it.
        self._result_cache: dict[str, str] = {}
        # Filled by start(); defined here so polling before start() raises the
        # intended RuntimeError instead of AttributeError.
        self._streams: dict[str, AsyncGenerator[TaskPayload | None, None]] = {}
        logger.info(
            f"Initialized priority task generator with {list(zip(self._generators.keys(), [gen.__class__.__name__ for gen in self._generators.values()], self._weights.values(), strict=True))}"
        )

    async def start(self) -> None:
        """Start all child generators concurrently and open their poll streams."""
        logger.info(f"Starting priority task generators")
        await asyncio.gather(*[gen.start() for gen in self._generators.values()])
        self._streams = {label: gen.poll() for label, gen in self._generators.items()}
        logger.info(f"Priority task generators started")

    async def stop(self) -> None:
        """Stop all child generators concurrently, then mark this one as shut down."""
        logger.warning(f"Stopping priority task generators")
        await asyncio.gather(*[gen.stop() for gen in self._generators.values()])
        await super().stop()
        logger.warning(f"Priority task generators stopped")

    async def _poll(self) -> AsyncGenerator[TaskPayload | None, None]:
        """Yield payloads chosen by weighted random draw among live generators.

        Yields ``None`` for a round in which no generator produced a payload.
        The stream ends once every generator is exhausted.

        Raises:
            RuntimeError: If ``start`` has not been called yet.
        """
        if not self._streams:
            raise RuntimeError("Task generators not started")
        while self._weights:
            # Per-round candidate pool; entries are dropped as they come up empty.
            _weights = self._weights.copy()
            polled = False
            while _weights:
                labels = list(_weights.keys())
                weights = list(_weights.values())
                labels = random.choices(labels, weights, k=1)
                label = labels[0]
                stream = self._streams[label]
                try:
                    payload = await anext(stream)
                    if payload:
                        self._result_cache[payload.payload_id] = label
                        yield payload
                        polled = True
                        break
                    else:
                        del _weights[label]
                except StopAsyncIteration:
                    # Exhausted for good: drop it from this and all later rounds.
                    logger.warning(
                        f"Task generator {label} exhausted, removing from pool"
                    )
                    if label in _weights:
                        del _weights[label]
                    if label in self._weights:
                        del self._weights[label]
                    continue
                except Exception:
                    # Failure only excludes the generator for the current round.
                    logger.error(
                        f"Error polling task generator {label}: {traceback.format_exc()}"
                    )
                    if label in _weights:
                        del _weights[label]
                    continue
            if not polled:
                yield None

    async def ack(self, result: TaskResult):
        """Forward the ack to the generator that produced the task.

        The cached label is removed; an unknown task is logged and ignored.
        """
        logger.debug(f"Acknowledging task {result.task.payload_id}")
        label = self._result_cache.pop(result.task.payload_id, None)
        if not label:
            logger.warning(
                f"No label found for task {result.task.payload_id}, skipping ack"
            )
            return
        gen = self._generators[label]
        await gen.ack(result)

    def identify_task_origin(self, result: TaskResult) -> TaskGeneratorType | None:
        """Return the originating generator's type, or None if it is unknown.

        Must be called before ``ack`` for the same task, because ``ack``
        removes the cached label.
        """
        logger.debug(f"Identifying task origin for {result.task.payload_id}")
        label = self._result_cache.get(result.task.payload_id)
        if not label:
            logger.warning(
                f"No label found for task {result.task.payload_id}, cannot identify origin"
            )
            return None
        gen = self._generators[label]
        return gen.identify_task_origin(result)
|
||||
46
grox/lib/stream.py
Normal file
46
grox/lib/stream.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import AsyncIterator, AsyncGenerator, TypeVar
|
||||
from asyncio import Queue, create_task
|
||||
from enum import Enum
|
||||
|
||||
import logging
|
||||
|
||||
T = TypeVar("T")


class StreamStatus(Enum):
    """Control markers placed on the merge queue alongside stream items."""

    # Put on the queue once by every producer when its stream has ended.
    STOP = "Stop"


logger = logging.getLogger(__name__)


async def parallel_merge(*streams: AsyncIterator[T]) -> AsyncGenerator[T, None]:
    """Merge several async iterators, yielding items as soon as any is ready.

    One producer task per stream pushes items onto a shared queue; this
    generator drains the queue until every producer has signalled completion.
    Relative order is preserved within a stream but not across streams.

    An exception raised by a stream is re-raised here. An item that is itself
    an ``Exception`` instance is raised rather than yielded.

    Args:
        *streams: Async iterators to merge. With none given, nothing is yielded.

    Yields:
        Items from all streams, interleaved in arrival order.
    """
    if not streams:
        return
    queue: Queue[T | StreamStatus | Exception] = Queue()

    async def enqueue(ait: AsyncIterator[T]):
        try:
            async for item in ait:
                await queue.put(item)
        except GeneratorExit:
            pass
        except Exception as e:
            await queue.put(e)
        finally:
            # Always signal completion. The queue is unbounded, so this put
            # does not suspend even while the task is being cancelled.
            await queue.put(StreamStatus.STOP)

    producers = [create_task(enqueue(s)) for s in streams]
    remaining = len(producers)
    try:
        while remaining:
            item = await queue.get()
            try:
                # Identity check: `==` would invoke arbitrary item __eq__.
                if item is StreamStatus.STOP:
                    remaining -= 1
                elif isinstance(item, Exception):
                    raise item
                else:
                    yield item
            finally:
                queue.task_done()
    finally:
        # Stop producers that are still running when the merge ends early,
        # either on an error or because the consumer closed the generator.
        for producer in producers:
            producer.cancel()
|
||||
6
grox/lib/utils.py
Normal file
6
grox/lib/utils.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def camel_to_snake(s: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Every uppercase character is replaced by an underscore followed by its
    lowercase form, then leading underscores are stripped. Consecutive
    capitals are split individually (``"HTTPServer"`` -> ``"h_t_t_p_server"``).
    """
    pieces = []
    for ch in s:
        if ch.isupper():
            pieces.append("_")
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    return "".join(pieces).lstrip("_")
|
||||
|
||||
|
||||
def snake_to_camel(s: str) -> str:
    """Convert a snake_case identifier to UpperCamelCase.

    The string is split on underscores and each part is passed through
    ``str.capitalize``, which also lowercases the rest of the part
    (``"fooBar_baz"`` -> ``"FoobarBaz"``). Empty parts vanish.
    """
    parts = s.split("_")
    return "".join(map(str.capitalize, parts))
|
||||
54
grox/main.py
Normal file
54
grox/main.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import signal
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from grox.engine import Engine
|
||||
from grox.service import GrpcServer
|
||||
from grox.dispatcher import Dispatcher
|
||||
from grox.config.config import grox_config
|
||||
from grox.schedules.init import init_proc
|
||||
from grox.schedules.context import (
|
||||
cleanup,
|
||||
new_context,
|
||||
shutdown_context,
|
||||
queue_connection_shutdown_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
shutdown = asyncio.Event()
|
||||
|
||||
|
||||
async def serve():
    """Run the Grox server until SIGINT/SIGTERM, then shut it down.

    Startup order is engine, dispatcher, gRPC server. On a termination
    signal the queue connections are shut down first, the process then waits
    300 seconds, and only afterwards the shared context is shut down and the
    three components are stopped concurrently.
    """
    await init_proc("main")
    logger.info("Starting grox server...")

    ctx = new_context()
    engine, dispatcher, grpc_server = Engine(ctx), Dispatcher(ctx), GrpcServer(ctx)
    for component in (engine, dispatcher, grpc_server):
        await component.start()
    logger.info("Grox server started")

    # Signal handlers are installed only after every component has started.
    loop = asyncio.get_running_loop()
    for signum in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(signum, shutdown.set)

    await shutdown.wait()
    logger.warning("Grox server shutting down...")
    queue_connection_shutdown_context(ctx)
    # Fixed pause between closing the queue connections and the full shutdown.
    await asyncio.sleep(300)

    shutdown_context(ctx)
    stoppers = (grpc_server.stop(), dispatcher.stop(), engine.stop())
    await asyncio.gather(*stoppers)
    cleanup()
    logger.warning("Grox server stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the server coroutine to completion.
    main_coroutine = serve()
    asyncio.run(main_coroutine)
|
||||
106
grox/plans/plan.py
Normal file
106
grox/plans/plan.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from functools import cache
|
||||
|
||||
from grox.lib.utils import camel_to_snake
|
||||
from grox.tasks.task import Task, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskResult, TaskContext, TaskPayload, TaskEligibility
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Plan(ABC):
|
||||
TASKS: dict[str, type[Task]] = {}
|
||||
TASK_DEPENDENCIES: dict[str, set[str]] = {}
|
||||
REQUIRED_ELIGIBILITY: TaskEligibility
|
||||
|
||||
def __init__(self):
|
||||
self.deps = set([d for deps in self.TASK_DEPENDENCIES.values() for d in deps])
|
||||
if any(t not in self.TASKS for t in self.deps) or any(
|
||||
t not in self.TASKS for t in self.TASK_DEPENDENCIES.keys()
|
||||
):
|
||||
raise ValueError("Not every task in TASK_DEPENDENCIES is defined in TASKS")
|
||||
|
||||
async def execute(self, task: TaskPayload) -> TaskResult | None:
|
||||
if not self._eligible(task):
|
||||
return None
|
||||
Metrics.counter("plan.execute.count").add(
|
||||
1, attributes={"plan_name": self.get_name()}
|
||||
)
|
||||
logger.debug(f"Creating execution plan for graph: {self.TASK_DEPENDENCIES}")
|
||||
loop = asyncio.get_running_loop()
|
||||
dependencies = {task: loop.create_future() for task in self.deps}
|
||||
start = time.perf_counter()
|
||||
ctx = TaskContext(task)
|
||||
try:
|
||||
await asyncio.gather(
|
||||
*[self._execute_task(t, ctx, dependencies) for t in self.TASKS.keys()]
|
||||
)
|
||||
Metrics.counter("plan.execute.success.count").add(
|
||||
1, attributes={"plan_name": self.get_name()}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing plan: {traceback.format_exc()}")
|
||||
ctx.errors.append(e)
|
||||
Metrics.counter("plan.execute.failed.count").add(
|
||||
1, attributes={"plan_name": self.get_name()}
|
||||
)
|
||||
finally:
|
||||
duration = time.perf_counter() - start
|
||||
Metrics.histogram("plan.execute.duration").record(
|
||||
duration, attributes={"plan_name": self.get_name()}
|
||||
)
|
||||
for fut in dependencies.values():
|
||||
try:
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error canceling dependency future: {traceback.format_exc()}"
|
||||
)
|
||||
dependencies.clear()
|
||||
return TaskResult(
|
||||
task=task,
|
||||
content_categories=[c.model_copy() for c in ctx.content_categories],
|
||||
task_started_at=ctx.start_time,
|
||||
task_finished_at=time.perf_counter(),
|
||||
multimodal_post_embedding=ctx.multimodal_post_embedding,
|
||||
reason=ctx.reason,
|
||||
success=len(ctx.errors) == 0,
|
||||
error="\n".join([str(e) for e in ctx.errors]),
|
||||
)
|
||||
|
||||
def _eligible(self, ctx: TaskPayload) -> bool:
|
||||
return self.REQUIRED_ELIGIBILITY in ctx.eligibilities
|
||||
|
||||
async def _execute_task(
    self, task_name: str, ctx: TaskContext, dependencies: dict[str, asyncio.Future]
):
    """Wait for a task's dependencies, run it, and publish its outcome.

    Args:
        task_name: Key into ``TASKS`` / ``TASK_DEPENDENCIES``.
        ctx: Shared context handed to the task's ``exec``.
        dependencies: Task name -> future carrying that task's result. A task
            without an entry simply does not publish its outcome.

    Raises:
        Exception: Whatever ``task.exec`` raises (after it has been set on this
            task's future), or whatever awaiting a failed dependency raises.
    """
    logger.debug(f"Waiting for task to become ready: {task_name}")
    task = self.TASKS[task_name]
    upstream_names = self.TASK_DEPENDENCIES.get(task_name, set())
    upstream_results = await asyncio.gather(
        *[dependencies[name] for name in upstream_names]
    )
    own_future = dependencies.get(task_name, None)

    # A skipped dependency skips this task too, and the skip is passed on
    # to anything waiting on this task's own future.
    upstream_skipped = any(
        outcome == TaskResultCategory.SKIPPED for outcome in upstream_results
    )
    if upstream_skipped:
        if own_future is not None:
            own_future.set_result(TaskResultCategory.SKIPPED)
        return

    logger.debug(f"Started executing task: {task_name}")
    try:
        outcome = await task.exec(ctx)
    except Exception as exc:
        if own_future is not None:
            own_future.set_exception(exc)
        raise exc
    else:
        if own_future is not None:
            own_future.set_result(outcome)
    logger.debug(f"Finished executing task: {task_name}")
|
||||
|
||||
@cache
def get_name(self) -> str:
    """Return the snake_case form of this plan's class name.

    The result is memoised by ``functools.cache``, keyed on ``self``.
    """
    class_name = self.__class__.__name__
    return camel_to_snake(class_name)
|
||||
37
grox/plans/plan_initial_banger.py
Normal file
37
grox/plans/plan_initial_banger.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.tasks.task_pub import (
|
||||
TaskPublishKafka,
|
||||
TaskPublishUnifiedPostAnnotationsManhattan,
|
||||
)
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_media import TaskMediaHydrationBanger
|
||||
from grox.tasks.task_filters import TaskInitialBangerFilter
|
||||
from grox.tasks.task_banger_screen import TaskBangerScreen
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitBangerAnnotationWithPost
|
||||
from grox.tasks.task_grok_upa_action_with_labels import TaskGrokUpaActionWithLabels
|
||||
|
||||
|
||||
class PlanInitialBanger(Plan):
    """Plan gated on ``TaskEligibility.BANGER_INITIAL_SCREEN``.

    Chain: filter -> rate limit -> media hydration -> banger screen. The
    screen then feeds both the UPA action task and the Manhattan publish;
    the Kafka publish waits on the Manhattan publish.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.BANGER_INITIAL_SCREEN

    # Task name -> task class.
    TASKS = {
        "task_initial_banger_filter": TaskInitialBangerFilter,
        "task_banger_annotation_rate_limit": TaskRateLimitBangerAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_banger_screen_initial": TaskBangerScreen,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_publish_unified_post_annotations_manhattan": (
            TaskPublishUnifiedPostAnnotationsManhattan
        ),
        "task_publish_kafka": TaskPublishKafka,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_initial_banger_filter": set(),
        "task_banger_annotation_rate_limit": {"task_initial_banger_filter"},
        "task_media_hydration": {"task_banger_annotation_rate_limit"},
        "task_banger_screen_initial": {"task_media_hydration"},
        "task_grok_upa_action_with_labels": {"task_banger_screen_initial"},
        "task_publish_unified_post_annotations_manhattan": {"task_banger_screen_initial"},
        "task_publish_kafka": {"task_publish_unified_post_annotations_manhattan"},
    }
|
||||
62
grox/plans/plan_master.py
Normal file
62
grox/plans/plan_master.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskResult, TaskPayload
|
||||
from grox.plans.plan_spam_comment import PlanSpamComment
|
||||
from grox.plans.plan_initial_banger import PlanInitialBanger
|
||||
from grox.plans.plan_post_embedding_with_summary import PlanPostEmbeddingWithSummary
|
||||
from grox.plans.plan_post_embedding_v5 import PlanPostEmbeddingV5
|
||||
from grox.plans.plan_post_embedding_v5_for_reply import PlanPostEmbeddingV5ForReply
|
||||
from grox.plans.plan_post_embedding_with_summary_for_reply import (
|
||||
PlanPostEmbeddingWithSummaryForReply,
|
||||
)
|
||||
from grox.plans.plan_post_safety import PlanPostSafety
|
||||
from grox.plans.plan_reply_ranking import PlanReplyRanking
|
||||
from grox.plans.plan_safety_ptos import PlanSafetyPtos
|
||||
|
||||
|
||||
class PlanMaster:
    """Fans one payload out to every registered plan and merges the results."""

    # Plan instances are created once, when this class body is evaluated.
    ALL_PLANS: list[Plan] = [
        PlanInitialBanger(),
        PlanPostSafety(),
        PlanSpamComment(),
        PlanPostEmbeddingWithSummary(),
        PlanPostEmbeddingWithSummaryForReply(),
        PlanPostEmbeddingV5(),
        PlanPostEmbeddingV5ForReply(),
        PlanReplyRanking(),
        PlanSafetyPtos(),
    ]

    @classmethod
    async def exec(cls, task: TaskPayload) -> TaskResult:
        """Run all plans concurrently for ``task`` and return one merged result.

        Plans whose ``execute`` returns ``None`` are left out of the merge.
        """
        results = await asyncio.gather(*[p.execute(task) for p in cls.ALL_PLANS])
        return cls.merge_results(task, [r for r in results if r is not None])

    @classmethod
    def merge_results(cls, task: TaskPayload, results: list[TaskResult]) -> TaskResult:
        """Collapse per-plan results into a single ``TaskResult``.

        Args:
            task: The payload the results belong to.
            results: Per-plan results; may be empty when no plan produced one.

        Returns:
            A result whose categories are the concatenation of all inputs,
            whose embedding is the first non-``None`` one, whose time window
            spans the earliest start to the latest finish, and whose
            ``success`` is true only if every input succeeded. With no inputs
            the window collapses to "now" and the result is an empty success.
        """
        # Local import: this module's import block does not include ``time``.
        import time

        # The first plan that produced an embedding wins.
        multimodal_post_embedding = next(
            (
                r.multimodal_post_embedding
                for r in results
                if r.multimodal_post_embedding is not None
            ),
            None,
        )

        if results:
            started_at = min(r.task_started_at for r in results)
            finished_at = max(r.task_finished_at for r in results)
        else:
            # min()/max() raise ValueError on an empty iterable; with no
            # plan results there is no window to report, so use "now".
            started_at = finished_at = time.perf_counter()

        return TaskResult(
            task=task,
            content_categories=[
                c.model_copy() for r in results for c in r.content_categories
            ],
            task_started_at=started_at,
            task_finished_at=finished_at,
            multimodal_post_embedding=multimodal_post_embedding,
            reason="\n".join([r.reason for r in results if r.reason]),
            success=all(r.success for r in results),
            error="\n".join(
                [r.error or "unknown error" for r in results if not r.success]
            ),
        )
|
||||
29
grox/plans/plan_post_embedding_v5.py
Normal file
29
grox/plans/plan_post_embedding_v5.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_multimodal_post_embedding import TaskMultimodalPostEmbeddingV5
|
||||
from grox.tasks.task_write_mm_embedding_sink import (
|
||||
TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies,
|
||||
)
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingV5
|
||||
from grox.tasks.task_asr import TaskASRTranscription
|
||||
|
||||
|
||||
class PlanPostEmbeddingV5(Plan):
    """Plan gated on ``TaskEligibility.MM_EMB_V5``.

    Strictly linear chain: rate limit -> media hydration -> ASR
    transcription -> V5 multimodal embedding -> embedding sink write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.MM_EMB_V5

    # Task name -> task class.
    TASKS = {
        "task_post_embedding_rate_limit": TaskRateLimitEmbeddingV5,
        "task_media_hydration": TaskMediaHydration,
        "task_asr_transcription": TaskASRTranscription,
        "task_multimodal_post_embedding_v5": TaskMultimodalPostEmbeddingV5,
        "task_write_post_embedding_sink_v5": (
            TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies
        ),
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit": set(),
        "task_media_hydration": {"task_post_embedding_rate_limit"},
        "task_asr_transcription": {"task_media_hydration"},
        "task_multimodal_post_embedding_v5": {"task_asr_transcription"},
        "task_write_post_embedding_sink_v5": {"task_multimodal_post_embedding_v5"},
    }
|
||||
32
grox/plans/plan_post_embedding_v5_for_reply.py
Normal file
32
grox/plans/plan_post_embedding_v5_for_reply.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryForReplyFilter
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_multimodal_post_embedding import TaskMultimodalPostEmbeddingV5
|
||||
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV5
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingV5ForReply
|
||||
from grox.tasks.task_asr import TaskASRTranscription
|
||||
|
||||
|
||||
class PlanPostEmbeddingV5ForReply(Plan):
    """Plan gated on ``TaskEligibility.MM_EMB_V5_FOR_REPLY``.

    Linear chain: rate limit -> reply filter -> media hydration -> ASR
    transcription -> V5 multimodal embedding -> embedding sink write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.MM_EMB_V5_FOR_REPLY

    # Task name -> task class.
    TASKS = {
        "task_post_embedding_rate_limit_v5_for_reply": TaskRateLimitEmbeddingV5ForReply,
        "task_post_embedding_filter_for_reply": (
            TaskPostEmbeddingWithSummaryForReplyFilter
        ),
        "task_media_hydration": TaskMediaHydration,
        "task_asr_transcription": TaskASRTranscription,
        "task_multimodal_post_embedding_v5": TaskMultimodalPostEmbeddingV5,
        "task_write_post_embedding_sink_v5": TaskWriteMMEmbeddingSinkV5,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_v5_for_reply": set(),
        "task_post_embedding_filter_for_reply": {"task_post_embedding_rate_limit_v5_for_reply"},
        "task_media_hydration": {"task_post_embedding_filter_for_reply"},
        "task_asr_transcription": {"task_media_hydration"},
        "task_multimodal_post_embedding_v5": {"task_asr_transcription"},
        "task_write_post_embedding_sink_v5": {"task_multimodal_post_embedding_v5"},
    }
|
||||
38
grox/plans/plan_post_embedding_with_summary.py
Normal file
38
grox/plans/plan_post_embedding_with_summary.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryFilter
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_multimodal_post_embedding import (
|
||||
TaskMultimodalPostEmbeddingWithSummary,
|
||||
)
|
||||
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV3
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingWithPostSummary
|
||||
from grox.tasks.task_summarizer_for_post_embedding import TaskPostEmbeddingSummarizer
|
||||
|
||||
|
||||
class PlanPostEmbeddingWithSummary(Plan):
    """Plan gated on ``TaskEligibility.POST_EMBEDDING_WITH_SUMMARY``.

    Linear chain: rate limit -> filter -> media hydration -> summarizer ->
    multimodal embedding with summary -> V3 embedding sink write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_EMBEDDING_WITH_SUMMARY

    # Task name -> task class.
    TASKS = {
        "task_post_embedding_rate_limit_summary": TaskRateLimitEmbeddingWithPostSummary,
        "task_post_embedding_with_summary_filter": TaskPostEmbeddingWithSummaryFilter,
        "task_media_hydration": TaskMediaHydration,
        "task_post_embedding_summarizer": TaskPostEmbeddingSummarizer,
        "task_multimodal_post_embedding_with_summary": (
            TaskMultimodalPostEmbeddingWithSummary
        ),
        "task_write_post_embedding_sink_v3": TaskWriteMMEmbeddingSinkV3,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_summary": set(),
        "task_post_embedding_with_summary_filter": {"task_post_embedding_rate_limit_summary"},
        "task_media_hydration": {"task_post_embedding_with_summary_filter"},
        "task_post_embedding_summarizer": {"task_media_hydration"},
        "task_multimodal_post_embedding_with_summary": {"task_post_embedding_summarizer"},
        "task_write_post_embedding_sink_v3": {"task_multimodal_post_embedding_with_summary"},
    }
|
||||
38
grox/plans/plan_post_embedding_with_summary_for_reply.py
Normal file
38
grox/plans/plan_post_embedding_with_summary_for_reply.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_filters import TaskPostEmbeddingWithSummaryForReplyFilter
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_multimodal_post_embedding import (
|
||||
TaskMultimodalPostEmbeddingWithSummary,
|
||||
)
|
||||
from grox.tasks.task_write_mm_embedding_sink import TaskWriteMMEmbeddingSinkV3
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitEmbeddingWithPostSummaryForReply
|
||||
from grox.tasks.task_summarizer_for_post_embedding import TaskPostEmbeddingSummarizer
|
||||
|
||||
|
||||
class PlanPostEmbeddingWithSummaryForReply(Plan):
    """Plan gated on ``TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY``.

    Linear chain: rate limit -> reply filter -> media hydration ->
    summarizer -> multimodal embedding with summary -> V3 sink write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY

    # Task name -> task class.
    TASKS = {
        "task_post_embedding_rate_limit_summary_for_reply": (
            TaskRateLimitEmbeddingWithPostSummaryForReply
        ),
        "task_post_embedding_with_summary_filter_for_reply": (
            TaskPostEmbeddingWithSummaryForReplyFilter
        ),
        "task_media_hydration": TaskMediaHydration,
        "task_post_embedding_summarizer": TaskPostEmbeddingSummarizer,
        "task_multimodal_post_embedding_with_summary": (
            TaskMultimodalPostEmbeddingWithSummary
        ),
        "task_write_post_embedding_sink_v3": TaskWriteMMEmbeddingSinkV3,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit_summary_for_reply": set(),
        "task_post_embedding_with_summary_filter_for_reply": {
            "task_post_embedding_rate_limit_summary_for_reply"
        },
        "task_media_hydration": {"task_post_embedding_with_summary_filter_for_reply"},
        "task_post_embedding_summarizer": {"task_media_hydration"},
        "task_multimodal_post_embedding_with_summary": {"task_post_embedding_summarizer"},
        "task_write_post_embedding_sink_v3": {"task_multimodal_post_embedding_with_summary"},
    }
|
||||
32
grox/plans/plan_post_safety.py
Normal file
32
grox/plans/plan_post_safety.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.tasks.task_pub import TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_filters import TaskPostSafetyDeluxeFilter
|
||||
from grox.tasks.task_media import TaskMediaHydrationBanger
|
||||
from grox.tasks.task_post_safety_screen_deluxe import TaskPostSafetyScreenDeluxe
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitPostSafetyAnnotationWithPost
|
||||
from grox.tasks.task_grok_upa_action_with_labels import TaskGrokUpaActionWithLabels
|
||||
|
||||
|
||||
class PlanPostSafety(Plan):
    """Plan gated on ``TaskEligibility.POST_SAFETY``.

    Chain: filter -> rate limit -> media hydration -> safety screen. The
    screen then feeds both the UPA action task and the bool-metadata upsert.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.POST_SAFETY

    # Task name -> task class.
    TASKS = {
        "task_post_safety_deluxe_filter": TaskPostSafetyDeluxeFilter,
        "task_post_safety_annotation_rate_limit": TaskRateLimitPostSafetyAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_post_safety_screen_deluxe": TaskPostSafetyScreenDeluxe,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_upsert_tweet_bool_metadata_to_unified_post_annotations_manhattan": (
            TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation
        ),
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_post_safety_deluxe_filter": set(),
        "task_post_safety_annotation_rate_limit": {"task_post_safety_deluxe_filter"},
        "task_media_hydration": {"task_post_safety_annotation_rate_limit"},
        "task_post_safety_screen_deluxe": {"task_media_hydration"},
        "task_grok_upa_action_with_labels": {"task_post_safety_screen_deluxe"},
        "task_upsert_tweet_bool_metadata_to_unified_post_annotations_manhattan": {
            "task_post_safety_screen_deluxe"
        },
    }
|
||||
27
grox/plans/plan_reply_ranking.py
Normal file
27
grox/plans/plan_reply_ranking.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.tasks.task_pub import TaskWriteReplyRankingManhattan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_filters import TaskReplyRankingFilter
|
||||
from grox.tasks.task_rank_replies import TaskRankReplies
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitReplyRankingAnnotationWithPost
|
||||
|
||||
|
||||
class PlanReplyRanking(Plan):
    """Plan gated on ``TaskEligibility.REPLY_RANKING``.

    Linear chain: filter -> rate limit -> media hydration -> reply ranking
    -> Manhattan write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.REPLY_RANKING

    # Task name -> task class.
    TASKS = {
        "task_reply_ranking_filter": TaskReplyRankingFilter,
        "task_reply_ranking_annotation_rate_limit": (
            TaskRateLimitReplyRankingAnnotationWithPost
        ),
        "task_media_hydration": TaskMediaHydration,
        "task_rank_replies": TaskRankReplies,
        "task_write_reply_ranking_manhattan": TaskWriteReplyRankingManhattan,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_reply_ranking_filter": set(),
        "task_reply_ranking_annotation_rate_limit": {"task_reply_ranking_filter"},
        "task_media_hydration": {"task_reply_ranking_annotation_rate_limit"},
        "task_rank_replies": {"task_media_hydration"},
        "task_write_reply_ranking_manhattan": {"task_rank_replies"},
    }
|
||||
32
grox/plans/plan_safety_ptos.py
Normal file
32
grox/plans/plan_safety_ptos.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_filters import TaskSafetyPtosFilter
|
||||
from grox.tasks.task_safety_ptos_category import TaskSafetyPtosCategoryDetection
|
||||
from grox.tasks.task_safety_ptos_policy import TaskSafetyPtosPolicyDetection
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitSafetyPtosAnnotationWithPost
|
||||
from grox.tasks.task_write_safety_post_annotations_result_sink import (
|
||||
TaskWriteSafetyPostAnnotationsResultSink,
|
||||
)
|
||||
|
||||
|
||||
class PlanSafetyPtos(Plan):
    """Plan gated on ``TaskEligibility.SAFETY_PTOS``.

    Linear chain: filter -> rate limit -> media hydration -> category
    detection -> policy detection -> safety annotations sink write.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.SAFETY_PTOS

    # Task name -> task class.
    TASKS = {
        "task_safety_ptos_filter": TaskSafetyPtosFilter,
        "task_safety_ptos_annotation_rate_limit": TaskRateLimitSafetyPtosAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_safety_ptos_category_detection": TaskSafetyPtosCategoryDetection,
        "task_safety_ptos_policy_detection": TaskSafetyPtosPolicyDetection,
        "task_write_safety_post_annotations_result_sink": (
            TaskWriteSafetyPostAnnotationsResultSink
        ),
    }

    # Task name -> names of the tasks it depends on. Every value is a set:
    # the root task uses set(), because a bare {} would be an empty dict.
    TASK_DEPENDENCIES = {
        "task_safety_ptos_filter": set(),
        "task_safety_ptos_annotation_rate_limit": {"task_safety_ptos_filter"},
        "task_media_hydration": {"task_safety_ptos_annotation_rate_limit"},
        "task_safety_ptos_category_detection": {"task_media_hydration"},
        "task_safety_ptos_policy_detection": {"task_safety_ptos_category_detection"},
        "task_write_safety_post_annotations_result_sink": {"task_safety_ptos_policy_detection"},
    }
|
||||
29
grox/plans/plan_spam_comment.py
Normal file
29
grox/plans/plan_spam_comment.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from grox.plans.plan import Plan
|
||||
from grox.tasks.task_pub import TaskPublishKafka, TaskWriteReplySpamManhattan
|
||||
from grox.schedules.types import TaskEligibility
|
||||
from grox.tasks.task_media import TaskMediaHydration
|
||||
from grox.tasks.task_filters import TaskSpamFilter
|
||||
from grox.tasks.task_spam_detection import TaskSpamDetection
|
||||
from grox.tasks.task_rate_limit import TaskRateLimitReplySpamAnnotationWithPost
|
||||
|
||||
|
||||
class PlanSpamComment(Plan):
    """Plan gated on ``TaskEligibility.SPAM_COMMENT``.

    Chain: filter -> rate limit -> media hydration -> spam detection.
    Detection then feeds both the Manhattan write and the Kafka publish.
    """

    REQUIRED_ELIGIBILITY = TaskEligibility.SPAM_COMMENT

    # Task name -> task class.
    TASKS = {
        "task_spam_filter": TaskSpamFilter,
        "task_reply_spam_annotation_rate_limit": TaskRateLimitReplySpamAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_spam_detection": TaskSpamDetection,
        "task_publish_reply_spam_mh": TaskWriteReplySpamManhattan,
        "task_publish_kafka": TaskPublishKafka,
    }

    # Task name -> names of the tasks it depends on.
    TASK_DEPENDENCIES = {
        "task_spam_filter": set(),
        "task_reply_spam_annotation_rate_limit": {"task_spam_filter"},
        "task_media_hydration": {"task_reply_spam_annotation_rate_limit"},
        "task_spam_detection": {"task_media_hydration"},
        "task_publish_reply_spam_mh": {"task_spam_detection"},
        "task_publish_kafka": {"task_spam_detection"},
    }
|
||||
47
grox/schedules/context.py
Normal file
47
grox/schedules/context.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import signal
|
||||
import logging
|
||||
from typing import Any
|
||||
from multiprocessing import Manager
|
||||
from multiprocessing.managers import DictProxy, SyncManager
|
||||
|
||||
type ScheduleContext = DictProxy[str, Any]
|
||||
logger = logging.getLogger(__name__)
|
||||
_manager: SyncManager | None = None
|
||||
|
||||
|
||||
def get_manager() -> SyncManager:
    """Return the module-wide ``SyncManager``, creating it on first use."""
    global _manager
    if _manager is not None:
        return _manager
    _manager = Manager()
    return _manager
|
||||
|
||||
|
||||
def new_context() -> ScheduleContext:
    """Build a manager-backed dict holding four queues and two events.

    Keys: ``task_queue``, ``resp_queue``, ``live_task_queue``,
    ``live_resp_queue`` (queues), then ``shutdown_event`` and
    ``queue_connection_shutdown_event`` (events).
    """
    manager = get_manager()
    queue_keys = ("task_queue", "resp_queue", "live_task_queue", "live_resp_queue")
    event_keys = ("shutdown_event", "queue_connection_shutdown_event")

    # Created in the same order as the keys are listed: queues, then events.
    entries = {key: manager.Queue() for key in queue_keys}
    for key in event_keys:
        entries[key] = manager.Event()
    return manager.dict(**entries)
|
||||
|
||||
|
||||
def prevent_default() -> None:
    """Make the calling process ignore SIGINT and SIGTERM."""
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal.SIG_IGN)
|
||||
|
||||
|
||||
def shutdown_context(context: ScheduleContext) -> None:
    """Set the ``shutdown_event`` stored in ``context``."""
    shutdown_event = context["shutdown_event"]
    shutdown_event.set()
|
||||
|
||||
|
||||
def queue_connection_shutdown_context(context: ScheduleContext) -> None:
    """Set the ``queue_connection_shutdown_event`` stored in ``context``."""
    connection_event = context["queue_connection_shutdown_event"]
    connection_event.set()
|
||||
|
||||
|
||||
def cleanup() -> None:
    """Shut down the module-wide manager, if one was ever created.

    The singleton is cleared before shutting down, so a later
    ``get_manager()`` creates a fresh manager instead of handing back one
    that has already been shut down. Calling this with no manager is a no-op.
    """
    global _manager
    if _manager is None:
        return
    logger.warning("Shutting down context manager")
    # Detach first so the stale handle is dropped even if shutdown() raises.
    manager, _manager = _manager, None
    manager.shutdown()
|
||||
35
grox/schedules/init.py
Normal file
35
grox/schedules/init.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import setproctitle
|
||||
|
||||
from grox.config.config import grox_config
|
||||
from grox.schedules.context import prevent_default
|
||||
from monitor.logging import Logging
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


async def init_proc(proc_name: str):
    """Initialise a worker process and start its periodic GC loop.

    Ignores SIGINT/SIGTERM via ``prevent_default``, configures logging and
    metrics from ``grox_config``, sets the process title to ``proc_name``,
    and schedules ``periodic_gc`` as a background task on the running loop.

    Args:
        proc_name: Used as the metrics process name and the process title.
    """
    prevent_default()
    Logging.config(grox_config.logging)
    Metrics.init(proc_name, grox_config.metrics)
    setproctitle.setproctitle(proc_name)
    logger.info(f"Changed process title to {proc_name}")

    gc_task = asyncio.create_task(periodic_gc(proc_name))
    # Hold the task until it completes; drop the reference afterwards.
    _background_tasks.add(gc_task)
    gc_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def periodic_gc(proc_name: str):
    """Run ``gc.collect()`` forever, sleeping between runs.

    Each sleep lasts ``grox_config.periodic_gc.interval`` plus a random
    whole number between 0 and ``int(grox_config.periodic_gc.jitter)``.
    The configuration is re-read on every iteration.

    Args:
        proc_name: Only used in the log messages.
    """
    while True:
        base_delay = grox_config.periodic_gc.interval
        jitter = random.randint(0, int(grox_config.periodic_gc.jitter))
        await asyncio.sleep(base_delay + jitter)

        logger.info(f"Running periodic GC for {proc_name}")
        started = time.perf_counter()
        gc.collect()
        elapsed = time.perf_counter() - started
        logger.info(f"GC for {proc_name} took {elapsed:.2f} seconds")
|
||||
73
grox/schedules/types.py
Normal file
73
grox/schedules/types.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from grox.config.config import TaskGeneratorType
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
User,
|
||||
UserContext,
|
||||
ContentCategoryResult,
|
||||
GroxContentAnalysis,
|
||||
ReplyScoreResult,
|
||||
SafetyPostAnnotations,
|
||||
)
|
||||
from grox.classifiers.content.classifier_data_collection import (
|
||||
ClassifierDataCollectionResult,
|
||||
)
|
||||
|
||||
|
||||
class TaskEligibility(str, Enum):
    """String-valued flags naming what a payload is eligible for."""

    # Screening and classification.
    SPAM_COMMENT = "spam_comment"
    BANGER_INITIAL_SCREEN = "banger_initial_screen"

    # Post embeddings.
    POST_EMBEDDING_WITH_SUMMARY = "post_embedding_with_summary"
    POST_EMBEDDING_WITH_SUMMARY_FOR_REPLY = "post_embedding_with_summary_for_reply"
    MM_EMB_V4 = "mm_emb_v4"
    MM_EMB_V5 = "mm_emb_v5"
    MM_EMB_V5_FOR_REPLY = "mm_emb_v5_for_reply"

    # Ranking and safety.
    REPLY_RANKING = "reply_ranking"
    SAFETY_PTOS = "safety_ptos"
    POST_SAFETY = "post_safety"
|
||||
|
||||
|
||||
class TaskPayload(BaseModel):
    """Unit of work handed to the plans; only ``payload_id`` is required."""

    payload_id: str

    # Optional subject data for the work item.
    post: Post | None = None
    user: User | None = None
    user_context: UserContext | None = None
    grox_content_analysis: GroxContentAnalysis | None = None

    # Scheduling metadata.
    attempt: int = 0
    # Each plan checks this set for its own REQUIRED_ELIGIBILITY.
    eligibilities: set[TaskEligibility] = Field(default_factory=set)
    deadline_ts_secs: int | None = None
    task_type: TaskGeneratorType | None = None
|
||||
|
||||
class TaskResult(BaseModel):
    """Outcome reported for a ``TaskPayload``."""

    task: TaskPayload

    # Timestamps on the time.perf_counter() clock; the finish time defaults
    # to the moment the model is instantiated.
    task_started_at: float
    task_finished_at: float = Field(default_factory=time.perf_counter)

    # Produced outputs.
    content_categories: list[ContentCategoryResult] = Field(default_factory=list)
    multimodal_post_embedding: list[float] | None = None
    reason: str = ""

    # Status.
    success: bool = True
    error: str | None = None
|
||||
|
||||
|
||||
class TaskContext:
    """Mutable scratch space created from a ``TaskPayload``.

    Construction records the start time on the ``time.perf_counter()`` clock
    and takes a copy of the payload's eligibilities, so changes made on the
    context do not alter the payload's own set.
    """

    def __init__(self, task: TaskPayload):
        # Input.
        self.payload = task
        self.eligibilities: set[TaskEligibility] = task.eligibilities.copy()

        # Accumulated outputs, all empty to begin with.
        self.content_categories: list[ContentCategoryResult] = []
        self.summary: str = ""
        self.multimodal_post_embedding: list[float] | None = None
        self.multimodal_post_embedding_dict: dict[str, list[float]] = {}
        self.reply_ranking_results: list[ReplyScoreResult] = []
        self.reason: str = ""
        self.available_topics: list | None = None

        # Bookkeeping.
        self.start_time = time.perf_counter()
        self.errors: list[Exception] = []
        self.safety_annotations: SafetyPostAnnotations | None = None
|
||||
|
||||
|
||||
class TaskError(Exception):
    """Exception type declared next to the task data types.

    NOTE(review): no raise site is visible in this part of the file;
    presumably raised by task code - confirm against callers.
    """
|
||||
44
grox/summarizer/eapi_summarizer.py
Normal file
44
grox/summarizer/eapi_summarizer.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from grok_sampler.eapi_sampler import EapiSampler
|
||||
from grox.config.config import EapiModelConfig
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class EapiSummarizer(ABC, Generic[T]):
    """Abstract summarizer holding an ``EapiSampler`` and its model config.

    Subclasses implement ``_summarize``; ``summarize`` wraps it with logging
    and request / error / success / latency metrics.
    """

    def __init__(self, eapi_config: EapiModelConfig, eapi_sampler: EapiSampler):
        self.eapi_config = eapi_config
        self.eapi_sampler = eapi_sampler

    async def summarize(self, input: T) -> Any:
        """Run ``_summarize`` on ``input`` with logging and metrics.

        Raises:
            Exception: Any error from ``_summarize`` is counted, logged and
                re-raised. No latency sample is recorded in that case.
        """
        tag = self.__class__.__name__
        logger.info(f"[{tag}] started processing summarize request")
        Metrics.counter("summarize.request.count").add(1)

        started = time.perf_counter()
        try:
            summary = await self._summarize(input)
        except Exception:
            Metrics.counter("summarize.error.count").add(1)
            logger.error(
                f"[{tag}] error processing summarize request: {traceback.format_exc()}"
            )
            raise
        else:
            Metrics.counter("summarize.success.count").add(1)
            elapsed = time.perf_counter() - started
            logger.info(
                f"[{tag}] finished processing summarize request in {elapsed:.2f} seconds"
            )
            Metrics.histogram("summarize.latency.seconds").record(elapsed)
            return summary

    @abstractmethod
    async def _summarize(self, input: T) -> Any:
        """Produce the summary for ``input``; implemented by subclasses."""
|
||||
55
grox/summarizer/post_embedding_summarizer.py
Normal file
55
grox/summarizer/post_embedding_summarizer.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from grox.lm.post import PostRenderer
|
||||
from grox.lm.user import UserRenderer
|
||||
from grox.config.config import grox_config, ModelName
|
||||
from grox.summarizer.summarizer import Summarizer
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.lm.convo import Conversation, Message, Role
|
||||
from grok_sampler.config import GrokModelConfig
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostEmbeddingSummarizer(Summarizer):
    """Summarizer that asks a vision model to describe a post.

    The system prompt is read from ``prompt_file`` on every request, and the
    answer is taken from between the ``<description>`` tags of the reply.
    """

    def __init__(self, prompt_file: str):
        """Validate the prompt path, then build the vision sampler.

        Raises:
            FileNotFoundError: If ``prompt_file`` does not exist. This is
                checked first so a bad path fails before a sampler is built.
        """
        if not os.path.exists(prompt_file):
            raise FileNotFoundError(f"Prompt file {prompt_file} not found")
        self.prompt_file: str = prompt_file

        vlm_config = grox_config.get_model(ModelName.VLM_MINI_CRITICAL)
        vlm = VisionSampler(GrokModelConfig(**vlm_config.model_dump()))
        super().__init__(vlm_config, vlm)

    async def _summarize(self, post: Post) -> str:
        """Sample the model for ``post`` and return the described section.

        Raises:
            IndexError: If the reply contains no ``<description>`` tag.
        """
        convo = await self._render_vlm_conversation(post)
        result = await self.vlm.sample(
            convo.interleave(), conversation_id=convo.conversation_id
        )
        # Text after the first opening tag, up to the next closing tag (or
        # to the end of the reply when the closing tag is missing).
        result_section = result.split("<description>")[1].split("</description>")[0]
        return result_section

    async def _render_vlm_conversation(
        self, post: Post, disable_thinking: bool = True
    ) -> Conversation:
        """Build the system / user / assistant conversation for ``post``.

        Args:
            post: Post rendered into the user message.
            disable_thinking: When true, the trailing assistant message is
                seeded with an empty string and an empty separator; otherwise
                a default assistant message is appended.
                NOTE(review): presumably the seeded form suppresses the
                model's thinking phase - confirm against the sampler.
        """
        convo = Conversation(conversation_id=uuid.uuid4().hex)
        # Explicit encoding: without it the locale's default is used to
        # decode the prompt, which can differ between hosts.
        with open(self.prompt_file, "r", encoding="utf-8") as f:
            prompt = f.read()
        convo.messages.append(Message(role=Role.SYSTEM, content=[prompt]))
        convo.messages.append(await self._build_task_message(post))
        if disable_thinking:
            convo.messages.append(
                Message(role=Role.ASSISTANT, content=[""], separator="")
            )
        else:
            convo.messages.append(Message(role=Role.ASSISTANT))
        return convo

    async def _build_task_message(self, post: Post) -> Message:
        """Return a user message with the rendered author, then the post."""
        msg: Message = Message(role=Role.USER, content=[])
        msg.content.extend(UserRenderer.render(post.user))
        msg.content.extend(PostRenderer.render(post))
        return msg
|
||||
44
grox/summarizer/summarizer.py
Normal file
44
grox/summarizer/summarizer.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from grok_sampler.vision_sampler import VisionSampler
|
||||
from grox.config.config import ModelConfig
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Summarizer(ABC, Generic[T]):
    """Abstract summarizer holding a ``VisionSampler`` and its model config.

    Subclasses implement ``_summarize``; ``summarize`` wraps it with logging
    and request / error / success / latency metrics.
    """

    def __init__(self, model_config: ModelConfig, vlm: VisionSampler):
        self.model_config = model_config
        self.vlm = vlm

    async def summarize(self, input: T) -> Any:
        """Run ``_summarize`` on ``input`` with logging and metrics.

        Raises:
            Exception: Any error from ``_summarize`` is counted, logged and
                re-raised. No latency sample is recorded in that case.
        """
        tag = self.__class__.__name__
        logger.info(f"[{tag}] started processing summarize request")
        Metrics.counter("summarize.request.count").add(1)

        started = time.perf_counter()
        try:
            summary = await self._summarize(input)
        except Exception:
            Metrics.counter("summarize.error.count").add(1)
            logger.error(
                f"[{tag}] error processing summarize request: {traceback.format_exc()}"
            )
            raise
        else:
            Metrics.counter("summarize.success.count").add(1)
            elapsed = time.perf_counter() - started
            logger.info(
                f"[{tag}] finished processing summarize request in {elapsed:.2f} seconds"
            )
            Metrics.histogram("summarize.latency.seconds").record(elapsed)
            return summary

    @abstractmethod
    async def _summarize(self, input: T) -> Any:
        """Produce the summary for ``input``; implemented by subclasses."""
|
||||
56
grox/tasks/disable_rules.py
Normal file
56
grox/tasks/disable_rules.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from abc import ABC
|
||||
|
||||
from grox.config.env import is_dev, is_prod, is_local, is_mm_emb_prod, is_ptos_prod
|
||||
from grox.schedules.types import TaskContext
|
||||
|
||||
|
||||
class DisableTaskRule(ABC):
    """Base rule for disabling a task; this base never disables anything.

    Subclasses override ``should_disable`` and set ``DISABLE_REASON``.
    """

    # Text returned by disable_reason().
    DISABLE_REASON: str = ""

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return whether the task should be disabled; always False here."""
        return False

    @classmethod
    def disable_reason(cls) -> str | None:
        """Return this rule's ``DISABLE_REASON``."""
        reason = cls.DISABLE_REASON
        return reason
|
||||
|
||||
|
||||
class DisableTaskForLocal(DisableTaskRule):
    """Rule driven by the ``is_local`` value from ``grox.config.env``."""

    DISABLE_REASON = "Task is disabled for local mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return ``is_local``; ``ctx`` is not consulted."""
        return is_local
|
||||
|
||||
|
||||
class DisableTaskForDev(DisableTaskRule):
    """Rule driven by the ``is_dev`` value from ``grox.config.env``."""

    DISABLE_REASON = "Task is disabled for dev mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return ``is_dev``; ``ctx`` is not consulted."""
        return is_dev
|
||||
|
||||
|
||||
class DisableTaskForNonProd(DisableTaskRule):
    """Rule that disables a task unless ``is_prod`` is truthy."""

    DISABLE_REASON = "Task is disabled for non-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_prod``; ``ctx`` is not consulted."""
        return not is_prod
|
||||
|
||||
|
||||
class DisableTaskForNonMmEmbProd(DisableTaskRule):
    """Rule that disables a task whenever the ``is_mm_emb_prod`` flag is falsy."""

    DISABLE_REASON = "Task is disabled for non-mm-emb-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_mm_emb_prod``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_mm_emb_prod`` is a module-level bool — confirm.
        return not is_mm_emb_prod
|
||||
|
||||
|
||||
class DisableTaskForNonPtosProd(DisableTaskRule):
    """Rule that disables a task whenever the ``is_ptos_prod`` flag is falsy."""

    DISABLE_REASON = "Task is disabled for non-ptos-prod mode"

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return the negation of ``is_ptos_prod``; ``ctx`` is not consulted."""
        # NOTE(review): assumes ``is_ptos_prod`` is a module-level bool — confirm.
        return not is_ptos_prod
|
||||
150
grox/tasks/task.py
Normal file
150
grox/tasks/task.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from grox.lib.utils import camel_to_snake
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.tasks.disable_rules import DisableTaskRule
|
||||
from grox.data_loaders.data_types import GroxContentAnalysis, Post, User, UserContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskResultCategory(str, Enum):
    """Outcome returned by ``Task.exec``: the task either ran or was skipped.

    The ``str`` mix-in makes each member compare equal to its plain string value.
    """

    SUCCESS = "success"
    SKIPPED = "skipped"
|
||||
|
||||
|
||||
class TaskStopExecution(Exception):
    """Raised from inside a task to end it early.

    ``Task.exec`` catches this exception and reports the run as skipped instead of
    counting it as a failure.
    """
|
||||
|
||||
|
||||
class Task(ABC):
    """Base class for tasks; everything is driven through classmethods.

    ``exec`` is the public entry point. It applies the skip/disable checks, awaits the
    subclass's ``_exec`` and records the outcome in metrics.
    """

    # Rules consulted by ``should_disable``; the first rule that fires disables the task.
    DISABLE_RULES: list[type[DisableTaskRule]] = []

    # The retry decorator must stay directly on the coroutine function: subclasses in
    # this package reach the undecorated function through ``Task.exec.__wrapped__``.
    @classmethod
    @retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task for ``ctx`` and report whether it ran or was skipped.

        Returns:
            ``TaskResultCategory.SKIPPED`` when ``should_skip`` is true or ``_exec``
            raises ``TaskStopExecution``; ``TaskResultCategory.SUCCESS`` otherwise.

        Raises:
            Any other exception from ``_exec`` is counted, logged with its traceback
            and re-raised, which hands control back to the retry decorator
            (configured for two attempts with a fixed wait of 1).
        """
        task_name = cls.get_name()
        attrs = {"task_name": task_name}

        def _skipped() -> TaskResultCategory:
            # Shared bookkeeping for both skip paths below.
            Metrics.counter("task.exec.skipped.count").add(1, attributes=attrs)
            logger.info(f"[{task_name}] skipping task")
            return TaskResultCategory.SKIPPED

        Metrics.counter("task.exec.count").add(1, attributes=attrs)
        logger.debug(f"[{task_name}] starting task")
        if cls.should_skip(ctx):
            return _skipped()

        Metrics.counter("task.exec.intaken.count").add(1, attributes=attrs)
        try:
            await cls._exec(ctx)
        except TaskStopExecution:
            # The task asked to stop early; this is a skip, not a failure.
            return _skipped()
        except Exception:
            Metrics.counter("task.exec.failed.count").add(1, attributes=attrs)
            logger.error(
                f"[{task_name}] failed to execute task with error: {traceback.format_exc()}"
            )
            raise
        else:
            Metrics.counter("task.exec.success.count").add(1, attributes=attrs)
            return TaskResultCategory.SUCCESS

    @classmethod
    @abstractmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Do the task's actual work; subclasses must implement this."""
        raise NotImplementedError()

    @classmethod
    def should_skip(cls, ctx: TaskContext) -> bool:
        """Return True when a disable rule fires, else defer to ``_should_skip``."""
        return True if cls.should_disable(ctx) else cls._should_skip(ctx)

    @classmethod
    def _should_skip(cls, ctx: TaskContext) -> bool:
        """Task-specific skip hook; the default never skips."""
        return False

    @classmethod
    def should_disable(cls, ctx: TaskContext) -> bool:
        """Return True as soon as one rule in ``DISABLE_RULES`` disables the task."""
        for rule in cls.DISABLE_RULES:
            if not rule.should_disable(ctx):
                continue
            logger.debug(
                f"[{cls.get_name()}] skipping task because {rule.disable_reason()}"
            )
            return True
        return False

    @classmethod
    @cache
    def get_name(cls) -> str:
        """Return the class name converted by ``camel_to_snake``; cached per class."""
        return camel_to_snake(cls.__name__)
|
||||
|
||||
|
||||
class TaskWithUser(Task):
    """Task that operates on ``ctx.payload.user`` and stops when there is none."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Pass the payload's user to ``_exec_with_user``.

        Raises:
            TaskStopExecution: when the payload carries no (truthy) user.
        """
        if user := ctx.payload.user:
            await cls._exec_with_user(ctx, user)
            return
        raise TaskStopExecution("No user for task")

    @classmethod
    @abstractmethod
    async def _exec_with_user(cls, ctx: TaskContext, user: User) -> None:
        """Do the work for ``user``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskWithUserContext(Task):
    """Task that operates on ``ctx.payload.user_context`` and stops when there is none."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Pass the payload's user context to ``_exec_with_user_context``.

        Raises:
            TaskStopExecution: when the payload carries no (truthy) user context.
        """
        if user_context := ctx.payload.user_context:
            await cls._exec_with_user_context(ctx, user_context)
            return
        raise TaskStopExecution("No user context for task")

    @classmethod
    @abstractmethod
    async def _exec_with_user_context(
        cls, ctx: TaskContext, user_context: UserContext
    ) -> None:
        """Do the work for ``user_context``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskWithPost(Task):
    """Task that operates on ``ctx.payload.post`` and stops when there is none."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Pass the payload's post to ``_exec_with_post``.

        Raises:
            TaskStopExecution: when the payload carries no (truthy) post.
        """
        if post := ctx.payload.post:
            await cls._exec_with_post(ctx, post)
            return
        raise TaskStopExecution("No post for task")

    @classmethod
    @abstractmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Do the work for ``post``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskWithContentAnalysis(Task):
    """Task that operates on ``ctx.payload.grox_content_analysis`` and stops without it."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Pass the payload's content analysis to ``_exec_with_content_analysis``.

        Raises:
            TaskStopExecution: when the payload carries no (truthy) content analysis.
        """
        analysis: GroxContentAnalysis | None = ctx.payload.grox_content_analysis
        if analysis:
            await cls._exec_with_content_analysis(ctx, analysis)
            return
        raise TaskStopExecution("No content analysis for task")

    @classmethod
    @abstractmethod
    async def _exec_with_content_analysis(
        cls, ctx: TaskContext, content_analysis: GroxContentAnalysis
    ) -> None:
        """Do the work for ``content_analysis``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
74
grox/tasks/task_asr.py
Normal file
74
grox/tasks/task_asr.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import logging
|
||||
|
||||
from grox.data_loaders.asr_processor import ASRProcessor
|
||||
from grox.data_loaders.data_types import Post, Video
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.tasks.task import TaskWithPost
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_video_url(video: Video) -> tuple[str | None, bool]:
    """Pick the URL to use for ``video`` and say whether it came from an animated GIF.

    Sources are tried in order: the best variant of ``animatedGifInfo``, the best
    variant of ``videoInfo``, then ``video.url``.

    Returns:
        ``(url, is_animated_gif)``; ``(None, False)`` when no source yields a URL.
    """
    # Attribute names are resolved lazily so each source is only read when reached.
    for attr_name, from_gif in (("animatedGifInfo", True), ("videoInfo", False)):
        info = getattr(video, attr_name)
        if not info:
            continue
        variant = info.get_best_variant()
        if variant and variant.url:
            return variant.url, from_gif
    return (video.url, False) if video.url else (None, False)
|
||||
|
||||
|
||||
class TaskASRTranscription(TaskWithPost):
    """Runs ``ASRProcessor`` over each video of a post and attaches the transcripts."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Transcribe every non-GIF video of ``post`` that has a usable URL.

        Posts without media or without any ``Video`` item are counted as skipped and
        left untouched. A transcript is written to ``video.convo_video.asr_transcript``
        when ``convo_video`` is present.
        """

        def _count_skip(reason: str) -> None:
            Metrics.counter("task.asr_transcription.skipped.count").add(
                1, attributes={"reason": reason}
            )

        Metrics.counter("task.asr_transcription.total.count").add(1)

        if not post.media:
            logger.debug(f"Post {post.id} has no media, skipping ASR")
            _count_skip("no_media")
            return

        videos = [item for item in post.media if isinstance(item, Video)]
        if not videos:
            logger.debug(f"Post {post.id} has no video, skipping ASR")
            _count_skip("no_video")
            return

        Metrics.counter("task.asr_transcription.has_video.count").add(1)

        for video in videos:
            video_url, is_animated_gif = _get_video_url(video)
            if not video_url:
                # No usable URL: silently move on to the next video.
                continue

            if is_animated_gif:
                logger.debug(
                    f"Post {post.id} video {video.id} is an animated GIF, skipping ASR (no audio)"
                )
                _count_skip("animated_gif")
                continue

            transcript = await ASRProcessor.process(post.id, video_url)
            if not transcript:
                Metrics.counter("task.asr_transcription.failed.count").add(
                    1, attributes={"reason": "processor_error"}
                )
                continue

            # NOTE(review): success is counted even when ``convo_video`` is None and the
            # transcript is therefore not stored anywhere — confirm this is intended.
            if video.convo_video is not None:
                video.convo_video.asr_transcript = transcript
            Metrics.counter("task.asr_transcription.success.count").add(1)
            logger.info(
                f"ASR completed for post {post.id} video {video.id}, transcript_len={len(transcript)}"
            )
|
||||
79
grox/tasks/task_banger_screen.py
Normal file
79
grox/tasks/task_banger_screen.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post, ContentCategoryType
|
||||
from grox.classifiers.content.banger_initial_screen import BangerInitialScreenClassifier
|
||||
from strato_http.queries.grok_topics import StratoGrokTopics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NOTE(review): not referenced anywhere else in this module — confirm it is still needed.
LOG_EVERY_N = 10000
# Maximum age, in seconds, of the cached grok topics before they are fetched again.
CACHE_TTL_SECONDS = 3600
|
||||
|
||||
|
||||
class TaskBangerScreen(TaskWithPost):
    """Classifies a post with ``BangerInitialScreenClassifier`` using cached grok topics."""

    classifier = BangerInitialScreenClassifier()
    # Topic cache kept on the class, together with the ``time.time()`` value of the
    # last successful fetch.
    _cached_topics = None
    _cache_timestamp = None

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task through the undecorated ``Task.exec`` (no retry wrapper)."""
        return await Task.exec.__wrapped__(cls, ctx)

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Refresh the topic cache if needed, classify ``post`` and record the verdict.

        The classifier results are appended to ``ctx.content_categories`` and the
        topics used are exposed as ``ctx.available_topics``.
        """
        Metrics.counter("task.banger_initial_screen.total.count").add(1)

        # Refresh when nothing is cached yet, or when the cached entry outlived its TTL.
        needs_refresh = cls._cached_topics is None
        if not needs_refresh and cls._cache_timestamp is not None:
            needs_refresh = time.time() - cls._cache_timestamp > CACHE_TTL_SECONDS

        if not needs_refresh:
            Metrics.counter(
                "task.banger_initial_screen.grok_cache.cache_hit.count"
            ).add(1)
        else:
            logger.info("Fetching grok topics for cache")
            Metrics.counter("task.banger_initial_screen.grok_cache.new_load.count").add(
                1
            )
            fetched_topics = await StratoGrokTopics().fetch()
            if not fetched_topics:
                # Keep whatever was cached before; the next call will try again.
                logger.warning("Failed to fetch grok topics")
                Metrics.counter(
                    "task.banger_initial_screen.grok_cache.new_load.failure.count"
                ).add(1)
            else:
                cls._cached_topics = fetched_topics
                cls._cache_timestamp = time.time()
                logger.info(f"Cached {len(cls._cached_topics)} categories with topics")
                Metrics.counter(
                    "task.banger_initial_screen.grok_cache.new_load.success.count"
                ).add(1)

        has_topics = bool(cls._cached_topics) and len(cls._cached_topics) > 0
        Metrics.counter(
            "task.banger_initial_screen.with_cached_topics.count"
            if has_topics
            else "task.banger_initial_screen.without_cached_topics.count"
        ).add(1)

        results = await cls.classifier.classify(post, topics=cls._cached_topics)
        ctx.content_categories.extend(results)
        ctx.available_topics = cls._cached_topics

        passed = any(
            item.category == ContentCategoryType.BANGER_INITIAL_SCREEN and item.positive
            for item in results
        )
        Metrics.counter(
            "task.banger_initial_screen.passed.count"
            if passed
            else "task.banger_initial_screen.failed.count"
        ).add(1)
|
||||
79
grox/tasks/task_embedding_pub.py
Normal file
79
grox/tasks/task_embedding_pub.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from functools import cache
|
||||
|
||||
from thrifts.gen.twitter.strato.columns.content_understanding.content_understanding.ttypes import (
|
||||
SimpleTweetEmbedding,
|
||||
)
|
||||
from thrifts.serdes import Serializer
|
||||
from grox.tasks.task import TaskWithPost
|
||||
from grox.tasks.disable_rules import DisableTaskForNonMmEmbProd
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.config.config import grox_config, KafkaTopicName
|
||||
from kafka_cli.producer import ScramKafkaProducer
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskPublishEmbeddingKafka(TaskWithPost):
    """Publishes ``ctx.multimodal_post_embedding`` for a post to a Kafka topic.

    Concrete subclasses select the topic by setting ``KAFKA_TOPIC_NAME``.
    """

    DISABLE_RULES = [DisableTaskForNonMmEmbProd]
    KAFKA_TOPIC_NAME: KafkaTopicName

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Publish the embedding held in ``ctx``, or skip when there is none.

        Raises:
            Exception: any publishing error is counted, logged and re-raised.
        """
        if (embedding := ctx.multimodal_post_embedding) is None:
            Metrics.counter("task.publish_embedding_kafka.skipped.count").add(
                1, attributes={"reason": "no_embedding"}
            )
            logger.info(f"No embedding available for post {post.id}, skipping")
            return

        Metrics.counter("task.publish_embedding_kafka.intaken.count").add(1)
        try:
            await cls._publish_to_kafka(post, embedding)
            Metrics.counter("task.publish_embedding_kafka.success.count").add(1)
            if post.created_at:
                # Seconds elapsed between post creation and this publish.
                elapsed = time.time() - post.created_at.timestamp()
                Metrics.histogram("task.publish_embedding_kafka.e2e_latency").record(
                    elapsed
                )
        except Exception:
            Metrics.counter("task.publish_embedding_kafka.failed.count").add(1)
            logger.error(
                f"Failed to publish embedding to Kafka: {traceback.format_exc()}"
            )
            raise

    @classmethod
    async def _publish_to_kafka(cls, post: Post, embedding: list[float]) -> None:
        """Serialize a ``SimpleTweetEmbedding`` for ``post`` and send it via the producer."""
        payload = Serializer.serialize(
            SimpleTweetEmbedding(tweetId=int(post.id), embedding1=embedding)
        )
        producer = cls._get_kafka_producer()
        await producer.send(id=post.id, value=payload)
        logger.info(
            f"Published embedding for post {post.id} to {cls.KAFKA_TOPIC_NAME.value}"
        )

    @classmethod
    @cache
    def _get_kafka_producer(cls) -> ScramKafkaProducer:
        """Build the producer for ``KAFKA_TOPIC_NAME``; the result is cached per class."""
        topic_config = grox_config.get_kafka_producer_topic(cls.KAFKA_TOPIC_NAME)
        logger.info(
            f"Creating embedding kafka producer with config: {topic_config.model_dump()}"
        )
        return ScramKafkaProducer(topic_config)
|
||||
|
||||
|
||||
class TaskPublishEmbeddingV4Kafka(TaskPublishEmbeddingKafka):
    """Embedding publisher bound to ``KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V4``."""

    KAFKA_TOPIC_NAME = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V4
|
||||
|
||||
|
||||
class TaskPublishEmbeddingV5Kafka(TaskPublishEmbeddingKafka):
    """Embedding publisher bound to ``KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V5``."""

    KAFKA_TOPIC_NAME = KafkaTopicName.GROX_MULTIMODAL_EMBEDDING_V5
|
||||
370
grox/tasks/task_filters.py
Normal file
370
grox/tasks/task_filters.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import override
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from grox.config.config import grox_config, TaskGeneratorType
|
||||
from grox.tasks.task import Task, TaskStopExecution
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskFilter(Task):
    """Task that acts as a gate: it stops execution when the context is not eligible."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Raise ``TaskStopExecution`` unless ``_eligible`` accepts ``ctx``."""
        is_eligible = await cls._eligible(ctx)
        if is_eligible:
            return
        raise TaskStopExecution()

    @classmethod
    @abstractmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Decide whether ``ctx`` passes the filter; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskFilterWithUser(TaskFilter):
    """Filter whose decision is based on ``ctx.payload.user``."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Reject contexts without a user; otherwise defer to ``_eligible_with_user``."""
        user = ctx.payload.user
        return await cls._eligible_with_user(user, ctx) if user else False

    @classmethod
    @abstractmethod
    async def _eligible_with_user(cls, user: User, ctx: TaskContext) -> bool:
        """Decide eligibility for ``user``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskFilterWithPost(TaskFilter):
    """Filter whose decision is based on ``ctx.payload.post``."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Reject contexts without a post; otherwise defer to ``_eligible_with_post``."""
        post = ctx.payload.post
        return await cls._eligible_with_post(post, ctx) if post else False

    @classmethod
    @abstractmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Decide eligibility for ``post``; subclasses must implement this."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskSpamFilter(TaskFilterWithPost):
    """Selects replies that should go through spam detection.

    A post is eligible only when it is a reply, its author and the authors of both
    the direct parent (``ancestors[-1]``) and the root (``ancestors[0]``) are known,
    the author is none of those two users, and neither of them has more followers
    than the threshold.
    """

    # NOTE(review): placeholder value. As an empty string it makes the follower-count
    # comparison below raise TypeError (int > str); it must be set to the real integer
    # threshold before this filter can run.
    FOLLOWER_COUNT_THRESHOLD_FOR_SPAM_DETECTION = ""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` qualifies for spam detection.

        Every rejection increments ``task.filter.skipped.count`` with the filter name
        ``spam_detection`` and a ``reason`` attribute.
        """

        def _reject(reason: str) -> bool:
            Metrics.counter("task.filter.skipped.count").add(
                1, attributes={"filter": "spam_detection", "reason": reason}
            )
            return False

        if not post.ancestors:
            return _reject("not_reply")
        if not post.user:
            return _reject("no_user")
        # The original repeated this exact check twice; the second copy could never
        # fire and was dropped.
        # NOTE(review): if the two checks once targeted different account ids, restore
        # the second id here.
        if post.user.id == 0:
            return _reject("is_system_account")

        parent_post = post.ancestors[-1]
        root_post = post.ancestors[0]

        if not parent_post.user:
            return _reject("previous_post_no_user")
        if post.user.id == parent_post.user.id:
            logger.info(
                f"Skipping reply spam since the replier is same as reply target post {post.id}"
            )
            return _reject("same_user_reply")
        if not root_post.user:
            return _reject("root_post_no_user")
        if post.user.id == root_post.user.id:
            logger.info(
                f"Skipping reply spam since the replier is same as reply root post {post.id}"
            )
            return _reject("same_user_reply_as_root")

        threshold = cls.FOLLOWER_COUNT_THRESHOLD_FOR_SPAM_DETECTION
        in_reply_user_follower_count = parent_post.user.follower_count or 0
        root_user_follower_count = root_post.user.follower_count or 0
        if (
            in_reply_user_follower_count > threshold
            or root_user_follower_count > threshold
        ):
            # Skipped here under the reason "reply_ranking_target".
            return _reject("reply_ranking_target")
        return True
|
||||
|
||||
|
||||
class TaskReplyRankingFilter(TaskFilterWithPost):
    """Selects replies that should go through reply ranking.

    A reply is eligible when its author and both ancestor authors are known, at least
    one of the parent/root authors exceeds the follower threshold, and the replier is
    neither of them.
    """

    # NOTE(review): placeholder value. As an empty string it makes the follower-count
    # comparison below raise TypeError (int <= str); set the real integer threshold.
    FOLLOWER_COUNT_THRESHOLD_FOR_REPLY_RANKING = ""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` qualifies for reply ranking.

        Every rejection increments ``task.filter.skipped.count`` with the filter name
        ``reply_ranking`` and a ``reason`` attribute.
        """

        def _reject(reason: str) -> bool:
            Metrics.counter("task.filter.skipped.count").add(
                1, attributes={"filter": "reply_ranking", "reason": reason}
            )
            return False

        if not post.ancestors:
            return _reject("not_reply")
        if not post.user:
            return _reject("no_user")

        parent_post = post.ancestors[-1]
        root_post = post.ancestors[0]
        if not parent_post.user:
            return _reject("previous_post_no_user")
        if not root_post.user:
            return _reject("root_post_no_user")

        parent_followers = parent_post.user.follower_count or 0
        root_followers = root_post.user.follower_count or 0
        # Both counts at or below the threshold <=> the larger one is.
        if (
            max(parent_followers, root_followers)
            <= cls.FOLLOWER_COUNT_THRESHOLD_FOR_REPLY_RANKING
        ):
            return _reject("low_blast_radius")

        if post.user.id == parent_post.user.id:
            logger.info(
                f"Skipping reply ranking since the replier is same as reply target post {post.id}"
            )
            return _reject("same_user_reply")
        if post.user.id == root_post.user.id:
            logger.info(
                f"Skipping reply ranking since the replier is same as reply root post {post.id}"
            )
            return _reject("same_user_reply_as_root")

        Metrics.counter("task.reply_ranking.eligible.count").add(1)
        return True
|
||||
|
||||
|
||||
class TaskPostEmbeddingWithSummaryFilter(TaskFilterWithPost):
    """Accepts original (non-reply) posts from known, non-protected authors."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` may be embedded with its summary.

        Rejections increment ``task.filter.skipped.count`` with the filter name
        ``post_embedding_with_summary`` and the first matching reason.
        """
        if post.ancestors:
            reason = "is_reply"
        elif not post.user:
            reason = "no_user"
        elif post.user.is_protected:
            reason = "private_account"
        else:
            reason = None

        if reason is None:
            Metrics.counter("task.post_embedding_with_summary.eligible.count").add(1)
            return True

        Metrics.counter("task.filter.skipped.count").add(
            1,
            attributes={
                "filter": "post_embedding_with_summary",
                "reason": reason,
            },
        )
        return False
|
||||
|
||||
|
||||
class TaskPostEmbeddingWithSummaryForReplyFilter(TaskFilterWithPost):
    """Accepts replies where the author, the parent author and the root author are all known and non-protected."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when the reply ``post`` may be embedded with its summary.

        Rejections increment ``task.filter.skipped.count`` with the filter name
        ``post_embedding_with_summary_for_reply`` and a ``reason`` attribute.
        """

        def _reject(reason: str) -> bool:
            Metrics.counter("task.filter.skipped.count").add(
                1,
                attributes={
                    "filter": "post_embedding_with_summary_for_reply",
                    "reason": reason,
                },
            )
            return False

        if not post.ancestors:
            return _reject("is_original")
        if not post.user:
            return _reject("no_user")

        # An ancestor can lack a user (the other reply filters in this module guard
        # for it); without these checks ``.is_protected`` would raise AttributeError.
        in_reply_user = post.ancestors[-1].user
        root_user = post.ancestors[0].user
        if not in_reply_user:
            return _reject("previous_post_no_user")
        if not root_user:
            return _reject("root_post_no_user")

        if (
            in_reply_user.is_protected
            or root_user.is_protected
            or post.user.is_protected
        ):
            return _reject("private_account")

        Metrics.counter(
            "task.post_embedding_with_summary_for_reply.eligible.count"
        ).add(1)
        return True
|
||||
|
||||
|
||||
class TaskSafetyPtosFilter(TaskFilterWithPost):
    """Accepts any post with a known author; metric names depend on the task type."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` has a user, recording request/eligible counters.

        The metric prefix is the deluxe variant when ``ctx.payload.task_type`` equals
        ``TaskGeneratorType.SAFETY_PTOS_DELUXE``.
        """
        if ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE:
            filter_name = "safety_ptos_deluxe_detection"
        else:
            filter_name = "safety_ptos_detection"

        Metrics.counter(f"task.{filter_name}.request.count").add(1)
        if post.user:
            Metrics.counter(f"task.{filter_name}.eligible.count").add(1)
            return True

        Metrics.counter("task.filter.skipped.count").add(
            1, attributes={"filter": filter_name, "reason": "no_user"}
        )
        return False
|
||||
|
||||
|
||||
class TaskPostSafetyDeluxeFilter(TaskFilterWithPost):
    """Accepts original posts from known authors that pass the hardcoded filters."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` qualifies for the deluxe post-safety task.

        Rejections increment ``task.filter.skipped.count`` with the filter name
        ``post_safety_deluxe`` and a ``reason`` attribute.
        """

        def _count_skip(reason: str) -> None:
            Metrics.counter("task.filter.skipped.count").add(
                1, attributes={"filter": "post_safety_deluxe", "reason": reason}
            )

        if not post.user:
            _count_skip("no_user")
            return False

        if post.ancestors:
            _count_skip("reply")
            logger.info(f"Skipping post {post.id} because it is a reply")
            return False

        hardcoded_reason = cls._get_hardcoded_filter_reason(post)
        if hardcoded_reason:
            logger.info(
                f"Skipping upa deluxe {post.id} because it is hit by hardcoded filters, reason: {hardcoded_reason}"
            )
            _count_skip(hardcoded_reason)
            return False

        Metrics.counter("task.post_safety_deluxe.eligible.count").add(1)
        return True

    @classmethod
    def _get_hardcoded_filter_reason(cls, post: Post) -> str | None:
        """Return ``"private_account"`` for a protected author, else ``None``."""
        if post.user and post.user.is_protected:
            return "private_account"
        return None
|
||||
|
||||
|
||||
class TaskInitialBangerFilter(TaskFilterWithPost):
    """Accepts original posts from known authors that pass the hardcoded filters."""

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True when ``post`` qualifies for the initial banger screen.

        Rejections increment ``task.filter.skipped.count`` with the filter name
        ``content_understanding`` and a ``reason`` attribute.
        """

        def _count_skip(reason: str) -> None:
            Metrics.counter("task.filter.skipped.count").add(
                1, attributes={"filter": "content_understanding", "reason": reason}
            )

        Metrics.counter("task.initial_banger_filter.count").add(1)

        if not post.user:
            # Previously the only rejection that left no trace in the skipped
            # counter, unlike the other filters in this module.
            _count_skip("no_user")
            return False

        if post.ancestors:
            _count_skip("reply")
            logger.info(f"Skipping post {post.id} because it is a reply")
            return False

        filter_reason = cls._get_hardcoded_filter_reason(post)
        if filter_reason:
            logger.info(
                f"Skipping post {post.id} because it is hit by hardcoded filters, reason: {filter_reason}"
            )
            _count_skip(filter_reason)
            return False

        logger.info(f"Post {post.id} is eligible for initial banger")
        Metrics.counter("task.initial_banger_filter.eligible.count").add(1)
        return True

    @classmethod
    def _get_hardcoded_filter_reason(cls, post: Post) -> str | None:
        """Return ``"private_account"`` for a protected author, else ``None``."""
        if not post.user:
            return None
        if post.user.is_protected:
            return "private_account"
        return None
|
||||
57
grox/tasks/task_grok_upa_action_with_labels.py
Normal file
57
grox/tasks/task_grok_upa_action_with_labels.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task
|
||||
from grox.tasks.disable_rules import DisableTaskForNonProd
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from strato_http.queries.grok_upa_action_with_labels import (
|
||||
StratoGrokUpaActionWithLabels,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskGrokUpaActionWithLabels(Task):
    """Sends a post's ``tweet_bool_metadata`` to ``StratoGrokUpaActionWithLabels`` and records the labels applied."""

    DISABLE_RULES = [DisableTaskForNonProd]

    _strato_grok_upa_action_with_labels = StratoGrokUpaActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Execute the labels action for the payload's post.

        Returns without calling the action when there is no post, no content
        categories, or no category result carrying ``tweet_bool_metadata``.
        """
        Metrics.counter("task.grok_upa_action_with_labels.count").add(1)

        post = ctx.payload.post
        if not post:
            return

        categories = ctx.content_categories
        if not categories:
            return

        # First classifier result that carries bool metadata, if any.
        grok_response = next(
            (item for item in categories if item.tweet_bool_metadata), None
        )
        if not (grok_response and grok_response.tweet_bool_metadata):
            return

        tweet_id = int(post.id)
        metadata = grok_response.tweet_bool_metadata.model_dump()
        action_result = await cls._strato_grok_upa_action_with_labels.execute(
            tweet_id, metadata
        )

        if not action_result:
            logger.info(f"grokUpaActionWithLabels failed for post {tweet_id}")
            Metrics.counter("task.grok_upa_action_with_labels.failed.count").add(1)
            return

        if len(action_result.applied_labels) > 0:
            logger.info(
                f"grokUpaActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {tweet_id}"
            )
            Metrics.counter("task.grok_upa_action_with_labels.applied.count").add(1)
            for label in action_result.applied_labels:
                Metrics.counter(
                    "task.grok_upa_action_with_labels.applied_label.count"
                ).add(1, attributes={"label": label})
        else:
            logger.info(
                f"grokUpaActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {tweet_id}"
            )
            Metrics.counter("task.grok_upa_action_with_labels.empty.count").add(1)
|
||||
35
grox/tasks/task_load_post_with_not_found_retry.py
Normal file
35
grox/tasks/task_load_post_with_not_found_retry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import time
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.schedules.types import TaskPayload
|
||||
from grox.tasks.task import TaskStopExecution
|
||||
from monitor.metrics import Metrics
|
||||
from tenacity import retry, wait_chain, wait_fixed, stop_after_attempt
|
||||
|
||||
|
||||
class TaskLoadPostWithNotFoundRetry(TaskWithPost):
    """Loads the full post through ``TweetStratoLoader``, retrying when it is not found."""

    @classmethod
    @retry(stop=stop_after_attempt(3), wait=wait_chain(wait_fixed(1), wait_fixed(1)))
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the undecorated ``Task.exec`` under this task's own retry policy (three attempts)."""
        return await Task.exec.__wrapped__(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Replace ``ctx.payload.post`` with the loaded post and record the load time.

        Raises:
            TaskStopExecution: the post was not found and the task type contains
                ``"recovery"``; the run is then reported as skipped.
            ValueError: the post was not found for any other task type; the error
                propagates out of ``exec`` so the retry decorator can act on it.
        """
        start_time = time.perf_counter_ns()
        loaded_post = await TweetStratoLoader.load_post(post.id)
        if loaded_post is None:
            task_type = (
                ctx.payload.task_type.value
                if ctx.payload and ctx.payload.task_type
                else "None"
            )
            if "recovery" in task_type:
                raise TaskStopExecution(f"Post not found: {post.id}")
            raise ValueError(f"Post not found: {post.id}")
        ctx.payload.post = loaded_post
        # perf_counter_ns() is in nanoseconds: 1 ms == 1_000_000 ns. Dividing by
        # 1_000 (as before) produced microseconds under a "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.embedding_load_post.duration_ms").record(duration_ms)
|
||||
27
grox/tasks/task_load_post_with_summary.py
Normal file
27
grox/tasks/task_load_post_with_summary.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from grox.tasks.task import TaskWithPost
|
||||
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.data_loaders.strato_loader import TweetStratoLoader
|
||||
from grox.tasks.task import TaskStopExecution
|
||||
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
|
||||
StratoPostMultimodalEmbeddingGrokSummaryMh,
|
||||
)
|
||||
|
||||
|
||||
class TaskLoadPostWithSummary(TaskWithPost):
    """Loads the full post together with its stored grok summary."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Replace ``ctx.payload.post`` with the loaded post, with ``summary`` attached.

        Raises:
            TaskStopExecution: when either the post or its summary cannot be found.
        """
        loaded_post = await TweetStratoLoader.load_post(post.id)
        if loaded_post is None:
            raise TaskStopExecution(f"Post not found: {post.id}")

        summary_query = StratoPostMultimodalEmbeddingGrokSummaryMh()
        # "v3" is passed as the second fetch argument.
        # NOTE(review): presumably a summary version tag — confirm.
        summary = await summary_query.fetch(int(post.id), "v3")
        if summary is None:
            raise TaskStopExecution(f"Summary not found: {post.id}")

        loaded_post.summary = summary
        ctx.payload.post = loaded_post
|
||||
23
grox/tasks/task_load_user_recent_posts.py
Normal file
23
grox/tasks/task_load_user_recent_posts.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import logging
|
||||
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.data_loaders.strato_loader import UserRecentPostsLoader
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.tasks.task import TaskWithPost, TaskStopExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskLoadUserRecentPosts(TaskWithPost):
    """Attach the author's recent posts to ``post.user.recent_posts``."""

    # NOTE(review): an empty string is passed as ``limit``; this looks like a
    # placeholder value. Confirm what UserRecentPostsLoader.load expects.
    RECENT_POSTS_LIMIT = ""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Load recent posts for the post's author.

        Raises:
            TaskStopExecution: the post has no user, or the user has no id.
        """
        author = post.user
        if not author or not author.id:
            raise TaskStopExecution("Post has no author user to load recent posts for")

        recent_posts = await UserRecentPostsLoader.load(
            author.id, limit=cls.RECENT_POSTS_LIMIT
        )
        author.recent_posts = recent_posts
        logger.info(f"Loaded {len(recent_posts)} recent posts for user {post.user.id}")
|
||||
21
grox/tasks/task_media.py
Normal file
21
grox/tasks/task_media.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from grox.tasks.task import TaskWithPost
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.data_loaders.media_processor import MediaProcessor
|
||||
from monitor.metrics import Metrics
|
||||
|
||||
|
||||
class TaskMediaHydration(TaskWithPost):
    """Run ``MediaProcessor`` on the post, counting attempts and completions."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Process the post's media with a 360 minute video duration limit."""
        attempts = Metrics.counter("task.media_hydration.total.count")
        attempts.add(1)
        await MediaProcessor.process(post, video_duration_limit_minutes=360)
        # Only reached when processing did not raise.
        completions = Metrics.counter("task.media_hydration.passed.count")
        completions.add(1)
|
||||
|
||||
|
||||
class TaskMediaHydrationBanger(TaskWithPost):
    """Media hydration variant that reports under the ``media_hydration_banger`` metrics."""

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Process the post's media with a 360 minute video duration limit."""
        attempts = Metrics.counter("task.media_hydration_banger.total.count")
        attempts.add(1)
        await MediaProcessor.process(post, video_duration_limit_minutes=360)
        # Only reached when processing did not raise.
        completions = Metrics.counter("task.media_hydration_banger.passed.count")
        completions.add(1)
|
||||
86
grox/tasks/task_multimodal_post_embedding.py
Normal file
86
grox/tasks/task_multimodal_post_embedding.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from grox.tasks.task import TaskWithPost
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext, TaskEligibility
|
||||
from grox.data_loaders.data_types import Post, Video
|
||||
from grox.embedder.multimodal_post_embedder_v2 import MultimodalPostEmbedderV2
|
||||
from grox.embedder.multimodal_post_embedder_v5 import MultimodalPostEmbedderV5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskMultimodalPostEmbeddingWithSummary(TaskWithPost):
    """Embed the post with the "qwen3" embedder and store it as the "v3" embedding."""

    # Shared by every invocation of the task; built once when the class is defined.
    embedder = MultimodalPostEmbedderV2(
        model="qwen3", renderer_version="lite", use_post_context_summary=True
    )

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Compute the embedding and record it in ``ctx.multimodal_post_embedding_dict``.

        Errors from the embedder are not handled here and propagate to the caller.
        """
        # embed() returns a pair; only the second element (the vector) is used.
        _unused, embedding = await cls.embedder.embed(post)
        ctx.multimodal_post_embedding_dict["v3"] = embedding
        logger.info(
            f"TaskMultimodalPostEmbeddingWithSummary Embedding Added, length: {len(embedding)}"
        )
        added = Metrics.counter("task.multimodal_post_embedding_with_summary.count")
        added.add(1)
|
||||
|
||||
|
||||
class TaskMultimodalPostEmbeddingRecsysV4(TaskWithPost):
    """Best-effort "v4" embedding: embedder failures are logged, not raised."""

    # Shared by every invocation of the task; built once when the class is defined.
    embedder = MultimodalPostEmbedderV2(
        model="v4",
        renderer_version="lite",
        use_post_context_summary=True,
    )

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the "v4" embedding in ``ctx.multimodal_post_embedding_dict``.

        When the embedder raises, an error counter is incremented, a warning
        is logged, and the task returns without adding an embedding.
        """
        try:
            _unused, embedding = await cls.embedder.embed(post)
        except Exception as e:
            failures = Metrics.counter("task.multimodal_post_embedding_recsys_v4.error")
            failures.add(1)
            logger.warning(
                f"TaskMultimodalPostEmbeddingRecsysV4 failed for post {post.id}: {e}"
            )
            return

        ctx.multimodal_post_embedding_dict["v4"] = embedding
        logger.info(
            f"TaskMultimodalPostEmbeddingRecsysV4 Embedding Added, length: {len(embedding)}"
        )
        added = Metrics.counter("task.multimodal_post_embedding_recsys_v4.count")
        added.add(1)
|
||||
|
||||
|
||||
class TaskMultimodalPostEmbeddingV5(TaskWithPost):
    """Embed the post with the V5 embedder, passing along any video transcripts."""

    embedder = MultimodalPostEmbedderV5()

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding under the "v5_1" key of ``ctx.multimodal_post_embedding_dict``.

        ASR transcripts of all video media are joined with newlines and handed
        to the embedder. Any failure while collecting transcripts or embedding
        is counted, logged, and re-raised.
        """
        try:
            # Keep only videos that actually carry a non-empty ASR transcript.
            transcripts = [
                m.convo_video.asr_transcript
                for m in (post.media or [])
                if isinstance(m, Video)
                and m.convo_video
                and m.convo_video.asr_transcript
            ]
            transcript = "\n".join(transcripts) if transcripts else None
            _unused, embedding = await cls.embedder.embed(post, transcript=transcript)
        except Exception as e:
            failures = Metrics.counter("task.multimodal_post_embedding_v5.error")
            failures.add(1)
            logger.warning(
                f"TaskMultimodalPostEmbeddingV5 failed for post {post.id}: {e}"
            )
            raise

        ctx.multimodal_post_embedding_dict["v5_1"] = embedding
        logger.info(
            f"TaskMultimodalPostEmbeddingV5 Embedding Added, length: {len(embedding)}, has_transcript={transcript is not None}"
        )
        Metrics.counter("task.multimodal_post_embedding_v5.count").add(1)
        if not transcript:
            return
        Metrics.counter(
            "task.multimodal_post_embedding_v5.with_transcript.count"
        ).add(1)
|
||||
25
grox/tasks/task_post_safety_screen_deluxe.py
Normal file
25
grox/tasks/task_post_safety_screen_deluxe.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.classifiers.content.post_safety_screen_deluxe import (
|
||||
PostSafetyDeluxeClassifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskPostSafetyScreenDeluxe(TaskWithPost):
    """Classify the post with ``PostSafetyDeluxeClassifier`` and collect the results."""

    classifier = PostSafetyDeluxeClassifier()

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Delegate to the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator on
        ``Task.exec``; confirm against ``grox.tasks.task``.
        """
        undecorated_exec = Task.exec.__wrapped__
        return await undecorated_exec(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Append the classifier's results to ``ctx.content_categories``."""
        Metrics.counter("task.post_safety_screen_deluxe.total.count").add(1)
        classifications = await cls.classifier.classify(post)
        ctx.content_categories.extend(classifications)
|
||||
555
grox/tasks/task_pub.py
Normal file
555
grox/tasks/task_pub.py
Normal file
@@ -0,0 +1,555 @@
|
||||
import time
|
||||
import logging
|
||||
import traceback
|
||||
from functools import cache
|
||||
|
||||
import thrifts.gen.twitter.strato.columns.content_understanding.content_understanding.ttypes as t
|
||||
from thrifts.serdes import Serializer
|
||||
from grox.tasks.task import Task
|
||||
from monitor.metrics import Metrics
|
||||
from grox.config.config import KafkaTopicName, grox_config
|
||||
from kafka_cli.producer import KafkaProducer
|
||||
from grox.schedules.types import ReplyScoreResult, TaskContext
|
||||
from grox.tasks.disable_rules import (
|
||||
DisableTaskForDev,
|
||||
DisableTaskForLocal,
|
||||
DisableTaskForNonProd,
|
||||
)
|
||||
from strato_http.queries.unified_post_annotations import (
|
||||
StratoUnifiedPostAnnotations,
|
||||
StratoUpsertTweetBoolMetadataToUnifiedPostAnnotations,
|
||||
)
|
||||
from strato_http.queries.grok_reply_spam_action_with_labels import (
|
||||
StratoGrokReplySpamActionWithLabels,
|
||||
)
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
ContentCategoryType,
|
||||
ContentCategoryResult,
|
||||
ContentCategoryScore,
|
||||
Image,
|
||||
Video,
|
||||
)
|
||||
from strato_http.queries.data_types import (
|
||||
EntityWithMetadata,
|
||||
FoundMetadata,
|
||||
UnifiedPostAnnotations,
|
||||
ReplyRankingScore,
|
||||
ReplyRankingScoreKafka,
|
||||
QualifiedId,
|
||||
)
|
||||
from grox.data_loaders.strato_loader import (
|
||||
ReplyRankingScoreStratoLoader,
|
||||
ReplySpamStratoLoader,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskPublishKafka(Task):
    """Publish the context's content-category results as a Kafka record."""

    DISABLE_RULES = [DisableTaskForLocal, DisableTaskForDev]

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Publish classification results for the payload post and emit metrics.

        Does nothing when there is no post. When there are no results, a
        "skipped" counter is incremented instead. Failures are counted,
        logged with their traceback, and re-raised.
        """
        post = ctx.payload.post
        results = ctx.content_categories
        if not post:
            return

        if not results:
            skipped = Metrics.counter("task.publish_kafka.skipped.count")
            skipped.add(1, attributes={"reason": "no_results"})
            return

        Metrics.counter("task.publish_kafka.intaken.count").add(1)
        try:
            author_id = post.user.id if post.user else None
            await cls._publish_to_kafka(
                post.id,
                author_id,
                results,
                summary=ctx.summary,
                embedding=ctx.multimodal_post_embedding,
            )
            for result in results:
                Metrics.counter("task.publish_kafka.success.count").add(
                    1, attributes={"category": result.category.value}
                )

            if not post.created_at:
                Metrics.counter("task.publish_kafka.post_no_created_at.count").add(1)
            else:
                # Seconds elapsed between post creation and this publish.
                latency = time.time() - post.created_at.timestamp()
                for res in results:
                    Metrics.histogram("task.classification_e2e_latency").record(
                        latency, attributes={"category": res.category.value}
                    )
        except Exception:
            Metrics.counter("task.publish_kafka.failed.count").add(1)
            logger.error(
                f"Failed to publish classification record: {traceback.format_exc()}"
            )
            raise

    @classmethod
    async def _publish_to_kafka(
        cls,
        post_id: str,
        user_id: int | None,
        results: list[ContentCategoryResult],
        summary: str,
        embedding: list[float] | None,
    ):
        """Serialize the results into a ``GroxContentAnalysis`` and send it.

        ``embedding`` is accepted but not used by this method.
        """
        category_results = []
        for r in results:
            # Empty or missing taxonomy categories are published as None.
            taxonomy_scores = None
            if r.taxonomy_categories:
                taxonomy_scores = [
                    t.TaxonomyCategoryScore(id=tc.id, name=tc.name, score=tc.score)
                    for tc in r.taxonomy_categories
                ]
            category_results.append(
                t.CategoryResult(
                    category=r.category.name,
                    positive=r.positive,
                    score=r.score,
                    summary=r.summary,
                    taxonomyCategories=taxonomy_scores,
                    keywords=None,
                )
            )

        grox_content_analysis = t.GroxContentAnalysis(
            postId=int(post_id),
            userId=user_id,
            categoryResults=category_results,
            summary=summary,
            createdAt=int(time.time()),
        )
        serialized_bytes = Serializer.serialize(grox_content_analysis)
        producer = cls._get_kafka_producer()
        await producer.send(id=post_id, value=serialized_bytes)

    @classmethod
    @cache
    def _get_kafka_producer(cls):
        """Build the content-analysis Kafka producer; ``cache`` memoizes it per class."""
        producer_config = grox_config.get_kafka_producer_topic(
            KafkaTopicName.GROX_CONTENT_ANALYSIS
        )
        logger.info(
            f"Creating kafka producer with config: {producer_config.model_dump()}"
        )
        return KafkaProducer(producer_config)
|
||||
|
||||
|
||||
class TaskPublishUnifiedPostAnnotationsManhattan(Task):
    """Write the BANGER_INITIAL_SCREEN result as unified post annotations."""

    DISABLE_RULES = [DisableTaskForNonProd]

    # Metric emitted for each slop score value that is tracked individually.
    _SLOP_SCORE_METRICS = {
        1: "task.publish_unified_post_annotations.slop_score_1.count",
        2: "task.publish_unified_post_annotations.slop_score_2.count",
        3: "task.publish_unified_post_annotations.slop_score_3.count",
    }

    # (tweet_bool_metadata attribute, metric emitted when the attribute is truthy)
    _BOOL_METADATA_METRICS = (
        (
            "isHighQuality",
            "task.publish_unified_post_annotations.is_high_quality_true.count",
        ),
        ("isNsfw", "task.publish_unified_post_annotations.is_nsfw_true.count"),
        ("isGore", "task.publish_unified_post_annotations.is_gore_true.count"),
        ("isViolent", "task.publish_unified_post_annotations.is_violent_true.count"),
        ("isSpam", "task.publish_unified_post_annotations.is_spam_true.count"),
        (
            "isSoftNsfw",
            "task.publish_unified_post_annotations.is_soft_nsfw_true.count",
        ),
        ("isAdult", "task.publish_unified_post_annotations.is_adult_true.count"),
    )

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Publish annotations for the payload post.

        Returns without writing when there are no content-category results,
        no post, or no BANGER_INITIAL_SCREEN result among them.
        """
        Metrics.counter("task.publish_unified_post_annotations.count").add(1)
        results = ctx.content_categories
        if not results:
            logger.info("No unified post annotations to publish")
            return

        post = ctx.payload.post
        if not post:
            return

        grok_response = next(
            (
                r
                for r in results
                if r.category == ContentCategoryType.BANGER_INITIAL_SCREEN
            ),
            None,
        )
        if not grok_response:
            return

        # Strict booleans: a missing or empty media list means "no video/image"
        # rather than leaking None or [] into the annotation's boolean fields.
        has_video = bool(post.media) and any(isinstance(m, Video) for m in post.media)
        has_image = bool(post.media) and any(isinstance(m, Image) for m in post.media)

        cls._record_response_metrics(post, grok_response, has_video, has_image)

        resolved_grok_topics = cls._resolve_grok_topics(
            grok_response.taxonomy_categories, ctx.available_topics
        )
        for topic in resolved_grok_topics:
            sanitized_topic_name_for_metric = (
                topic.name.lower().replace(" ", "_").replace("&", "and")
            )
            Metrics.counter(
                f"task.publish_unified_post_annotations.topic_{sanitized_topic_name_for_metric}.count"
            ).add(1)

        entities = cls._build_entities(resolved_grok_topics)

        annotations = UnifiedPostAnnotations(
            tweetId=post.id,
            entities=entities,
            tags=[{"tag": tag, "score": 0.0} for tag in (grok_response.tags or [])],
            tweetBoolMetadata=grok_response.tweet_bool_metadata.model_dump()
            if grok_response.tweet_bool_metadata
            else None,
            description=grok_response.summary,
            isImageEditableByGrok=grok_response.is_image_editable_by_grok,
            slopScore=grok_response.slop_score,
            originalOcrText="",
            evergreenScore=None,
            hasVideo=has_video,
            hasImage=has_image,
            imageDescription=None,
            videoDescription=None,
            qualityScore=grok_response.score,
            hasMinorScore=grok_response.has_minor_score,
            hasCard=post.card is not None,
            foundMetadata=FoundMetadata(
                imageCount=sum(1 for m in post.media if isinstance(m, Image))
                if post.media
                else 0,
                videoCount=sum(1 for m in post.media if isinstance(m, Video))
                if post.media
                else 0,
                cardCount=1 if post.card else 0,
                cardV2Count=len(post.cardsV2) if post.cardsV2 else 0,
            ),
        )

        await StratoUnifiedPostAnnotations().put(int(post.id), annotations)
        Metrics.counter("task.publish_unified_post_annotations.success.count").add(1)

    @classmethod
    def _record_response_metrics(cls, post, grok_response, has_video, has_image) -> None:
        """Emit the per-field counters describing the Grok response and the post media."""
        if grok_response.slop_score is not None:
            slop_metric = cls._SLOP_SCORE_METRICS.get(grok_response.slop_score)
            if slop_metric is not None:
                Metrics.counter(slop_metric).add(1)

        bool_metadata = grok_response.tweet_bool_metadata
        if bool_metadata:
            for attribute, metric_name in cls._BOOL_METADATA_METRICS:
                if not getattr(bool_metadata, attribute):
                    continue
                Metrics.counter(metric_name).add(1)
                if attribute == "isNsfw":
                    # NOTE(review): record_nsfw_detection is neither imported nor
                    # defined in the visible part of this module; unless it is
                    # provided elsewhere this call raises NameError. Confirm
                    # where it lives and import it.
                    record_nsfw_detection(post)

        if grok_response.tags:
            Metrics.counter(
                "task.publish_unified_post_annotations.tags_non_empty.count"
            ).add(1)

        if grok_response.is_image_editable_by_grok:
            Metrics.counter(
                "task.publish_unified_post_annotations.is_image_editable_by_grok_true.count"
            ).add(1)

        if has_video:
            Metrics.counter(
                "task.publish_unified_post_annotations.has_video_true.count"
            ).add(1)
        if has_image:
            Metrics.counter(
                "task.publish_unified_post_annotations.has_image_true.count"
            ).add(1)

    @classmethod
    def _resolve_grok_topics(cls, taxonomy_categories, available_topics) -> list:
        """Validate Grok topic ids against the available topics.

        Unknown ids are logged, counted, and dropped. When the same id occurs
        more than once, the entry with the highest score is kept. Returns an
        empty list when there is nothing to validate or nothing to validate
        against.
        """
        if not taxonomy_categories:
            return []
        if not available_topics:
            logger.warning("No available topics to validate grok_topics")
            return []

        id_to_name = {}
        # Keyed by topic id rather than topic name, so that two topics sharing
        # a name under different categories each resolve to their own category.
        id_to_category_id = {}
        for category in available_topics:
            id_to_name[category.categoryEntityId] = category.categoryName
            id_to_category_id[category.categoryEntityId] = category.categoryEntityId
            for sub in category.subtopics:
                id_to_name[sub.topicEntityId] = sub.topicName
                id_to_category_id[sub.topicEntityId] = category.categoryEntityId

        topic_id_to_best_score = {}
        for grok_topic in taxonomy_categories:
            topic_id = grok_topic.id
            if topic_id not in id_to_name:
                logger.warning(
                    f"Invalid topic ID from Grok: {topic_id} not found in available topics"
                )
                Metrics.counter(
                    "task.publish_unified_post_annotations.invalid_grok_topic.count"
                ).add(1)
                continue

            topic_name = id_to_name[topic_id]
            category_id = id_to_category_id[topic_id]
            resolved_grok_topic = ContentCategoryScore(
                id=topic_id,
                name=topic_name,
                score=grok_topic.score,
                category_id=category_id,
            )
            logger.info(
                f"Validated grok_topic: ID {topic_id} -> '{topic_name}' (category_id: {category_id})"
            )

            if (
                topic_id not in topic_id_to_best_score
                or grok_topic.score > topic_id_to_best_score[topic_id].score
            ):
                topic_id_to_best_score[topic_id] = resolved_grok_topic

        return list(topic_id_to_best_score.values())

    @classmethod
    def _build_entities(cls, resolved_grok_topics) -> list:
        """Convert resolved topics into ``EntityWithMetadata`` records and count the outcome."""
        if not resolved_grok_topics:
            Metrics.counter(
                "task.publish_unified_post_annotations.with_empty_grok_topics.count"
            ).add(1)
            return []

        Metrics.counter(
            "task.publish_unified_post_annotations.with_grok_topics.count"
        ).add(1)
        return [
            EntityWithMetadata(
                qualifiedId=QualifiedId(domainId=236, entityId=str(grok_topic.id)),
                score=grok_topic.score,
                categoryId=QualifiedId(
                    domainId=236, entityId=str(grok_topic.category_id)
                )
                if grok_topic.category_id
                else None,
            )
            for grok_topic in resolved_grok_topics
        ]
|
||||
|
||||
|
||||
class TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation(Task):
    """Upsert the POST_SAFETY_SCREEN boolean metadata into unified post annotations."""

    DISABLE_RULES = [DisableTaskForNonProd]

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Write ``tweet_bool_metadata`` for the payload post and emit flag counters.

        Returns without writing when there are no results, no post, no
        POST_SAFETY_SCREEN result, or the result has no boolean metadata.
        """
        Metrics.counter(
            "task.upsert_tweet_bool_metadata_to_unified_post_annotations.count"
        ).add(1)
        results = ctx.content_categories
        if not results:
            logger.info("No unified post annotations to publish")
            return

        post = ctx.payload.post
        if not post:
            return

        grok_response = None
        for candidate in results:
            if candidate.category == ContentCategoryType.POST_SAFETY_SCREEN:
                grok_response = candidate
                break
        if not grok_response or not grok_response.tweet_bool_metadata:
            return

        bool_metadata = grok_response.tweet_bool_metadata
        # (metadata attribute, counter incremented when the attribute is truthy)
        flag_metrics = (
            (
                "isHighQuality",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_high_quality_true.count",
            ),
            (
                "isNsfw",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_nsfw_true.count",
            ),
            (
                "isGore",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_gore_true.count",
            ),
            (
                "isViolent",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_violent_true.count",
            ),
            (
                "isSpam",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_spam_true.count",
            ),
            (
                "isSoftNsfw",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_soft_nsfw_true.count",
            ),
            (
                "isAdult",
                "task.upsert_tweet_bool_metadata_to_unified_post_annotations.is_adult_true.count",
            ),
        )
        for attribute, metric_name in flag_metrics:
            if getattr(bool_metadata, attribute):
                Metrics.counter(metric_name).add(1)

        await StratoUpsertTweetBoolMetadataToUnifiedPostAnnotations().put(
            int(post.id), grok_response.tweet_bool_metadata.model_dump()
        )
        Metrics.counter(
            "task.upsert_tweet_bool_metadata_to_unified_post_annotations.success.count"
        ).add(1)
|
||||
|
||||
|
||||
class TaskWriteReplyRankingManhattan(Task):
    """Persist the reply-ranking score for the payload post."""

    DISABLE_RULES = [DisableTaskForNonProd]

    _strato_grok_reply_spam_action_with_labels = StratoGrokReplySpamActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Write the reply-ranking result, counting skips, successes and failures.

        Does nothing when there is no post. Failures are counted, logged with
        their traceback, and re-raised.
        """
        post = ctx.payload.post
        results = ctx.reply_ranking_results
        if not post:
            return
        if not results:
            Metrics.counter("task.write_reply_ranking_manhattan.skipped.count").add(
                1, attributes={"reason": "no_results"}
            )
            return
        Metrics.counter("task.write_reply_ranking_manhattan.intaken.count").add(1)
        try:
            await cls._publish_to_reply_ranking_manhattan(post, results)
            # The post may have no user; dereferencing post.user.id here would
            # turn an already successful write into a reported failure.
            user_id = post.user.id if post.user else None
            logger.info(
                f"Published reply ranking post to manhattan: {post.id=} post.user.id={user_id!r}"
            )
        except Exception:
            Metrics.counter("task.write_reply_ranking_manhattan.failed.count").add(1)
            logger.error(
                f"Failed to write reply ranking score to manhattan: {traceback.format_exc()}"
            )
            raise

    @classmethod
    async def _publish_to_reply_ranking_manhattan(
        cls, post: Post, results: list[ReplyScoreResult]
    ):
        """Save the first result's score and reasoning through both loader calls.

        When there is no result, a score of 3.0 with empty reasoning is
        written. A score of exactly 0.0 additionally triggers the reply-spam
        action, whose outcome is logged and counted. Only the last 500
        characters of the reasoning are stored.
        """
        logger.info(
            f"[_publish_to_reply_ranking_manhattan] checking results: {results}"
        )

        # Only the first result is used; None when there are no results.
        reply_ranking_result = next(iter(results or []), None)

        score = reply_ranking_result.score if reply_ranking_result else 3.0
        reasoning = reply_ranking_result.reason if reply_ranking_result else ""

        if post.user:
            logger.info(
                f"[_publish_to_reply_ranking_manhattan] {reasoning=} {post.id=} {post.user.id=} {score=}"
            )
        else:
            logger.info(
                f"Missing user id [_publish_to_reply_ranking_manhattan] {reasoning=} {post.id=} {score=}"
            )

        if score == 0.0:
            action_result = (
                await cls._strato_grok_reply_spam_action_with_labels.execute(
                    int(post.id)
                )
            )
            if action_result and len(action_result.applied_labels) > 0:
                logger.info(
                    f"grokReplySpamActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post.id}"
                )
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.applied.count"
                ).add(1)
            elif action_result:
                logger.info(
                    f"grokReplySpamActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {post.id}"
                )
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.empty.count"
                ).add(1)
            else:
                logger.info(f"grokReplySpamActionWithLabels failed for post {post.id}")
                Metrics.counter(
                    "task.grok_reply_spam_action_with_labels.failed.count"
                ).add(1)

        await ReplyRankingScoreStratoLoader.save_reply_ranking_score(
            post_id=post.id,
            reply_ranking_score=ReplyRankingScore(
                score=score, reasoning=reasoning[-500:]
            ),
        )

        await ReplyRankingScoreStratoLoader.save_reply_ranking_kafka_v2(
            post_id=post.id,
            reply_ranking_score_kafka=ReplyRankingScoreKafka(
                postId=int(post.id), score=score, reasoning=reasoning[-500:]
            ),
        )

        Metrics.counter("task.write_reply_ranking_manhattan.success.count").add(
            1, attributes={"column": "reply_ranking"}
        )
|
||||
|
||||
|
||||
class TaskWriteReplySpamManhattan(Task):
    """Persist SPAM_COMMENT classification results for the payload post."""

    DISABLE_RULES = [DisableTaskForNonProd]

    _strato_grok_reply_spam_action_with_labels = StratoGrokReplySpamActionWithLabels()

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Save a spam-reply annotation for every SPAM_COMMENT result.

        A positive result first triggers the reply-spam action, whose outcome
        is logged and counted. Does nothing when there is no post.
        """
        post = ctx.payload.post
        if not post:
            return

        results = ctx.content_categories
        for result in results:
            if result.category != ContentCategoryType.SPAM_COMMENT:
                continue

            if result.positive:
                action_result = (
                    await cls._strato_grok_reply_spam_action_with_labels.execute(
                        int(post.id)
                    )
                )
                if action_result and len(action_result.applied_labels) > 0:
                    logger.info(
                        f"grokReplySpamActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.applied.count"
                    ).add(1)
                elif action_result:
                    logger.info(
                        f"grokReplySpamActionWithLabels no labels applied: debugString='{action_result.debug_string}' for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.empty.count"
                    ).add(1)
                else:
                    logger.info(
                        f"grokReplySpamActionWithLabels failed for post {post.id}"
                    )
                    Metrics.counter(
                        "task.grok_reply_spam_action_with_labels.failed.count"
                    ).add(1)

            await ReplySpamStratoLoader.save_spam_reply_annotation(
                post.id, result.score, result.positive, ""
            )
            # The post may have no user; dereferencing post.user.id here would
            # raise after the annotation has already been saved.
            user_id = post.user.id if post.user else None
            logger.info(
                f"Published reply spam annotation to manhattan: {post.id=} post.user.id={user_id!r}"
            )
|
||||
31
grox/tasks/task_rank_replies.py
Normal file
31
grox/tasks/task_rank_replies.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.classifiers.content.reply_ranking import ReplyScorer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRankReplies(TaskWithPost):
    """Score the post with ``ReplyScorer`` and collect the reply-ranking results."""

    scorer = ReplyScorer()

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Delegate to the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator on
        ``Task.exec``; confirm against ``grox.tasks.task``.
        """
        undecorated_exec = Task.exec.__wrapped__
        return await undecorated_exec(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Log the post's scoring inputs, then append the scores to the context."""
        user = post.user
        # User-derived fields are logged as None when the post has no user.
        scoring_inputs = (
            f"[task_rank_replies] {post.id=} "
            f"is_pasted={post.is_pasted} "
            f"user_agent={post.user_agent!r} "
            f"composition_source={post.composition_source!r} "
            f"app_attestation_status={post.app_attestation_status!r} "
            f"has_risky_user_safety_label={user.has_risky_user_safety_label if user else None} "
            f"num_legit_blocks_received_last_24hrs={user.num_legit_blocks_received_last_24hrs if user else None}"
        )
        logger.info(scoring_inputs)
        scores = await cls.scorer.score(post)
        ctx.reply_ranking_results.extend(scores)
|
||||
191
grox/tasks/task_rate_limit.py
Normal file
191
grox/tasks/task_rate_limit.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Awaitable, override
|
||||
|
||||
from cachetools import TTLCache
|
||||
from grox.tasks.task import Task, TaskContext, TaskStopExecution
|
||||
from grox.config.config import TaskGeneratorType
|
||||
from monitor.metrics import Metrics
|
||||
from grox.data_loaders.data_types import Post, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRateLimit(Task):
    """Base task that halts execution when ``_eligible`` returns a falsy value."""

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Record the eligibility decision and stop the task when it is negative.

        Raises:
            TaskStopExecution: ``_eligible`` reported the context as not eligible.
        """
        passed = await cls._eligible(ctx)
        decisions = Metrics.counter("task.rate_limit.count")
        decisions.add(1, attributes={"task_name": cls.get_name(), "passed": passed})
        if passed:
            return
        raise TaskStopExecution()

    @classmethod
    @abstractmethod
    def _eligible(cls, ctx: TaskContext) -> Awaitable[bool]:
        """Return an awaitable resolving to whether ``ctx`` may proceed."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskRateLimitWithPost(TaskRateLimit):
    """Rate limit whose decision is made from the payload's post."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Return False when the payload has no post, else ask ``_eligible_with_post``."""
        post = ctx.payload.post
        if post:
            return await cls._eligible_with_post(post, ctx)
        return False

    @classmethod
    @abstractmethod
    def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> Awaitable[bool]:
        """Return an awaitable resolving to whether ``post`` may proceed."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskRateLimitWithUser(TaskRateLimit):
    """Rate limit whose decision is made from the payload's user."""

    @override
    @classmethod
    async def _eligible(cls, ctx: TaskContext) -> bool:
        """Return False when the payload has no user, else ask ``_eligible_with_user``."""
        user = ctx.payload.user
        if user:
            return await cls._eligible_with_user(user, ctx)
        return False

    @classmethod
    @abstractmethod
    def _eligible_with_user(cls, user: User, ctx: TaskContext) -> Awaitable[bool]:
        """Return an awaitable resolving to whether ``user`` may proceed."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TaskRateLimitEmbeddingWithPostSummary(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from a 60 second TTL cache."""

    # Post ids seen recently; class-level, so shared by all invocations.
    POST_CACHE_FOR_MM_EMB_SUMMARY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True and remember the post id if it is not cached, else False."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_SUMMARY
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for mm emb with summary")
            return False
        seen[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitEmbeddingWithPostSummaryForReply(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from a 60 second TTL cache."""

    # Post ids seen recently; class-level, so shared by all invocations.
    POST_CACHE_FOR_MM_EMB_SUMMARY_REPLY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True and remember the post id if it is not cached, else False."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_SUMMARY_REPLY
        if post_id in seen:
            logger.info(
                f"Post {post_id} already hit rate limit for mm emb with summary for reply"
            )
            return False
        seen[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitEmbeddingV5(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from a 60 second TTL cache."""

    # Post ids seen recently; class-level, so shared by all invocations.
    POST_CACHE_FOR_MM_EMB_V5 = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True and remember the post id if it is not cached, else False."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_V5
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for mm emb v5")
            return False
        seen[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitEmbeddingV5ForReply(TaskRateLimitWithPost):
    """Lets a post id through once while it is absent from a 60 second TTL cache."""

    # Post ids seen recently; class-level, so shared by all invocations.
    POST_CACHE_FOR_MM_EMB_V5_REPLY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True and remember the post id if it is not cached, else False."""
        post_id = post.id
        seen = cls.POST_CACHE_FOR_MM_EMB_V5_REPLY
        if post_id in seen:
            logger.info(f"Post {post_id} already hit rate limit for mm emb v5 for reply")
            return False
        seen[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitBangerAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limit for the "banger" annotation task.

    A post id passes once and is then rejected while it stays in the cache.
    """

    # Ids are evicted by the cache itself (maxsize=10_000, ttl=60 -- presumably
    # seconds; confirm against the TTLCache implementation in use).
    POST_CACHE_FOR_BANGER = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True on the first sighting of ``post.id``, False on repeats."""
        post_id = post.id
        seen_posts = cls.POST_CACHE_FOR_BANGER
        if post_id in seen_posts:
            logger.info(f"Post {post_id} already hit rate limit for banger")
            return False
        # First sighting: remember the id so repeats are rejected.
        seen_posts[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitReplySpamAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limit for the "reply spam" annotation task.

    A post id passes once and is then rejected while it stays in the cache.
    """

    # Ids are evicted by the cache itself (maxsize=10_000, ttl=60 -- presumably
    # seconds; confirm against the TTLCache implementation in use).
    POST_CACHE_FOR_REPLY_SPAM = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True on the first sighting of ``post.id``, False on repeats."""
        post_id = post.id
        seen_posts = cls.POST_CACHE_FOR_REPLY_SPAM
        if post_id in seen_posts:
            logger.info(f"Post {post_id} already hit rate limit for reply spam")
            return False
        # First sighting: remember the id so repeats are rejected.
        seen_posts[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitReplyRankingAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limit for the "reply ranking" annotation task.

    A post id passes once and is then rejected while it stays in the cache.
    """

    # Ids are evicted by the cache itself (maxsize=10_000, ttl=60 -- presumably
    # seconds; confirm against the TTLCache implementation in use).
    POST_CACHE_FOR_REPLY_RANKING = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True on the first sighting of ``post.id``, False on repeats."""
        post_id = post.id
        seen_posts = cls.POST_CACHE_FOR_REPLY_RANKING
        if post_id in seen_posts:
            logger.info(f"Post {post_id} already hit rate limit for reply ranking")
            return False
        # First sighting: remember the id so repeats are rejected.
        seen_posts[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitPostSafetyAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limit for the "post safety" annotation task.

    A post id passes once and is then rejected while it stays in the cache.
    """

    # Ids are evicted by the cache itself (maxsize=10_000, ttl=60 -- presumably
    # seconds; confirm against the TTLCache implementation in use).
    POST_CACHE_FOR_POST_SAFETY = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True on the first sighting of ``post.id``, False on repeats."""
        post_id = post.id
        seen_posts = cls.POST_CACHE_FOR_POST_SAFETY
        if post_id in seen_posts:
            logger.info(f"Post {post_id} already hit rate limit for post safety")
            return False
        # First sighting: remember the id so repeats are rejected.
        seen_posts[post_id] = True
        return True
|
||||
|
||||
|
||||
class TaskRateLimitSafetyPtosAnnotationWithPost(TaskRateLimitWithPost):
    """Per-post rate limit for the safety PTOS tasks.

    Regular and deluxe runs are tracked in separate caches, so a post that
    already went through one variant can still pass the other.
    """

    # Ids are evicted by the caches themselves (maxsize=10_000, ttl=60 --
    # presumably seconds; confirm against the TTLCache implementation in use).
    POST_CACHE_FOR_SAFETY_PTOS = TTLCache(maxsize=10_000, ttl=60)
    POST_CACHE_FOR_SAFETY_PTOS_DELUXE = TTLCache(maxsize=10_000, ttl=60)

    @override
    @classmethod
    async def _eligible_with_post(cls, post: Post, ctx: TaskContext) -> bool:
        """Return True on the first sighting of ``post.id`` for this variant.

        The variant (deluxe vs. regular) is chosen from ``ctx.payload.task_type``.
        """
        post_id = post.id
        if ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE:
            cache = cls.POST_CACHE_FOR_SAFETY_PTOS_DELUXE
            label = "safety ptos deluxe"
        else:
            cache = cls.POST_CACHE_FOR_SAFETY_PTOS
            label = "safety ptos"

        if post_id in cache:
            logger.info(f"Post {post_id} already hit rate limit for {label}")
            return False
        # First sighting for this variant: remember the id.
        cache[post_id] = True
        return True
|
||||
61
grox/tasks/task_safety_ptos_category.py
Normal file
61
grox/tasks/task_safety_ptos_category.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.classifiers.content.safety_ptos import SafetyPtosCategoryClassifier
|
||||
from grox.config.config import ModelName, TaskGeneratorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskSafetyPtosCategoryDetection(TaskWithPost):
    """Classifies a post into safety PTOS categories and stores the result on the context.

    Two classifier instances are kept at class level; the deluxe one is used
    when the payload's task type is ``SAFETY_PTOS_DELUXE``.
    """

    classifier = SafetyPtosCategoryClassifier(ModelName.VLM_SAFETY)
    deluxe_classifier = SafetyPtosCategoryClassifier(
        ModelName.VLM_PRIMARY_CRITICAL, deluxe=True
    )

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Run the active classifier on ``post`` and emit metrics and logs.

        Side effects: sets ``ctx.safety_annotations`` to the classifier output
        and increments counters under a prefix that depends on the variant.
        """
        is_deluxe = ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE
        if is_deluxe:
            active_classifier = cls.deluxe_classifier
            metric_prefix = "task.safety_ptos_deluxe_category"
            mode = " (deluxe)"
        else:
            active_classifier = cls.classifier
            metric_prefix = "task.safety_ptos_category"
            mode = ""

        safety_annotations = await active_classifier.classify_post(post)
        ctx.safety_annotations = safety_annotations

        # violatedPolicies may be None; treat that the same as "no violations".
        safety_categories = safety_annotations.violatedPolicies or []
        violation_count = len(safety_categories)
        Metrics.counter(f"{metric_prefix}.classified.count").add(1)

        if violation_count == 0:
            Metrics.counter(f"{metric_prefix}.no_violations.count").add(1)
            logger.info(f"Post {post.id}: No safety violations detected{mode}")
            return

        Metrics.counter(f"{metric_prefix}.has_violations.count").add(1)
        Metrics.counter(f"{metric_prefix}.violations.count").add(violation_count)

        violation_details = []
        for violation in safety_categories:
            category_value = violation.category.value
            Metrics.counter(f"{metric_prefix}.violations_by_category.count").add(
                1, attributes={"category": category_value}
            )
            violation_details.append(f"{category_value}({violation.score})")

        logger.info(
            f"Post {post.id}: Found {violation_count} violations{mode} - Details: {', '.join(violation_details)}"
        )

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task via the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator applied
        to ``Task.exec`` -- confirm in grox.tasks.task.
        """
        return await Task.exec.__wrapped__(cls, ctx)
|
||||
89
grox/tasks/task_safety_ptos_policy.py
Normal file
89
grox/tasks/task_safety_ptos_policy.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import (
|
||||
Post,
|
||||
SafetyPolicyCategory,
|
||||
SafetyPolicyType,
|
||||
SafetyPtosViolatedPolicy,
|
||||
)
|
||||
from grox.classifiers.content.safety_ptos import SafetyPtosPolicyClassifier
|
||||
from grox.config.config import TaskGeneratorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskSafetyPtosPolicyDetection(TaskWithPost):
    """Assigns a concrete safety policy to each violation already on the context.

    Operates on ``ctx.safety_annotations`` and does nothing when that is unset.
    """

    violated_policy_classifier = SafetyPtosPolicyClassifier()
    deluxe_violated_policy_classifier = SafetyPtosPolicyClassifier(deluxe=True)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify the policy for every violation and write the list back.

        In deluxe mode, when no AdultContent violation is present, a synthetic
        AdultContent "recheck" violation is classified as well; it is dropped
        again if the classifier returns no policy or ``NoViolation``.
        """
        if not ctx.safety_annotations:
            return

        is_deluxe = ctx.payload.task_type == TaskGeneratorType.SAFETY_PTOS_DELUXE
        if is_deluxe:
            active_classifier = cls.deluxe_violated_policy_classifier
            metric_prefix = "task.safety_ptos_deluxe_policy"
        else:
            active_classifier = cls.violated_policy_classifier
            metric_prefix = "task.safety_ptos_policy"

        # Work on a copy so the injected recheck can be added/removed freely.
        violations = list(ctx.safety_annotations.violatedPolicies or [])

        injected_recheck = None
        if is_deluxe and not any(
            v.category == SafetyPolicyCategory.AdultContent for v in violations
        ):
            injected_recheck = SafetyPtosViolatedPolicy(
                category=SafetyPolicyCategory.AdultContent,
                reason="adult content recheck",
                score=50,
            )
            violations.append(injected_recheck)

        # Violations are classified one at a time, in list order.
        for violation in violations:
            violation.safetyPolicy = (
                await active_classifier.classify_policy_for_violation(post, violation)
            )
            if violation.safetyPolicy:
                cls._record_policy_metrics(metric_prefix, violation)

        if injected_recheck is not None:
            policy = injected_recheck.safetyPolicy
            recheck_confirmed = policy and not (
                policy.policyType == SafetyPolicyType.NoViolation
            )
            if not recheck_confirmed:
                violations.remove(injected_recheck)

        ctx.safety_annotations.violatedPolicies = violations

    @classmethod
    def _record_policy_metrics(
        cls, metric_prefix: str, violation: SafetyPtosViolatedPolicy
    ) -> None:
        """Increment the policy counters for one classified violation.

        The caller only invokes this when ``violation.safetyPolicy`` is set.
        Categories missing from the mapping below only bump the total counter.
        """
        Metrics.counter(f"{metric_prefix}.classified_total.count").add(1)

        metric_key_by_category = {
            SafetyPolicyCategory.ViolentMedia: "violent_media",
            SafetyPolicyCategory.AdultContent: "adult_content",
            SafetyPolicyCategory.Spam: "spam",
            SafetyPolicyCategory.IllegalAndRegulatedBehaviors: "illegal_and_regulated_behaviors",
            SafetyPolicyCategory.HateOrAbuse: "hate_or_abuse",
            SafetyPolicyCategory.ViolentSpeech: "violent_speech",
            SafetyPolicyCategory.SuicideOrSelfHarm: "suicide_or_self_harm",
        }
        category_key = metric_key_by_category.get(violation.category)
        if not category_key:
            return

        Metrics.counter(
            f"{metric_prefix}.classified_{category_key}_violations.count"
        ).add(1)
        Metrics.counter(
            f"{metric_prefix}.classified_{category_key}_policy_types.count"
        ).add(1, attributes={"policy_type": violation.safetyPolicy.policyType.name})

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task via the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator applied
        to ``Task.exec`` -- confirm in grox.tasks.task.
        """
        return await Task.exec.__wrapped__(cls, ctx)
|
||||
57
grox/tasks/task_spam_detection.py
Normal file
57
grox/tasks/task_spam_detection.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post, ContentCategoryType
|
||||
from grox.classifiers.content.spam import SpamEapiLowFollowerClassifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskSpamDetection(TaskWithPost):
    """Runs the low-follower spam classifier on a post and records the outcome."""

    eapi_low_follower_classifier = SpamEapiLowFollowerClassifier()

    @classmethod
    def get_follower_bucket_string(cls, post: Post) -> str:
        """Bucket a reply by the follower counts of its first and last ancestors.

        Returns ``"invalid"`` when the post has no ancestors. Otherwise returns
        the smallest of ``lte_100`` / ``lte_500`` / ``lte_1000`` that both
        follower counts fit into, or ``gt_1000``. A missing follower count is
        treated as 0.
        """
        if not post.ancestors:
            return "invalid"

        in_reply_user_follower_count = post.ancestors[-1].user.follower_count or 0
        root_user_follower_count = post.ancestors[0].user.follower_count or 0

        # Both counts are within a limit exactly when the larger one is.
        larger_count = max(in_reply_user_follower_count, root_user_follower_count)
        for limit, bucket in ((100, "lte_100"), (500, "lte_500"), (1000, "lte_1000")):
            if larger_count <= limit:
                return bucket
        return "gt_1000"

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task via the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator applied
        to ``Task.exec`` -- confirm in grox.tasks.task.
        """
        return await Task.exec.__wrapped__(cls, ctx)

    @override
    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Classify ``post``, append the results to the context and emit metrics."""
        results = await cls.eapi_low_follower_classifier.classify(post)
        ctx.content_categories.extend(results)

        # Positive as soon as one SPAM_COMMENT result is positive.
        passed = False
        for result in results:
            if result.category == ContentCategoryType.SPAM_COMMENT and result.positive:
                passed = True
                break

        follower_bucket_string = cls.get_follower_bucket_string(post)

        if not passed:
            Metrics.counter("task.spam_comment_detection.negative.count").add(
                1, attributes={"reason": follower_bucket_string}
            )
            return

        if follower_bucket_string != "gt_1000":
            logger.info(
                f"Reply Spam Found for lower than 1000 follower bucket. The post_id is {post.id} and the follower bucket is {follower_bucket_string}"
            )
        Metrics.counter("task.spam_comment_detection.positive.count").add(
            1, attributes={"reason": follower_bucket_string}
        )
|
||||
25
grox/tasks/task_summarizer_for_post_embedding.py
Normal file
25
grox/tasks/task_summarizer_for_post_embedding.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from grox.tasks.task_load_post_with_not_found_retry import TaskLoadPostWithNotFoundRetry
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext, TaskPayload
|
||||
from grox.data_loaders.data_types import Post
|
||||
from grox.summarizer.post_embedding_summarizer import PostEmbeddingSummarizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskPostEmbeddingSummarizer(TaskWithPost):
    """Generates a summary for a post and stores it on ``post.summary``."""

    # NOTE(review): prompt_file is an empty string here -- confirm this is the
    # intended configuration for PostEmbeddingSummarizer.
    summarizer = PostEmbeddingSummarizer(prompt_file="")

    @classmethod
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task via the undecorated ``Task.exec``.

        NOTE(review): ``__wrapped__`` presumably bypasses a decorator applied
        to ``Task.exec`` -- confirm in grox.tasks.task.
        """
        return await Task.exec.__wrapped__(cls, ctx)

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Summarize ``post`` and attach the result to it.

        Raises:
            AssertionError: if the summarizer returns ``None``.
        """
        generated_summary = await cls.summarizer.summarize(post)
        assert generated_summary is not None
        post.summary = generated_summary
        Metrics.counter("task.post_embedding_summarizer.count").add(1)
|
||||
193
grox/tasks/task_write_mm_embedding_sink.py
Normal file
193
grox/tasks/task_write_mm_embedding_sink.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from grox.tasks.task import Task, TaskWithPost, TaskResultCategory
|
||||
from grox.tasks.disable_rules import DisableTaskForNonMmEmbProd
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import Post
|
||||
from strato_http.queries.post_multimodal_embedding_sink import (
|
||||
StratoPostMultimodalEmbeddingSink,
|
||||
)
|
||||
from tenacity import retry, wait_chain, wait_fixed, stop_after_attempt
|
||||
from strato_http.queries.post_multimodal_embedding_mh_searchai import (
|
||||
StratoPostMultimodalEmbeddingMhSearchAi,
|
||||
TweetEmbedding,
|
||||
StratoPostMultimodalEmbeddingMhSearchAiNoCache,
|
||||
StratoPostMultimodalEmbeddingGrokSummaryMh,
|
||||
StratoMultiModalEmbeddingTopic,
|
||||
)
|
||||
from grox.tasks.task_embedding_pub import (
|
||||
TaskPublishEmbeddingV4Kafka,
|
||||
TaskPublishEmbeddingV5Kafka,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkBase(TaskWithPost):
    """Common base for tasks that write a multimodal post embedding to a sink.

    Subclasses set ``model_version`` and implement ``_exec_with_post``.
    """

    # Key into ctx.multimodal_post_embedding_dict; set by each subclass.
    model_version: str

    DISABLE_RULES = [DisableTaskForNonMmEmbProd]

    @classmethod
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_chain(wait_fixed(1), wait_fixed(2)),
    )
    async def exec(cls, ctx: TaskContext) -> TaskResultCategory:
        """Run the task with up to three attempts (waits of 1 then 2 between them).

        NOTE(review): ``Task.exec.__wrapped__`` presumably bypasses a decorator
        on ``Task.exec`` so that this retry policy is the only wrapper --
        confirm in grox.tasks.task.
        """
        return await Task.exec.__wrapped__(cls, ctx)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkExperiment(TaskWriteMMEmbeddingSinkBase):
    """Writes the "v2" multimodal post embedding to the Strato search-AI sink."""

    model_version = "v2"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding for ``post`` and record duration/count metrics.

        Raises:
            KeyError: if no embedding is registered under ``model_version``.
            AssertionError: if the registered embedding is ``None``.
        """
        start_time = time.perf_counter_ns()
        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAi()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )
        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() counts nanoseconds and 1 ms == 1_000_000 ns.
        # Dividing by 1_000 (as before) reported microseconds under a
        # "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram(
            "task.write_post_embedding_sink_experiment.duration_ms"
        ).record(duration_ms)
        Metrics.counter("task.write_post_embedding_sink_experiment.count").add(1)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkV3(TaskWriteMMEmbeddingSinkBase):
    """Writes the "v3" post summary and multimodal embedding to the Strato sinks."""

    model_version = "v3"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the summary, then the embedding, then publish to the topic.

        The three writes happen sequentially in that order; a failure in a
        later write leaves the earlier ones in place.

        Raises:
            AssertionError: if ``post.summary`` or the embedding is ``None``.
            KeyError: if no embedding is registered under ``model_version``.
        """
        start_time = time.perf_counter_ns()

        summary = post.summary
        assert summary is not None
        summary_sink = StratoPostMultimodalEmbeddingGrokSummaryMh()
        await summary_sink.put(int(post.id), cls.model_version, summary)

        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAi()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )

        embedding_topic = StratoMultiModalEmbeddingTopic()
        await embedding_topic.insert(
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding)
        )

        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() counts nanoseconds and 1 ms == 1_000_000 ns.
        # Dividing by 1_000 (as before) reported microseconds under a
        # "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v3.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v3.count").add(1)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkV4(TaskWriteMMEmbeddingSinkBase):
    """Writes the "v4" multimodal post embedding to Strato and publishes it to Kafka."""

    model_version = "v4"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding via the no-cache Strato sink, then publish it.

        Raises:
            KeyError: if no embedding is registered under ``model_version``.
            AssertionError: if the registered embedding is ``None``.
        """
        start_time = time.perf_counter_ns()

        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )

        await TaskPublishEmbeddingV4Kafka._publish_to_kafka(post, embedding)

        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() counts nanoseconds and 1 ms == 1_000_000 ns.
        # Dividing by 1_000 (as before) reported microseconds under a
        # "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v4.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v4.count").add(1)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkV5(TaskWriteMMEmbeddingSinkBase):
    """Writes the "v5_1" multimodal post embedding to Strato and publishes it to Kafka."""

    model_version = "v5_1"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding via the no-cache Strato sink, then publish it.

        Raises:
            KeyError: if no embedding is registered under ``model_version``.
            AssertionError: if the registered embedding is ``None``.
        """
        start_time = time.perf_counter_ns()

        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )

        await TaskPublishEmbeddingV5Kafka._publish_to_kafka(post, embedding)

        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version})"
        )
        # perf_counter_ns() counts nanoseconds and 1 ms == 1_000_000 ns.
        # Dividing by 1_000 (as before) reported microseconds under a
        # "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v5.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v5.count").add(1)
|
||||
|
||||
|
||||
class TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies(TaskWriteMMEmbeddingSinkBase):
    """Writes the "v5_1" embedding to Strato; publishes to Kafka only for non-replies.

    A post counts as a reply when ``post.ancestors`` is non-empty.
    """

    model_version = "v5_1"

    @classmethod
    async def _exec_with_post(cls, ctx: TaskContext, post: Post) -> None:
        """Store the embedding, conditionally publish it, and record metrics.

        Raises:
            KeyError: if no embedding is registered under ``model_version``.
            AssertionError: if the registered embedding is ``None``.
        """
        start_time = time.perf_counter_ns()

        embedding = ctx.multimodal_post_embedding_dict[cls.model_version]
        assert embedding is not None
        query = StratoPostMultimodalEmbeddingMhSearchAiNoCache()
        await query.put(
            int(post.id),
            cls.model_version,
            TweetEmbedding(tweetId=int(post.id), embedding1=embedding),
        )

        is_reply = bool(post.ancestors)
        if not is_reply:
            await TaskPublishEmbeddingV5Kafka._publish_to_kafka(post, embedding)
        else:
            Metrics.counter(
                "task.write_post_embedding_sink_v5.kafka_skipped_reply.count"
            ).add(1)
            logger.info(
                f"Skipping Kafka publish for reply post {post.id} (written to Manhattan only)"
            )

        logger.info(
            f"wrote post embedding to strato sink for post {post.id} (model: {cls.model_version}, kafka={'yes' if not is_reply else 'no'})"
        )
        # perf_counter_ns() counts nanoseconds and 1 ms == 1_000_000 ns.
        # Dividing by 1_000 (as before) reported microseconds under a
        # "duration_ms" metric.
        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000
        Metrics.histogram("task.write_post_embedding_sink_v5.duration_ms").record(
            duration_ms
        )
        Metrics.counter("task.write_post_embedding_sink_v5.count").add(1)
|
||||
342
grox/tasks/task_write_safety_post_annotations_result_sink.py
Normal file
342
grox/tasks/task_write_safety_post_annotations_result_sink.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from grox.tasks.task import Task
|
||||
from grox.tasks.disable_rules import DisableTaskForNonPtosProd
|
||||
from monitor.metrics import Metrics
|
||||
from grox.schedules.types import TaskContext
|
||||
from grox.data_loaders.data_types import (
|
||||
Image,
|
||||
Video,
|
||||
SafetyPolicyCategory,
|
||||
SafetyPolicyType,
|
||||
)
|
||||
from strato_http.queries.data_types import (
|
||||
SafetyPostAnnotations,
|
||||
SafetyPostAnnotationsResult,
|
||||
SafetyBoolMetadata,
|
||||
SafetyPtosViolatedPolicy,
|
||||
SafetyPolicy,
|
||||
FoundMetadata,
|
||||
)
|
||||
from strato_http.queries.safety_post_annotations_result import (
|
||||
StratoSafetyPostAnnotationsResultMh,
|
||||
StratoSafetyPostAnnotationsResultDirectMh,
|
||||
StratoSafetyPostAnnotationsResultKafka,
|
||||
)
|
||||
from strato_http.queries.grok_ptos_action_with_labels import (
|
||||
StratoGrokPtosActionWithLabels,
|
||||
)
|
||||
from strato_http.queries.grok_ptos_delete_labels import StratoGrokPtosDeleteLabels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskWriteSafetyPostAnnotationsResultSink(Task):
    """Persists a post's safety annotations and triggers the PTOS label actions.

    ``_exec`` merges the annotation on the context with any previously stored
    result, runs the "action with labels" and "delete labels" Strato calls,
    and writes the merged result through ``_strato_mh`` and ``_strato_kafka``.
    """

    DISABLE_RULES = [DisableTaskForNonPtosProd]

    # Strato clients, created once at class-definition time and shared.
    _strato_mh = StratoSafetyPostAnnotationsResultMh()
    _strato_direct_mh = StratoSafetyPostAnnotationsResultDirectMh()
    _strato_kafka = StratoSafetyPostAnnotationsResultKafka()
    _strato_grok_ptos_action_with_labels = StratoGrokPtosActionWithLabels()
    _strato_grok_ptos_delete_labels = StratoGrokPtosDeleteLabels()

    @staticmethod
    def _build_found_metadata(post) -> FoundMetadata:
        """Count the images, videos and v2 cards attached to ``post``.

        Missing/empty ``post.media`` or ``post.cardsV2`` count as zero.
        """
        media = post.media or []
        return FoundMetadata(
            imageCount=sum(1 for m in media if isinstance(m, Image)),
            videoCount=sum(1 for m in media if isinstance(m, Video)),
            cardV2Count=len(post.cardsV2) if post.cardsV2 else 0,
        )

    @classmethod
    def _compute_bool_metadata_from_violations(
        cls, safety_annotations
    ) -> SafetyBoolMetadata:
        """Derive the gore/nsfw/soft-nsfw/spam flags from classified violations.

        Violations without a ``safetyPolicy`` are ignored. A detection counter
        is incremented for every violation that sets a flag.
        """
        is_gore = False
        is_nsfw = False
        is_soft_nsfw = False
        is_spam = False

        if safety_annotations and safety_annotations.violatedPolicies:
            for violation in safety_annotations.violatedPolicies:
                policy = violation.safetyPolicy
                if not policy:
                    # Every flag below requires a classified policy.
                    continue
                category = violation.category
                policy_type = policy.policyType
                is_violation = policy_type != SafetyPolicyType.NoViolation

                if category == SafetyPolicyCategory.ViolentMedia and is_violation:
                    is_gore = True
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.detected_gore.count"
                    ).add(1)

                if (
                    category == SafetyPolicyCategory.AdultContent
                    and policy_type == SafetyPolicyType.AdultContentSexualHard
                ):
                    is_nsfw = True
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.detected_nsfw.count"
                    ).add(1)

                if (
                    category == SafetyPolicyCategory.AdultContent
                    and policy_type == SafetyPolicyType.AdultContentSexualSoft
                ):
                    is_soft_nsfw = True
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.detected_soft_nsfw.count"
                    ).add(1)

                if category == SafetyPolicyCategory.Spam and is_violation:
                    is_spam = True
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.detected_spam.count"
                    ).add(1)

                # Illegal/regulated-behavior violations also set the spam flag
                # but are counted under their own metric.
                if (
                    category == SafetyPolicyCategory.IllegalAndRegulatedBehaviors
                    and is_violation
                ):
                    is_spam = True
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.detected_spam_illegal.count"
                    ).add(1)

        return SafetyBoolMetadata(
            isGore=is_gore, isNsfw=is_nsfw, isSoftNsfw=is_soft_nsfw, isSpam=is_spam
        )

    @classmethod
    def _merge_bool_metadata(
        cls, existing: SafetyBoolMetadata | None, new: SafetyBoolMetadata
    ) -> SafetyBoolMetadata:
        """OR the flags of ``existing`` and ``new``; return ``new`` if nothing exists.

        Flags are sticky: once set in a stored result they are never cleared here.
        """
        if existing is None:
            return new
        return SafetyBoolMetadata(
            isGore=bool(existing.isGore or new.isGore),
            isNsfw=bool(existing.isNsfw or new.isNsfw),
            isSoftNsfw=bool(existing.isSoftNsfw or new.isSoftNsfw),
            isSpam=bool(existing.isSpam or new.isSpam),
        )

    @classmethod
    async def _exec(cls, ctx: TaskContext) -> None:
        """Merge, act on and persist the safety annotations for the payload's post.

        Returns early (after bumping the invocation counter) when the payload
        has no post or the context has no safety annotations.
        """
        Metrics.counter("task.write_safety_post_annotations_result_sink.count").add(1)

        post = ctx.payload.post
        if not post:
            return

        safety_annotations = ctx.safety_annotations
        if not safety_annotations:
            return

        post_id = int(post.id)

        existing_result = await cls._strato_direct_mh.fetch(post_id)
        if existing_result:
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.existing_found.count"
            ).add(1)
        else:
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.existing_not_found.count"
            ).add(1)

        # Highest-scoring violations first; a missing score sorts as 0.
        if safety_annotations.violatedPolicies:
            safety_annotations.violatedPolicies.sort(
                key=lambda x: x.score or 0, reverse=True
            )

        task_type_suffix = (
            ctx.payload.task_type.value if ctx.payload.task_type else "unknown"
        )
        identifier = f"{task_type_suffix}"
        timestamp_ms = int(time.time() * 1000)
        found_metadata = cls._build_found_metadata(post)

        new_annotation = SafetyPostAnnotations(
            tweetId=post_id,
            violatedPolicies=[
                policy.model_dump()
                for policy in (safety_annotations.violatedPolicies or [])
            ],
            foundMetadata=found_metadata,
            identifier=identifier,
            timestamp=timestamp_ms,
        )

        # Append to the stored history rather than replacing it.
        annotations_list = (
            list(existing_result.safetyPostAnnotations)
            if existing_result and existing_result.safetyPostAnnotations
            else []
        )
        annotations_list.append(new_annotation)

        new_bool_metadata = cls._compute_bool_metadata_from_violations(
            safety_annotations
        )
        existing_bool_metadata = (
            existing_result.safetyBoolMetadata if existing_result else None
        )
        merged_bool_metadata = cls._merge_bool_metadata(
            existing_bool_metadata, new_bool_metadata
        )

        # violatedPolicies may be None (it is guarded that way above); iterating
        # it unguarded here would raise TypeError before anything is written.
        violation_details = []
        for v in safety_annotations.violatedPolicies or []:
            policy_type = v.safetyPolicy.policyType.name if v.safetyPolicy else "none"
            violation_details.append(f"{v.category.value}:{policy_type}")
        violations_summary = ", ".join(violation_details)

        action_result = await cls._strato_grok_ptos_action_with_labels.execute(
            new_annotation
        )
        if action_result and len(action_result.applied_labels) > 0:
            logger.info(
                f"grokPtosActionWithLabels applied labels: debugString='{action_result.debug_string}', appliedLabels={action_result.applied_labels} for post {post_id} (result_sink), violations=[{violations_summary}]"
            )
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.count"
            ).add(1)
            for label in action_result.applied_labels:
                Metrics.counter(
                    "task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.applied_label.count"
                ).add(1, attributes={"label": label})
        elif action_result:
            logger.info(
                f"grokPtosActionWithLabels did not apply any labels: (debugString='{action_result.debug_string}') for post {post_id} (result_sink), violations=[{violations_summary}] "
            )
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.empty.count"
            ).add(1)
        else:
            logger.info(
                f"grokPtosActionWithLabels failed for post {post_id} (result_sink), violations=[{violations_summary}] "
            )
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.grok_ptos_action_with_labels.failed.count"
            ).add(1)

        # If the safemodel classifier flagged the post but this run's PTOS
        # result did not mark it NSFW, add a synthetic hard-adult annotation.
        # NOTE(review): assumes ctx.safemodel_sex_nudity is always populated
        # when this task runs -- confirm against the schedule definition.
        ptos_already_nsfw = new_bool_metadata.isNsfw
        if ctx.safemodel_sex_nudity.positive and not ptos_already_nsfw:
            safemodel_confidence_int = round(ctx.safemodel_sex_nudity.confidence * 100)
            safemodel_annotation = SafetyPostAnnotations(
                tweetId=post_id,
                violatedPolicies=[
                    SafetyPtosViolatedPolicy(
                        category=SafetyPolicyCategory.AdultContent,
                        score=safemodel_confidence_int,
                        reason="safemodel sex-and-nudity classifier detected adult content",
                        safetyPolicy=SafetyPolicy(
                            policyType=SafetyPolicyType.AdultContentSexualHard,
                            confidenceScore=safemodel_confidence_int,
                            reason="safemodel sex-and-nudity classifier detected adult content",
                        ),
                    ).model_dump(),
                ],
                foundMetadata=found_metadata,
                identifier="safemodel-sex-nudity",
                timestamp=timestamp_ms,
            )
            annotations_list.append(safemodel_annotation)
            merged_bool_metadata = cls._merge_bool_metadata(
                merged_bool_metadata,
                SafetyBoolMetadata(
                    isGore=False, isNsfw=True, isSoftNsfw=False, isSpam=False
                ),
            )

            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.safemodel_enforced.count"
            ).add(1)
            safemodel_action_result = (
                await cls._strato_grok_ptos_action_with_labels.execute(
                    safemodel_annotation
                )
            )
            if (
                safemodel_action_result
                and len(safemodel_action_result.applied_labels) > 0
            ):
                logger.info(
                    f"safemodel enforce: grokPtosActionWithLabels applied labels: debugString='{safemodel_action_result.debug_string}', "
                    f"appliedLabels={safemodel_action_result.applied_labels} for post {post_id}"
                )
                Metrics.counter(
                    "task.write_safety_post_annotations_result_sink.safemodel_action.count"
                ).add(1)
                for label in safemodel_action_result.applied_labels:
                    Metrics.counter(
                        "task.write_safety_post_annotations_result_sink.safemodel_action.applied_label.count"
                    ).add(1, attributes={"label": label})
            elif safemodel_action_result:
                logger.info(
                    f"safemodel enforce: grokPtosActionWithLabels no labels applied (debugString='{safemodel_action_result.debug_string}') for post {post_id}"
                )
                Metrics.counter(
                    "task.write_safety_post_annotations_result_sink.safemodel_action.empty.count"
                ).add(1)
            else:
                logger.info(
                    f"safemodel enforce: grokPtosActionWithLabels returned None for post {post_id}"
                )
                Metrics.counter(
                    "task.write_safety_post_annotations_result_sink.safemodel_action.failed.count"
                ).add(1)

        final_result = SafetyPostAnnotationsResult(
            tweetId=post_id,
            safetyPostAnnotations=annotations_list,
            safetyBoolMetadata=merged_bool_metadata,
        )

        delete_labels_result = await cls._strato_grok_ptos_delete_labels.execute(
            final_result
        )
        if delete_labels_result:
            logger.info(
                f"grokPtosDeleteLabels returned '{delete_labels_result}' for post {post_id} (result_sink)"
            )
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.grok_ptos_delete_labels.count"
            ).add(1)
        else:
            logger.info(
                f"grokPtosDeleteLabels returned no result for post {post_id} (result_sink)"
            )
            Metrics.counter(
                "task.write_safety_post_annotations_result_sink.grok_ptos_delete_labels.empty.count"
            ).add(1)

        # Persist the merged result, then publish it; each step is counted
        # separately so a failure between them is visible in the metrics.
        await cls._strato_mh.put(post_id, final_result)
        Metrics.counter(
            "task.write_safety_post_annotations_result_sink.mh.success.count"
        ).add(1)

        await cls._strato_kafka.insert(post_id, final_result)
        Metrics.counter(
            "task.write_safety_post_annotations_result_sink.kafka.success.count"
        ).add(1)

        Metrics.counter(
            "task.write_safety_post_annotations_result_sink.success.count"
        ).add(1)
|
||||
20
home-mixer/ads/mod.rs
Normal file
20
home-mixer/ads/mod.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
mod partition_organic_blender;
|
||||
mod safe_gap_blender;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use partition_organic_blender::PartitionOrganicAdsBlender;
|
||||
pub use safe_gap_blender::SafeGapAdsBlender;
|
||||
|
||||
use util::{record_ad_risk_stats, record_post_verdict_stats};
|
||||
use xai_home_mixer_proto::{FeedItem, ScoredPost};
|
||||
use xai_recsys_proto::AdIndexInfo;
|
||||
|
||||
pub trait AdsBlender: Send + Sync {
|
||||
fn blend_inner(&self, scored_posts: Vec<ScoredPost>, ads: Vec<AdIndexInfo>) -> Vec<FeedItem>;
|
||||
|
||||
fn blend(&self, scored_posts: Vec<ScoredPost>, ads: Vec<AdIndexInfo>) -> Vec<FeedItem> {
|
||||
record_post_verdict_stats(&scored_posts);
|
||||
record_ad_risk_stats(&ads);
|
||||
self.blend_inner(scored_posts, ads)
|
||||
}
|
||||
}
|
||||
190
home-mixer/ads/partition_organic_blender.rs
Normal file
190
home-mixer/ads/partition_organic_blender.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use super::AdsBlender;
|
||||
use super::util::*;
|
||||
use crate::params::RESULT_SIZE;
|
||||
use xai_home_mixer_proto::{FeedItem, ScoredPost, feed_item};
|
||||
use xai_recsys_proto::AdIndexInfo;
|
||||
use xai_stats_receiver::global_stats_receiver;
|
||||
|
||||
/// Counter name used by `emit_enforcement_metrics`; each increment is tagged
/// with an `action` label (`drop`, `ok`, `handle_drop`, `keyword_drop`).
const ENFORCEMENT_METRIC: &str = "PartitionOrganic.enforcement";
|
||||
|
||||
/// Stateless blender whose `AdsBlender` implementation forwards to this
/// module's `blend_impl`.
pub struct PartitionOrganicAdsBlender;
|
||||
|
||||
impl AdsBlender for PartitionOrganicAdsBlender {
|
||||
fn blend_inner(&self, scored_posts: Vec<ScoredPost>, ads: Vec<AdIndexInfo>) -> Vec<FeedItem> {
|
||||
blend_impl(scored_posts, ads, MIN_POSTS_FOR_ADS)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn blend_impl(
|
||||
scored_posts: Vec<ScoredPost>,
|
||||
ads: Vec<AdIndexInfo>,
|
||||
min_posts: usize,
|
||||
) -> Vec<FeedItem> {
|
||||
let n = scored_posts.len();
|
||||
|
||||
if ads.is_empty() || n < min_posts {
|
||||
return posts_to_feed_items(scored_posts);
|
||||
}
|
||||
|
||||
let spacing = compute_spacing(&ads);
|
||||
|
||||
let safe_count = scored_posts.iter().filter(|p| !has_avoid(p)).count();
|
||||
let max_from_safe = safe_count / 2;
|
||||
let expected_from_spacing = if spacing.requested > 0 {
|
||||
n.saturating_sub(1) / spacing.requested
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let actual_ads = ads.len().min(expected_from_spacing).min(max_from_safe);
|
||||
|
||||
if actual_ads == 0 {
|
||||
return posts_to_feed_items(scored_posts);
|
||||
}
|
||||
|
||||
let mut safe: Vec<ScoredPost> = Vec::new();
|
||||
let mut unsafe_posts: Vec<ScoredPost> = Vec::new();
|
||||
for post in scored_posts {
|
||||
if has_avoid(&post) {
|
||||
unsafe_posts.push(post);
|
||||
} else {
|
||||
safe.push(post);
|
||||
}
|
||||
}
|
||||
|
||||
let num_safe = safe.len();
|
||||
let group_size = num_safe / actual_ads;
|
||||
|
||||
let mut safe_opts: Vec<Option<ScoredPost>> = safe.into_iter().map(Some).collect();
|
||||
let mut triples: Vec<(AdIndexInfo, ScoredPost, ScoredPost)> = Vec::new();
|
||||
|
||||
let mut bsr_drop: u64 = 0;
|
||||
let mut bsr_ok: u64 = 0;
|
||||
let mut handle_drop: u64 = 0;
|
||||
let mut keyword_drop: u64 = 0;
|
||||
|
||||
let mut group_idx = 0;
|
||||
|
||||
for ad in ads {
|
||||
if group_idx >= actual_ads {
|
||||
break;
|
||||
}
|
||||
let group_start = group_idx * group_size;
|
||||
let above_ref = safe_opts[group_start].as_ref();
|
||||
let below_ref = safe_opts[group_start + 1].as_ref();
|
||||
|
||||
if should_drop_bsr_low(&ad, above_ref, below_ref) {
|
||||
bsr_drop += 1;
|
||||
continue;
|
||||
}
|
||||
if is_bsr_low_ad(&ad) {
|
||||
bsr_ok += 1;
|
||||
}
|
||||
|
||||
if should_drop_handle(&ad, above_ref, below_ref) {
|
||||
handle_drop += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if should_drop_keyword(&ad, above_ref, below_ref) {
|
||||
keyword_drop += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let above = safe_opts[group_start].take().unwrap();
|
||||
let below = safe_opts[group_start + 1].take().unwrap();
|
||||
triples.push((ad, above, below));
|
||||
group_idx += 1;
|
||||
}
|
||||
|
||||
let placed_ads = triples.len();
|
||||
emit_enforcement_metrics(bsr_drop, bsr_ok, handle_drop, keyword_drop);
|
||||
|
||||
if placed_ads == 0 {
|
||||
let mut all_posts: Vec<ScoredPost> = safe_opts.into_iter().flatten().collect();
|
||||
all_posts.extend(unsafe_posts);
|
||||
all_posts.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
return posts_to_feed_items(all_posts);
|
||||
}
|
||||
|
||||
let mut filler: Vec<ScoredPost> =
|
||||
Vec::with_capacity(num_safe - 2 * placed_ads + unsafe_posts.len());
|
||||
filler.extend(safe_opts.into_iter().flatten());
|
||||
filler.extend(unsafe_posts);
|
||||
filler.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let inter_ad_gaps = placed_ads;
|
||||
let filler_per_gap = filler.len() / inter_ad_gaps;
|
||||
let remainder = filler.len() % inter_ad_gaps;
|
||||
let mut filler_iter = filler.into_iter();
|
||||
|
||||
let mut items: Vec<FeedItem> = Vec::with_capacity(n + placed_ads);
|
||||
|
||||
for (i, (ad, above, below)) in triples.into_iter().enumerate() {
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Post(above)),
|
||||
});
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Ad(ad)),
|
||||
});
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Post(below)),
|
||||
});
|
||||
|
||||
let count = filler_per_gap + if i >= inter_ad_gaps - remainder { 1 } else { 0 };
|
||||
for _ in 0..count {
|
||||
if let Some(post) = filler_iter.next() {
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Post(post)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items.truncate(RESULT_SIZE);
|
||||
if matches!(items.last(), Some(item) if matches!(item.item, Some(feed_item::Item::Ad(_)))) {
|
||||
items.pop();
|
||||
}
|
||||
for (i, item) in items.iter_mut().enumerate() {
|
||||
item.position = i as i32;
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
fn emit_enforcement_metrics(bsr_drop: u64, bsr_ok: u64, handle_drop: u64, keyword_drop: u64) {
|
||||
let Some(receiver) = global_stats_receiver() else {
|
||||
return;
|
||||
};
|
||||
if bsr_drop > 0 {
|
||||
receiver.incr(ENFORCEMENT_METRIC, &[("action", "drop")], bsr_drop);
|
||||
}
|
||||
if bsr_ok > 0 {
|
||||
receiver.incr(ENFORCEMENT_METRIC, &[("action", "ok")], bsr_ok);
|
||||
}
|
||||
if handle_drop > 0 {
|
||||
receiver.incr(
|
||||
ENFORCEMENT_METRIC,
|
||||
&[("action", "handle_drop")],
|
||||
handle_drop,
|
||||
);
|
||||
}
|
||||
if keyword_drop > 0 {
|
||||
receiver.incr(
|
||||
ENFORCEMENT_METRIC,
|
||||
&[("action", "keyword_drop")],
|
||||
keyword_drop,
|
||||
);
|
||||
}
|
||||
}
|
||||
95
home-mixer/ads/safe_gap_blender.rs
Normal file
95
home-mixer/ads/safe_gap_blender.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use super::AdsBlender;
|
||||
use super::util::*;
|
||||
use xai_home_mixer_proto::{FeedItem, ScoredPost};
|
||||
use xai_recsys_proto::AdIndexInfo;
|
||||
|
||||
/// Stateless blender whose `AdsBlender` implementation forwards to this
/// module's `blend_impl`.
pub struct SafeGapAdsBlender;
|
||||
|
||||
impl AdsBlender for SafeGapAdsBlender {
|
||||
fn blend_inner(&self, scored_posts: Vec<ScoredPost>, ads: Vec<AdIndexInfo>) -> Vec<FeedItem> {
|
||||
blend_impl(scored_posts, ads, MIN_POSTS_FOR_ADS)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn blend_impl(
|
||||
scored_posts: Vec<ScoredPost>,
|
||||
ads: Vec<AdIndexInfo>,
|
||||
min_posts: usize,
|
||||
) -> Vec<FeedItem> {
|
||||
let n = scored_posts.len();
|
||||
|
||||
if ads.is_empty() || n < min_posts {
|
||||
return posts_to_feed_items(scored_posts);
|
||||
}
|
||||
|
||||
let safe_gaps = find_safe_gaps(&scored_posts);
|
||||
let spacing = compute_spacing(&ads);
|
||||
let first_ideal = ads[0].insert_position.max(0) as usize;
|
||||
let placements = assign_ads_to_gaps(&safe_gaps, ads.len(), &spacing, first_ideal);
|
||||
|
||||
interleave_and_finalize(scored_posts, ads, &placements)
|
||||
}
|
||||
|
||||
pub(crate) fn assign_ads_to_gaps(
|
||||
safe_gaps: &[usize],
|
||||
num_ads: usize,
|
||||
spacing: &AdSpacing,
|
||||
first_ideal: usize,
|
||||
) -> Vec<usize> {
|
||||
let mut placements: Vec<usize> = Vec::new();
|
||||
let mut search_from: usize = 0;
|
||||
let mut prev_ideal = first_ideal;
|
||||
|
||||
for _ in 0..num_ads {
|
||||
if search_from >= safe_gaps.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let (ideal, min) = match placements.last() {
|
||||
None => (first_ideal, 1),
|
||||
Some(&last_actual) => {
|
||||
let ideal = prev_ideal + spacing.requested;
|
||||
let min = (prev_ideal + spacing.min).max(last_actual + DEFAULT_SPACING.min);
|
||||
(ideal, min)
|
||||
}
|
||||
};
|
||||
|
||||
let gap = find_best_gap(&safe_gaps[search_from..], ideal, min);
|
||||
|
||||
match gap {
|
||||
Some((offset, g)) => {
|
||||
placements.push(g);
|
||||
search_from += offset + 1;
|
||||
prev_ideal = ideal;
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
placements
|
||||
}
|
||||
|
||||
/// Finds the gap closest to `ideal` among the entries of `gaps` that are at
/// least `min`.
///
/// `gaps` must be sorted ascending (both lookups are binary searches).
/// Returns `(index into gaps, gap value)`, or `None` when no gap is `>= min`.
/// When two gaps are equally far from `ideal`, the smaller one is chosen.
pub(crate) fn find_best_gap(gaps: &[usize], ideal: usize, min: usize) -> Option<(usize, usize)> {
    let skipped = gaps.partition_point(|&gap| gap < min);
    let eligible = gaps.get(skipped..).filter(|rest| !rest.is_empty())?;

    // Index of the first eligible gap that is >= ideal.
    let split = eligible.partition_point(|&gap| gap < ideal);

    let pick = match (split.checked_sub(1), eligible.get(split)) {
        // Everything is below the ideal: take the largest.
        (_, None) => eligible.len() - 1,
        // Everything is at or above the ideal: take the smallest.
        (None, Some(_)) => 0,
        // Ideal lies between two gaps: take the nearer, preferring the lower.
        (Some(prev), Some(&next)) => {
            if ideal - eligible[prev] <= next - ideal {
                prev
            } else {
                split
            }
        }
    };

    Some((skipped + pick, eligible[pick]))
}
|
||||
228
home-mixer/ads/util.rs
Normal file
228
home-mixer/ads/util.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
use crate::params::RESULT_SIZE;
|
||||
use std::sync::LazyLock;
|
||||
use xai_home_mixer_proto::{BrandSafetyVerdict, FeedItem, ScoredPost, feed_item};
|
||||
use xai_post_text::TweetTokenizer;
|
||||
use xai_recsys_proto::{AdIndexInfo, BrandSafetyRiskLevel};
|
||||
use xai_stats_receiver::global_stats_receiver;
|
||||
|
||||
/// Shared tokenizer, constructed lazily on first access; used by
/// `should_drop_keyword` for both ad keywords and post text.
static TWEET_TOKENIZER: LazyLock<TweetTokenizer> = LazyLock::new(TweetTokenizer::new);
|
||||
|
||||
/// Minimum number of scored posts the blenders require before inserting any
/// ad; with fewer posts they return the posts alone.
pub(crate) const MIN_POSTS_FOR_ADS: usize = 5;
|
||||
|
||||
/// Smallest spacing that `compute_spacing` will accept from the ads' insert
/// positions; anything smaller falls back to `DEFAULT_SPACING`.
pub(crate) const MIN_REQUESTED_GAP: usize = 3;
|
||||
|
||||
/// Spacing used when it cannot be derived from the ads themselves (fewer than
/// two ads, or a derived gap below `MIN_REQUESTED_GAP`). Its `min` is also the
/// floor applied between consecutive actual placements in `assign_ads_to_gaps`.
pub(crate) const DEFAULT_SPACING: AdSpacing = AdSpacing { min: 2, requested: 3 };
|
||||
|
||||
/// Distance constraints between consecutive ads, measured in feed positions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct AdSpacing {
    /// Preferred distance from one ad's ideal position to the next.
    pub(crate) requested: usize,
    /// Smallest distance tolerated when the preferred one is unavailable.
    pub(crate) min: usize,
}
|
||||
|
||||
pub(crate) fn has_avoid(post: &ScoredPost) -> bool {
|
||||
post.brand_safety_verdict() == BrandSafetyVerdict::MediumRisk
|
||||
}
|
||||
|
||||
pub(crate) fn find_safe_gaps(scored_posts: &[ScoredPost]) -> Vec<usize> {
|
||||
let n = scored_posts.len();
|
||||
let mut safe = Vec::new();
|
||||
for g in 1..n {
|
||||
if has_avoid(&scored_posts[g - 1]) {
|
||||
continue;
|
||||
}
|
||||
if g < n && has_avoid(&scored_posts[g]) {
|
||||
continue;
|
||||
}
|
||||
safe.push(g);
|
||||
}
|
||||
safe
|
||||
}
|
||||
|
||||
pub(crate) fn compute_spacing(ads: &[AdIndexInfo]) -> AdSpacing {
|
||||
if ads.len() < 2 {
|
||||
return DEFAULT_SPACING;
|
||||
}
|
||||
|
||||
let mut positions: Vec<i32> = ads.iter().take(4).map(|a| a.insert_position).collect();
|
||||
positions.sort_unstable();
|
||||
|
||||
let min_diff = positions
|
||||
.windows(2)
|
||||
.map(|w| (w[1] - w[0]) as usize)
|
||||
.filter(|&d| d > 0)
|
||||
.min();
|
||||
|
||||
match min_diff {
|
||||
Some(requested) if requested >= MIN_REQUESTED_GAP => AdSpacing {
|
||||
requested,
|
||||
min: requested.div_ceil(2),
|
||||
},
|
||||
_ => DEFAULT_SPACING,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_bsr_low_ad(ad: &AdIndexInfo) -> bool {
|
||||
let risk = ad
|
||||
.ad_adjacency_control
|
||||
.as_ref()
|
||||
.map(|c| c.brand_safety_risk())
|
||||
.unwrap_or(BrandSafetyRiskLevel::BsrUnknown);
|
||||
matches!(
|
||||
risk,
|
||||
BrandSafetyRiskLevel::BsrLow | BrandSafetyRiskLevel::BsrIas
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn should_drop_bsr_low(
|
||||
ad: &AdIndexInfo,
|
||||
above: Option<&ScoredPost>,
|
||||
below: Option<&ScoredPost>,
|
||||
) -> bool {
|
||||
let risk = ad
|
||||
.ad_adjacency_control
|
||||
.as_ref()
|
||||
.map(|c| c.brand_safety_risk())
|
||||
.unwrap_or(BrandSafetyRiskLevel::BsrUnknown);
|
||||
if !matches!(
|
||||
risk,
|
||||
BrandSafetyRiskLevel::BsrLow | BrandSafetyRiskLevel::BsrIas
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
let is_lr = |p: &ScoredPost| p.brand_safety_verdict() == BrandSafetyVerdict::LowRisk;
|
||||
above.map(is_lr).unwrap_or(false) || below.map(is_lr).unwrap_or(false)
|
||||
}
|
||||
|
||||
pub(crate) fn should_drop_handle(
|
||||
ad: &AdIndexInfo,
|
||||
above: Option<&ScoredPost>,
|
||||
below: Option<&ScoredPost>,
|
||||
) -> bool {
|
||||
let handles = match ad.ad_adjacency_control.as_ref() {
|
||||
Some(ctrl) if !ctrl.handles.is_empty() => &ctrl.handles,
|
||||
_ => return false,
|
||||
};
|
||||
above
|
||||
.map(|p| handles.contains(&(p.author_id as i64)))
|
||||
.unwrap_or(false)
|
||||
|| below
|
||||
.map(|p| handles.contains(&(p.author_id as i64)))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub(crate) fn should_drop_keyword(
|
||||
ad: &AdIndexInfo,
|
||||
above: Option<&ScoredPost>,
|
||||
below: Option<&ScoredPost>,
|
||||
) -> bool {
|
||||
let keywords = match ad.ad_adjacency_control.as_ref() {
|
||||
Some(ctrl) if !ctrl.keywords.is_empty() => &ctrl.keywords,
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
let tokenizer = &*TWEET_TOKENIZER;
|
||||
|
||||
let tokenized_keywords: Vec<_> = keywords
|
||||
.iter()
|
||||
.map(|kw| tokenizer.tokenize(kw))
|
||||
.filter(|seq| !seq.is_empty())
|
||||
.collect();
|
||||
|
||||
if tokenized_keywords.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let text_matches = |p: &ScoredPost| {
|
||||
if p.tweet_text.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let tweet_tokens = tokenizer.tokenize(&p.tweet_text);
|
||||
if tweet_tokens.is_empty() {
|
||||
return false;
|
||||
}
|
||||
tokenized_keywords
|
||||
.iter()
|
||||
.any(|kw_tokens| tweet_tokens.contains_keyword_sequence(kw_tokens))
|
||||
};
|
||||
above.map(text_matches).unwrap_or(false) || below.map(text_matches).unwrap_or(false)
|
||||
}
|
||||
|
||||
pub(crate) fn posts_to_feed_items(scored_posts: Vec<ScoredPost>) -> Vec<FeedItem> {
|
||||
scored_posts
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, post)| FeedItem {
|
||||
position: i as i32,
|
||||
item: Some(feed_item::Item::Post(post)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn interleave_and_finalize(
|
||||
scored_posts: Vec<ScoredPost>,
|
||||
ads: Vec<AdIndexInfo>,
|
||||
placements: &[usize],
|
||||
) -> Vec<FeedItem> {
|
||||
let n = scored_posts.len();
|
||||
let mut items: Vec<FeedItem> = Vec::with_capacity(n + placements.len());
|
||||
let mut ads_iter = ads.into_iter();
|
||||
let mut pi = 0;
|
||||
|
||||
for (i, post) in scored_posts.into_iter().enumerate() {
|
||||
if pi < placements.len() && placements[pi] == i {
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Ad(ads_iter.next().unwrap())),
|
||||
});
|
||||
pi += 1;
|
||||
}
|
||||
items.push(FeedItem {
|
||||
position: 0,
|
||||
item: Some(feed_item::Item::Post(post)),
|
||||
});
|
||||
}
|
||||
|
||||
items.truncate(RESULT_SIZE);
|
||||
if matches!(items.last(), Some(item) if matches!(item.item, Some(feed_item::Item::Ad(_)))) {
|
||||
items.pop();
|
||||
}
|
||||
|
||||
for (i, item) in items.iter_mut().enumerate() {
|
||||
item.position = i as i32;
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
/// Counter name used by `record_post_verdict_stats`, tagged with `verdict`.
const VERDICT_METRIC: &str = "AdsBlender.post_brand_safety_verdict";
|
||||
/// Counter name used by `record_ad_risk_stats`, tagged with `risk`.
const RISK_METRIC: &str = "AdsBlender.ad_brand_safety_risk";
|
||||
|
||||
pub(crate) fn record_post_verdict_stats(posts: &[ScoredPost]) {
|
||||
let Some(receiver) = global_stats_receiver() else {
|
||||
return;
|
||||
};
|
||||
|
||||
for post in posts {
|
||||
let label = post.brand_safety_verdict().as_str_name();
|
||||
receiver.incr(VERDICT_METRIC, &[("verdict", label)], 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_ad_risk_stats(ads: &[AdIndexInfo]) {
|
||||
let Some(receiver) = global_stats_receiver() else {
|
||||
return;
|
||||
};
|
||||
|
||||
for ad in ads {
|
||||
let risk_level = ad
|
||||
.ad_adjacency_control
|
||||
.as_ref()
|
||||
.map(|c| c.brand_safety_risk())
|
||||
.unwrap_or(BrandSafetyRiskLevel::BsrUnknown);
|
||||
|
||||
receiver.incr(RISK_METRIC, &[("risk", risk_level.as_str_name())], 1);
|
||||
}
|
||||
}
|
||||
176
home-mixer/candidate_hydrators/ads_brand_safety_hydrator.rs
Normal file
176
home-mixer/candidate_hydrators/ads_brand_safety_hydrator.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use crate::models::brand_safety::{
|
||||
BrandSafetyVerdict, botmaker_rule_category, botmaker_rule_id_from, compute_verdict,
|
||||
truncate_description, worst_verdict,
|
||||
};
|
||||
use crate::models::candidate::{PostCandidate, SafetyLabelInfo};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::*;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::{
|
||||
MokaCache, TweetAgeExpiry, build_moka_cache_tweet_age,
|
||||
};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
use xai_safety_label_store::SafetyLabelStoreClient;
|
||||
|
||||
/// Size argument handed to `build_moka_cache_tweet_age` for the brand-safety
/// cache (presumably the maximum number of entries — confirm against the
/// cache builder).
const CACHE_SIZE: u64 = 1_000_000;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct CachedBrandSafety {
|
||||
verdict: BrandSafetyVerdict,
|
||||
safety_labels: Vec<SafetyLabelInfo>,
|
||||
}
|
||||
pub struct AdsBrandSafetyHydrator {
|
||||
pub client: Arc<dyn SafetyLabelStoreClient>,
|
||||
pub cache: MokaCache<u64, CachedBrandSafety>,
|
||||
}
|
||||
|
||||
impl AdsBrandSafetyHydrator {
|
||||
pub fn new(client: Arc<dyn SafetyLabelStoreClient>) -> Self {
|
||||
let cache = build_moka_cache_tweet_age(
|
||||
CACHE_SIZE,
|
||||
TweetAgeExpiry {
|
||||
age_threshold: Duration::from_secs(5 * 60),
|
||||
new_tweet_ttl: Duration::from_secs(60),
|
||||
old_tweet_ttl: Duration::from_secs(60 * 60),
|
||||
},
|
||||
);
|
||||
Self { client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for AdsBrandSafetyHydrator {
|
||||
type CacheKey = u64;
|
||||
type CacheValue = CachedBrandSafety;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
query.params.get(EnableAdsBrandSafetyHydrator)
|
||||
&& !query
|
||||
.decider
|
||||
.as_ref()
|
||||
.is_some_and(|d| d.enabled("vf_brand_safety_dark_traffic"))
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.tweet_id
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
CachedBrandSafety {
|
||||
verdict: hydrated.brand_safety_verdict.unwrap_or_default(),
|
||||
safety_labels: hydrated.safety_labels.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
brand_safety_verdict: Some(value.verdict),
|
||||
safety_labels: value.safety_labels,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let mut all_ids: HashSet<u64> = HashSet::new();
|
||||
for c in candidates {
|
||||
all_ids.insert(c.retweeted_tweet_id.unwrap_or(c.tweet_id));
|
||||
if let Some(qt_id) = c.quoted_tweet_id {
|
||||
all_ids.insert(qt_id);
|
||||
}
|
||||
}
|
||||
|
||||
let all_ids_vec: Vec<u64> = all_ids.into_iter().collect();
|
||||
let all_ids_i64: Vec<i64> = all_ids_vec.iter().map(|&id| id as i64).collect();
|
||||
|
||||
let mut label_map: HashMap<u64, xai_safety_label_store::types::SafetyLabelMap> =
|
||||
HashMap::new();
|
||||
let mut error_map: HashMap<u64, String> = HashMap::new();
|
||||
|
||||
match self.client.batch_get_all_labels(&all_ids_i64).await {
|
||||
Ok(per_key_results) => {
|
||||
for (&id, result) in all_ids_vec.iter().zip(per_key_results) {
|
||||
match result {
|
||||
Ok(labels) => {
|
||||
label_map.insert(id, labels);
|
||||
}
|
||||
Err(e) => {
|
||||
error_map.insert(id, e.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
for &id in &all_ids_vec {
|
||||
error_map.insert(id, err_str.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let primary_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id);
|
||||
|
||||
if let Some(err) = error_map.get(&primary_id) {
|
||||
return Err(format!("safety label lookup error for tweet {primary_id}: {err}"));
|
||||
}
|
||||
|
||||
let empty = HashMap::new();
|
||||
let primary_labels = label_map.get(&primary_id).unwrap_or(&empty);
|
||||
let mut verdict = compute_verdict(primary_labels, primary_id);
|
||||
let mut safety_labels: Vec<SafetyLabelInfo> = primary_labels
|
||||
.iter()
|
||||
.map(|(k, v)| SafetyLabelInfo {
|
||||
label_type: *k,
|
||||
description: v.source.as_deref().map(truncate_description),
|
||||
source: botmaker_rule_id_from(v)
|
||||
.map(|id| botmaker_rule_category(id).to_string()),
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(qt_id) = c.quoted_tweet_id {
|
||||
if error_map.contains_key(&qt_id) {
|
||||
verdict = worst_verdict(&verdict, &BrandSafetyVerdict::MediumRisk);
|
||||
} else {
|
||||
let qt_labels = label_map.get(&qt_id).unwrap_or(&empty);
|
||||
verdict = worst_verdict(&verdict, &compute_verdict(qt_labels, qt_id));
|
||||
safety_labels.extend(qt_labels.iter().map(|(k, v)| {
|
||||
SafetyLabelInfo {
|
||||
label_type: *k,
|
||||
description: v.source.as_deref().map(truncate_description),
|
||||
source: botmaker_rule_id_from(v)
|
||||
.map(|id| botmaker_rule_category(id).to_string()),
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
safety_labels.sort_unstable_by_key(|l| i32::from(l.label_type));
|
||||
safety_labels.dedup_by(|a, b| a.label_type == b.label_type);
|
||||
|
||||
Ok(PostCandidate {
|
||||
brand_safety_verdict: Some(verdict),
|
||||
safety_labels,
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.brand_safety_verdict = hydrated.brand_safety_verdict;
|
||||
candidate.safety_labels = hydrated.safety_labels;
|
||||
}
|
||||
}
|
||||
108
home-mixer/candidate_hydrators/ads_brand_safety_vf_hydrator.rs
Normal file
108
home-mixer/candidate_hydrators/ads_brand_safety_vf_hydrator.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use crate::models::brand_safety::{
|
||||
BrandSafetyVerdict, botmaker_rule_category, botmaker_rule_id_from, compute_verdict,
|
||||
truncate_description, worst_verdict,
|
||||
};
|
||||
use crate::models::candidate::{PostCandidate, SafetyLabelInfo};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::*;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_visibility_filtering::vf_safety_labels_client::VfClient;
|
||||
|
||||
pub struct AdsBrandSafetyVfHydrator {
|
||||
pub client: Arc<dyn VfClient>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for AdsBrandSafetyVfHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
query.params.get(EnableAdsBrandSafetyHydrator)
|
||||
&& query
|
||||
.decider
|
||||
.as_ref()
|
||||
.is_some_and(|d| d.enabled("vf_brand_safety_dark_traffic"))
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let mut all_ids: HashSet<u64> = HashSet::new();
|
||||
for c in candidates {
|
||||
all_ids.insert(c.retweeted_tweet_id.unwrap_or(c.tweet_id));
|
||||
if let Some(qt_id) = c.quoted_tweet_id {
|
||||
all_ids.insert(qt_id);
|
||||
}
|
||||
}
|
||||
|
||||
let tweet_ids: Vec<u64> = all_ids.into_iter().collect();
|
||||
let batch = match self.client.get_safety_labels(tweet_ids).await {
|
||||
Ok(batch) => batch,
|
||||
Err(e) => {
|
||||
let err = format!("VF get_safety_labels failed: {e}");
|
||||
return candidates.iter().map(|_| Err(err.clone())).collect();
|
||||
}
|
||||
};
|
||||
|
||||
let failed_ids: HashSet<u64> = batch.failures.keys().copied().collect();
|
||||
let label_map = batch.labels;
|
||||
|
||||
candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let primary_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id);
|
||||
|
||||
if failed_ids.contains(&primary_id) {
|
||||
return Err(format!("VF lookup failed for tweet {primary_id}"));
|
||||
}
|
||||
|
||||
let empty = HashMap::new();
|
||||
let primary_labels = label_map.get(&primary_id).unwrap_or(&empty);
|
||||
let mut verdict = compute_verdict(primary_labels, primary_id);
|
||||
let mut safety_labels: Vec<SafetyLabelInfo> = primary_labels
|
||||
.iter()
|
||||
.map(|(k, v)| SafetyLabelInfo {
|
||||
label_type: *k,
|
||||
description: v.source.as_deref().map(truncate_description),
|
||||
source: botmaker_rule_id_from(v)
|
||||
.map(|id| botmaker_rule_category(id).to_string()),
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(qt_id) = c.quoted_tweet_id {
|
||||
if failed_ids.contains(&qt_id) {
|
||||
verdict = worst_verdict(&verdict, &BrandSafetyVerdict::MediumRisk);
|
||||
} else {
|
||||
let qt_labels = label_map.get(&qt_id).unwrap_or(&empty);
|
||||
verdict = worst_verdict(&verdict, &compute_verdict(qt_labels, qt_id));
|
||||
safety_labels.extend(qt_labels.iter().map(|(k, v)| {
|
||||
SafetyLabelInfo {
|
||||
label_type: *k,
|
||||
description: v.source.as_deref().map(truncate_description),
|
||||
source: botmaker_rule_id_from(v)
|
||||
.map(|id| botmaker_rule_category(id).to_string()),
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
safety_labels.sort_unstable_by_key(|l| i32::from(l.label_type));
|
||||
safety_labels.dedup_by(|a, b| a.label_type == b.label_type);
|
||||
|
||||
Ok(PostCandidate {
|
||||
brand_safety_verdict: Some(verdict),
|
||||
safety_labels,
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.brand_safety_verdict = hydrated.brand_safety_verdict;
|
||||
candidate.safety_labels = hydrated.safety_labels;
|
||||
}
|
||||
}
|
||||
57
home-mixer/candidate_hydrators/blocked_by_hydrator.rs
Normal file
57
home-mixer/candidate_hydrators/blocked_by_hydrator.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::clients::SocialGraphClientOps;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
|
||||
pub struct BlockedByHydrator {
|
||||
pub socialgraph_client: Arc<dyn SocialGraphClientOps>,
|
||||
}
|
||||
|
||||
impl BlockedByHydrator {
|
||||
pub async fn new(socialgraph_client: Arc<dyn SocialGraphClientOps>) -> Self {
|
||||
Self { socialgraph_client }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for BlockedByHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let author_ids: Vec<u64> = candidates.iter().map(|x| x.author_id).collect();
|
||||
|
||||
let blocked_by_user_ids = match self
|
||||
.socialgraph_client
|
||||
.check_blocked_by(query.user_id, &author_ids)
|
||||
.await
|
||||
{
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
return candidates.iter().map(|_| Err(err_msg.clone())).collect();
|
||||
}
|
||||
};
|
||||
candidates
|
||||
.iter()
|
||||
.map(|candidate| {
|
||||
let author_blocks_viewer = blocked_by_user_ids.contains(&candidate.author_id);
|
||||
Ok(PostCandidate {
|
||||
author_blocks_viewer: Some(author_blocks_viewer),
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.author_blocks_viewer = hydrated.author_blocks_viewer;
|
||||
}
|
||||
}
|
||||
@@ -1,52 +1,108 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
use xai_stats_receiver::global_stats_receiver;
|
||||
|
||||
const FOUND_SCOPE: [(&str, &str); 1] = [("hydration", "found")];
|
||||
const MISSING_SCOPE: [(&str, &str); 1] = [("hydration", "missing")];
|
||||
|
||||
pub struct CoreDataCandidateHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub cache: MokaCache<u64, CoreDataCacheValue>,
|
||||
}
|
||||
|
||||
impl CoreDataCandidateHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
Self { tes_client }
|
||||
let cache = default_moka_cache();
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
async fn hydrate(
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
|
||||
type CacheKey = u64;
|
||||
|
||||
type CacheValue = CoreDataCacheValue;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.tweet_id
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
CoreDataCacheValue {
|
||||
author_id: hydrated.author_id,
|
||||
retweeted_user_id: hydrated.retweeted_user_id,
|
||||
retweeted_tweet_id: hydrated.retweeted_tweet_id,
|
||||
in_reply_to_tweet_id: hydrated.in_reply_to_tweet_id,
|
||||
tweet_text: hydrated.tweet_text.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
author_id: value.author_id,
|
||||
retweeted_user_id: value.retweeted_user_id,
|
||||
retweeted_tweet_id: value.retweeted_tweet_id,
|
||||
in_reply_to_tweet_id: value.in_reply_to_tweet_id,
|
||||
tweet_text: value.tweet_text,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.tes_client;
|
||||
|
||||
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
|
||||
let tweet_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
|
||||
|
||||
let post_features = client.get_tweet_core_datas(tweet_ids.clone()).await;
|
||||
let post_features = post_features.map_err(|e| e.to_string())?;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
let mut hydrated_count = 0usize;
|
||||
let mut missing_count = 0usize;
|
||||
for tweet_id in tweet_ids {
|
||||
let post_features = post_features.get(&tweet_id);
|
||||
let core_data = post_features.and_then(|x| x.as_ref());
|
||||
let text = core_data.map(|x| x.text.clone());
|
||||
let hydrated = PostCandidate {
|
||||
author_id: core_data.map(|x| x.author_id).unwrap_or_default(),
|
||||
retweeted_user_id: core_data.and_then(|x| x.source_user_id),
|
||||
retweeted_tweet_id: core_data.and_then(|x| x.source_tweet_id),
|
||||
in_reply_to_tweet_id: core_data.and_then(|x| x.in_reply_to_tweet_id),
|
||||
tweet_text: text.unwrap_or_default(),
|
||||
..Default::default()
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
match post_features {
|
||||
Some(Ok(Some(core_data))) => {
|
||||
hydrated_count += 1;
|
||||
let text = core_data.text.clone();
|
||||
let hydrated = PostCandidate {
|
||||
author_id: core_data.author_id,
|
||||
retweeted_user_id: core_data.source_user_id,
|
||||
retweeted_tweet_id: core_data.source_tweet_id,
|
||||
in_reply_to_tweet_id: core_data.in_reply_to_tweet_id,
|
||||
tweet_text: text,
|
||||
..Default::default()
|
||||
};
|
||||
hydrated_candidates.push(Ok(hydrated));
|
||||
}
|
||||
Some(Ok(None)) | None => {
|
||||
missing_count += 1;
|
||||
hydrated_candidates.push(Ok(PostCandidate::default()));
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
hydrated_candidates.push(Err(err.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hydrated_candidates)
|
||||
self.record_hydration_stats(hydrated_count, missing_count);
|
||||
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
@@ -56,3 +112,22 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
|
||||
candidate.tweet_text = hydrated.tweet_text;
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached slice of a post's core data, as stored by `CoreDataCandidateHydrator`.
///
/// Every field mirrors the `PostCandidate` field of the same name; the hydrator
/// copies them out of a hydrated candidate when storing and back into a fresh
/// candidate when reading.
#[derive(Debug, Clone)]
pub struct CoreDataCacheValue {
    /// Id of the post's author.
    pub author_id: u64,
    /// Id of the user whose post was retweeted, if any.
    pub retweeted_user_id: Option<u64>,
    /// Id of the retweeted post, if any.
    pub retweeted_tweet_id: Option<u64>,
    /// Id of the post this one replies to, if any.
    pub in_reply_to_tweet_id: Option<u64>,
    /// Text of the post.
    pub tweet_text: String,
}
|
||||
|
||||
impl CoreDataCandidateHydrator {
|
||||
fn record_hydration_stats(&self, hydrated_count: usize, missing_count: usize) {
|
||||
if let Some(receiver) = global_stats_receiver() {
|
||||
let metric_name = format!("{}.hydrate", self.name());
|
||||
receiver.incr(metric_name.as_str(), &FOUND_SCOPE, hydrated_count as u64);
|
||||
receiver.incr(metric_name.as_str(), &MISSING_SCOPE, missing_count as u64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
113
home-mixer/candidate_hydrators/engagement_counts_hydrator.rs
Normal file
113
home-mixer/candidate_hydrators/engagement_counts_hydrator.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::{CandidateHelpers, PostCandidate};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::EnableContextFeatures;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::{
|
||||
MokaCache, TweetAgeExpiry, build_moka_cache_tweet_age,
|
||||
};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
/// Engagement counters for one post, as stored in the hydrator's cache.
///
/// Each field mirrors the `PostCandidate` field of the same name; `None`
/// means the counter was not available when the entry was written.
#[derive(Debug, Clone)]
pub struct CachedCounts {
    /// Number of favorites (likes).
    fav_count: Option<i64>,
    /// Number of replies.
    reply_count: Option<i64>,
    /// Number of reposts.
    repost_count: Option<i64>,
    /// Number of quotes.
    quote_count: Option<i64>,
}
|
||||
|
||||
pub struct EngagementCountsHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
cache: MokaCache<u64, CachedCounts>,
|
||||
}
|
||||
|
||||
impl EngagementCountsHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
let cache = build_moka_cache_tweet_age(
|
||||
1_000_000,
|
||||
TweetAgeExpiry {
|
||||
age_threshold: Duration::from_secs(30 * 60),
|
||||
new_tweet_ttl: Duration::from_secs(5 * 60),
|
||||
old_tweet_ttl: Duration::from_secs(10 * 60),
|
||||
},
|
||||
);
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for EngagementCountsHydrator {
|
||||
type CacheKey = u64;
|
||||
type CacheValue = CachedCounts;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
(query.params.get(EnableContextFeatures) || query.is_shadow_traffic)
|
||||
&& !query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.get_original_tweet_id()
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
CachedCounts {
|
||||
fav_count: hydrated.fav_count,
|
||||
reply_count: hydrated.reply_count,
|
||||
repost_count: hydrated.repost_count,
|
||||
quote_count: hydrated.quote_count,
|
||||
}
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
fav_count: value.fav_count,
|
||||
reply_count: value.reply_count,
|
||||
repost_count: value.repost_count,
|
||||
quote_count: value.quote_count,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let tweet_ids: Vec<u64> = candidates
|
||||
.iter()
|
||||
.map(|c| c.get_original_tweet_id())
|
||||
.collect();
|
||||
|
||||
let counts_results = self.tes_client.get_api_counts(tweet_ids.clone()).await;
|
||||
|
||||
tweet_ids
|
||||
.iter()
|
||||
.map(|tweet_id| {
|
||||
let counts = counts_results
|
||||
.get(tweet_id)
|
||||
.and_then(|r| r.as_ref().ok())
|
||||
.and_then(|opt| opt.as_ref());
|
||||
Ok(PostCandidate {
|
||||
fav_count: counts.and_then(|c| c.favorite_count),
|
||||
reply_count: counts.and_then(|c| c.reply_count),
|
||||
repost_count: counts.and_then(|c| c.retweet_count),
|
||||
quote_count: counts.and_then(|c| c.quote_count),
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.fav_count = hydrated.fav_count;
|
||||
candidate.reply_count = hydrated.reply_count;
|
||||
candidate.repost_count = hydrated.repost_count;
|
||||
candidate.quote_count = hydrated.quote_count;
|
||||
}
|
||||
}
|
||||
134
home-mixer/candidate_hydrators/filtered_topics_hydrator.rs
Normal file
134
home-mixer/candidate_hydrators/filtered_topics_hydrator.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use crate::filters::topic_ids_filter::TopicFilteringOverrideMap;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::candidate_features::{FilteredTopicsByExperiment, TopicFilteringExperiment};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::{EnableNewUserTopicFiltering, TopicFilteringId, TopicFilteringOverrides};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use tracing::warn;
|
||||
use xai_candidate_pipeline::component_library::clients::StratoClient;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_strato::{StratoResult, StratoValue, decode};
|
||||
|
||||
fn decode_topics_pair(
|
||||
result: &Result<Vec<u8>, Box<dyn std::error::Error>>,
|
||||
experiment: TopicFilteringExperiment,
|
||||
need_unfiltered: bool,
|
||||
) -> (Option<Vec<i64>>, Option<Vec<i64>>) {
|
||||
match result {
|
||||
Ok(bytes) if !bytes.is_empty() => {
|
||||
let decoded: StratoResult<StratoValue<FilteredTopicsByExperiment>> = decode(bytes);
|
||||
match decoded {
|
||||
StratoResult::Ok(v) => {
|
||||
let ft = v.v;
|
||||
let exp_topics = ft
|
||||
.as_ref()
|
||||
.and_then(|ft| ft.topic_ids_for_experiment(experiment).cloned());
|
||||
let unf_topics = if need_unfiltered {
|
||||
ft.as_ref().and_then(|ft| {
|
||||
ft.topic_ids_for_experiment(TopicFilteringExperiment::Unfiltered)
|
||||
.cloned()
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(exp_topics, unf_topics)
|
||||
}
|
||||
StratoResult::Err(_) => (None, None),
|
||||
}
|
||||
}
|
||||
Ok(_) => (None, None),
|
||||
Err(e) => {
|
||||
warn!("FilteredTopicsHydrator: strato fetch error: {}", e);
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilteredTopicsHydrator {
|
||||
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for FilteredTopicsHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
query.is_topic_request()
|
||||
|| query.has_excluded_topics()
|
||||
|| (query.params.get(EnableNewUserTopicFiltering) && query.has_new_user_topic_ids())
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let experiment = if query.is_bulk_topic_request() || query.has_excluded_topics() {
|
||||
TopicFilteringExperiment::Unfiltered
|
||||
} else {
|
||||
let default_experiment =
|
||||
TopicFilteringExperiment::parse(&query.params.get(TopicFilteringId));
|
||||
let override_map =
|
||||
TopicFilteringOverrideMap::parse(&query.params.get(TopicFilteringOverrides));
|
||||
override_map.resolve(&query.topic_ids, default_experiment)
|
||||
};
|
||||
|
||||
let client = &self.strato_client;
|
||||
let need_unfiltered = experiment != TopicFilteringExperiment::Unfiltered;
|
||||
|
||||
let mut all_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
|
||||
let retweet_offset = all_ids.len();
|
||||
for c in candidates {
|
||||
if let Some(rt_id) = c.retweeted_tweet_id {
|
||||
all_ids.push(rt_id);
|
||||
}
|
||||
}
|
||||
|
||||
let all_results = client
|
||||
.batch_get_filtered_topics_by_experiment(&all_ids)
|
||||
.await;
|
||||
|
||||
let mut retweet_topics: HashMap<u64, Vec<i64>> = HashMap::new();
|
||||
let mut retweet_unfiltered: HashMap<u64, Vec<i64>> = HashMap::new();
|
||||
let mut rt_idx = retweet_offset;
|
||||
for c in candidates {
|
||||
if let Some(rt_id) = c.retweeted_tweet_id {
|
||||
let (exp, unf) =
|
||||
decode_topics_pair(&all_results[rt_idx], experiment, need_unfiltered);
|
||||
if let Some(topics) = exp {
|
||||
retweet_topics.insert(rt_id, topics);
|
||||
}
|
||||
if let Some(topics) = unf {
|
||||
retweet_unfiltered.insert(rt_id, topics);
|
||||
}
|
||||
rt_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| {
|
||||
let (topics, unf_topics) = if let Some(rt_id) = c.retweeted_tweet_id {
|
||||
(
|
||||
retweet_topics.get(&rt_id).cloned(),
|
||||
retweet_unfiltered.get(&rt_id).cloned(),
|
||||
)
|
||||
} else {
|
||||
decode_topics_pair(&all_results[i], experiment, need_unfiltered)
|
||||
};
|
||||
|
||||
Ok(PostCandidate {
|
||||
filtered_topic_ids: topics,
|
||||
unfiltered_topic_ids: unf_topics,
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.filtered_topic_ids = hydrated.filtered_topic_ids;
|
||||
candidate.unfiltered_topic_ids = hydrated.unfiltered_topic_ids;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::in_network_reply::InNetworkReply;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::{
|
||||
EnableFollowingRepliedUsersFacepile, FollowingRepliedUsersFacepileMaxPosts,
|
||||
FollowingRepliedUsersFacepileMinUsers,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
|
||||
/// Minimum follower count the viewer must have for this hydrator to run.
const VIEWER_FOLLOWERS_THRESHOLD: i64 = 1000;

/// Hydrates, per candidate post, the ids of users found in the query's
/// in-network replies who replied to that post.
pub struct FollowingRepliedUsersHydrator;
|
||||
|
||||
impl FollowingRepliedUsersHydrator {
|
||||
fn build_reply_author_map(replies: &[InNetworkReply]) -> HashMap<u64, Vec<u64>> {
|
||||
let mut map: HashMap<u64, Vec<u64>> = HashMap::new();
|
||||
for reply in replies {
|
||||
map.entry(reply.in_reply_to_tweet_id)
|
||||
.or_default()
|
||||
.push(reply.author_id);
|
||||
}
|
||||
for authors in map.values_mut() {
|
||||
authors.sort_unstable();
|
||||
authors.dedup();
|
||||
}
|
||||
map
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for FollowingRepliedUsersHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
let has_enough_followers = query
|
||||
.user_features
|
||||
.follower_count
|
||||
.is_some_and(|c| c >= VIEWER_FOLLOWERS_THRESHOLD);
|
||||
|
||||
has_enough_followers && query.params.get(EnableFollowingRepliedUsersFacepile)
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let min_users = query
|
||||
.params
|
||||
.get(FollowingRepliedUsersFacepileMinUsers)
|
||||
.max(0) as usize;
|
||||
let max_posts = query
|
||||
.params
|
||||
.get(FollowingRepliedUsersFacepileMaxPosts)
|
||||
.max(0) as usize;
|
||||
|
||||
let empty = Vec::new();
|
||||
let replies = query.in_network_replies.get().unwrap_or(&empty);
|
||||
|
||||
let reply_author_map = Self::build_reply_author_map(replies);
|
||||
|
||||
let mut results = Vec::with_capacity(candidates.len());
|
||||
let mut selected_count: usize = 0;
|
||||
|
||||
for candidate in candidates {
|
||||
let is_root_tweet = candidate.in_reply_to_tweet_id.is_none();
|
||||
|
||||
let authors: Vec<u64> = reply_author_map
|
||||
.get(&candidate.tweet_id)
|
||||
.map(|ids| {
|
||||
ids.iter()
|
||||
.copied()
|
||||
.filter(|&aid| aid != candidate.author_id)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let eligible = is_root_tweet && authors.len() >= min_users;
|
||||
let user_ids = if eligible && selected_count < max_posts {
|
||||
selected_count += 1;
|
||||
authors
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
results.push(Ok(PostCandidate {
|
||||
following_replied_user_ids: user_ids,
|
||||
..Default::default()
|
||||
}));
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.following_replied_user_ids = hydrated.following_replied_user_ids;
|
||||
}
|
||||
}
|
||||
@@ -1,28 +1,67 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::clients::gizmoduck_client::GizmoduckClient;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
pub struct GizmoduckCandidateHydrator {
|
||||
pub gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
|
||||
pub cache: MokaCache<GizmoduckCacheKey, GizmoduckCacheValue>,
|
||||
}
|
||||
|
||||
impl GizmoduckCandidateHydrator {
|
||||
pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>) -> Self {
|
||||
Self { gizmoduck_client }
|
||||
let cache = default_moka_cache();
|
||||
Self {
|
||||
gizmoduck_client,
|
||||
cache,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
async fn hydrate(
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
|
||||
type CacheKey = GizmoduckCacheKey;
|
||||
type CacheValue = GizmoduckCacheValue;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
GizmoduckCacheKey {
|
||||
author_id: candidate.author_id,
|
||||
retweeted_user_id: candidate.retweeted_user_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
GizmoduckCacheValue {
|
||||
author_followers_count: hydrated.author_followers_count,
|
||||
author_screen_name: hydrated.author_screen_name.clone(),
|
||||
retweeted_screen_name: hydrated.retweeted_screen_name.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
author_followers_count: value.author_followers_count,
|
||||
author_screen_name: value.author_screen_name,
|
||||
retweeted_screen_name: value.retweeted_screen_name,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.gizmoduck_client;
|
||||
|
||||
let author_ids: Vec<_> = candidates.iter().map(|c| c.author_id).collect();
|
||||
@@ -37,40 +76,54 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
|
||||
user_ids_to_fetch.dedup();
|
||||
|
||||
let users = client.get_users(user_ids_to_fetch).await;
|
||||
let users = users.map_err(|e| e.to_string())?;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
|
||||
for candidate in candidates {
|
||||
let user = users
|
||||
.get(&(candidate.author_id as i64))
|
||||
.and_then(|user| user.as_ref());
|
||||
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
|
||||
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
|
||||
|
||||
let author_followers_count: Option<i32> =
|
||||
user_counts.map(|x| x.followers_count).map(|x| x as i32);
|
||||
let author_screen_name: Option<String> = user_profile.map(|x| x.screen_name.clone());
|
||||
let user = users.get(&(candidate.author_id as i64));
|
||||
let user = match user {
|
||||
Some(Ok(Some(user))) => Ok(Some(user)),
|
||||
Some(Ok(None)) | None => Ok(None),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let retweet_user = candidate
|
||||
.retweeted_user_id
|
||||
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)))
|
||||
.and_then(|user| user.as_ref());
|
||||
let retweet_profile =
|
||||
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
|
||||
let retweeted_screen_name: Option<String> =
|
||||
retweet_profile.map(|x| x.screen_name.clone());
|
||||
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)));
|
||||
let retweet_user = match retweet_user {
|
||||
Some(Ok(Some(user))) => Ok(Some(user)),
|
||||
Some(Ok(None)) | None => Ok(None),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
|
||||
let hydrated = PostCandidate {
|
||||
author_followers_count,
|
||||
author_screen_name,
|
||||
retweeted_screen_name,
|
||||
..Default::default()
|
||||
let hydrated = match (user, retweet_user) {
|
||||
(Ok(user), Ok(retweet_user)) => {
|
||||
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
|
||||
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
|
||||
|
||||
let author_followers_count: Option<i32> =
|
||||
user_counts.map(|x| x.followers_count).map(|x| x as i32);
|
||||
let author_screen_name: Option<String> =
|
||||
user_profile.map(|x| x.screen_name.clone());
|
||||
|
||||
let retweet_profile =
|
||||
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
|
||||
let retweeted_screen_name: Option<String> =
|
||||
retweet_profile.map(|x| x.screen_name.clone());
|
||||
|
||||
Ok(PostCandidate {
|
||||
author_followers_count,
|
||||
author_screen_name,
|
||||
retweeted_screen_name,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => Err(err),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
|
||||
Ok(hydrated_candidates)
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
@@ -79,3 +132,16 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
|
||||
candidate.retweeted_screen_name = hydrated.retweeted_screen_name;
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache key for `GizmoduckCandidateHydrator`.
///
/// Built from a candidate's `author_id` and `retweeted_user_id`, so two
/// candidates share an entry only when both ids match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GizmoduckCacheKey {
    /// Id of the candidate's author.
    pub author_id: u64,
    /// Id of the retweeted user, if the candidate has one.
    pub retweeted_user_id: Option<u64>,
}
|
||||
|
||||
/// User-profile fields cached by `GizmoduckCandidateHydrator`.
///
/// Each field mirrors the `PostCandidate` field of the same name.
#[derive(Debug, Clone)]
pub struct GizmoduckCacheValue {
    /// Follower count of the candidate's author, if known.
    pub author_followers_count: Option<i32>,
    /// Screen name of the candidate's author, if known.
    pub author_screen_name: Option<String>,
    /// Screen name of the retweeted user, if known.
    pub retweeted_screen_name: Option<String>,
}
|
||||
|
||||
85
home-mixer/candidate_hydrators/has_media_hydrator.rs
Normal file
85
home-mixer/candidate_hydrators/has_media_hydrator.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::{CandidateHelpers, PostCandidate};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::EnableHasMediaHydration;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
pub struct HasMediaHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub cache: MokaCache<u64, Option<bool>>,
|
||||
}
|
||||
|
||||
impl HasMediaHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
let cache = default_moka_cache();
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for HasMediaHydrator {
|
||||
type CacheKey = u64;
|
||||
|
||||
type CacheValue = Option<bool>;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
(query.params.get(EnableHasMediaHydration) || query.is_shadow_traffic)
|
||||
&& !query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.get_original_tweet_id()
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
hydrated.has_media
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
has_media: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.tes_client;
|
||||
|
||||
let tweet_ids: Vec<u64> = candidates
|
||||
.iter()
|
||||
.map(|c| c.get_original_tweet_id())
|
||||
.collect();
|
||||
|
||||
let has_media_results = client.get_has_media(tweet_ids.clone()).await;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for tweet_id in tweet_ids {
|
||||
let result = has_media_results.get(&tweet_id);
|
||||
let hydrated = match result {
|
||||
Some(Ok(value)) => Ok(PostCandidate {
|
||||
has_media: *value,
|
||||
..Default::default()
|
||||
}),
|
||||
None => Err(format!("Missing has_media for tweet_id={}", tweet_id)),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.has_media = hydrated.has_media;
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::collections::HashSet;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
@@ -8,13 +8,16 @@ pub struct InNetworkCandidateHydrator;
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
let viewer_id = query.user_id as u64;
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let viewer_id = query.user_id;
|
||||
let followed_ids: HashSet<u64> = query
|
||||
.user_features
|
||||
.followed_user_ids
|
||||
@@ -23,19 +26,17 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
|
||||
.map(|id| id as u64)
|
||||
.collect();
|
||||
|
||||
let hydrated_candidates = candidates
|
||||
candidates
|
||||
.iter()
|
||||
.map(|candidate| {
|
||||
let is_self = candidate.author_id == viewer_id;
|
||||
let is_in_network = is_self || followed_ids.contains(&candidate.author_id);
|
||||
PostCandidate {
|
||||
Ok(PostCandidate {
|
||||
in_network: Some(is_in_network),
|
||||
..Default::default()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(hydrated_candidates)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
|
||||
84
home-mixer/candidate_hydrators/language_code_hydrator.rs
Normal file
84
home-mixer/candidate_hydrators/language_code_hydrator.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::{CandidateHelpers, PostCandidate};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
pub struct LanguageCodeHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub cache: MokaCache<u64, Option<String>>,
|
||||
}
|
||||
|
||||
impl LanguageCodeHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
let cache = default_moka_cache();
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for LanguageCodeHydrator {
|
||||
type CacheKey = u64;
|
||||
|
||||
type CacheValue = Option<String>;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.get_original_tweet_id()
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
hydrated.language_code.clone()
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
language_code: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.tes_client;
|
||||
|
||||
let tweet_ids: Vec<u64> = candidates
|
||||
.iter()
|
||||
.map(|c| c.get_original_tweet_id())
|
||||
.collect();
|
||||
|
||||
let language_results = client.get_language_code(tweet_ids.clone()).await;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for tweet_id in tweet_ids {
|
||||
let result = language_results.get(&tweet_id);
|
||||
let hydrated = match result {
|
||||
Some(Ok(value)) => Ok(PostCandidate {
|
||||
language_code: value.clone(),
|
||||
..Default::default()
|
||||
}),
|
||||
None => Err(format!("Missing language_code for tweet_id={}", tweet_id)),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.language_code = hydrated.language_code;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,16 @@
|
||||
pub mod ads_brand_safety_hydrator;
|
||||
pub mod ads_brand_safety_vf_hydrator;
|
||||
pub mod blocked_by_hydrator;
|
||||
pub mod core_data_candidate_hydrator;
|
||||
pub mod filtered_topics_hydrator;
|
||||
pub mod following_replied_users_hydrator;
|
||||
pub mod gizmoduck_hydrator;
|
||||
pub mod has_media_hydrator;
|
||||
pub mod in_network_candidate_hydrator;
|
||||
pub mod language_code_hydrator;
|
||||
pub mod mutual_follow_jaccard_hydrator;
|
||||
pub mod quote_hydrator;
|
||||
pub mod subscription_hydrator;
|
||||
pub mod tweet_type_metrics_hydrator;
|
||||
pub mod vf_candidate_hydrator;
|
||||
pub mod video_duration_candidate_hydrator;
|
||||
|
||||
118
home-mixer/candidate_hydrators/mutual_follow_jaccard_hydrator.rs
Normal file
118
home-mixer/candidate_hydrators/mutual_follow_jaccard_hydrator.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::EnableMutualFollowJaccardHydration;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::clients::StratoClient;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
|
||||
const MIN_HASHES: usize = 256;
|
||||
|
||||
pub struct MutualFollowJaccardHydrator {
|
||||
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
|
||||
}
|
||||
|
||||
/// Estimates Jaccard similarity from two minhash signatures.
///
/// Compares the signatures position by position over the shorter length and
/// returns the fraction of positions that hold equal values. Returns `0.0`
/// when either signature is empty.
fn jaccard_from_minhash(a: &[i64], b: &[i64]) -> f64 {
    let compared = a.len().min(b.len());
    match compared {
        0 => 0.0,
        n => {
            // `zip` stops at the shorter slice, i.e. after `n` pairs.
            let mut agreeing = 0usize;
            for (left, right) in a.iter().zip(b) {
                if left == right {
                    agreeing += 1;
                }
            }
            agreeing as f64 / n as f64
        }
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for MutualFollowJaccardHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
query.params.get(EnableMutualFollowJaccardHydration) && query.viewer_minhash.is_some()
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let viewer_minhash = match &query.viewer_minhash {
|
||||
Some(mh) if mh.len() >= MIN_HASHES => mh,
|
||||
_ => {
|
||||
return candidates
|
||||
.iter()
|
||||
.map(|_| {
|
||||
Ok(PostCandidate {
|
||||
mutual_follow_jaccard: None,
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
};
|
||||
|
||||
let unique_author_ids: Vec<i64> = candidates
|
||||
.iter()
|
||||
.map(|c| c.author_id as i64)
|
||||
.collect::<HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let results = self
|
||||
.strato_client
|
||||
.batch_get_minhash_with_count(&unique_author_ids)
|
||||
.await;
|
||||
|
||||
let mut author_result: HashMap<i64, Result<Option<Vec<i64>>, String>> = HashMap::new();
|
||||
for (uid, result) in unique_author_ids.iter().zip(results) {
|
||||
match result {
|
||||
Ok(Some((minhash, _count))) if minhash.len() >= MIN_HASHES => {
|
||||
author_result.insert(*uid, Ok(Some(minhash)));
|
||||
}
|
||||
Ok(Some((minhash, _))) => {
|
||||
author_result.insert(
|
||||
*uid,
|
||||
Err(format!(
|
||||
"Invalid minhash length {} (need >= {}) for author_id={}",
|
||||
minhash.len(),
|
||||
MIN_HASHES,
|
||||
uid,
|
||||
)),
|
||||
);
|
||||
}
|
||||
Ok(None) => {
|
||||
author_result.insert(*uid, Ok(None));
|
||||
}
|
||||
Err(e) => {
|
||||
author_result.insert(*uid, Err(e.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let author_id = c.author_id as i64;
|
||||
match author_result.get(&author_id) {
|
||||
Some(Ok(Some(author_mh))) => Ok(PostCandidate {
|
||||
mutual_follow_jaccard: Some(jaccard_from_minhash(
|
||||
viewer_minhash,
|
||||
author_mh,
|
||||
)),
|
||||
..Default::default()
|
||||
}),
|
||||
Some(Ok(None)) => Ok(PostCandidate {
|
||||
mutual_follow_jaccard: None,
|
||||
..Default::default()
|
||||
}),
|
||||
Some(Err(err)) => Err(err.clone()),
|
||||
None => Err(format!(
|
||||
"Missing minhash fetch result for author_id={}",
|
||||
author_id,
|
||||
)),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.mutual_follow_jaccard = hydrated.mutual_follow_jaccard;
|
||||
}
|
||||
}
|
||||
176
home-mixer/candidate_hydrators/quote_hydrator.rs
Normal file
176
home-mixer/candidate_hydrators/quote_hydrator.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params::EnableQuotedVqvDurationCheck;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::clients::SocialGraphClientOps;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, Hydrator};
|
||||
|
||||
pub struct QuoteHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub socialgraph_client: Arc<dyn SocialGraphClientOps>,
|
||||
pub cache: MokaCache<u64, QuoteCacheValue>,
|
||||
}
|
||||
|
||||
impl QuoteHydrator {
|
||||
pub async fn new(
|
||||
tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
socialgraph_client: Arc<dyn SocialGraphClientOps>,
|
||||
) -> Self {
|
||||
let cache = default_moka_cache();
|
||||
Self {
|
||||
tes_client,
|
||||
socialgraph_client,
|
||||
cache,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_quoted_video_durations(
|
||||
&self,
|
||||
quoted_tweet_ids: Vec<u64>,
|
||||
) -> HashMap<u64, Option<i32>> {
|
||||
if quoted_tweet_ids.is_empty() {
|
||||
return HashMap::new();
|
||||
}
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_millis(200),
|
||||
self.tes_client.get_min_video_durations(quoted_tweet_ids),
|
||||
)
|
||||
.await;
|
||||
match result {
|
||||
Ok(durations) => durations
|
||||
.into_iter()
|
||||
.filter_map(|(id, result)| result.ok().map(|d| (id, d.map(|v| v as i32))))
|
||||
.collect(),
|
||||
Err(_) => HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_blocked_by(&self, viewer_id: u64, quoted_user_ids: Vec<u64>) -> HashSet<u64> {
|
||||
if quoted_user_ids.is_empty() {
|
||||
return HashSet::new();
|
||||
}
|
||||
self.socialgraph_client
|
||||
.check_blocked_by(viewer_id, "ed_user_ids)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached outcome of a quoted-tweet lookup for one tweet.
///
/// Both fields are `None` when no quoted tweet was resolved for the tweet.
#[derive(Clone, Debug)]
pub struct QuoteCacheValue {
    /// Id of the quoted tweet, if one was resolved.
    pub quoted_tweet_id: Option<u64>,
    /// Author id of the quoted tweet, if one was resolved.
    pub quoted_user_id: Option<u64>,
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for QuoteHydrator {
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let tweet_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
|
||||
|
||||
let mut cache_misses: Vec<u64> = Vec::new();
|
||||
let mut resolved: Vec<(u64, Option<u64>, Option<u64>)> =
|
||||
Vec::with_capacity(tweet_ids.len());
|
||||
|
||||
for &tweet_id in &tweet_ids {
|
||||
if let Some(cached) = self.cache.get(&tweet_id).await {
|
||||
resolved.push((tweet_id, cached.quoted_tweet_id, cached.quoted_user_id));
|
||||
} else {
|
||||
cache_misses.push(tweet_id);
|
||||
resolved.push((tweet_id, None, None));
|
||||
}
|
||||
}
|
||||
|
||||
if !cache_misses.is_empty() {
|
||||
let quoted_tweets = self
|
||||
.tes_client
|
||||
.get_quoted_tweets(cache_misses.clone())
|
||||
.await;
|
||||
|
||||
for entry in resolved.iter_mut() {
|
||||
let tweet_id = entry.0;
|
||||
if !cache_misses.contains(&tweet_id) {
|
||||
continue;
|
||||
}
|
||||
let (qt_tweet_id, qt_user_id) = match quoted_tweets.get(&tweet_id) {
|
||||
Some(Ok(Some(qt))) => (Some(qt.tweet_id), Some(qt.user_id)),
|
||||
_ => (None, None),
|
||||
};
|
||||
entry.1 = qt_tweet_id;
|
||||
entry.2 = qt_user_id;
|
||||
|
||||
self.cache
|
||||
.insert(
|
||||
tweet_id,
|
||||
QuoteCacheValue {
|
||||
quoted_tweet_id: qt_tweet_id,
|
||||
quoted_user_id: qt_user_id,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let quoted_user_ids: Vec<u64> = resolved
|
||||
.iter()
|
||||
.filter_map(|(_, _, uid)| *uid)
|
||||
.collect::<HashSet<u64>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let fetch_quoted_duration = query.params.get(EnableQuotedVqvDurationCheck);
|
||||
let quoted_tweet_ids: Vec<u64> = if fetch_quoted_duration {
|
||||
resolved
|
||||
.iter()
|
||||
.filter_map(|(_, qt_id, _)| *qt_id)
|
||||
.collect::<HashSet<u64>>()
|
||||
.into_iter()
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let (blocked_by, quoted_durations) = tokio::join!(
|
||||
self.get_blocked_by(query.user_id, quoted_user_ids),
|
||||
self.get_quoted_video_durations(quoted_tweet_ids),
|
||||
);
|
||||
|
||||
resolved
|
||||
.iter()
|
||||
.map(|(_, qt_tweet_id, qt_user_id)| {
|
||||
let quoted_author_blocks_viewer = qt_user_id
|
||||
.map(|uid| blocked_by.contains(&uid))
|
||||
.unwrap_or(false);
|
||||
let quoted_video_duration_ms = qt_tweet_id
|
||||
.and_then(|id| quoted_durations.get(&id).copied())
|
||||
.flatten();
|
||||
Ok(PostCandidate {
|
||||
quoted_tweet_id: *qt_tweet_id,
|
||||
quoted_user_id: *qt_user_id,
|
||||
quoted_author_blocks_viewer: Some(quoted_author_blocks_viewer),
|
||||
quoted_video_duration_ms,
|
||||
..Default::default()
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.quoted_tweet_id = hydrated.quoted_tweet_id;
|
||||
candidate.quoted_user_id = hydrated.quoted_user_id;
|
||||
candidate.quoted_author_blocks_viewer = hydrated.quoted_author_blocks_viewer;
|
||||
candidate.quoted_video_duration_ms = hydrated.quoted_video_duration_ms;
|
||||
}
|
||||
}
|
||||
@@ -1,47 +1,79 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
pub struct SubscriptionHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub cache: MokaCache<u64, Option<u64>>,
|
||||
}
|
||||
|
||||
impl SubscriptionHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
Self { tes_client }
|
||||
let cache = default_moka_cache();
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
async fn hydrate(
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
|
||||
type CacheKey = u64;
|
||||
type CacheValue = Option<u64>;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.tweet_id
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
hydrated.subscription_author_id
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
subscription_author_id: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.tes_client;
|
||||
|
||||
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
|
||||
let tweet_ids: Vec<u64> = candidates.iter().map(|c| c.tweet_id).collect();
|
||||
|
||||
let post_features = client.get_subscription_author_ids(tweet_ids.clone()).await;
|
||||
let post_features = post_features.map_err(|e| e.to_string())?;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for tweet_id in tweet_ids {
|
||||
let post_features = post_features.get(&tweet_id);
|
||||
let subscription_author_id = post_features.and_then(|x| *x);
|
||||
let hydrated = PostCandidate {
|
||||
subscription_author_id,
|
||||
..Default::default()
|
||||
let hydrated = match post_features {
|
||||
Some(Ok(value)) => Ok(PostCandidate {
|
||||
subscription_author_id: *value,
|
||||
..Default::default()
|
||||
}),
|
||||
None => Err(format!(
|
||||
"Missing subscription author id for tweet_id={}",
|
||||
tweet_id
|
||||
)),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
|
||||
Ok(hydrated_candidates)
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
|
||||
180
home-mixer/candidate_hydrators/tweet_type_metrics_hydrator.rs
Normal file
180
home-mixer/candidate_hydrators/tweet_type_metrics_hydrator.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::util::tweet_type_metrics::*;
|
||||
use std::collections::HashSet;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::duration_since_creation_opt;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
|
||||
// Tweet-age bucket boundaries, all expressed in milliseconds.
const THIRTY_MINUTES_MS: u64 = ONE_HOUR_MS / 2;
const ONE_HOUR_MS: u64 = 3_600_000;
const SIX_HOURS_MS: u64 = 6 * ONE_HOUR_MS;
const TWELVE_HOURS_MS: u64 = 12 * ONE_HOUR_MS;
const TWENTY_FOUR_HOURS_MS: u64 = 24 * ONE_HOUR_MS;
|
||||
|
||||
/// Stateless hydrator that encodes a candidate's tweet-type flags into a
/// compact byte bitset (`tweet_type_metrics`).
pub struct TweetTypeMetricsHydrator;
|
||||
|
||||
impl TweetTypeMetricsHydrator {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn create_tweet_type_bitset(
|
||||
candidate: &PostCandidate,
|
||||
query: &ScoredPostsQuery,
|
||||
) -> HashSet<usize> {
|
||||
let mut true_tweet_types = HashSet::new();
|
||||
|
||||
true_tweet_types.insert(ANY_CANDIDATE);
|
||||
|
||||
if candidate.retweeted_tweet_id.is_some() {
|
||||
true_tweet_types.insert(RETWEET);
|
||||
}
|
||||
|
||||
if candidate.in_reply_to_tweet_id.is_some() {
|
||||
true_tweet_types.insert(REPLY);
|
||||
}
|
||||
|
||||
if candidate.subscription_author_id.is_some() {
|
||||
true_tweet_types.insert(SUBSCRIPTION_POST);
|
||||
}
|
||||
|
||||
if let Some(score) = candidate.score
|
||||
&& score != 0.0
|
||||
{
|
||||
true_tweet_types.insert(FULL_SCORING_SUCCEEDED);
|
||||
}
|
||||
|
||||
if !candidate.ancestors.is_empty() {
|
||||
true_tweet_types.insert(HAS_ANCESTORS);
|
||||
}
|
||||
|
||||
if candidate.in_network.unwrap_or(true) {
|
||||
true_tweet_types.insert(IN_NETWORK);
|
||||
}
|
||||
|
||||
if let Some(followers) = candidate.author_followers_count {
|
||||
let followers_u32 = followers as u32;
|
||||
if followers_u32 < 100 {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_0_100);
|
||||
}
|
||||
if (100..1000).contains(&followers_u32) {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_100_1K);
|
||||
}
|
||||
if (1000..10000).contains(&followers_u32) {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_1K_10K);
|
||||
}
|
||||
if (10000..100000).contains(&followers_u32) {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_10K_100K);
|
||||
}
|
||||
if (100000..1000000).contains(&followers_u32) {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_100K_1M);
|
||||
}
|
||||
if followers_u32 >= 1000000 {
|
||||
true_tweet_types.insert(AUTHOR_FOLLOWERS_1M_PLUS);
|
||||
}
|
||||
}
|
||||
|
||||
if candidate.min_video_duration_ms.is_some() {
|
||||
true_tweet_types.insert(VIDEO);
|
||||
}
|
||||
|
||||
if let Some(duration_ms) = candidate.min_video_duration_ms {
|
||||
let duration_ms_u32 = duration_ms as u32;
|
||||
if duration_ms_u32 <= 10000 {
|
||||
true_tweet_types.insert(VIDEO_LTE_10_SEC);
|
||||
}
|
||||
if duration_ms_u32 > 10000 && duration_ms_u32 <= 60000 {
|
||||
true_tweet_types.insert(VIDEO_BT_10_60_SEC);
|
||||
}
|
||||
if duration_ms_u32 > 60000 {
|
||||
true_tweet_types.insert(VIDEO_GT_60_SEC);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(age) = duration_since_creation_opt(candidate.tweet_id) {
|
||||
let age_ms = age.as_millis() as u64;
|
||||
|
||||
if age_ms <= THIRTY_MINUTES_MS {
|
||||
true_tweet_types.insert(TWEET_AGE_LTE_30_MINUTES);
|
||||
}
|
||||
if age_ms <= ONE_HOUR_MS {
|
||||
true_tweet_types.insert(TWEET_AGE_LTE_1_HOUR);
|
||||
}
|
||||
if age_ms <= SIX_HOURS_MS {
|
||||
true_tweet_types.insert(TWEET_AGE_LTE_6_HOURS);
|
||||
}
|
||||
if age_ms <= TWELVE_HOURS_MS {
|
||||
true_tweet_types.insert(TWEET_AGE_LTE_12_HOURS);
|
||||
}
|
||||
if age_ms >= TWENTY_FOUR_HOURS_MS {
|
||||
true_tweet_types.insert(TWEET_AGE_GTE_24_HOURS);
|
||||
}
|
||||
}
|
||||
|
||||
let served_size = query.served_ids.len();
|
||||
if served_size == 0 {
|
||||
true_tweet_types.insert(EMPTY_REQUEST);
|
||||
}
|
||||
if served_size < 3 {
|
||||
true_tweet_types.insert(NEAR_EMPTY);
|
||||
}
|
||||
if served_size < 20 {
|
||||
true_tweet_types.insert(SERVED_SIZE_LESS_THAN_20);
|
||||
}
|
||||
if served_size < 10 {
|
||||
true_tweet_types.insert(SERVED_SIZE_LESS_THAN_10);
|
||||
}
|
||||
if served_size < 5 {
|
||||
true_tweet_types.insert(SERVED_SIZE_LESS_THAN_5);
|
||||
}
|
||||
|
||||
true_tweet_types
|
||||
}
|
||||
|
||||
pub fn bitset_to_bytes(bits: &HashSet<usize>) -> Vec<u8> {
|
||||
if bits.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let max_bit = bits.iter().max().copied().unwrap_or(0);
|
||||
let num_bytes = (max_bit / 8) + 1;
|
||||
let mut bytes = vec![0u8; num_bytes];
|
||||
|
||||
for &bit_index in bits {
|
||||
let byte_index = bit_index / 8;
|
||||
let bit_offset = bit_index % 8;
|
||||
bytes[byte_index] |= 1u8 << bit_offset;
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for TweetTypeMetricsHydrator {
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for candidate in candidates {
|
||||
let true_tweet_types = Self::create_tweet_type_bitset(candidate, query);
|
||||
|
||||
let tweet_type_metrics = Some(Self::bitset_to_bytes(&true_tweet_types));
|
||||
|
||||
let hydrated = PostCandidate {
|
||||
tweet_type_metrics,
|
||||
..Default::default()
|
||||
};
|
||||
hydrated_candidates.push(Ok(hydrated));
|
||||
}
|
||||
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.tweet_type_metrics = hydrated.tweet_type_metrics;
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use anyhow::Result;
|
||||
use futures::future::join;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
@@ -7,7 +8,7 @@ use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_twittercontext_proto::GetTwitterContextViewer;
|
||||
use xai_twittercontext_proto::TwitterContextViewer;
|
||||
use xai_visibility_filtering::models::FilteredReason;
|
||||
use xai_visibility_filtering::models::{Action, FilteredReason};
|
||||
use xai_visibility_filtering::vf_client::SafetyLevel;
|
||||
use xai_visibility_filtering::vf_client::SafetyLevel::{TimelineHome, TimelineHomeRecommendations};
|
||||
use xai_visibility_filtering::vf_client::VisibilityFilteringClient;
|
||||
@@ -23,44 +24,55 @@ impl VFCandidateHydrator {
|
||||
|
||||
async fn fetch_vf_results(
|
||||
client: &Arc<dyn VisibilityFilteringClient + Send + Sync>,
|
||||
tweet_ids: Vec<i64>,
|
||||
tweet_ids: Vec<u64>,
|
||||
safety_level: SafetyLevel,
|
||||
for_user_id: i64,
|
||||
for_user_id: u64,
|
||||
context: Option<TwitterContextViewer>,
|
||||
) -> Result<HashMap<i64, Option<FilteredReason>>, String> {
|
||||
) -> HashMap<u64, Result<Option<FilteredReason>>> {
|
||||
if tweet_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
return HashMap::new();
|
||||
}
|
||||
|
||||
client
|
||||
.get_result(tweet_ids, safety_level, for_user_id, context)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
async fn hydrate(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let context = query.get_viewer();
|
||||
let user_id = query.user_id;
|
||||
let client = &self.vf_client;
|
||||
|
||||
let mut in_network_ids = Vec::new();
|
||||
let mut oon_ids = Vec::new();
|
||||
let mut in_network_ids: Vec<u64> = Vec::new();
|
||||
let mut oon_ids: Vec<u64> = Vec::new();
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if candidate.in_network.unwrap_or(false) {
|
||||
in_network_ids.push(candidate.tweet_id);
|
||||
} else {
|
||||
oon_ids.push(candidate.tweet_id);
|
||||
}
|
||||
for &ancestor_id in &candidate.ancestors {
|
||||
oon_ids.push(ancestor_id);
|
||||
}
|
||||
if let Some(quoted_id) = candidate.quoted_tweet_id {
|
||||
oon_ids.push(quoted_id);
|
||||
}
|
||||
if let Some(retweeted_id) = candidate.retweeted_tweet_id {
|
||||
in_network_ids.push(retweeted_id);
|
||||
}
|
||||
}
|
||||
|
||||
oon_ids.sort_unstable();
|
||||
oon_ids.dedup();
|
||||
|
||||
let in_network_future = Self::fetch_vf_results(
|
||||
client,
|
||||
in_network_ids,
|
||||
@@ -78,24 +90,73 @@ impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
|
||||
);
|
||||
|
||||
let (in_network_result, oon_result) = join(in_network_future, oon_future).await;
|
||||
let mut result: HashMap<i64, Option<FilteredReason>> = HashMap::new();
|
||||
result.extend(in_network_result?);
|
||||
result.extend(oon_result?);
|
||||
let mut all_results: HashMap<u64, Result<Option<FilteredReason>>> = HashMap::new();
|
||||
all_results.extend(in_network_result);
|
||||
all_results.extend(oon_result);
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for candidate in candidates {
|
||||
let visibility_reason = result.get(&candidate.tweet_id);
|
||||
let visibility_reason = visibility_reason.unwrap_or(&None);
|
||||
let hydrated = PostCandidate {
|
||||
visibility_reason: visibility_reason.clone(),
|
||||
..Default::default()
|
||||
let primary_result = all_results.get(&candidate.tweet_id);
|
||||
let visibility_reason = match primary_result {
|
||||
Some(Ok(Some(reason))) => Some(reason.clone()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let drop_ancillary = should_drop_ancillary(candidate, &all_results);
|
||||
|
||||
let hydrated = match primary_result {
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
_ => Ok(PostCandidate {
|
||||
visibility_reason,
|
||||
drop_ancillary_posts: Some(drop_ancillary),
|
||||
..Default::default()
|
||||
}),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
Ok(hydrated_candidates)
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.visibility_reason = hydrated.visibility_reason;
|
||||
candidate.drop_ancillary_posts = hydrated.drop_ancillary_posts;
|
||||
}
|
||||
}
|
||||
|
||||
fn should_drop_ancillary(
|
||||
candidate: &PostCandidate,
|
||||
vf_results: &HashMap<u64, Result<Option<FilteredReason>>>,
|
||||
) -> bool {
|
||||
for &ancestor_id in &candidate.ancestors {
|
||||
if let Some(Ok(Some(reason))) = vf_results.get(&ancestor_id)
|
||||
&& should_drop_reason(reason)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(quoted_id) = candidate.quoted_tweet_id
|
||||
&& let Some(Ok(Some(reason))) = vf_results.get("ed_id)
|
||||
&& should_drop_reason(reason)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(retweeted_id) = candidate.retweeted_tweet_id
|
||||
&& let Some(Ok(Some(reason))) = vf_results.get(&retweeted_id)
|
||||
&& should_drop_reason(reason)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn should_drop_reason(reason: &FilteredReason) -> bool {
|
||||
match reason {
|
||||
FilteredReason::SafetyResult(safety_result) => {
|
||||
matches!(safety_result.action, Action::Drop(_))
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,62 +1,85 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::candidate_features::MediaInfo;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::clients::tweet_entity_service_client::TESClient;
|
||||
use crate::models::candidate::{CandidateHelpers, PostCandidate};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::component_library::utils::{MokaCache, default_moka_cache};
|
||||
use xai_candidate_pipeline::hydrator::{CacheStore, CachedHydrator};
|
||||
|
||||
pub struct VideoDurationCandidateHydrator {
|
||||
pub tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
pub cache: MokaCache<u64, Option<i32>>,
|
||||
}
|
||||
|
||||
impl VideoDurationCandidateHydrator {
|
||||
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
|
||||
Self { tes_client }
|
||||
let cache = default_moka_cache();
|
||||
Self { tes_client, cache }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
|
||||
#[xai_stats_macro::receive_stats]
|
||||
async fn hydrate(
|
||||
impl CachedHydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
|
||||
type CacheKey = u64;
|
||||
|
||||
type CacheValue = Option<i32>;
|
||||
|
||||
fn enable(&self, query: &ScoredPostsQuery) -> bool {
|
||||
!query.has_cached_posts
|
||||
}
|
||||
|
||||
fn cache_store(&self) -> &dyn CacheStore<Self::CacheKey, Self::CacheValue> {
|
||||
&self.cache
|
||||
}
|
||||
fn cache_key(&self, candidate: &PostCandidate) -> Self::CacheKey {
|
||||
candidate.get_original_tweet_id()
|
||||
}
|
||||
|
||||
fn cache_value(&self, hydrated: &PostCandidate) -> Self::CacheValue {
|
||||
hydrated.min_video_duration_ms
|
||||
}
|
||||
|
||||
fn hydrate_from_cache(&self, value: Self::CacheValue) -> PostCandidate {
|
||||
PostCandidate {
|
||||
min_video_duration_ms: value,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn hydrate_from_client(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: &[PostCandidate],
|
||||
) -> Result<Vec<PostCandidate>, String> {
|
||||
) -> Vec<Result<PostCandidate, String>> {
|
||||
let client = &self.tes_client;
|
||||
|
||||
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
|
||||
let tweet_ids: Vec<u64> = candidates
|
||||
.iter()
|
||||
.map(|c| c.get_original_tweet_id())
|
||||
.collect();
|
||||
|
||||
let post_features = client.get_tweet_media_entities(tweet_ids.clone()).await;
|
||||
let post_features = post_features.map_err(|e| e.to_string())?;
|
||||
let durations = client.get_min_video_durations(tweet_ids.clone()).await;
|
||||
|
||||
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
|
||||
for tweet_id in tweet_ids {
|
||||
let post_features = post_features.get(&tweet_id);
|
||||
let media_entities = post_features.and_then(|x| x.as_ref());
|
||||
|
||||
let video_duration_ms = media_entities.and_then(|entities| {
|
||||
entities.iter().find_map(|entity| {
|
||||
if let Some(MediaInfo::VideoInfo(video_info)) = &entity.media_info {
|
||||
Some(video_info.duration_millis)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let hydrated = PostCandidate {
|
||||
video_duration_ms,
|
||||
..Default::default()
|
||||
let hydrated = match durations.get(&tweet_id) {
|
||||
Some(Ok(min_video_duration_ms)) => Ok(PostCandidate {
|
||||
min_video_duration_ms: min_video_duration_ms.map(|v| v as i32),
|
||||
..Default::default()
|
||||
}),
|
||||
None => Err(format!(
|
||||
"Missing min video duration for tweet_id={}",
|
||||
tweet_id
|
||||
)),
|
||||
Some(Err(err)) => Err(err.to_string()),
|
||||
};
|
||||
hydrated_candidates.push(hydrated);
|
||||
}
|
||||
|
||||
Ok(hydrated_candidates)
|
||||
hydrated_candidates
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
|
||||
candidate.video_duration_ms = hydrated.video_duration_ms;
|
||||
candidate.min_video_duration_ms = hydrated.min_video_duration_ms;
|
||||
}
|
||||
}
|
||||
|
||||
278
home-mixer/candidate_pipeline/for_you_candidate_pipeline.rs
Normal file
278
home-mixer/candidate_pipeline/for_you_candidate_pipeline.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use crate::clients::ad_index_client::{AdIndexClient, MockAdIndexClient, ProdAdIndexClient};
|
||||
use crate::clients::kafka_publisher_client::{KafkaPublisherClient, MockKafkaPublisherClient};
|
||||
use crate::clients::past_request_timestamps_client::{
|
||||
MockPastRequestTimestampsClient, PastRequestTimestampsClient, ProdPastRequestTimestampsClient,
|
||||
};
|
||||
use crate::clients::prompts_client::{MockPromptsClient, ProdPromptsClient, PromptsClient};
|
||||
use crate::clients::served_history_client::{
|
||||
MockServedHistoryClient, ProdServedHistoryClient, ServedHistoryClient,
|
||||
};
|
||||
use crate::clients::tweet_entity_service_client::{MockTESClient, ProdTESClient, TESClient};
|
||||
use crate::clients::who_to_follow_client::{
|
||||
MockWhoToFollowClient, ProdWhoToFollowClient, WhoToFollowClient,
|
||||
};
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params;
|
||||
use crate::query_hydrators::past_request_timestamps_query_hydrator::PastRequestTimestampsQueryHydrator;
|
||||
use crate::query_hydrators::served_history_query_hydrator::ServedHistoryQueryHydrator;
|
||||
use crate::scored_posts_server::ScoredPostsServer;
|
||||
use crate::selectors::BlenderSelector;
|
||||
use crate::side_effects::ads_injection_logging_side_effect::AdsInjectionLoggingSideEffect;
|
||||
use crate::side_effects::client_events_kafka_side_effect::ClientEventsKafkaSideEffect;
|
||||
use crate::side_effects::for_you_response_stats_side_effect::ForYouResponseStatsSideEffect;
|
||||
use crate::side_effects::publish_seen_ids_to_kafka_side_effect::PublishSeenIdsToKafkaSideEffect;
|
||||
use crate::side_effects::served_candidates_kafka_side_effect::ServedCandidatesKafkaSideEffect;
|
||||
use crate::side_effects::truncate_served_history_side_effect::TruncateServedHistorySideEffect;
|
||||
use crate::side_effects::update_past_request_timestamps_side_effect::UpdatePastRequestTimestampsSideEffect;
|
||||
use crate::side_effects::update_served_history_side_effect::UpdateServedHistorySideEffect;
|
||||
use crate::sources::ads_source::AdsSource;
|
||||
use crate::sources::prompts_source::PromptsSource;
|
||||
use crate::sources::push_to_home_source::PushToHomeSource;
|
||||
use crate::sources::scored_posts_source::ScoredPostsSource;
|
||||
use crate::sources::who_to_follow_source::WhoToFollowSource;
|
||||
use std::sync::Arc;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
|
||||
use xai_candidate_pipeline::component_library::clients::{
|
||||
MockReplyMixerClient, ProdReplyMixerClient, ReplyMixerClient,
|
||||
};
|
||||
use xai_candidate_pipeline::filter::Filter;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
|
||||
use xai_candidate_pipeline::scorer::Scorer;
|
||||
use xai_candidate_pipeline::selector::Selector;
|
||||
use xai_candidate_pipeline::side_effect::SideEffect;
|
||||
use xai_candidate_pipeline::source::Source;
|
||||
use xai_home_mixer_proto::FeedItem;
|
||||
|
||||
pub struct ForYouCandidatePipeline {
|
||||
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
|
||||
sources: Vec<Box<dyn Source<ScoredPostsQuery, FeedItem>>>,
|
||||
selector: BlenderSelector,
|
||||
side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>>,
|
||||
}
|
||||
|
||||
impl ForYouCandidatePipeline {
|
||||
pub async fn new(scored_posts_server: Arc<ScoredPostsServer>, datacenter: &str) -> Self {
|
||||
let (
|
||||
ad_index_client,
|
||||
ads_injection_logging,
|
||||
publish_seen_ids,
|
||||
served_candidates,
|
||||
client_events,
|
||||
served_history_client,
|
||||
who_to_follow_client,
|
||||
prompts_client,
|
||||
past_request_timestamps_client,
|
||||
tes_client,
|
||||
reply_mixer_client,
|
||||
) = tokio::join!(
|
||||
async {
|
||||
Arc::new(
|
||||
ProdAdIndexClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create AdIndex client"),
|
||||
) as Arc<dyn AdIndexClient + Send + Sync>
|
||||
},
|
||||
AdsInjectionLoggingSideEffect::prod(),
|
||||
PublishSeenIdsToKafkaSideEffect::prod(),
|
||||
ServedCandidatesKafkaSideEffect::prod(),
|
||||
ClientEventsKafkaSideEffect::prod(),
|
||||
async {
|
||||
Arc::new(
|
||||
ProdServedHistoryClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create ServedHistoryClient"),
|
||||
) as Arc<dyn ServedHistoryClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdWhoToFollowClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create WhoToFollowClient"),
|
||||
) as Arc<dyn WhoToFollowClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdPromptsClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create PromptsClient"),
|
||||
) as Arc<dyn PromptsClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdPastRequestTimestampsClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create PastRequestTimestampsClient"),
|
||||
) as Arc<dyn PastRequestTimestampsClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdTESClient::new(None, datacenter)
|
||||
.await
|
||||
.expect("Failed to create TES client"),
|
||||
) as Arc<dyn TESClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdReplyMixerClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create ReplyMixer client"),
|
||||
) as Arc<dyn ReplyMixerClient>
|
||||
},
|
||||
);
|
||||
|
||||
Self::build(
|
||||
scored_posts_server,
|
||||
ad_index_client,
|
||||
ads_injection_logging,
|
||||
publish_seen_ids,
|
||||
served_candidates,
|
||||
client_events,
|
||||
served_history_client,
|
||||
who_to_follow_client,
|
||||
prompts_client,
|
||||
past_request_timestamps_client,
|
||||
tes_client,
|
||||
reply_mixer_client,
|
||||
)
|
||||
}
|
||||
|
||||
fn build(
|
||||
scored_posts_server: Arc<ScoredPostsServer>,
|
||||
ad_index_client: Arc<dyn AdIndexClient + Send + Sync>,
|
||||
ads_injection_logging: AdsInjectionLoggingSideEffect,
|
||||
publish_seen_ids: PublishSeenIdsToKafkaSideEffect,
|
||||
served_candidates: ServedCandidatesKafkaSideEffect,
|
||||
client_events: ClientEventsKafkaSideEffect,
|
||||
served_history_client: Arc<dyn ServedHistoryClient>,
|
||||
who_to_follow_client: Arc<dyn WhoToFollowClient + Send + Sync>,
|
||||
prompts_client: Arc<dyn PromptsClient + Send + Sync>,
|
||||
past_request_timestamps_client: Arc<dyn PastRequestTimestampsClient>,
|
||||
tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
reply_mixer_client: Arc<dyn ReplyMixerClient>,
|
||||
) -> Self {
|
||||
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
|
||||
Box::new(ServedHistoryQueryHydrator::from_client(Arc::clone(
|
||||
&served_history_client,
|
||||
))),
|
||||
Box::new(PastRequestTimestampsQueryHydrator::new(Arc::clone(
|
||||
&past_request_timestamps_client,
|
||||
))),
|
||||
];
|
||||
|
||||
let sources: Vec<Box<dyn Source<ScoredPostsQuery, FeedItem>>> = vec![
|
||||
Box::new(ScoredPostsSource {
|
||||
scored_posts_server,
|
||||
}),
|
||||
Box::new(AdsSource { ad_index_client }),
|
||||
Box::new(WhoToFollowSource {
|
||||
who_to_follow_client,
|
||||
}),
|
||||
Box::new(PromptsSource { prompts_client }),
|
||||
Box::new(PushToHomeSource {
|
||||
tes_client,
|
||||
reply_mixer_client,
|
||||
}),
|
||||
];
|
||||
|
||||
let selector = BlenderSelector::new();
|
||||
|
||||
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>> =
|
||||
Arc::new(vec![
|
||||
Box::new(ads_injection_logging),
|
||||
Box::new(publish_seen_ids),
|
||||
Box::new(served_candidates),
|
||||
Box::new(client_events),
|
||||
Box::new(ForYouResponseStatsSideEffect),
|
||||
Box::new(UpdatePastRequestTimestampsSideEffect::new(
|
||||
past_request_timestamps_client,
|
||||
)),
|
||||
Box::new(UpdateServedHistorySideEffect::new(Arc::clone(
|
||||
&served_history_client,
|
||||
))),
|
||||
Box::new(TruncateServedHistorySideEffect::new(served_history_client)),
|
||||
]);
|
||||
|
||||
Self {
|
||||
query_hydrators,
|
||||
sources,
|
||||
selector,
|
||||
side_effects,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mock(scored_posts_server: Arc<ScoredPostsServer>) -> Self {
|
||||
let ad_index_client: Arc<dyn AdIndexClient + Send + Sync> = Arc::new(MockAdIndexClient);
|
||||
let mock_kafka = Arc::new(MockKafkaPublisherClient) as Arc<dyn KafkaPublisherClient>;
|
||||
let ads_injection = AdsInjectionLoggingSideEffect::new(Arc::clone(&mock_kafka));
|
||||
let publish_seen_ids = PublishSeenIdsToKafkaSideEffect::new(Arc::clone(&mock_kafka));
|
||||
let served_candidates = ServedCandidatesKafkaSideEffect::new(Arc::clone(&mock_kafka));
|
||||
let client_events = ClientEventsKafkaSideEffect::new(mock_kafka);
|
||||
let served_history_client: Arc<dyn ServedHistoryClient> = Arc::new(MockServedHistoryClient);
|
||||
let who_to_follow_client: Arc<dyn WhoToFollowClient + Send + Sync> =
|
||||
Arc::new(MockWhoToFollowClient);
|
||||
let prompts_client: Arc<dyn PromptsClient + Send + Sync> = Arc::new(MockPromptsClient);
|
||||
let past_request_timestamps_client: Arc<dyn PastRequestTimestampsClient> =
|
||||
Arc::new(MockPastRequestTimestampsClient);
|
||||
let tes_client: Arc<dyn TESClient + Send + Sync> = Arc::new(MockTESClient::default());
|
||||
let reply_mixer_client: Arc<dyn ReplyMixerClient> = Arc::new(MockReplyMixerClient);
|
||||
Self::build(
|
||||
scored_posts_server,
|
||||
ad_index_client,
|
||||
ads_injection,
|
||||
publish_seen_ids,
|
||||
served_candidates,
|
||||
client_events,
|
||||
served_history_client,
|
||||
who_to_follow_client,
|
||||
prompts_client,
|
||||
past_request_timestamps_client,
|
||||
tes_client,
|
||||
reply_mixer_client,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CandidatePipeline<ScoredPostsQuery, FeedItem> for ForYouCandidatePipeline {
|
||||
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>>] {
|
||||
&self.query_hydrators
|
||||
}
|
||||
|
||||
fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, FeedItem>>] {
|
||||
&self.sources
|
||||
}
|
||||
|
||||
fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, FeedItem>>] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, FeedItem>>] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, FeedItem>>] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn selector(&self) -> &dyn Selector<ScoredPostsQuery, FeedItem> {
|
||||
&self.selector
|
||||
}
|
||||
|
||||
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, FeedItem>>] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, FeedItem>>] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, FeedItem>>>> {
|
||||
Arc::clone(&self.side_effects)
|
||||
}
|
||||
|
||||
fn result_size(&self) -> usize {
|
||||
params::FOR_YOU_MAX_RESULT_SIZE
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,2 @@
|
||||
pub mod candidate;
|
||||
pub mod candidate_features;
|
||||
pub mod for_you_candidate_pipeline;
|
||||
pub mod phoenix_candidate_pipeline;
|
||||
pub mod query;
|
||||
pub mod query_features;
|
||||
|
||||
@@ -1,51 +1,130 @@
|
||||
use crate::candidate_hydrators::ads_brand_safety_hydrator::AdsBrandSafetyHydrator;
|
||||
use crate::candidate_hydrators::ads_brand_safety_vf_hydrator::AdsBrandSafetyVfHydrator;
|
||||
use crate::candidate_hydrators::blocked_by_hydrator::BlockedByHydrator;
|
||||
use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;
|
||||
use crate::candidate_hydrators::filtered_topics_hydrator::FilteredTopicsHydrator;
|
||||
use crate::candidate_hydrators::following_replied_users_hydrator::FollowingRepliedUsersHydrator;
|
||||
use crate::candidate_hydrators::gizmoduck_hydrator::GizmoduckCandidateHydrator;
|
||||
use crate::candidate_hydrators::has_media_hydrator::HasMediaHydrator;
|
||||
use crate::candidate_hydrators::in_network_candidate_hydrator::InNetworkCandidateHydrator;
|
||||
use crate::candidate_hydrators::language_code_hydrator::LanguageCodeHydrator;
|
||||
use crate::candidate_hydrators::mutual_follow_jaccard_hydrator::MutualFollowJaccardHydrator;
|
||||
use crate::candidate_hydrators::quote_hydrator::QuoteHydrator;
|
||||
use crate::candidate_hydrators::subscription_hydrator::SubscriptionHydrator;
|
||||
use crate::candidate_hydrators::tweet_type_metrics_hydrator::TweetTypeMetricsHydrator;
|
||||
use crate::candidate_hydrators::vf_candidate_hydrator::VFCandidateHydrator;
|
||||
use crate::candidate_hydrators::video_duration_candidate_hydrator::VideoDurationCandidateHydrator;
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::clients::gizmoduck_client::{GizmoduckClient, ProdGizmoduckClient};
|
||||
use crate::clients::phoenix_prediction_client::{
|
||||
PhoenixPredictionClient, ProdPhoenixPredictionClient,
|
||||
use crate::clients::followed_grok_topics_store_client::{
|
||||
FollowedGrokTopicsStoreClient, MockFollowedGrokTopicsStoreClient,
|
||||
ProdFollowedGrokTopicsStoreClient,
|
||||
};
|
||||
use crate::clients::phoenix_retrieval_client::{
|
||||
PhoenixRetrievalClient, ProdPhoenixRetrievalClient,
|
||||
use crate::clients::followed_starter_packs_store_client::{
|
||||
FollowedStarterPacksStoreClient, MockFollowedStarterPacksStoreClient,
|
||||
ProdFollowedStarterPacksStoreClient,
|
||||
};
|
||||
use crate::clients::gender_prediction_client::{
|
||||
GenderPredictionGrpcClient, MockGenderPredictionGrpcClient, ProdGenderPredictionGrpcClient,
|
||||
};
|
||||
use crate::clients::gizmoduck_client::{GizmoduckClient, MockGizmoduckClient, ProdGizmoduckClient};
|
||||
use crate::clients::impressed_posts_client::ImpressedPostsClient;
|
||||
use crate::clients::kafka_publisher_client::{
|
||||
KafkaCluster, KafkaPublisherClient, MockKafkaPublisherClient, PHOENIX_SCORES_TOPIC,
|
||||
ProdKafkaPublisherClient, RERANKING_TOPIC,
|
||||
};
|
||||
use crate::clients::s2s::{S2S_CHAIN_PATH, S2S_CRT_PATH, S2S_KEY_PATH};
|
||||
use crate::clients::socialgraph_client::SocialGraphClient;
|
||||
use crate::clients::strato_client::{ProdStratoClient, StratoClient};
|
||||
use crate::clients::thunder_client::ThunderClient;
|
||||
use crate::clients::tweet_entity_service_client::{ProdTESClient, TESClient};
|
||||
use crate::clients::uas_fetcher::UserActionSequenceFetcher;
|
||||
use crate::clients::tweet_entity_service_client::{MockTESClient, ProdTESClient, TESClient};
|
||||
use crate::clients::user_action_aggregation_client::{
|
||||
MockUserActionAggregationClient, ProdUserActionAggregationClient, UserActionAggregationClient,
|
||||
};
|
||||
use crate::clients::user_demographics_client::{
|
||||
MockUserDemographicsClient, ProdUserDemographicsClient, UserDemographicsClient,
|
||||
};
|
||||
use crate::clients::user_inferred_gender_store_client::{
|
||||
MockUserInferredGenderStoreClient, ProdUserInferredGenderStoreClient,
|
||||
UserInferredGenderStoreClient,
|
||||
};
|
||||
use crate::clients::vm_ranker_client::{MockVMRankerClient, ProdVMRankerClient, VMRankerClient};
|
||||
use crate::filters::age_filter::AgeFilter;
|
||||
use crate::filters::ancillary_vf_filter::AncillaryVFFilter;
|
||||
use crate::filters::author_socialgraph_filter::AuthorSocialgraphFilter;
|
||||
use crate::filters::core_data_hydration_filter::CoreDataHydrationFilter;
|
||||
use crate::filters::dedup_conversation_filter::DedupConversationFilter;
|
||||
use crate::filters::drop_duplicates_filter::DropDuplicatesFilter;
|
||||
use crate::filters::ineligible_subscription_filter::IneligibleSubscriptionFilter;
|
||||
use crate::filters::muted_keyword_filter::MutedKeywordFilter;
|
||||
use crate::filters::new_user_topic_ids_filter::NewUserTopicIdsFilter;
|
||||
use crate::filters::previously_seen_posts_backup_filter::PreviouslySeenPostsBackupFilter;
|
||||
use crate::filters::previously_seen_posts_filter::PreviouslySeenPostsFilter;
|
||||
use crate::filters::previously_served_posts_filter::PreviouslyServedPostsFilter;
|
||||
use crate::filters::retweet_deduplication_filter::RetweetDeduplicationFilter;
|
||||
use crate::filters::self_tweet_filter::SelfTweetFilter;
|
||||
use crate::filters::topic_ids_filter::TopicIdsFilter;
|
||||
use crate::filters::vf_filter::VFFilter;
|
||||
use crate::filters::video_filter::VideoFilter;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use crate::params;
|
||||
use crate::query_hydrators::user_action_seq_query_hydrator::UserActionSeqQueryHydrator;
|
||||
use crate::query_hydrators::user_features_query_hydrator::UserFeaturesQueryHydrator;
|
||||
use crate::scorers::author_diversity_scorer::AuthorDiversityScorer;
|
||||
use crate::scorers::oon_scorer::OONScorer;
|
||||
use crate::query_hydrators::blocked_user_ids_query_hydrator::BlockedUserIdsQueryHydrator;
|
||||
use crate::query_hydrators::cached_posts_query_hydrator::CachedPostsQueryHydrator;
|
||||
use crate::query_hydrators::followed_grok_topics_query_hydrator::FollowedGrokTopicsQueryHydrator;
|
||||
use crate::query_hydrators::followed_starter_packs_query_hydrator::FollowedStarterPacksQueryHydrator;
|
||||
use crate::query_hydrators::followed_user_ids_query_hydrator::FollowedUserIdsQueryHydrator;
|
||||
use crate::query_hydrators::ip_query_hydrator::IpQueryHydrator;
|
||||
use crate::query_hydrators::impressed_posts_query_hydrator::ImpressedPostsQueryHydrator;
|
||||
use crate::query_hydrators::impression_bloom_filter_query_hydrator::ImpressionBloomFilterQueryHydrator;
|
||||
use crate::query_hydrators::inferred_grok_topics_query_hydrator::InferredGrokTopicsQueryHydrator;
|
||||
use crate::query_hydrators::muted_user_ids_query_hydrator::MutedUserIdsQueryHydrator;
|
||||
use crate::query_hydrators::mutual_follow_query_hydrator::MutualFollowQueryHydrator;
|
||||
use crate::query_hydrators::retrieval_sequence_query_hydrator::RetrievalSequenceQueryHydrator;
|
||||
use crate::query_hydrators::scoring_sequence_query_hydrator::ScoringSequenceQueryHydrator;
|
||||
use crate::query_hydrators::subscribed_user_ids_query_hydrator::SubscribedUserIdsQueryHydrator;
|
||||
use crate::query_hydrators::user_demographics_query_hydrator::UserDemographicsQueryHydrator;
|
||||
use crate::query_hydrators::user_inferred_gender_query_hydrator::UserInferredGenderQueryHydrator;
|
||||
use crate::scorers::phoenix_scorer::PhoenixScorer;
|
||||
use crate::scorers::weighted_scorer::WeightedScorer;
|
||||
use crate::scorers::ranking_scorer::RankingScorer;
|
||||
use crate::scorers::vm_ranker::VMRanker;
|
||||
use crate::selectors::TopKScoreSelector;
|
||||
use crate::side_effects::cache_request_info_side_effect::CacheRequestInfoSideEffect;
|
||||
use crate::side_effects::mutual_follow_stats_side_effect::MutualFollowStatsSideEffect;
|
||||
use crate::side_effects::phoenix_experiments_side_effect::PhoenixExperimentsSideEffect;
|
||||
use crate::side_effects::phoenix_request_cache_side_effect::PhoenixRequestCacheSideEffect;
|
||||
use crate::side_effects::redis_post_candidate_cache_side_effect::RedisPostCandidateCacheSideEffect;
|
||||
use crate::side_effects::reranking_kafka_side_effect::RerankingKafkaSideEffect;
|
||||
use crate::side_effects::scored_stats_side_effect::ScoredStatsSideEffect;
|
||||
use crate::sources::cached_posts_source::CachedPostsSource;
|
||||
use crate::sources::phoenix_moe_source::PhoenixMOESource;
|
||||
use crate::sources::phoenix_source::PhoenixSource;
|
||||
use crate::sources::phoenix_topics_source::PhoenixTopicsSource;
|
||||
use crate::sources::thunder_source::ThunderSource;
|
||||
use crate::sources::tweet_mixer_source::TweetMixerSource;
|
||||
use xai_candidate_pipeline::component_library::clients::{
|
||||
MockTweetMixerClient, ProdTweetMixerClient, TweetMixerClient,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
|
||||
use xai_candidate_pipeline::component_library::clients::ThunderClient;
|
||||
use xai_candidate_pipeline::component_library::clients::egress_prediction_client::EgressPhoenixPredictionClient;
|
||||
use xai_candidate_pipeline::component_library::clients::phoenix_prediction_client::{
|
||||
MockPredictClient, PhoenixPredictionClient, ProdPhoenixPredictionClient,
|
||||
};
|
||||
use xai_candidate_pipeline::component_library::clients::phoenix_retrieval_client::{
|
||||
MockRetrievalClient, PhoenixRetrievalClient, PhoenixRetrievalCluster,
|
||||
ProdPhoenixRetrievalClient,
|
||||
};
|
||||
use xai_candidate_pipeline::component_library::clients::redis_client::{
|
||||
MockRedisClient, RedisClient,
|
||||
};
|
||||
use xai_candidate_pipeline::component_library::clients::{
|
||||
ImpressionBloomFilterClient, MockImpressionBloomFilterClient, ProdImpressionBloomFilterClient,
|
||||
};
|
||||
use xai_candidate_pipeline::component_library::clients::{
|
||||
MockSocialGraphClient, SocialGraphClient, SocialGraphClientOps,
|
||||
};
|
||||
use xai_candidate_pipeline::component_library::clients::{
|
||||
MockStratoClient, ProdStratoClient, StratoClient,
|
||||
};
|
||||
use xai_candidate_pipeline::filter::Filter;
|
||||
use xai_candidate_pipeline::hydrator::Hydrator;
|
||||
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
|
||||
@@ -53,9 +132,13 @@ use xai_candidate_pipeline::scorer::Scorer;
|
||||
use xai_candidate_pipeline::selector::Selector;
|
||||
use xai_candidate_pipeline::side_effect::SideEffect;
|
||||
use xai_candidate_pipeline::source::Source;
|
||||
use xai_geo_ip::GeoIpLocationClient;
|
||||
use xai_redis_client::{XdsRedisClient, XdsRedisConfig};
|
||||
use xai_visibility_filtering::vf_client::{
|
||||
ProdVisibilityFilteringClient, VisibilityFilteringClient,
|
||||
MockVisibilityFilteringClient, ProdVisibilityFilteringClient, VisibilityFilteringClient,
|
||||
};
|
||||
use xai_visibility_filtering::vf_safety_labels_client::{MockVfClient, ProdVfClient, VfClient};
|
||||
use xai_x_rpc::wily_lookup_service::ShardCoordinate;
|
||||
|
||||
pub struct PhoenixCandidatePipeline {
|
||||
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
|
||||
@@ -70,42 +153,124 @@ pub struct PhoenixCandidatePipeline {
|
||||
}
|
||||
|
||||
impl PhoenixCandidatePipeline {
|
||||
async fn build_with_clients(
|
||||
uas_fetcher: Arc<UserActionSequenceFetcher>,
|
||||
pub(crate) async fn build_with_clients(
|
||||
user_action_aggregation_client: Arc<dyn UserActionAggregationClient + Send + Sync>,
|
||||
phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
|
||||
egress_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
|
||||
phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
|
||||
thunder_client: Arc<ThunderClient>,
|
||||
strato_client: Arc<dyn StratoClient + Send + Sync>,
|
||||
tweet_mixer_client: Arc<dyn TweetMixerClient>,
|
||||
tes_client: Arc<dyn TESClient + Send + Sync>,
|
||||
gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
|
||||
vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
|
||||
redis_client: Arc<dyn RedisClient + Send + Sync>,
|
||||
phoenix_kafka_client: Arc<dyn KafkaPublisherClient>,
|
||||
reranking_kafka_client: Arc<dyn KafkaPublisherClient>,
|
||||
socialgraph_client: Arc<dyn SocialGraphClientOps>,
|
||||
vm_ranker_client: Arc<dyn VMRankerClient>,
|
||||
safety_label_client: Arc<dyn xai_safety_label_store::SafetyLabelStoreClient>,
|
||||
vf_safety_labels_client: Arc<dyn VfClient>,
|
||||
phoenix_request_cache_redis_atla_client: Arc<dyn RedisClient + Send + Sync>,
|
||||
phoenix_request_cache_redis_pdxa_client: Arc<dyn RedisClient + Send + Sync>,
|
||||
impression_bloom_filter_client: Arc<dyn ImpressionBloomFilterClient>,
|
||||
ip_client: Arc<GeoIpLocationClient>,
|
||||
user_demographics_client: Arc<dyn UserDemographicsClient>,
|
||||
user_inferred_gender_store_client: Arc<dyn UserInferredGenderStoreClient>,
|
||||
user_inferred_gender_grpc_client: Arc<dyn GenderPredictionGrpcClient>,
|
||||
impressed_posts_client: Arc<dyn ImpressedPostsClient>,
|
||||
followed_grok_topics_client: Arc<dyn FollowedGrokTopicsStoreClient>,
|
||||
followed_starter_packs_client: Arc<dyn FollowedStarterPacksStoreClient>,
|
||||
) -> PhoenixCandidatePipeline {
|
||||
// Query Hydrators
|
||||
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
|
||||
Box::new(UserActionSeqQueryHydrator::new(uas_fetcher)),
|
||||
Box::new(UserFeaturesQueryHydrator {
|
||||
Box::new(ScoringSequenceQueryHydrator::new(
|
||||
user_action_aggregation_client.clone(),
|
||||
)),
|
||||
Box::new(RetrievalSequenceQueryHydrator::new(
|
||||
user_action_aggregation_client,
|
||||
)),
|
||||
Box::new(BlockedUserIdsQueryHydrator {
|
||||
socialgraph_client: socialgraph_client.clone(),
|
||||
}),
|
||||
Box::new(MutedUserIdsQueryHydrator {
|
||||
socialgraph_client: socialgraph_client.clone(),
|
||||
}),
|
||||
Box::new(FollowedUserIdsQueryHydrator {
|
||||
socialgraph_client: socialgraph_client.clone(),
|
||||
}),
|
||||
Box::new(SubscribedUserIdsQueryHydrator {
|
||||
socialgraph_client: socialgraph_client.clone(),
|
||||
}),
|
||||
Box::new(CachedPostsQueryHydrator {
|
||||
redis_client: redis_client.clone(),
|
||||
}),
|
||||
Box::new(MutualFollowQueryHydrator {
|
||||
strato_client: strato_client.clone(),
|
||||
}),
|
||||
Box::new(UserDemographicsQueryHydrator {
|
||||
client: user_demographics_client,
|
||||
}),
|
||||
Box::new(FollowedGrokTopicsQueryHydrator::new(
|
||||
followed_grok_topics_client,
|
||||
)),
|
||||
Box::new(FollowedStarterPacksQueryHydrator::new(
|
||||
followed_starter_packs_client,
|
||||
)),
|
||||
Box::new(InferredGrokTopicsQueryHydrator {
|
||||
strato_client: strato_client.clone(),
|
||||
}),
|
||||
Box::new(ImpressionBloomFilterQueryHydrator {
|
||||
client: impression_bloom_filter_client,
|
||||
}),
|
||||
Box::new(IpQueryHydrator {
|
||||
client: ip_client,
|
||||
}),
|
||||
Box::new(UserInferredGenderQueryHydrator::new(
|
||||
user_inferred_gender_store_client,
|
||||
user_inferred_gender_grpc_client,
|
||||
)),
|
||||
];
|
||||
|
||||
// Sources
|
||||
let _impressed_posts_hydrator = ImpressedPostsQueryHydrator {
|
||||
client: impressed_posts_client,
|
||||
};
|
||||
|
||||
let phoenix_source = Box::new(PhoenixSource {
|
||||
phoenix_retrieval_client: phoenix_retrieval_client.clone(),
|
||||
});
|
||||
let phoenix_topics_source = Box::new(PhoenixTopicsSource {
|
||||
phoenix_retrieval_client: phoenix_retrieval_client.clone(),
|
||||
});
|
||||
let phoenix_moe_source = Box::new(PhoenixMOESource {
|
||||
phoenix_retrieval_client,
|
||||
});
|
||||
let thunder_source = Box::new(ThunderSource { thunder_client });
|
||||
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> =
|
||||
vec![phoenix_source, thunder_source];
|
||||
let tweet_mixer_source = Box::new(TweetMixerSource { tweet_mixer_client });
|
||||
let cached_posts_source = Box::new(CachedPostsSource);
|
||||
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
thunder_source,
|
||||
tweet_mixer_source,
|
||||
phoenix_source,
|
||||
phoenix_topics_source,
|
||||
phoenix_moe_source,
|
||||
cached_posts_source,
|
||||
];
|
||||
|
||||
// Hydrators
|
||||
let hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
Box::new(InNetworkCandidateHydrator),
|
||||
Box::new(CoreDataCandidateHydrator::new(tes_client.clone()).await),
|
||||
Box::new(QuoteHydrator::new(tes_client.clone(), socialgraph_client.clone()).await),
|
||||
Box::new(VideoDurationCandidateHydrator::new(tes_client.clone()).await),
|
||||
Box::new(HasMediaHydrator::new(tes_client.clone()).await),
|
||||
Box::new(SubscriptionHydrator::new(tes_client.clone()).await),
|
||||
Box::new(GizmoduckCandidateHydrator::new(gizmoduck_client).await),
|
||||
Box::new(BlockedByHydrator::new(socialgraph_client).await),
|
||||
Box::new(FilteredTopicsHydrator {
|
||||
strato_client: strato_client.clone(),
|
||||
}),
|
||||
Box::new(LanguageCodeHydrator::new(tes_client.clone()).await),
|
||||
];
|
||||
|
||||
// Filters
|
||||
let filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
Box::new(DropDuplicatesFilter),
|
||||
Box::new(CoreDataHydrationFilter),
|
||||
@@ -114,37 +279,63 @@ impl PhoenixCandidatePipeline {
|
||||
Box::new(RetweetDeduplicationFilter),
|
||||
Box::new(IneligibleSubscriptionFilter),
|
||||
Box::new(PreviouslySeenPostsFilter),
|
||||
Box::new(PreviouslySeenPostsBackupFilter),
|
||||
Box::new(PreviouslyServedPostsFilter),
|
||||
Box::new(MutedKeywordFilter::new()),
|
||||
Box::new(AuthorSocialgraphFilter),
|
||||
Box::new(VideoFilter),
|
||||
Box::new(TopicIdsFilter),
|
||||
Box::new(NewUserTopicIdsFilter),
|
||||
];
|
||||
|
||||
// Scorers
|
||||
let phoenix_scorer = Box::new(PhoenixScorer { phoenix_client });
|
||||
let weighted_scorer = Box::new(WeightedScorer);
|
||||
let author_diversity_scorer = Box::new(AuthorDiversityScorer::default());
|
||||
let oon_scorer = Box::new(OONScorer);
|
||||
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
phoenix_scorer,
|
||||
weighted_scorer,
|
||||
author_diversity_scorer,
|
||||
oon_scorer,
|
||||
];
|
||||
let phoenix_scorer = Box::new(PhoenixScorer {
|
||||
phoenix_client: phoenix_client.clone(),
|
||||
egress_client: Arc::clone(&egress_client),
|
||||
});
|
||||
let ranking_scorer = Box::new(RankingScorer);
|
||||
let vm_ranker = Box::new(VMRanker {
|
||||
client: vm_ranker_client,
|
||||
});
|
||||
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> =
|
||||
vec![phoenix_scorer, ranking_scorer, vm_ranker];
|
||||
|
||||
// Selector
|
||||
let selector = TopKScoreSelector;
|
||||
|
||||
// Post-selection hydrators
|
||||
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> =
|
||||
vec![Box::new(VFCandidateHydrator::new(vf_client.clone()).await)];
|
||||
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
Box::new(VFCandidateHydrator::new(vf_client.clone()).await),
|
||||
Box::new(AdsBrandSafetyHydrator::new(safety_label_client)),
|
||||
Box::new(AdsBrandSafetyVfHydrator {
|
||||
client: vf_safety_labels_client,
|
||||
}),
|
||||
Box::new(TweetTypeMetricsHydrator::new()),
|
||||
Box::new(FollowingRepliedUsersHydrator),
|
||||
Box::new(MutualFollowJaccardHydrator {
|
||||
strato_client: strato_client.clone(),
|
||||
}),
|
||||
];
|
||||
|
||||
// Post-selection filters
|
||||
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> =
|
||||
vec![Box::new(VFFilter), Box::new(DedupConversationFilter)];
|
||||
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
|
||||
Box::new(VFFilter),
|
||||
Box::new(AncillaryVFFilter),
|
||||
Box::new(DedupConversationFilter),
|
||||
];
|
||||
|
||||
// Side Effects
|
||||
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> =
|
||||
Arc::new(vec![Box::new(CacheRequestInfoSideEffect { strato_client })]);
|
||||
Arc::new(vec![
|
||||
Box::new(PhoenixExperimentsSideEffect::new(
|
||||
phoenix_client,
|
||||
egress_client,
|
||||
phoenix_kafka_client,
|
||||
)),
|
||||
Box::new(RerankingKafkaSideEffect::new(reranking_kafka_client)),
|
||||
Box::new(RedisPostCandidateCacheSideEffect::new(redis_client)),
|
||||
Box::new(ScoredStatsSideEffect),
|
||||
Box::new(MutualFollowStatsSideEffect),
|
||||
Box::new(PhoenixRequestCacheSideEffect::new(
|
||||
phoenix_request_cache_redis_atla_client,
|
||||
phoenix_request_cache_redis_pdxa_client,
|
||||
)),
|
||||
]);
|
||||
|
||||
PhoenixCandidatePipeline {
|
||||
query_hydrators,
|
||||
@@ -159,54 +350,379 @@ impl PhoenixCandidatePipeline {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prod() -> PhoenixCandidatePipeline {
|
||||
let uas_fetcher =
|
||||
Arc::new(UserActionSequenceFetcher::new().expect("Failed to create UAS fetcher"));
|
||||
let _sgs_client = Arc::new(SocialGraphClient::new());
|
||||
let phoenix_client = Arc::new(
|
||||
ProdPhoenixPredictionClient::new()
|
||||
.await
|
||||
.expect("Failed to create Phoenix prediction client"),
|
||||
);
|
||||
let phoenix_retrieval_client = Arc::new(
|
||||
ProdPhoenixRetrievalClient::new()
|
||||
.await
|
||||
.expect("Failed to create Phoenix retrieval client"),
|
||||
);
|
||||
let thunder_client = Arc::new(ThunderClient::new().await);
|
||||
let strato_client = Arc::new(
|
||||
ProdStratoClient::new()
|
||||
.await
|
||||
.expect("Failed to create Strato client"),
|
||||
);
|
||||
let tes_client = Arc::new(
|
||||
ProdTESClient::new()
|
||||
.await
|
||||
.expect("Failed to create TES client"),
|
||||
);
|
||||
let gizmoduck_client = Arc::new(
|
||||
ProdGizmoduckClient::new()
|
||||
.await
|
||||
.expect("Failed to create Gizmoduck client"),
|
||||
);
|
||||
let vf_client = Arc::new(
|
||||
ProdVisibilityFilteringClient::new(
|
||||
S2S_CHAIN_PATH.clone(),
|
||||
S2S_CRT_PATH.clone(),
|
||||
S2S_KEY_PATH.clone()
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create VF client"),
|
||||
pub async fn prod(
|
||||
shard_coordinate: Option<ShardCoordinate>,
|
||||
datacenter: &str,
|
||||
) -> PhoenixCandidatePipeline {
|
||||
let local_cache_eds = String::new();
|
||||
let atla_phoenix_cache_eds = "";
|
||||
let pdxa_phoenix_cache_eds = "";
|
||||
|
||||
let (
|
||||
flock_socialgraph_client,
|
||||
user_action_aggregation_client,
|
||||
phoenix_client,
|
||||
egress_client,
|
||||
phoenix_retrieval_client,
|
||||
thunder_client,
|
||||
strato_client,
|
||||
tweet_mixer_client,
|
||||
tes_client,
|
||||
gizmoduck_client,
|
||||
vf_client,
|
||||
redis_client,
|
||||
phoenix_request_cache_redis_atla_client,
|
||||
phoenix_request_cache_redis_pdxa_client,
|
||||
phoenix_kafka_client,
|
||||
reranking_kafka_client,
|
||||
vm_ranker_client,
|
||||
safety_label_client,
|
||||
vf_safety_labels_client,
|
||||
impression_bloom_filter_client,
|
||||
ip_client,
|
||||
user_demographics_client,
|
||||
user_inferred_gender_store_client,
|
||||
user_inferred_gender_grpc_client,
|
||||
impressed_posts_client,
|
||||
followed_grok_topics_client,
|
||||
followed_starter_packs_client,
|
||||
) = tokio::join!(
|
||||
async {
|
||||
Arc::new(
|
||||
SocialGraphClient::new(
|
||||
datacenter,
|
||||
&S2S_CHAIN_PATH,
|
||||
&S2S_CRT_PATH,
|
||||
&S2S_KEY_PATH,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create flock SocialGraphClient"),
|
||||
) as Arc<dyn SocialGraphClientOps>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdUserActionAggregationClient::new()
|
||||
.await
|
||||
.expect("Failed to create User Action Aggregation client"),
|
||||
) as Arc<dyn UserActionAggregationClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdPhoenixPredictionClient::new()
|
||||
.await
|
||||
.expect("Failed to create Phoenix prediction client"),
|
||||
) as Arc<dyn PhoenixPredictionClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
EgressPhoenixPredictionClient::connect()
|
||||
.await
|
||||
.expect("Failed to connect to egress sidecar"),
|
||||
) as Arc<dyn PhoenixPredictionClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdPhoenixRetrievalClient::new(Some((
|
||||
PhoenixRetrievalCluster::Experiment1Fou,
|
||||
PhoenixRetrievalCluster::Experiment1Lap7,
|
||||
)))
|
||||
.await
|
||||
.expect("Failed to create Phoenix retrieval client"),
|
||||
) as Arc<dyn PhoenixRetrievalClient + Send + Sync>
|
||||
},
|
||||
async { Arc::new(ThunderClient::new().await) },
|
||||
async {
|
||||
Arc::new(
|
||||
ProdStratoClient::new(shard_coordinate, datacenter)
|
||||
.await
|
||||
.expect("Failed to create Strato client"),
|
||||
) as Arc<dyn StratoClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdTweetMixerClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create TweetMixer client"),
|
||||
) as Arc<dyn TweetMixerClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdTESClient::new(shard_coordinate, datacenter)
|
||||
.await
|
||||
.expect("Failed to create TES client"),
|
||||
) as Arc<dyn TESClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdGizmoduckClient::new(
|
||||
shard_coordinate,
|
||||
datacenter,
|
||||
Some("home-mixer.prod".to_string()),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create Gizmoduck client"),
|
||||
) as Arc<dyn GizmoduckClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdVisibilityFilteringClient::new(
|
||||
S2S_CHAIN_PATH.clone(),
|
||||
S2S_CRT_PATH.clone(),
|
||||
S2S_KEY_PATH.clone(),
|
||||
"home-mixer.prod".to_string(),
|
||||
datacenter.to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create VF client"),
|
||||
) as Arc<dyn VisibilityFilteringClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
XdsRedisClient::new(XdsRedisConfig {
|
||||
eds_resource_name: local_cache_eds.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create xDS Redis client for local cache"),
|
||||
) as Arc<dyn RedisClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
XdsRedisClient::new(XdsRedisConfig {
|
||||
eds_resource_name: atla_phoenix_cache_eds.into(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create xDS Redis client for atla phoenix cache"),
|
||||
) as Arc<dyn RedisClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
XdsRedisClient::new(XdsRedisConfig {
|
||||
eds_resource_name: pdxa_phoenix_cache_eds.into(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create xDS Redis client for pdxa phoenix cache"),
|
||||
) as Arc<dyn RedisClient + Send + Sync>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdKafkaPublisherClient::new(PHOENIX_SCORES_TOPIC, KafkaCluster::Aiml).await,
|
||||
) as Arc<dyn KafkaPublisherClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdKafkaPublisherClient::new(RERANKING_TOPIC, KafkaCluster::Phoenix).await,
|
||||
) as Arc<dyn KafkaPublisherClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdVMRankerClient::new()
|
||||
.await
|
||||
.expect("Failed to create VMRanker client"),
|
||||
) as Arc<dyn VMRankerClient>
|
||||
},
|
||||
async {
|
||||
let s2s = xai_manhattan::s2s::S2sConfig {
|
||||
client_cert_path: S2S_CRT_PATH.clone(),
|
||||
client_key_path: S2S_KEY_PATH.clone(),
|
||||
ca_cert_path: S2S_CHAIN_PATH.clone(),
|
||||
};
|
||||
Arc::new(
|
||||
xai_safety_label_store::ProdSafetyLabelStoreClient::new(datacenter, s2s)
|
||||
.await
|
||||
.expect("Failed to create SafetyLabelStore client"),
|
||||
) as Arc<dyn xai_safety_label_store::SafetyLabelStoreClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdVfClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create VF SafetyLabels client")
|
||||
.with_timeout_ms(500)
|
||||
.with_max_batch_size(150),
|
||||
) as Arc<dyn VfClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdImpressionBloomFilterClient::new(datacenter)
|
||||
.await
|
||||
.expect("Failed to create ImpressionBloomFilter client"),
|
||||
) as Arc<dyn ImpressionBloomFilterClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
GeoIpLocationClient::new(
|
||||
&S2S_CHAIN_PATH,
|
||||
&S2S_CRT_PATH,
|
||||
&S2S_KEY_PATH,
|
||||
datacenter,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create GeoIpLocationClient"),
|
||||
)
|
||||
},
|
||||
async {
|
||||
let s2s = xai_manhattan::s2s::S2sConfig {
|
||||
client_cert_path: S2S_CRT_PATH.clone(),
|
||||
client_key_path: S2S_KEY_PATH.clone(),
|
||||
ca_cert_path: S2S_CHAIN_PATH.clone(),
|
||||
};
|
||||
Arc::new(
|
||||
ProdUserDemographicsClient::new(datacenter, s2s)
|
||||
.await
|
||||
.expect("Failed to create UserDemographics client"),
|
||||
) as Arc<dyn UserDemographicsClient>
|
||||
},
|
||||
async {
|
||||
let s2s = xai_manhattan::s2s::S2sConfig {
|
||||
client_cert_path: S2S_CRT_PATH.clone(),
|
||||
client_key_path: S2S_KEY_PATH.clone(),
|
||||
ca_cert_path: S2S_CHAIN_PATH.clone(),
|
||||
};
|
||||
Arc::new(
|
||||
ProdUserInferredGenderStoreClient::new(datacenter, s2s)
|
||||
.await
|
||||
.expect("Failed to create UserInferredGenderStore client"),
|
||||
) as Arc<dyn UserInferredGenderStoreClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
ProdGenderPredictionGrpcClient::new()
|
||||
.await
|
||||
.expect("Failed to create GenderPredictionGrpcClient"),
|
||||
) as Arc<dyn GenderPredictionGrpcClient>
|
||||
},
|
||||
async {
|
||||
Arc::new(
|
||||
crate::clients::impressed_posts_client::ProdImpressedPostsClient::new(
|
||||
datacenter,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create ImpressedPosts client"),
|
||||
) as Arc<dyn ImpressedPostsClient>
|
||||
},
|
||||
async {
|
||||
let s2s = xai_manhattan::s2s::S2sConfig {
|
||||
client_cert_path: S2S_CRT_PATH.clone(),
|
||||
client_key_path: S2S_KEY_PATH.clone(),
|
||||
ca_cert_path: S2S_CHAIN_PATH.clone(),
|
||||
};
|
||||
Arc::new(
|
||||
ProdFollowedGrokTopicsStoreClient::new(datacenter, s2s)
|
||||
.await
|
||||
.expect("Failed to create FollowedGrokTopicsStore client"),
|
||||
) as Arc<dyn FollowedGrokTopicsStoreClient>
|
||||
},
|
||||
async {
|
||||
let s2s = xai_manhattan::s2s::S2sConfig {
|
||||
client_cert_path: S2S_CRT_PATH.clone(),
|
||||
client_key_path: S2S_KEY_PATH.clone(),
|
||||
ca_cert_path: S2S_CHAIN_PATH.clone(),
|
||||
};
|
||||
Arc::new(
|
||||
ProdFollowedStarterPacksStoreClient::new(datacenter, s2s)
|
||||
.await
|
||||
.expect("Failed to create FollowedStarterPacksStore client"),
|
||||
) as Arc<dyn FollowedStarterPacksStoreClient>
|
||||
},
|
||||
);
|
||||
|
||||
PhoenixCandidatePipeline::build_with_clients(
|
||||
uas_fetcher,
|
||||
user_action_aggregation_client,
|
||||
phoenix_client,
|
||||
egress_client,
|
||||
phoenix_retrieval_client,
|
||||
thunder_client,
|
||||
strato_client,
|
||||
tweet_mixer_client,
|
||||
tes_client,
|
||||
gizmoduck_client,
|
||||
vf_client,
|
||||
redis_client,
|
||||
phoenix_kafka_client,
|
||||
reranking_kafka_client,
|
||||
flock_socialgraph_client,
|
||||
vm_ranker_client,
|
||||
safety_label_client,
|
||||
vf_safety_labels_client,
|
||||
phoenix_request_cache_redis_atla_client,
|
||||
phoenix_request_cache_redis_pdxa_client,
|
||||
impression_bloom_filter_client,
|
||||
ip_client,
|
||||
user_demographics_client,
|
||||
user_inferred_gender_store_client,
|
||||
user_inferred_gender_grpc_client,
|
||||
impressed_posts_client,
|
||||
followed_grok_topics_client,
|
||||
followed_starter_packs_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn mock() -> PhoenixCandidatePipeline {
|
||||
let user_action_aggregation_client = Arc::new(MockUserActionAggregationClient);
|
||||
let phoenix_client = Arc::new(MockPredictClient);
|
||||
let phoenix_retrieval_client = Arc::new(MockRetrievalClient);
|
||||
let thunder_client = Arc::new(ThunderClient::mock());
|
||||
let strato_client = Arc::new(MockStratoClient::default());
|
||||
let tweet_mixer_client: Arc<dyn TweetMixerClient> = Arc::new(MockTweetMixerClient);
|
||||
let tes_client = Arc::new(MockTESClient::default());
|
||||
let gizmoduck_client = Arc::new(MockGizmoduckClient::default());
|
||||
let vf_client = Arc::new(MockVisibilityFilteringClient);
|
||||
let redis_client = Arc::new(MockRedisClient::default());
|
||||
let kafka_client: Arc<dyn KafkaPublisherClient> = Arc::new(MockKafkaPublisherClient);
|
||||
let reranking_kafka_client: Arc<dyn KafkaPublisherClient> =
|
||||
Arc::new(MockKafkaPublisherClient);
|
||||
let mock_socialgraph: Arc<dyn SocialGraphClientOps> = Arc::new(MockSocialGraphClient);
|
||||
let vm_ranker_client: Arc<dyn VMRankerClient> = Arc::new(MockVMRankerClient);
|
||||
let safety_label_client: Arc<dyn xai_safety_label_store::SafetyLabelStoreClient> =
|
||||
Arc::new(xai_safety_label_store::MockSafetyLabelStoreClient);
|
||||
let vf_safety_labels_client: Arc<dyn VfClient> = Arc::new(MockVfClient);
|
||||
let phoenix_request_cache_redis_atla_client = Arc::new(MockRedisClient::default());
|
||||
let phoenix_request_cache_redis_pdxa_client: Arc<dyn RedisClient + Send + Sync> =
|
||||
Arc::new(MockRedisClient::default());
|
||||
let impression_bloom_filter_client: Arc<dyn ImpressionBloomFilterClient> =
|
||||
Arc::new(MockImpressionBloomFilterClient::default());
|
||||
let ip_client = Arc::new(GeoIpLocationClient::mock());
|
||||
let user_demographics_client: Arc<dyn UserDemographicsClient> =
|
||||
Arc::new(MockUserDemographicsClient);
|
||||
let user_inferred_gender_store_client: Arc<dyn UserInferredGenderStoreClient> =
|
||||
Arc::new(MockUserInferredGenderStoreClient);
|
||||
let user_inferred_gender_grpc_client: Arc<dyn GenderPredictionGrpcClient> =
|
||||
Arc::new(MockGenderPredictionGrpcClient);
|
||||
let impressed_posts_client: Arc<dyn ImpressedPostsClient> =
|
||||
Arc::new(crate::clients::impressed_posts_client::MockImpressedPostsClient::default());
|
||||
let followed_grok_topics_client: Arc<dyn FollowedGrokTopicsStoreClient> =
|
||||
Arc::new(MockFollowedGrokTopicsStoreClient);
|
||||
let followed_starter_packs_client: Arc<dyn FollowedStarterPacksStoreClient> =
|
||||
Arc::new(MockFollowedStarterPacksStoreClient);
|
||||
PhoenixCandidatePipeline::build_with_clients(
|
||||
user_action_aggregation_client,
|
||||
phoenix_client.clone(),
|
||||
phoenix_client,
|
||||
phoenix_retrieval_client,
|
||||
thunder_client,
|
||||
strato_client,
|
||||
tweet_mixer_client,
|
||||
tes_client,
|
||||
gizmoduck_client,
|
||||
vf_client,
|
||||
redis_client,
|
||||
kafka_client,
|
||||
reranking_kafka_client,
|
||||
mock_socialgraph,
|
||||
vm_ranker_client,
|
||||
safety_label_client,
|
||||
vf_safety_labels_client,
|
||||
phoenix_request_cache_redis_atla_client,
|
||||
phoenix_request_cache_redis_pdxa_client,
|
||||
impression_bloom_filter_client,
|
||||
ip_client,
|
||||
user_demographics_client,
|
||||
user_inferred_gender_store_client,
|
||||
user_inferred_gender_grpc_client,
|
||||
impressed_posts_client,
|
||||
followed_grok_topics_client,
|
||||
followed_starter_packs_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -253,3 +769,4 @@ impl CandidatePipeline<ScoredPostsQuery, PostCandidate> for PhoenixCandidatePipe
|
||||
params::RESULT_SIZE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::util::snowflake;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::time::Duration;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::component_library::utils::duration_since_creation_opt;
|
||||
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||
|
||||
/// Filter that removes tweets older than a specified duration.
|
||||
@@ -15,24 +14,23 @@ impl AgeFilter {
|
||||
Self { max_age }
|
||||
}
|
||||
|
||||
fn is_within_age(&self, tweet_id: i64) -> bool {
|
||||
snowflake::duration_since_creation_opt(tweet_id)
|
||||
fn is_within_age(&self, tweet_id: u64) -> bool {
|
||||
duration_since_creation_opt(tweet_id)
|
||||
.map(|age| age <= self.max_age)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter<ScoredPostsQuery, PostCandidate> for AgeFilter {
|
||||
async fn filter(
|
||||
fn filter(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: Vec<PostCandidate>,
|
||||
) -> Result<FilterResult<PostCandidate>, String> {
|
||||
) -> FilterResult<PostCandidate> {
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates
|
||||
.into_iter()
|
||||
.partition(|c| self.is_within_age(c.tweet_id));
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
FilterResult { kept, removed }
|
||||
}
|
||||
}
|
||||
|
||||
19
home-mixer/filters/ancillary_vf_filter.rs
Normal file
19
home-mixer/filters/ancillary_vf_filter.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||
|
||||
pub struct AncillaryVFFilter;
|
||||
|
||||
impl Filter<ScoredPostsQuery, PostCandidate> for AncillaryVFFilter {
|
||||
fn filter(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: Vec<PostCandidate>,
|
||||
) -> FilterResult<PostCandidate> {
|
||||
let (removed, kept): (Vec<_>, Vec<_>) = candidates
|
||||
.into_iter()
|
||||
.partition(|c| c.drop_ancillary_posts.unwrap_or(false));
|
||||
|
||||
FilterResult { kept, removed }
|
||||
}
|
||||
}
|
||||
@@ -1,27 +1,26 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use tonic::async_trait;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||
|
||||
// Remove candidates that are blocked or muted by the viewer
|
||||
pub struct AuthorSocialgraphFilter;
|
||||
|
||||
#[async_trait]
|
||||
impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
|
||||
async fn filter(
|
||||
fn filter(
|
||||
&self,
|
||||
query: &ScoredPostsQuery,
|
||||
candidates: Vec<PostCandidate>,
|
||||
) -> Result<FilterResult<PostCandidate>, String> {
|
||||
let viewer_blocked_user_ids = query.user_features.blocked_user_ids.clone();
|
||||
let viewer_muted_user_ids = query.user_features.muted_user_ids.clone();
|
||||
|
||||
if viewer_blocked_user_ids.is_empty() && viewer_muted_user_ids.is_empty() {
|
||||
return Ok(FilterResult {
|
||||
kept: candidates,
|
||||
removed: Vec::new(),
|
||||
});
|
||||
}
|
||||
) -> FilterResult<PostCandidate> {
|
||||
let viewer_blocked_user_ids: HashSet<i64> = query
|
||||
.user_features
|
||||
.blocked_user_ids
|
||||
.iter()
|
||||
.copied()
|
||||
.collect();
|
||||
let viewer_muted_user_ids: HashSet<i64> =
|
||||
query.user_features.muted_user_ids.iter().copied().collect();
|
||||
|
||||
let mut kept: Vec<PostCandidate> = Vec::new();
|
||||
let mut removed: Vec<PostCandidate> = Vec::new();
|
||||
@@ -30,13 +29,33 @@ impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
|
||||
let author_id = candidate.author_id as i64;
|
||||
let muted = viewer_muted_user_ids.contains(&author_id);
|
||||
let blocked = viewer_blocked_user_ids.contains(&author_id);
|
||||
if muted || blocked {
|
||||
let author_blocks_viewer = candidate.author_blocks_viewer.unwrap_or(false);
|
||||
|
||||
let quoted_author_blocks_viewer =
|
||||
candidate.quoted_author_blocks_viewer.unwrap_or(false);
|
||||
let viewer_blocks_quoted_author = candidate
|
||||
.quoted_user_id
|
||||
.map(|uid| viewer_blocked_user_ids.contains(&(uid as i64)))
|
||||
.unwrap_or(false);
|
||||
|
||||
let viewer_blocks_retweeted_user = candidate
|
||||
.retweeted_user_id
|
||||
.map(|uid| viewer_blocked_user_ids.contains(&(uid as i64)))
|
||||
.unwrap_or(false);
|
||||
|
||||
if muted
|
||||
|| blocked
|
||||
|| author_blocks_viewer
|
||||
|| quoted_author_blocks_viewer
|
||||
|| viewer_blocks_quoted_author
|
||||
|| viewer_blocks_retweeted_user
|
||||
{
|
||||
removed.push(candidate);
|
||||
} else {
|
||||
kept.push(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
FilterResult { kept, removed }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use tonic::async_trait;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||
|
||||
pub struct CoreDataHydrationFilter;
|
||||
|
||||
#[async_trait]
|
||||
impl Filter<ScoredPostsQuery, PostCandidate> for CoreDataHydrationFilter {
|
||||
async fn filter(
|
||||
fn filter(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: Vec<PostCandidate>,
|
||||
) -> Result<FilterResult<PostCandidate>, String> {
|
||||
let (kept, removed) = candidates
|
||||
.into_iter()
|
||||
.partition(|c| c.author_id != 0 && !c.tweet_text.trim().is_empty());
|
||||
Ok(FilterResult { kept, removed })
|
||||
) -> FilterResult<PostCandidate> {
|
||||
let (kept, removed) = candidates.into_iter().partition(|c| c.author_id != 0);
|
||||
FilterResult { kept, removed }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
use crate::candidate_pipeline::candidate::PostCandidate;
|
||||
use crate::candidate_pipeline::query::ScoredPostsQuery;
|
||||
use crate::models::candidate::PostCandidate;
|
||||
use crate::models::query::ScoredPostsQuery;
|
||||
use std::collections::HashMap;
|
||||
use tonic::async_trait;
|
||||
use xai_candidate_pipeline::filter::{Filter, FilterResult};
|
||||
|
||||
/// Keeps only the highest-scored candidate per branch of a conversation tree
|
||||
pub struct DedupConversationFilter;
|
||||
|
||||
#[async_trait]
|
||||
impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
|
||||
async fn filter(
|
||||
fn filter(
|
||||
&self,
|
||||
_query: &ScoredPostsQuery,
|
||||
candidates: Vec<PostCandidate>,
|
||||
) -> Result<FilterResult<PostCandidate>, String> {
|
||||
) -> FilterResult<PostCandidate> {
|
||||
let mut kept: Vec<PostCandidate> = Vec::new();
|
||||
let mut removed: Vec<PostCandidate> = Vec::new();
|
||||
let mut best_per_convo: HashMap<u64, (usize, f64)> = HashMap::new();
|
||||
@@ -37,7 +34,7 @@ impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
FilterResult { kept, removed }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,5 +44,5 @@ fn get_conversation_id(candidate: &PostCandidate) -> u64 {
|
||||
.iter()
|
||||
.copied()
|
||||
.min()
|
||||
.unwrap_or(candidate.tweet_id as u64)
|
||||
.unwrap_or(candidate.tweet_id)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user