Commit
Refactor create messages and stats validation into class
Showing 7 changed files with 60 additions and 246 deletions.
listenbrainz_spark/stats/incremental/listener/entity.py (136 changes: 18 additions & 118 deletions)
@@ -1,132 +1,32 @@
 import abc
 import logging
-from datetime import datetime
-from pathlib import Path
-from typing import List
-
-from pyspark.errors import AnalysisException
-from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType, StructField, TimestampType
-
-import listenbrainz_spark
-from listenbrainz_spark import hdfs_connection
-from listenbrainz_spark.config import HDFS_CLUSTER_URI
-from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
-    LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY, LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY
-from listenbrainz_spark.stats import run_query
-from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
+from datetime import date
+from typing import Optional
+
+from listenbrainz_spark.path import LISTENBRAINZ_LISTENER_STATS_DIRECTORY
+from listenbrainz_spark.stats.incremental.user.entity import UserEntity
 
 logger = logging.getLogger(__name__)
-BOOKKEEPING_SCHEMA = StructType([
-    StructField('from_date', TimestampType(), nullable=False),
-    StructField('to_date', TimestampType(), nullable=False),
-    StructField('created', TimestampType(), nullable=False),
-])
 
 
-class EntityListener(abc.ABC):
-
-    def __init__(self, entity):
-        self.entity = entity
-
-    def get_existing_aggregate_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}"
-
-    def get_bookkeeping_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}"
+class EntityListener(UserEntity, abc.ABC):
 
-    def get_partial_aggregate_schema(self) -> StructType:
-        raise NotImplementedError()
+    def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
+        if not database:
+            database = f"{self.entity}_listeners_{self.stats_range}_{date.today().strftime('%Y%m%d')}"
+        super().__init__(entity, stats_range, database, message_type)
 
-    def aggregate(self, table, cache_tables) -> DataFrame:
-        raise NotImplementedError()
+    def get_table_prefix(self) -> str:
+        return f"{self.entity}_listener_{self.stats_range}"
 
-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        raise NotImplementedError()
+    def get_base_path(self) -> str:
+        return LISTENBRAINZ_LISTENER_STATS_DIRECTORY
 
-    def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame:
+    def get_entity_id(self):
         raise NotImplementedError()
 
-    def get_top_n(self, final_aggregate, N) -> DataFrame:
-        raise NotImplementedError()
-
-    def get_cache_tables(self) -> List[str]:
-        raise NotImplementedError()
-
-    def generate_stats(self, stats_range: str, from_date: datetime,
-                       to_date: datetime, top_entity_limit: int):
-        cache_tables = []
-        for idx, df_path in enumerate(self.get_cache_tables()):
-            df_name = f"entity_data_cache_{idx}"
-            cache_tables.append(df_name)
-            read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)
-
-        metadata_path = self.get_bookkeeping_path(stats_range)
-        try:
-            metadata = listenbrainz_spark \
-                .session \
-                .read \
-                .schema(BOOKKEEPING_SCHEMA) \
-                .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
-                .collect()[0]
-            existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
-            existing_aggregate_usable = existing_from_date.date() == from_date.date()
-        except AnalysisException:
-            existing_aggregate_usable = False
-            logger.info("Existing partial aggregate not found!")
-
-        prefix = f"entity_listener_{self.entity}_{stats_range}"
-        existing_aggregate_path = self.get_existing_aggregate_path(stats_range)
-
-        only_inc_entities = True
-
-        if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
-            table = f"{prefix}_full_listens"
-            get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table)
-
-            logger.info("Creating partial aggregate from full dump listens")
-            hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
-            full_df = self.aggregate(table, cache_tables)
-            full_df.write.mode("overwrite").parquet(existing_aggregate_path)
-
-            hdfs_connection.client.makedirs(Path(metadata_path).parent)
-            metadata_df = listenbrainz_spark.session.createDataFrame(
-                [(from_date, to_date, datetime.now())],
-                schema=BOOKKEEPING_SCHEMA
-            )
-            metadata_df.write.mode("overwrite").json(metadata_path)
-            only_inc_entities = False
-
-        full_df = read_files_from_HDFS(existing_aggregate_path)
-
-        if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False):
-            table = f"{prefix}_incremental_listens"
-            read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \
-                .createOrReplaceTempView(table)
-            inc_df = self.aggregate(table, cache_tables)
-        else:
-            inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
-            only_inc_entities = False
-
-        full_table = f"{prefix}_existing_aggregate"
-        full_df.createOrReplaceTempView(full_table)
-
-        inc_table = f"{prefix}_incremental_aggregate"
-        inc_df.createOrReplaceTempView(inc_table)
-
-        if only_inc_entities:
-            existing_table = f"{prefix}_filtered_aggregate"
-            filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
-            filtered_aggregate_df.createOrReplaceTempView(existing_table)
-        else:
-            existing_table = full_table
-
-        combined_df = self.combine_aggregates(existing_table, inc_table)
-
-        combined_table = f"{prefix}_combined_aggregate"
-        combined_df.createOrReplaceTempView(combined_table)
-        results_df = self.get_top_n(combined_table, top_entity_limit)
+    def items_per_message(self):
+        return 10000
 
-        return only_inc_entities, results_df.toLocalIterator()
+    def parse_one_user_stats(self, entry: dict):
+        return entry
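For context, here is a minimal sketch (not part of this commit) of how a concrete listener-stats class might sit on top of the refactored EntityListener: it only fills in the entity-specific hooks visible in this diff, while the default database name, message batching (items_per_message) and per-entry parsing are inherited. The class name, entity value and id column below are hypothetical, not taken from the commit.

```python
from typing import Optional

from listenbrainz_spark.stats.incremental.listener.entity import EntityListener


class ArtistEntityListener(EntityListener):
    """Hypothetical subclass; names and values are illustrative only."""

    def __init__(self, stats_range: str, database: Optional[str] = None,
                 message_type: Optional[str] = None):
        # "artist" is an assumed entity name; the database-name default and
        # message batching come from EntityListener / UserEntity.
        super().__init__("artist", stats_range, database, message_type)

    def get_entity_id(self):
        # Assumed id column; a real subclass returns whatever id column its
        # aggregation query produces.
        return "artist_mbid"
```

A real subclass would also supply the aggregation queries and cache tables that were removed from EntityListener here (presumably now declared on the shared UserEntity base), so this sketch is an outline of the new hierarchy rather than a drop-in implementation.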