Add OnlineNewsMediaCloudESProvider.random_sample method & tests #47

Open · wants to merge 6 commits into main (base branch)
203 changes: 150 additions & 53 deletions mc_providers/onlinenews.py
@@ -4,7 +4,7 @@
import os
import random
from collections import Counter
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, TypeAlias, TypedDict
from typing import Any, Dict, Iterable, List, Mapping, NamedTuple, Optional, Sequence, TypeAlias, TypedDict

# PyPI
import ciso8601
@@ -298,12 +298,15 @@ def _selector_query_clauses(cls, kwargs: dict) -> list[str]:
domain_strings = " OR ".join(domains)
selector_clauses.append(f"{cls.domain_search_string()}:({domain_strings})")

# put all filters in single query string (NOTE: additive)
# put all filters in single query string
# (NOTE: filters are additive, not subtractive!)
filters = kwargs.get('filters', [])
if len(filters) > 0:
for filter in filters:
if "AND" in filter:
# parenthesize if any chance it has a grabby AND.
# (Phil: did I get this in reverse? and would need to parenthesize
# things containing OR if ANDing subtractive clauses together?)
selector_clauses.append(f"({filter})")
else:
selector_clauses.append(filter)
@@ -661,8 +664,12 @@ def _assemble_and_chunk_query_str(cls, base_query: str, chunk: bool = True, **kw

from .exceptions import MysteryProviderException, ProviderParseException, PermanentProviderException, ProviderException, TemporaryProviderException

Field: TypeAlias = str | InstrumentedField # quiet mypy complaints
FilterTuple: TypeAlias = tuple[int, Query | None]
ES_Fieldname: TypeAlias = str | InstrumentedField # quiet mypy complaints
ES_Fieldnames: TypeAlias = list[ES_Fieldname]

class FilterTuple(NamedTuple):
weight: int | float
query: Query | None

_ES_MAXPAGE = 1000 # define globally (ie; in .providers)???

@@ -672,10 +679,8 @@ def _assemble_and_chunk_query_str(cls, base_query: str, chunk: bool = True, **kw
_DEF_SORT_FIELD = "indexed_date"
_DEF_SORT_ORDER = "desc"

# Secondary sort key to break ties if sort_field is non-empty.
# Used as sole sort key if "sort_field" argument or
# _DEF_SORT_FIELD (above) is empty string.
# _doc is documented as most efficient sort key at:
# Secondary sort key to break ties
# (see above about identical indexed_date values)
# https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html
#
# But at
@@ -686,44 +691,89 @@ def _assemble_and_chunk_query_str(cls, base_query: str, chunk: bool = True, **kw
# documents with the same sort values are not ordered consistently."
#
# HOWEVER: use of session_id/preference should route all requests
# from the same session to the same shards for each successive query.
# from the same session to the same shards for each successive query,
# so (to quote HHGttG) "mostly harmless"?
_SECONDARY_SORT_ARGS = {"_doc": "asc"}
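
For reference, a minimal sketch (not part of this diff) of how the default sort, the _doc tie-breaker and a per-session "preference" might be combined; the helper name and session_id argument are illustrative only:

def _sorted_search(s: Search, session_id: str) -> Search:
    # sketch only: sort on indexed_date, tie-broken by _doc (the cheapest sort key);
    # "preference" routes every query in a session to the same shard copies,
    # so the _doc order stays consistent across successive pages.
    return (s.sort({_DEF_SORT_FIELD: _DEF_SORT_ORDER}, _SECONDARY_SORT_ARGS)
             .params(preference=session_id))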

def _sanitize(s: str) -> str:
"""
quote slashes to avoid interpretation as /regexp/
as done by _sanitize_es_query in mc_providers/mediacloud.py client library

should only be used by SanitizedQueryString!
"""
return s.replace("/", r"\/")

class SanitizedQueryString(QueryString):
"""
query string (expression) with quoting
"""
def __init__(self, query: str, **kwargs: Any):
super().__init__(query=_sanitize(query), **kwargs)
# XXX always pass allow_leading_wildcard=False
# (take as an argument, defaulting to False)??

# quote slashes to avoid interpretation as /regexp/
# (which not only appear in URLs but are expensive as well)
# as done by _sanitize_es_query in mc_providers/mediacloud.py client library
sanitized = query.replace("/", r"\/")
super().__init__(query=sanitized, **kwargs)
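
For example (illustrative values only), a user expression containing a URL path is escaped before ES can read it as a /regexp/:

# sketch: the slash in the expression is escaped, so the /.../ regexp
# syntax is never triggered by URLs in the query.
q = SanitizedQueryString(query="url:*example.com/politics*")
# query string actually sent to ES: "url:*example.com\/politics*"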

def _format_match(hit: Hit, expanded: bool = False) -> dict:
"""
from news-search-api/client.py EsClientWrapper.format_match
from news-search-api/client.py EsClientWrapper.format_match;
Unparsed (JSON safe), so can be returned by overview for caching.
Result passed to _match_to_row called for any data returned to user.
"""
res = {
"article_title": getattr(hit, "article_title", None),
"publication_date": getattr(hit, "publication_date", "")[:10] or None,
"indexed_date": getattr(hit, "indexed_date", None),
"language": getattr(hit, "language", None),
"full_langauge": getattr(hit, "full_language", None),
"full_langauge": getattr(hit, "full_language", None), # never returned!
"url": getattr(hit, "url", None),
"original_url": getattr(hit, "original_url", None),
"original_url": getattr(hit, "original_url", None), # never returned!
"canonical_domain": getattr(hit, "canonical_domain", None),
"id": hit.meta.id # PB: was re-hash of url!
}
if expanded:
res["text_content"] = getattr(hit, "text_content", None)
return res

# Added for format_match_fields, which was added for random_sample
# NOTE! full_language and original_url are NOT included,
# since they're never returned in a "row".
class _ES_Field:
"""ordinary field"""
METADATA = False

def __init__(self, field_name: str):
self.es_field_name = field_name

def get(self, hit: Hit) -> Any:
return getattr(hit, self.es_field_name)

class _ES_DateTime(_ES_Field):
def get(self, hit: Hit) -> Any:
return ciso8601.parse_datetime(super().get(hit) + "Z")

class _ES_Date(_ES_Field):
def get(self, hit: Hit) -> Any:
return dt.date.fromisoformat(super().get(hit)[:10])

class _ES_MetaData(_ES_Field):
"""
metadata field (incl 'id', 'index', 'score')
"""
METADATA = True # does not need to be requested

def get(self, hit: Hit) -> Any:
return getattr(hit.meta, self.es_field_name)

# map external ("row") field name to _ES_Field instance
# (with "get" method to fetch/parse field from Hit)
_FIELDS: dict[str, _ES_Field] = {
"id": _ES_MetaData("id"),
"indexed_date": _ES_DateTime("indexed_date"),
"language": _ES_Field("language"),
"media_name": _ES_Field("canonical_domain"),
"media_url": _ES_Field("canonical_domain"),
"publish_date": _ES_Date("publication_date"),
"text": _ES_Field("text_content"),
"title": _ES_Field("article_title"),
"url": _ES_Field("url"),
}
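
A small usage sketch (not from this PR): the table yields both the ES source field to request and a parsed value for each external name:

# sketch: resolving an external ("row") field name
assert _FIELDS["publish_date"].es_field_name == "publication_date"
assert _FIELDS["id"].METADATA    # metadata fields never need to be requested
# and, given a returned Hit:
#   _FIELDS["publish_date"].get(hit)  -> datetime.date
#   _FIELDS["indexed_date"].get(hit)  -> datetime.datetime (parsed via ciso8601)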

def _format_day_counts(bucket: list) -> Counts:
"""
from news-search-api/client.py EsClientWrapper.format_count
@@ -839,7 +889,7 @@ def get_client(self):
"""
return None

def _fields(self, expanded: bool) -> list[Field]:
def _fields(self, expanded: bool) -> ES_Fieldnames:
"""
from news-search-api/client.py QueryBuilder constructor:
return list of fields for item, paged_items, all_items to return
@@ -848,40 +898,46 @@ def _fields(self, expanded: bool) -> list[Field]:
with only millisecond resolution, while the stored string usually
has microsecond resolution.
"""
fields: list[Field] = [
fields: ES_Fieldnames = [
"article_title", "publication_date", "indexed_date",
"language", "full_language", "canonical_domain", "url", "original_url"
"language", "canonical_domain", "url",
# never returned to user: consider removing?
"full_language", "original_url"
]
if expanded:
fields.append("text_content")
return fields

# Allow weighting in order to (experimentally) apply filters in
# most efficient order (most selective filter first). An average
# day is about 500K stories for all sources. 1K stories/day would
# be a large source. BUT adding a day means only expanding a
# range, not adding another term...
# Multipliers to allow weighting in order to (experimentally)
# apply filters in most efficient order (cheapest/most selective
# filter first). If all sources and all days were equal they
# would be equally selective. BUT adding a day means only
# expanding a range. _LOWER_ values mean filter applied first.
# So test increasing SELECTOR_WEIGHT?
SELECTOR_WEIGHT = 1 # domains, filters, url_search_strings
DAY_WEIGHT = 1
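
An illustrative sketch (not in the diff) of the intended ordering: with the current weights a two-domain selector sorts ahead of a seven-day date range, so it is applied first:

filters = [
    FilterTuple(7, Range(publication_date={"gte": "2024-01-01", "lte": "2024-01-07"})),  # 7 days * DAY_WEIGHT
    FilterTuple(2, SanitizedQueryString(query="canonical_domain:(a.com OR b.com)")),     # 2 domains * SELECTOR_WEIGHT
]
filters.sort(key=lambda ft: ft.weight)  # lower weight first: the selector filter gets applied first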

@classmethod
def _selector_filter_tuple(cls, kwargs: dict) -> FilterTuple:
"""
function to allow construction of DSL
rather than resorting to formatting/quoting query-string
only to have ES have to parse it??
For canonical_domain: "Match" query defaults to OR for space separated words
For url: use "Wildcard"??
Should initially take temp kwarg bool to allow A/B testing!!!
"""

# rather than resorting to formatting/quoting query-string
# only to have ES have to parse it??
# For canonical_domain: "Match" query defaults to OR for space separated words
# For url: use "Wildcard"??
# Should initially take (another) temp kwarg bool to allow A/B testing!!!
# elasticsearch_dsl allows "Query | Query"

selector_clauses = cls._selector_query_clauses(kwargs)
if selector_clauses:
sqs = cls._selector_query_string_from_clauses(selector_clauses)
return (cls._selector_count(kwargs) * cls.SELECTOR_WEIGHT,
SanitizedQueryString(query=sqs))
return FilterTuple(cls._selector_count(kwargs) * cls.SELECTOR_WEIGHT,
SanitizedQueryString(query=sqs))
else:
return (0, None)
# return dummy record, will be weeded out
return FilterTuple(0, None)

def _basic_search(self, user_query: str, start_date: dt.datetime, end_date: dt.datetime,
expanded: bool = False, source: bool = True, **kwargs: Any) -> Search:
@@ -925,17 +981,18 @@ def _basic_search(self, user_query: str, start_date: dt.datetime, end_date: dt.d

days = (end_date - start_date).days + 1
filters : list[FilterTuple] = [
(days * self.DAY_WEIGHT, Range(publication_date={'gte': start, "lte": end})),
FilterTuple(days * self.DAY_WEIGHT, Range(publication_date={'gte': start, "lte": end})),
self._selector_filter_tuple(kwargs)
# could include languages (etc) here
]

# try applying more selective queries (fewer results) first
filters.sort()
for weight, query in filters:
if query:
# key function avoids attempts to compare Query objects when tied!
filters.sort(key=lambda ft : ft.weight)
for ft in filters:
if ft.query:
# ends up as list under bool.filter:
s = s.filter(query)
s = s.filter(ft.query)

if source: # return source (fields)?
return s.source(self._fields(expanded))
@@ -1285,12 +1342,54 @@ def all_items(self, query: str,
if not next_page_token:
break

def _sample_titles(self, query: str, start_date: dt.datetime, end_date: dt.datetime, sample_size: int,
**kwargs: Any) -> Iterable[list[dict[str,str]]]:
@staticmethod
def _hit_to_row(hit: Hit, fields: list[str]) -> dict[str, Any]:
"""
format a Hit returned by ES into an external "row".
fields is a list of external/row field names to be returned
(cf. _format_match (above) and _match_to_row)
"""
# need to iterate over _external_ names rather than just returned
# fields to be able to return metadata fields
res = {
field: _FIELDS[field].get(hit)
for field in fields
}
return res

# max controlled by index-level index.max_result_window, default is 10K.
MAX_RANDOM_SAMPLE = 10000 # need pagination for more!
def random_sample(self, query: str, start_date: dt.datetime, end_date: dt.datetime,
limit: int, fields: list[str], **kwargs: Any) -> AllItems:
"""
helper for generic Provider._sampled_title_words;
returns a single page with entire sample_size of dicts
returns generator to allow pagination, but actual pagination may
require more work.

If the pagination issues are worked out, "fields", "randomize" and "limit"
_COULD_ be kwargs processed by _basic_query?!!!

_PERHAPS_ just be up-front, and have randomize require
passing seed and field arguments for RandomScore?

To discourage indiscriminate use/impact, allow
MAX_RANDOM_SAMPLE rows for a single-field query, half that for
two fields, and so on... Some fields are (obviously) more
expensive than others (language vs full_text!) Could have
_ES_Field take a per-field weight as optional argument!
"""
if not fields:
# _COULD_ default to everything, but make user think
# about what they need!
raise ValueError("ES.random_sample requires fields list")

if limit < 1 or limit > self.MAX_RANDOM_SAMPLE/len(fields):
raise ValueError(f"ES.random_sample limit must be between 1 and {self.MAX_RANDOM_SAMPLE}/nfields")

# convert requested field names to ES field names
es_fields: ES_Fieldnames = [
_FIELDS[f].es_field_name for f in fields if not _FIELDS[f].METADATA
]

search = self._basic_search(query, start_date, end_date, **kwargs)\
.query(
FunctionScore(
@@ -1305,17 +1404,15 @@ def _sample_titles(self, query: str, start_date: dt.datetime, end_date: dt.datet
]
)
)\
.source(["article_title", "language"])\
.extra(size=sample_size) # everything in one query
.source(es_fields)\
.extra(size=limit) # everything in one query (for now)

hits = self._search_hits(search)
ret = [{"title": hit.article_title, "language": hit.language} for hit in hits]
return [ret] # single page
yield [self._hit_to_row(hit, fields) for hit in hits] # just one page

def words(self, query: str, start_date: dt.datetime, end_date: dt.datetime, limit: int = 100,
**kwargs: Any) -> list[Term]:
"""
uses generic Provider._sampled_title_words
with data from local helper above
uses generic Provider._sampled_title_words and Provider._sample_titles!
"""
return self._sampled_title_words(query, start_date, end_date, limit=limit, **kwargs)
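
A hedged usage sketch of the new method (not part of this diff; constructor arguments are omitted and assumed to come from the usual provider configuration):

provider = OnlineNewsMediaCloudESProvider()  # assumption: default/ambient configuration suffices
for page in provider.random_sample("climate change",
                                   dt.datetime(2024, 1, 1), dt.datetime(2024, 1, 31),
                                   limit=500, fields=["title", "language"]):
    for row in page:
        print(row["language"], row["title"])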
17 changes: 12 additions & 5 deletions mc_providers/provider.py
@@ -32,7 +32,7 @@ def set_default_timeout(timeout: int) -> None:

Item: TypeAlias = dict[str, Any] # differs between providers?
Items: TypeAlias = list[Item] # page of items
AllItems: TypeAlias = Generator[Items, None, None]
AllItems: TypeAlias = Iterable[Items] # iterable of pages

class Date(TypedDict):
"""
@@ -135,6 +135,7 @@ def __init__(self,
def everything_query(self) -> str:
raise QueryingEverythingUnsupportedQuery()

# historically not a random sample, see random_sample below!
def sample(self, query: str, start_date: dt.datetime, end_date: dt.datetime, limit: int = 20,
**kwargs: Any) -> list[dict]:
raise NotImplementedError("Doesn't support sample content.")
@@ -171,6 +172,11 @@ def paged_items(self, query: str, start_date: dt.datetime, end_date: dt.datetime
# should read in token, offset, or whatever else they need from `kwargs` to determine which page to return
raise NotImplementedError("Doesn't support fetching all matching content.")

def random_sample(self, query: str, start_date: dt.datetime, end_date: dt.datetime,
limit: int, fields: list[str], **kwargs: Any) -> AllItems:
# NOTE! could be subsumed by passing keyword arguments (fields, randomize) to all_items?!
raise NotImplementedError("Doesn't support fetching random sample.")

def normalized_count_over_time(self, query: str, start_date: dt.datetime, end_date: dt.datetime,
**kwargs: Any) -> dict:
"""
@@ -226,7 +232,7 @@ def _sampled_languages(self, query: str, start_date: dt.datetime, end_date: dt.d
counts: collections.Counter = collections.Counter()
for page in self.all_items(query, start_date, end_date, limit=sample_size):
sampled_count += len(page)
[counts.update(t['language'] for t in page)]
[counts.update(t.get('language', "UNK") for t in page)]
# clean up results
results = [Language(language=w, value=c, ratio=c/sampled_count) for w, c in counts.most_common(limit)]
return results
@@ -236,22 +242,23 @@ def _sample_titles(self, query: str, start_date: dt.datetime, end_date: dt.datet
"""
default helper for _sampled_title_words: return a sampling of stories for top words
"""
# XXX force sort on something non-chronological???
return self.all_items(query, start_date, end_date, limit=sample_size)
return self.random_sample(query, start_date, end_date, sample_size,
fields=["title", "language"], **kwargs)

# use this if you need to sample some content for top words
def _sampled_title_words(self, query: str, start_date: dt.datetime, end_date: dt.datetime, limit: int = 100,
**kwargs: Any) -> list[Term]:
# support sample_size kwarg
sample_size = kwargs.pop('sample_size', self.WORDS_SAMPLE)
# NOTE! english stopwords contain contractions!!!
remove_punctuation = bool(kwargs.pop("remove_punctuation", True)) # XXX TEMP?

# grab a sample and count terms as we page through it
sampled_count = 0
counts: collections.Counter = collections.Counter()
for page in self._sample_titles(query, start_date, end_date, sample_size, **kwargs):
sampled_count += len(page)
[counts.update(terms_without_stopwords(t['language'], t['title'], remove_punctuation)) for t in page]
[counts.update(terms_without_stopwords(t.get('language', 'UNK'), t['title'], remove_punctuation)) for t in page if 'title' in t]
# clean up results
results = [Term(term=w, count=c, ratio=c/sampled_count) for w, c in counts.most_common(limit)]
self.trace(Trace.RESULTS, "_sampled_title_words %r", results)