Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions server/routes/nl/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flask import render_template
from flask import request
from flask import url_for
from markupsafe import escape

import server.services.datacommons as dc

Expand All @@ -49,7 +50,7 @@ def eval_rig():
def eval_retrieval_generation():
if not current_app.config.get('ENABLE_DATAGEMMA_EVAL_TOOLS', False):
flask.abort(404)
sheet_id = request.args.get('sheet_id')
sheet_id = str(escape(request.args.get('sheet_id', '')))
if not sheet_id:
return redirect(url_for('nl.eval_retrieval_generation',
sheet_id=_TEST_SHEET_ID),
Expand All @@ -61,9 +62,9 @@ def eval_retrieval_generation():
def eval_retrieval_generation_sxs():
if not current_app.config.get('ENABLE_DATAGEMMA_EVAL_TOOLS', False):
flask.abort(404)
sheet_id_a = request.args.get('sheetIdA', '')
sheet_id_b = request.args.get('sheetIdB', '')
session_id = request.args.get('sessionId', '')
sheet_id_a = str(escape(request.args.get('sheetIdA', '')))
sheet_id_b = str(escape(request.args.get('sheetIdB', '')))
session_id = str(escape(request.args.get('sessionId', '')))
return render_template('/eval_retrieval_generation_sxs.html',
sheet_id_a=sheet_id_a,
sheet_id_b=sheet_id_b,
Expand Down
8 changes: 7 additions & 1 deletion server/routes/place/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flask import g
from flask import jsonify
from flask import request
from markupsafe import escape

from server.lib.cache import cache
from server.lib.cache import cache_and_log_mixer_usage
Expand Down Expand Up @@ -78,8 +79,10 @@ async def fetch_place_types(
child_place_type_to_highlight_task,
place_type_task)

place_dcid = str(escape(place_dcid))
# Validate the category parameter.
place_category = request.args.get("category", place_utils.OVERVIEW_CATEGORY)
place_category = str(
escape(request.args.get("category", place_utils.OVERVIEW_CATEGORY)))
if place_category not in place_utils.ALLOWED_CATEGORIES:
return error_response(
f"Argument 'category' {place_category} must be one of: {', '.join(place_utils.ALLOWED_CATEGORIES)}"
Expand Down Expand Up @@ -145,6 +148,7 @@ async def related_places(place_dcid: str):
- Place details (name, type, etc.)
- Lists of nearby, similar, and child places
"""
place_dcid = str(escape(place_dcid))
# Fetch the current place.
place = place_utils.fetch_place(place_dcid, g.locale)

Expand Down Expand Up @@ -223,6 +227,7 @@ def overview_table(place_dcid: str):
"""
Fetches and returns overview table data for the specified place.
"""
place_dcid = str(escape(place_dcid))
data_rows, mixer_response_ids = place_utils.fetch_overview_table_data(
place_dcid)
response_data = PlaceOverviewTableApiResponse(
Expand All @@ -237,5 +242,6 @@ async def place_summary(place_dcid: str):
"""
Fetches and returns place summary data for the specified place.
"""
place_dcid = str(escape(place_dcid))
summary = await place_utils.generate_place_summary(place_dcid, g.locale)
return jsonify(PlaceSummaryApiResponse(summary=summary))
6 changes: 5 additions & 1 deletion server/routes/place/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import flask
from flask import current_app
from flask import g
from markupsafe import escape
from werkzeug.datastructures import MultiDict

from server.lib.cache import cache
Expand Down Expand Up @@ -154,6 +155,7 @@ def place_explorer():
request includes a dcid.
"""
dcid = flask.request.args.get('dcid', None)
dcid = str(escape(dcid)) if dcid else None

# If the request contains a dcid, redirect to the place page.
# This handles redirects from Google Search "Explore More" link.
Expand Down Expand Up @@ -189,7 +191,9 @@ def place(place_dcid):
Args:
place_dcid: DCID of the place to redirect to
"""
place_dcid = str(escape(place_dcid))
redirect_args = dict(flask.request.args)
redirect_args = {k: str(escape(v)) for k, v in redirect_args.items()}
# Strip trailing slashes from place dcids
should_redirect = False
if place_dcid and place_dcid.endswith('/'):
Expand All @@ -198,7 +202,7 @@ def place(place_dcid):

# Rename legacy "topic" request argument to "category"
if 'topic' in flask.request.args:
redirect_args['category'] = flask.request.args.get('topic', '')
redirect_args['category'] = str(escape(flask.request.args.get('topic', '')))
del redirect_args['topic']
should_redirect = True

Expand Down
3 changes: 2 additions & 1 deletion server/routes/redirects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flask import redirect
from flask import request
from flask import url_for
from markupsafe import escape

from server.lib.config import GLOBAL_CONFIG_BUCKET
from shared.lib import gcs
Expand All @@ -34,7 +35,7 @@

@bp.route('/kg')
def kg():
dcid = request.args.get('dcid', '')
dcid = str(escape(request.args.get('dcid', '')))
if dcid:
url = url_for('browser.browser_node', dcid=dcid)
else:
Expand Down
5 changes: 3 additions & 2 deletions server/routes/search/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flask import Blueprint
from flask import current_app
from flask import request
from markupsafe import escape

import server.services.datacommons as dc

Expand All @@ -28,7 +29,7 @@
@bp.route('/search')
def search():
"""Custom search page"""
query = request.args.get('q', '')
query = str(escape(request.args.get('q', '')))
return flask.render_template('search.html',
maps_api_key=current_app.config['MAPS_API_KEY'],
query=query)
Expand All @@ -37,7 +38,7 @@ def search():
@bp.route('/search_dc')
def search_dc():
"""Add DC API powered search for non-place searches temporarily"""
query_text = request.args.get('q', '')
query_text = str(escape(request.args.get('q', '')))
max_results = int(request.args.get('l', _MAX_SEARCH_RESULTS))
if query_text:
search_response = dc.search(query_text, max_results)
Expand Down
26 changes: 20 additions & 6 deletions server/routes/shared_api/facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,24 @@

from flask import Blueprint
from flask import request
from markupsafe import escape

from server.lib import fetch

bp = Blueprint("facets", __name__, url_prefix='/api/facets')


def _get_escaped_arg(name: str, default=None):
value = request.args.get(name, default)
if value is None:
return None
return str(escape(value))


def _get_escaped_arg_list(name: str) -> list[str]:
return [str(escape(v)) for v in request.args.getlist(name)]


def is_valid_date(date):
"""
Returns whether or not the date string is valid. Valid date strings are:
Expand All @@ -46,16 +58,17 @@ def get_facets_within():
date: If empty, fetch for all date; Otherwise could be "LATEST" or
specific date.
"""
parent_entity = request.args.get('parentEntity')
parent_entity = _get_escaped_arg('parentEntity')
if not parent_entity:
return 'error: must provide a parentEntity field', 400
child_type = request.args.get('childType')
child_type = _get_escaped_arg('childType')
if not child_type:
return 'error: must provide a childType field', 400
variables = list(filter(lambda x: x != "", request.args.getlist('variables')))
variables = list(
filter(lambda x: x != "", _get_escaped_arg_list('variables')))
if not variables:
return 'error: must provide a variables field', 400
date = request.args.get('date')
date = _get_escaped_arg('date', '')
if not is_valid_date(date):
return 'error: date must be LATEST or YYYY or YYYY-MM or YYYY-MM-DD', 400
return fetch.point_within_facet(parent_entity, child_type, variables, date,
Expand All @@ -66,8 +79,9 @@ def get_facets_within():
def get_facets():
"""Gets the available facets for a list of stat vars for a list of places.
"""
entities = list(filter(lambda x: x != "", request.args.getlist('entities')))
variables = list(filter(lambda x: x != "", request.args.getlist('variables')))
entities = list(filter(lambda x: x != "", _get_escaped_arg_list('entities')))
variables = list(
filter(lambda x: x != "", _get_escaped_arg_list('variables')))
if not entities:
return 'error: must provide a `entities` field', 400
if not variables:
Expand Down
37 changes: 29 additions & 8 deletions server/routes/shared_api/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import flask
from flask import request
from flask import Response
from markupsafe import escape

from server.lib import fetch

Expand All @@ -31,6 +32,23 @@
_OUT_ARROW = '->'


def _escaped_arg(name: str, default=None):
value = request.args.get(name, default)
if value is None:
return None
return str(escape(value))


def _escaped_arg_list(name: str) -> list[str]:
return [str(escape(v)) for v in request.args.getlist(name)]


def _escaped_list(values) -> list[str]:
if not values:
return []
return [str(escape(v)) for v in values]


@dataclass
class PropertySpec:
# name of the property to get values for
Expand All @@ -44,6 +62,8 @@ class PropertySpec:
@bp.route('/triples/<path:direction>/<path:dcid>')
def triples(direction, dcid):
"""Returns all the triples given a node dcid."""
direction = str(escape(direction))
dcid = str(escape(dcid))
if direction != 'in' and direction != 'out':
return "Invalid direction provided, please use 'in' or 'out'", 400
return fetch.triples([dcid], direction == 'out').get(dcid, {})
Expand All @@ -52,14 +72,15 @@ def triples(direction, dcid):
@bp.route('/propvals/<path:direction>', methods=['GET', 'POST'])
def get_property_value(direction):
"""Returns the property values for given node dcids and property label."""
direction = str(escape(direction))
if direction != "in" and direction != "out":
return "Invalid direction provided, please use 'in' or 'out'", 400
dcids = request.args.getlist('dcids')
dcids = _escaped_arg_list('dcids')
if not dcids:
dcids = request.json['dcids']
prop = request.args.get('prop')
dcids = _escaped_list(request.json['dcids'])
prop = _escaped_arg('prop')
if not prop:
prop = request.json['prop']
prop = str(escape(request.json['prop']))
response = fetch.raw_property_values(dcids, prop, direction == 'out')
return Response(json.dumps(response), 200, mimetype='application/json')

Expand Down Expand Up @@ -140,12 +161,12 @@ def expression_property_value():
Returns the property values for given node dcids and a relation expression for
a property (https://docs.datacommons.org/api/rest/v2#relation-expressions)
"""
dcids = request.args.getlist('dcids')
dcids = _escaped_arg_list('dcids')
if not dcids and request.json:
dcids = request.json['dcids']
prop_expression = request.args.get('propExpr')
dcids = _escaped_list(request.json['dcids'])
prop_expression = _escaped_arg('propExpr')
if not prop_expression and request.json:
prop_expression = request.json['propExpr']
prop_expression = str(escape(request.json['propExpr']))
if not dcids:
return 'error: must provide a `dcids` field', 400
if not prop_expression:
Expand Down
22 changes: 17 additions & 5 deletions server/routes/shared_api/observation/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from flask import Blueprint
from flask import request
from markupsafe import escape

from server.lib import util
import server.services.datacommons as dc
Expand All @@ -22,18 +23,29 @@
bp = Blueprint("observation_dates", __name__)


def _get_escaped_arg(name: str, default=None):
value = request.args.get(name, default)
if value is None:
return None
return str(escape(value))


def _get_escaped_arg_list(name: str) -> list[str]:
return [str(escape(v)) for v in request.args.getlist(name)]


@bp.route('/api/observation-dates')
def observation_dates():
"""Given ancestor place, child place type and stat vars, return the dates that
have data for each stat var across all child places.
"""
parent_entity = request.args.get('parentEntity')
parent_entity = _get_escaped_arg('parentEntity')
if not parent_entity:
return 'error: must provide a parentEntity field', 400
child_type = request.args.get('childType')
child_type = _get_escaped_arg('childType')
if not child_type:
return 'error: must provide a childType field', 400
variable = request.args.get('variable')
variable = _get_escaped_arg('variable')
if not variable:
return 'error: must provide a variable field', 400
return dc.get_series_dates(parent_entity, child_type, [variable])
Expand Down Expand Up @@ -116,10 +128,10 @@ def observation_dates_entities():
}
```
"""
entities = request.args.getlist('entities')
entities = _get_escaped_arg_list('entities')
if len(entities) == 0:
return 'error: must provide entities field', 400
variables = request.args.getlist('variables')
variables = _get_escaped_arg_list('variables')
if len(variables) == 0:
return 'error: must provide a variables field', 400
return util.get_series_dates_from_entities(entities, variables)
Loading