diff --git a/server/routes/browser/api.py b/server/routes/browser/api.py index 9640b1db7a..0ab03264ee 100644 --- a/server/routes/browser/api.py +++ b/server/routes/browser/api.py @@ -18,6 +18,7 @@ import flask from flask import request from flask import Response +from markupsafe import escape from server.lib import fetch from server.lib.cache import cache @@ -75,18 +76,23 @@ def get_observation_id(): return Response(json.dumps("error: must provide a place field"), 400, mimetype='application/json') + place_id = str(escape(place_id)) stat_var_id = request.args.get("statVar") if not stat_var_id: return Response(json.dumps("error: must provide a statVar field"), 400, mimetype='application/json') + stat_var_id = str(escape(stat_var_id)) date = request.args.get("date", "") if not date: return Response(json.dumps("error: must provide a date field"), 400, mimetype='application/json') - request_mmethod = request.args.get("measurementMethod", NO_MMETHOD_KEY) - request_obsPeriod = request.args.get("obsPeriod", NO_OBSPERIOD_KEY) + date = str(escape(date)) + request_mmethod = str( + escape(request.args.get("measurementMethod", NO_MMETHOD_KEY))) + request_obsPeriod = str( + escape(request.args.get("obsPeriod", NO_OBSPERIOD_KEY))) sparql_query = get_sparql_query(place_id, stat_var_id, date) result = "" (_, rows) = dc.query(sparql_query) diff --git a/server/routes/browser/html.py b/server/routes/browser/html.py index 8dcabd3fc9..8bc8c5d915 100644 --- a/server/routes/browser/html.py +++ b/server/routes/browser/html.py @@ -21,6 +21,7 @@ from flask import current_app from flask import g from flask import render_template +from markupsafe import escape import server.lib.render as lib_render import server.lib.shared as shared_api @@ -40,6 +41,7 @@ def bio_browser_main(): @bp.route('/') def browser_node(dcid): + dcid = str(escape(dcid)) node_name = dcid try: api_name = shared_api.names([dcid]).get(dcid) diff --git a/server/routes/dev_datagemma/api.py b/server/routes/dev_datagemma/api.py index 45a22cee62..78b3b21c79 100644 --- a/server/routes/dev_datagemma/api.py +++ b/server/routes/dev_datagemma/api.py @@ -25,6 +25,7 @@ from flask import current_app from flask import request from flask import Response +from markupsafe import escape from server.lib.nl.detection.llm_api import detect_model_name @@ -80,6 +81,8 @@ def datagemma_query(): return 'error: must provide a query field', 400 if not mode or mode not in [_RIG_MODE, _RAG_MODE]: return f'error: must provide a mode field with values {_RIG_MODE} or {_RAG_MODE}', 400 + query = str(escape(query)) + mode = str(escape(mode)) dg_result = _get_datagemma_result(query, mode) result = {'answer': '', 'debug': ''} if dg_result: diff --git a/server/routes/disaster/api.py b/server/routes/disaster/api.py index ce3672dbb9..4da39f1de1 100644 --- a/server/routes/disaster/api.py +++ b/server/routes/disaster/api.py @@ -20,6 +20,7 @@ from flask import current_app from flask import request from flask import Response +from markupsafe import escape from server.lib.cache import cache import server.lib.fetch as fetch @@ -36,6 +37,12 @@ DATA_RETRIEVAL_DATE_LENGTH = 7 +def _escape_value(value): + if value is None: + return None + return str(escape(value)) + + @bp.route('/event-date-range') def event_date_range(): """Gets the date range of event data for a specific event type @@ -46,13 +53,13 @@ def event_date_range(): maxDate: string } """ - event_type = request.args.get('eventType', '') + event_type = _escape_value(request.args.get('eventType', '')) if not event_type: return "error: must provide a eventType field", 400 - place = request.args.get('place', '') + place = _escape_value(request.args.get('place', '')) if not place: return "error: must provide a place field", 400 - use_cache = request.args.get('useCache', '') + use_cache = _escape_value(request.args.get('useCache', '')) result = {'minDate': "", 'maxDate': ""} date_list = [] if use_cache == '1': @@ -151,23 +158,24 @@ def json_event_data(): } } """ - event_type = request.args.get('eventType', '') + event_type = _escape_value(request.args.get('eventType', '')) if not event_type: return "error: must provide a eventType field", 400 - min_date = request.args.get('minDate', '') + min_date = _escape_value(request.args.get('minDate', '')) if not min_date: return "error: must provide a minDate field", 400 - max_date = request.args.get('maxDate', '') + max_date = _escape_value(request.args.get('maxDate', '')) if not max_date: return "error: must provide a maxDate field", 400 - place = request.args.get('place', '') + place = _escape_value(request.args.get('place', '')) if not place: return "error: must provide a place field", 400 - filter_prop = request.args.get('filterProp', '') - filter_unit = request.args.get('filterUnit', '') - filter_upper_limit = float(request.args.get('filterUpperLimit', float("inf"))) - filter_lower_limit = float(request.args.get('filterLowerLimit', - -float("inf"))) + filter_prop = _escape_value(request.args.get('filterProp', '')) + filter_unit = _escape_value(request.args.get('filterUnit', '')) + filter_upper_limit = float( + _escape_value(request.args.get('filterUpperLimit', float("inf")))) + filter_lower_limit = float( + _escape_value(request.args.get('filterLowerLimit', -float("inf")))) event_points = [] disaster_data = current_app.config['DISASTER_DASHBOARD_DATA'] date_list = get_date_list(min_date, max_date) @@ -233,23 +241,23 @@ def event_data(): } } """ - event_type = request.args.get('eventType', '') + event_type = _escape_value(request.args.get('eventType', '')) if not event_type: return "error: must provide a eventType field", 400 - min_date = request.args.get('minDate', '') + min_date = _escape_value(request.args.get('minDate', '')) if not min_date: return "error: must provide a minDate field", 400 - max_date = request.args.get('maxDate', '') + max_date = _escape_value(request.args.get('maxDate', '')) if not max_date: return "error: must provide a maxDate field", 400 - place = request.args.get('place', '') + place = _escape_value(request.args.get('place', '')) if not place: return "error: must provide a place field", 400 - filter_prop = request.args.get('filterProp', '') - filter_unit = request.args.get('filterUnit', '') - req_upper = request.args.get('filterUpperLimit', None) + filter_prop = _escape_value(request.args.get('filterProp', '')) + filter_unit = _escape_value(request.args.get('filterUnit', '')) + req_upper = _escape_value(request.args.get('filterUpperLimit', None)) filter_upper_limit = float(req_upper) if req_upper else None - req_lower = request.args.get('filterLowerLimit', None) + req_lower = _escape_value(request.args.get('filterLowerLimit', None)) filter_lower_limit = float(req_lower) if req_lower else None date_list = get_date_list(min_date, max_date) event_points = [] diff --git a/server/routes/disaster/html.py b/server/routes/disaster/html.py index c52e398cfb..6961eecf18 100644 --- a/server/routes/disaster/html.py +++ b/server/routes/disaster/html.py @@ -22,6 +22,7 @@ from flask import redirect from flask import url_for from google.protobuf.json_format import MessageToJson +from markupsafe import escape import server.lib.subject_page_config as lib_subject_page_config import server.lib.util @@ -42,6 +43,7 @@ def disaster_dashboard(place_dcid=None): place_dcid=lib_subject_page_config.DEFAULT_PLACE_DCID), code=302) + place_dcid = str(escape(place_dcid)) raw_dashboard_config = current_app.config['DISASTER_DASHBOARD_CONFIG'] if current_app.config['LOCAL']: # Reload configs for faster local iteration. diff --git a/server/routes/experiments/biomed_nl/api.py b/server/routes/experiments/biomed_nl/api.py index 47c983c140..9c1a61afde 100644 --- a/server/routes/experiments/biomed_nl/api.py +++ b/server/routes/experiments/biomed_nl/api.py @@ -21,6 +21,7 @@ from flask import request from flask import Response from google import genai +from markupsafe import escape from pydantic import BaseModel from pydantic import Field @@ -201,6 +202,7 @@ def llm_search(): query = request.args.get('q') if not query: return 'error: q param is required', 400 + query = str(escape(query)) result = _fulfill_traversal_query(query).model_dump(by_alias=True, mode="json") return Response(json.dumps(result), 200, mimetype='application/json') diff --git a/server/routes/explore/api.py b/server/routes/explore/api.py index 22cad5c05f..eb77f90f43 100644 --- a/server/routes/explore/api.py +++ b/server/routes/explore/api.py @@ -23,6 +23,7 @@ from flask import current_app from flask import request from flask import Response +from markupsafe import escape from server.lib.cache import cache from server.lib.nl.common import serialize @@ -49,13 +50,20 @@ bp = Blueprint('explore_api', __name__, url_prefix='/api/explore') +def _escape_value(value): + if value is None: + return None + return str(escape(value)) + + # # The detection endpoint. # @bp.route('/detect', methods=['POST']) def detect(): debug_logs = {} - client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value) + client = _escape_value(request.args.get(Params.CLIENT.value, + Clients.DEFAULT.value)) utterance, error_json = helpers.parse_query_and_detect( request, 'explore', client, debug_logs) @@ -105,8 +113,9 @@ def fulfill(): def detect_and_fulfill(): debug_logs = {} - test = request.args.get(Params.TEST.value, '') - client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value) + test = _escape_value(request.args.get(Params.TEST.value, '')) + client = _escape_value( + request.args.get(Params.CLIENT.value, Clients.DEFAULT.value)) # First sanity DC name, if any. dc_name = request.get_json().get(Params.DC.value) @@ -127,8 +136,8 @@ def detect_and_fulfill(): utterance.insight_ctx[ Params.EXP_MORE_DISABLED.value] = request.get_json().get( Params.EXP_MORE_DISABLED, "") - utterance.insight_ctx[Params.SKIP_RELATED_THINGS] = request.args.get( - Params.SKIP_RELATED_THINGS.value, '') == 'true' + utterance.insight_ctx[Params.SKIP_RELATED_THINGS] = _escape_value( + request.args.get(Params.SKIP_RELATED_THINGS.value, '')) == 'true' helpers.update_insight_ctx_for_chart_fulfill(request, utterance, dc_name) # Important to setup utterance for explore flow (this is really the only difference @@ -252,9 +261,10 @@ def _fulfill_with_chart_config(utterance: nl_utterance.Utterance, def _fulfill_with_insight_ctx(request: Dict, debug_logs: Dict, counters: ctr.Counters) -> Dict: insight_ctx = request.get_json() - test = request.args.get(Params.TEST.value, '') - client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value) - mode = request.args.get(Params.MODE.value, '') + test = _escape_value(request.args.get(Params.TEST.value, '')) + client = _escape_value( + request.args.get(Params.CLIENT.value, Clients.DEFAULT.value)) + mode = _escape_value(request.args.get(Params.MODE.value, '')) if not insight_ctx: return helpers.abort('Sorry, could not answer your query.', '', [], diff --git a/server/routes/explore/helpers.py b/server/routes/explore/helpers.py index 0e38a70ef0..55cadfe8c7 100644 --- a/server/routes/explore/helpers.py +++ b/server/routes/explore/helpers.py @@ -58,12 +58,17 @@ _SANITY_TEST = 'sanity' +def _escape_value(value): + if value is None: + return None + return str(escape(value)) + + # Get the default place to be used for fulfillment. If there is a place in the # request, use that. Otherwise, use pre-chosen places. def _get_default_place(request: Dict, is_special_dc: bool, debug_logs: Dict): - default_place_dcid = request.args.get(params.Params.DEFAULT_PLACE, - default='', - type=str) + default_place_dcid = _escape_value( + request.args.get(params.Params.DEFAULT_PLACE, default='', type=str)) # If default place from request is earth, use the Earth place object if default_place_dcid == constants.EARTH.dcid: return constants.EARTH @@ -92,15 +97,15 @@ def parse_query_and_detect(request: Dict, backend: str, client: str, flask.abort(404) nl_bad_words = current_app.config['NL_BAD_WORDS'] - test = request.args.get(params.Params.TEST.value, '') + test = _escape_value(request.args.get(params.Params.TEST.value, '')) # i18n param - i18n_str = request.args.get(params.Params.I18N.value, '') + i18n_str = _escape_value(request.args.get(params.Params.I18N.value, '')) i18n = i18n_str and i18n_str.lower() == 'true' # Index-type default is in nl_server. - idx_param_str = request.args.get(params.Params.INDEX.value, '') + idx_param_str = _escape_value(request.args.get(params.Params.INDEX.value, '')) embeddings_index_types = [x.strip() for x in idx_param_str.split(',')] - original_query = request.args.get('q') + original_query = _escape_value(request.args.get('q')) if not original_query: err_json = helpers.abort( 'Received an empty query, please type a few words :)', @@ -115,13 +120,14 @@ def parse_query_and_detect(request: Dict, backend: str, client: str, embeddings_index_types = params.dc_to_embedding_types(dc, embeddings_index_types) - detector_type = request.args.get(params.Params.DETECTOR.value, - default=RequestedDetectorType.Hybrid.value, - type=str) + detector_type = _escape_value( + request.args.get(params.Params.DETECTOR.value, + default=RequestedDetectorType.Hybrid.value, + type=str)) # mode param use_default_place = True - mode = request.args.get(params.Params.MODE.value, '') + mode = _escape_value(request.args.get(params.Params.MODE.value, '')) if mode == QueryMode.STRICT: # Strict mode is compatible only with Heuristic Detector! detector_type = RequestedDetectorType.Heuristic.value @@ -180,10 +186,10 @@ def parse_query_and_detect(request: Dict, backend: str, client: str, use_default_place = False # See if we have a variable reranker model specified. - reranker = request.args.get(params.Params.RERANKER.value) + reranker = _escape_value(request.args.get(params.Params.RERANKER.value)) # Get sv threshold as a float if it was passed in the request - var_threshold = request.args.get(params.Params.VAR_THRESHOLD.value) + var_threshold = _escape_value(request.args.get(params.Params.VAR_THRESHOLD.value)) if var_threshold: # if sv_threshold is not a float, don't set sv_threshold try: @@ -192,8 +198,8 @@ def parse_query_and_detect(request: Dict, backend: str, client: str, var_threshold = None # StopWords handling - include_stop_words_str = request.args.get( - params.Params.INCLUDE_STOP_WORDS.value, '') + include_stop_words_str = _escape_value( + request.args.get(params.Params.INCLUDE_STOP_WORDS.value, '')) detection_args = DetectionArgs( embeddings_index_types=embeddings_index_types, @@ -280,7 +286,7 @@ def update_insight_ctx_for_chart_fulfill(request: Dict, params.Params.MAX_CHARTS, params.Params.CHART_TYPE, ]: - param_val = request.args.get(p, None) + param_val = _escape_value(request.args.get(p, None)) if param_val != None: if param_val.isnumeric(): param_val = int(param_val) diff --git a/server/routes/explore/html.py b/server/routes/explore/html.py index 573ea4721b..a2d455d6cf 100644 --- a/server/routes/explore/html.py +++ b/server/routes/explore/html.py @@ -18,6 +18,7 @@ from flask import Blueprint from flask import current_app from flask import render_template +from markupsafe import escape bp = Blueprint('explore', __name__, url_prefix='/explore') @@ -32,6 +33,7 @@ def page(): @bp.route('/') def landing(topic): + topic = str(escape(topic)) return render_template('/explore_landing.html', topic=topic, website_hash=os.environ.get("WEBSITE_HASH")) diff --git a/server/routes/nl/api.py b/server/routes/nl/api.py index c1f7354436..25b3aea8d3 100644 --- a/server/routes/nl/api.py +++ b/server/routes/nl/api.py @@ -19,6 +19,7 @@ from flask import Blueprint from flask import request from flask import Response +from markupsafe import escape from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field @@ -39,6 +40,16 @@ Dcid = str +def _escape_value(value): + if value is None: + return None + return str(escape(value)) + + +def _escape_list(values: list[str]) -> list[str]: + return [str(escape(v)) for v in values] + + class IndicatorScore(BaseModel): dcid: str score: float @@ -106,7 +117,7 @@ class SearchVariablesResponse(ApiBaseModel): @bp.route('/encode-vector', methods=['POST']) def encode_vector(): """Retrieves the embedding vector for a given query and model.""" - model = request.args.get('model') + model = _escape_value(request.args.get('model')) queries = request.json.get('queries', []) return json.dumps(dc.nl_encode(model, queries)) @@ -114,7 +125,7 @@ def encode_vector(): @bp.route('/search-vector', methods=['POST']) def search_vector(): """Performs vector search for a given query and embedding index.""" - idx = request.args.get('idx') + idx = _escape_value(request.args.get('idx')) if not idx: flask.abort(400, 'Must provide an `idx`') queries = request.json.get('queries') @@ -123,7 +134,8 @@ def search_vector(): return dc.nl_search_vars(queries, idx.split(','), - skip_topics=request.args.get('skip_topics', '')) + skip_topics=_escape_value( + request.args.get('skip_topics', ''))) def _get_property_value(sv_data: dict, @@ -270,14 +282,14 @@ def _build_final_response(queries: list[str], indices: list[str], def _parse_request_args() -> SearchIndicatorsRequest: """Parses and validates all query string arguments for the search.""" - queries = request.args.getlist(PARAM_QUERIES) + queries = _escape_list(request.args.getlist(PARAM_QUERIES)) if not queries: flask.abort(400, f"`{PARAM_QUERIES}` is a required parameter") threshold_override = None if PARAM_THRESHOLD in request.args: try: - threshold_override = float(request.args[PARAM_THRESHOLD]) + threshold_override = float(_escape_value(request.args[PARAM_THRESHOLD])) except ValueError: flask.abort(400, f"The `{PARAM_THRESHOLD}` parameter must be a valid float.") @@ -285,18 +297,18 @@ def _parse_request_args() -> SearchIndicatorsRequest: limit_per_index = None if PARAM_LIMIT_PER_INDEX in request.args: try: - limit_per_index = int(request.args[PARAM_LIMIT_PER_INDEX]) + limit_per_index = int(_escape_value(request.args[PARAM_LIMIT_PER_INDEX])) except ValueError: flask.abort( 400, f"The `{PARAM_LIMIT_PER_INDEX}` parameter must be a valid integer.") - indices = request.args.getlist(PARAM_INDEX) + indices = _escape_list(request.args.getlist(PARAM_INDEX)) if not indices: server_config = dc.nl_server_config() indices = server_config.get("default_indexes", []) - include_types = request.args.getlist(PARAM_INCLUDE_TYPES) + include_types = _escape_list(request.args.getlist(PARAM_INCLUDE_TYPES)) skip_topics = False if include_types and TYPE_TOPIC not in include_types: skip_topics = True diff --git a/server/routes/nl/html.py b/server/routes/nl/html.py index 1a05c18e0a..718b0ca54b 100644 --- a/server/routes/nl/html.py +++ b/server/routes/nl/html.py @@ -23,6 +23,7 @@ from flask import render_template from flask import request from flask import url_for +from markupsafe import escape import server.services.datacommons as dc @@ -49,7 +50,7 @@ def eval_rig(): def eval_retrieval_generation(): if not current_app.config.get('ENABLE_DATAGEMMA_EVAL_TOOLS', False): flask.abort(404) - sheet_id = request.args.get('sheet_id') + sheet_id = str(escape(request.args.get('sheet_id', ''))) if not sheet_id: return redirect(url_for('nl.eval_retrieval_generation', sheet_id=_TEST_SHEET_ID), @@ -61,9 +62,9 @@ def eval_retrieval_generation(): def eval_retrieval_generation_sxs(): if not current_app.config.get('ENABLE_DATAGEMMA_EVAL_TOOLS', False): flask.abort(404) - sheet_id_a = request.args.get('sheetIdA', '') - sheet_id_b = request.args.get('sheetIdB', '') - session_id = request.args.get('sessionId', '') + sheet_id_a = str(escape(request.args.get('sheetIdA', ''))) + sheet_id_b = str(escape(request.args.get('sheetIdB', ''))) + session_id = str(escape(request.args.get('sessionId', ''))) return render_template('/eval_retrieval_generation_sxs.html', sheet_id_a=sheet_id_a, sheet_id_b=sheet_id_b, diff --git a/server/routes/oembed/chart.py b/server/routes/oembed/chart.py index af2be3f1da..fef639498e 100644 --- a/server/routes/oembed/chart.py +++ b/server/routes/oembed/chart.py @@ -17,6 +17,7 @@ from flask import current_app from flask import render_template from flask import request +from markupsafe import escape from server.lib.cache import cache from server.routes import TIMEOUT @@ -42,12 +43,14 @@ def render_chart(): chart_type = request.args.get("chartType", None) if not chart_type or not chart_type in ALLOWED_CHART_TYPES: return "error: must provide a valid chart type", 400 + chart_type = str(escape(chart_type)) attributes = "" for key in request.args.keys(): if key != "chartType": - values = request.args.getlist(key) - attributes += f'{key}="{" ".join(values)}" ' + sanitized_key = str(escape(key)) + values = [str(escape(v)) for v in request.args.getlist(key)] + attributes += f'{sanitized_key}="{" ".join(values)}" ' component = ( f"") diff --git a/server/routes/oembed/oembed.py b/server/routes/oembed/oembed.py index 606f7030ca..f7b7211a41 100644 --- a/server/routes/oembed/oembed.py +++ b/server/routes/oembed/oembed.py @@ -21,6 +21,7 @@ from flask import current_app from flask import request from flask import Response +from markupsafe import escape from server.lib.cache import cache from server.routes import TIMEOUT @@ -50,6 +51,7 @@ def render_chart(): if not url or not re.match(url_regex, url): # reject request if url not matching allowed pattern or not provided return "error: must provide a valid url", 400 + url = str(escape(url)) max_width = request.args.get("maxwidth", type=int, default=500) max_height = request.args.get("maxheight", type=int, default=400) diff --git a/server/routes/place/api.py b/server/routes/place/api.py index 79bcd06669..8ff3d67499 100644 --- a/server/routes/place/api.py +++ b/server/routes/place/api.py @@ -21,6 +21,7 @@ from flask import g from flask import jsonify from flask import request +from markupsafe import escape from server.lib.cache import cache from server.lib.cache import cache_and_log_mixer_usage @@ -78,8 +79,10 @@ async def fetch_place_types( child_place_type_to_highlight_task, place_type_task) + place_dcid = str(escape(place_dcid)) # Validate the category parameter. - place_category = request.args.get("category", place_utils.OVERVIEW_CATEGORY) + place_category = str( + escape(request.args.get("category", place_utils.OVERVIEW_CATEGORY))) if place_category not in place_utils.ALLOWED_CATEGORIES: return error_response( f"Argument 'category' {place_category} must be one of: {', '.join(place_utils.ALLOWED_CATEGORIES)}" @@ -145,6 +148,7 @@ async def related_places(place_dcid: str): - Place details (name, type, etc.) - Lists of nearby, similar, and child places """ + place_dcid = str(escape(place_dcid)) # Fetch the current place. place = place_utils.fetch_place(place_dcid, g.locale) @@ -223,6 +227,7 @@ def overview_table(place_dcid: str): """ Fetches and returns overview table data for the specified place. """ + place_dcid = str(escape(place_dcid)) data_rows, mixer_response_ids = place_utils.fetch_overview_table_data( place_dcid) response_data = PlaceOverviewTableApiResponse( @@ -237,5 +242,6 @@ async def place_summary(place_dcid: str): """ Fetches and returns place summary data for the specified place. """ + place_dcid = str(escape(place_dcid)) summary = await place_utils.generate_place_summary(place_dcid, g.locale) return jsonify(PlaceSummaryApiResponse(summary=summary)) diff --git a/server/routes/place/html.py b/server/routes/place/html.py index 06120187a3..49c2379e83 100644 --- a/server/routes/place/html.py +++ b/server/routes/place/html.py @@ -23,6 +23,7 @@ import flask from flask import current_app from flask import g +from markupsafe import escape from werkzeug.datastructures import MultiDict from server.lib.cache import cache @@ -154,6 +155,7 @@ def place_explorer(): request includes a dcid. """ dcid = flask.request.args.get('dcid', None) + dcid = str(escape(dcid)) if dcid else None # If the request contains a dcid, redirect to the place page. # This handles redirects from Google Search "Explore More" link. @@ -189,7 +191,9 @@ def place(place_dcid): Args: place_dcid: DCID of the place to redirect to """ + place_dcid = str(escape(place_dcid)) redirect_args = dict(flask.request.args) + redirect_args = {k: str(escape(v)) for k, v in redirect_args.items()} # Strip trailing slashes from place dcids should_redirect = False if place_dcid and place_dcid.endswith('/'): @@ -198,7 +202,7 @@ def place(place_dcid): # Rename legacy "topic" request argument to "category" if 'topic' in flask.request.args: - redirect_args['category'] = flask.request.args.get('topic', '') + redirect_args['category'] = str(escape(flask.request.args.get('topic', ''))) del redirect_args['topic'] should_redirect = True diff --git a/server/routes/place_list/html.py b/server/routes/place_list/html.py index 67d28082cd..6483385861 100644 --- a/server/routes/place_list/html.py +++ b/server/routes/place_list/html.py @@ -16,6 +16,7 @@ from flask import Blueprint from flask import render_template +from markupsafe import escape from server.lib.cache import cache from server.lib.fetch import raw_property_values @@ -41,6 +42,7 @@ def index(): @bp.route('/place-list/') @cache.memoize(timeout=TIMEOUT) def node(dcid): + dcid = str(escape(dcid)) child_places = child_fetch(dcid) place_by_type = collections.defaultdict(list) for place_type, childs in child_places.items(): diff --git a/server/routes/ranking/api.py b/server/routes/ranking/api.py index 43d28cbb1c..a94066bdf1 100644 --- a/server/routes/ranking/api.py +++ b/server/routes/ranking/api.py @@ -17,6 +17,7 @@ import logging import flask +from markupsafe import escape from server.lib.util import error_response import server.routes.shared_api.place as place_api @@ -44,6 +45,9 @@ def ranking_api(stat_var, place_type, place=None): pc (per capita - the presence of the key enables it) bottom (show bottom ranking instead - the presence of the key enables it) """ + stat_var = str(escape(stat_var)) + place_type = str(escape(place_type)) + place = str(escape(place)) if place else place is_per_capita = flask.request.args.get('pc', False) != False is_show_bottom = flask.request.args.get('bottom', False) != False rank_keys = BOTTOM_KEYS_KEEP if is_show_bottom else TOP_KEYS_KEEP diff --git a/server/routes/ranking/html.py b/server/routes/ranking/html.py index ea9ee01943..36d85508bb 100644 --- a/server/routes/ranking/html.py +++ b/server/routes/ranking/html.py @@ -17,6 +17,7 @@ import flask from flask import current_app +from markupsafe import escape import server.routes.shared_api.place as place_api @@ -26,6 +27,9 @@ @bp.route('//', strict_slashes=False) @bp.route('///') def ranking(stat_var, place_type, place_dcid=''): + stat_var = str(escape(stat_var)) + place_type = str(escape(place_type)) + place_dcid = str(escape(place_dcid)) place_name = '' if place_dcid: place_names = place_api.get_i18n_name([place_dcid]) diff --git a/server/routes/redirects.py b/server/routes/redirects.py index bd948370ea..89161f8f37 100644 --- a/server/routes/redirects.py +++ b/server/routes/redirects.py @@ -22,6 +22,7 @@ from flask import redirect from flask import request from flask import url_for +from markupsafe import escape from server.lib.config import GLOBAL_CONFIG_BUCKET from shared.lib import gcs @@ -34,7 +35,7 @@ @bp.route('/kg') def kg(): - dcid = request.args.get('dcid', '') + dcid = str(escape(request.args.get('dcid', ''))) if dcid: url = url_for('browser.browser_node', dcid=dcid) else: diff --git a/server/routes/search/html.py b/server/routes/search/html.py index f6810ad4ad..b2de3e909d 100644 --- a/server/routes/search/html.py +++ b/server/routes/search/html.py @@ -17,6 +17,7 @@ from flask import Blueprint from flask import current_app from flask import request +from markupsafe import escape import server.services.datacommons as dc @@ -28,7 +29,7 @@ @bp.route('/search') def search(): """Custom search page""" - query = request.args.get('q', '') + query = str(escape(request.args.get('q', ''))) return flask.render_template('search.html', maps_api_key=current_app.config['MAPS_API_KEY'], query=query) @@ -37,7 +38,7 @@ def search(): @bp.route('/search_dc') def search_dc(): """Add DC API powered search for non-place searches temporarily""" - query_text = request.args.get('q', '') + query_text = str(escape(request.args.get('q', ''))) max_results = int(request.args.get('l', _MAX_SEARCH_RESULTS)) if query_text: search_response = dc.search(query_text, max_results) diff --git a/server/routes/shared_api/autocomplete/autocomplete.py b/server/routes/shared_api/autocomplete/autocomplete.py index f179ff5319..b889957236 100644 --- a/server/routes/shared_api/autocomplete/autocomplete.py +++ b/server/routes/shared_api/autocomplete/autocomplete.py @@ -22,6 +22,7 @@ from flask import Blueprint from flask import jsonify from flask import request +from markupsafe import escape from server.lib.feature_flags import ENABLE_STAT_VAR_AUTOCOMPLETE from server.lib.feature_flags import is_feature_enabled @@ -41,9 +42,9 @@ async def autocomplete(): """Predicts the user query for location and stat vars.""" start_time = time.time() - lang = request.args.get('hl', 'en') - original_query = request.args.get('query', '') - has_location = request.args.get('has_location', 'false') == 'true' + lang = str(escape(request.args.get('hl', 'en'))) + original_query = str(escape(request.args.get('query', ''))) + has_location = str(escape(request.args.get('has_location', 'false'))) == 'true' # Don't trigger autocomplete on short queries or if the last word is a stop word. words = original_query.split() diff --git a/server/routes/shared_api/choropleth.py b/server/routes/shared_api/choropleth.py index b107337d75..b1cd2f525a 100644 --- a/server/routes/shared_api/choropleth.py +++ b/server/routes/shared_api/choropleth.py @@ -26,6 +26,7 @@ from flask import send_file from flask import url_for from geojson_rewind import rewind +from markupsafe import escape from server.lib.cache import cache import server.lib.fetch as fetch @@ -245,14 +246,18 @@ def geojson(): return Response(json.dumps("error: must provide a placeDcid field"), 400, mimetype='application/json') + place_dcid = str(escape(place_dcid)) place_type = request.args.get("placeType") if not place_type: place_dcid, place_type = get_choropleth_display_level(place_dcid) + else: + place_type = str(escape(place_type)) place_name_prop = request.args.get("placeNameProp") + place_name_prop = str(escape(place_name_prop)) if place_name_prop else None # If the request has a geoJsonProp, use that. Otherwise, use the default # property specified in the app config. - geojson_prop = request.args.get("geoJsonProp", - current_app.config["GEO_JSON_PROP"]) + geojson_prop = str( + escape(request.args.get("geoJsonProp", current_app.config["GEO_JSON_PROP"]))) cached_geojson = current_app.config['CACHED_GEOJSONS'].get( place_dcid, {}).get(place_type, {}).get(geojson_prop, {}) if cached_geojson: @@ -407,6 +412,7 @@ def choropleth_data(dcid): sources: [], } """ + dcid = str(escape(dcid)) cc = request.json.get('spec', None) if not cc: return Response(json.dumps({}), 200, mimetype='application/json') @@ -493,11 +499,13 @@ def get_map_points(): return Response(json.dumps("error: must provide a placeDcid field"), 400, mimetype='application/json') + place_dcid = str(escape(place_dcid)) place_type = request.args.get("placeType") if not place_type: return Response(json.dumps("error: must provide a placeType field"), 400, mimetype='application/json') + place_type = str(escape(place_type)) geos = [] geos = fetch.descendent_places([place_dcid], place_type).get(place_dcid, []) if not geos: diff --git a/server/routes/shared_api/facets.py b/server/routes/shared_api/facets.py index fcb03ae7e8..9aeeaf1b0a 100644 --- a/server/routes/shared_api/facets.py +++ b/server/routes/shared_api/facets.py @@ -16,12 +16,24 @@ from flask import Blueprint from flask import request +from markupsafe import escape from server.lib import fetch bp = Blueprint("facets", __name__, url_prefix='/api/facets') +def _get_escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _get_escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + def is_valid_date(date): """ Returns whether or not the date string is valid. Valid date strings are: @@ -46,16 +58,17 @@ def get_facets_within(): date: If empty, fetch for all date; Otherwise could be "LATEST" or specific date. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _get_escaped_arg('parentEntity') if not parent_entity: return 'error: must provide a parentEntity field', 400 - child_type = request.args.get('childType') + child_type = _get_escaped_arg('childType') if not child_type: return 'error: must provide a childType field', 400 - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) if not variables: return 'error: must provide a variables field', 400 - date = request.args.get('date') + date = _get_escaped_arg('date', '') if not is_valid_date(date): return 'error: date must be LATEST or YYYY or YYYY-MM or YYYY-MM-DD', 400 return fetch.point_within_facet(parent_entity, child_type, variables, date, @@ -66,8 +79,9 @@ def get_facets_within(): def get_facets(): """Gets the available facets for a list of stat vars for a list of places. """ - entities = list(filter(lambda x: x != "", request.args.getlist('entities'))) - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + entities = list(filter(lambda x: x != "", _get_escaped_arg_list('entities'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) if not entities: return 'error: must provide a `entities` field', 400 if not variables: diff --git a/server/routes/shared_api/node.py b/server/routes/shared_api/node.py index 50b5876779..9a7e999ebb 100644 --- a/server/routes/shared_api/node.py +++ b/server/routes/shared_api/node.py @@ -20,6 +20,7 @@ import flask from flask import request from flask import Response +from markupsafe import escape from server.lib import fetch @@ -31,6 +32,23 @@ _OUT_ARROW = '->' +def _escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + +def _escaped_list(values) -> list[str]: + if not values: + return [] + return [str(escape(v)) for v in values] + + @dataclass class PropertySpec: # name of the property to get values for @@ -44,6 +62,8 @@ class PropertySpec: @bp.route('/triples//') def triples(direction, dcid): """Returns all the triples given a node dcid.""" + direction = str(escape(direction)) + dcid = str(escape(dcid)) if direction != 'in' and direction != 'out': return "Invalid direction provided, please use 'in' or 'out'", 400 return fetch.triples([dcid], direction == 'out').get(dcid, {}) @@ -52,14 +72,15 @@ def triples(direction, dcid): @bp.route('/propvals/', methods=['GET', 'POST']) def get_property_value(direction): """Returns the property values for given node dcids and property label.""" + direction = str(escape(direction)) if direction != "in" and direction != "out": return "Invalid direction provided, please use 'in' or 'out'", 400 - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids: - dcids = request.json['dcids'] - prop = request.args.get('prop') + dcids = _escaped_list(request.json['dcids']) + prop = _escaped_arg('prop') if not prop: - prop = request.json['prop'] + prop = str(escape(request.json['prop'])) response = fetch.raw_property_values(dcids, prop, direction == 'out') return Response(json.dumps(response), 200, mimetype='application/json') @@ -140,12 +161,12 @@ def expression_property_value(): Returns the property values for given node dcids and a relation expression for a property (https://docs.datacommons.org/api/rest/v2#relation-expressions) """ - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids and request.json: - dcids = request.json['dcids'] - prop_expression = request.args.get('propExpr') + dcids = _escaped_list(request.json['dcids']) + prop_expression = _escaped_arg('propExpr') if not prop_expression and request.json: - prop_expression = request.json['propExpr'] + prop_expression = str(escape(request.json['propExpr'])) if not dcids: return 'error: must provide a `dcids` field', 400 if not prop_expression: diff --git a/server/routes/shared_api/observation/date.py b/server/routes/shared_api/observation/date.py index d44a97df72..2c6fdd773d 100644 --- a/server/routes/shared_api/observation/date.py +++ b/server/routes/shared_api/observation/date.py @@ -14,6 +14,7 @@ from flask import Blueprint from flask import request +from markupsafe import escape from server.lib import util import server.services.datacommons as dc @@ -22,18 +23,29 @@ bp = Blueprint("observation_dates", __name__) +def _get_escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _get_escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + @bp.route('/api/observation-dates') def observation_dates(): """Given ancestor place, child place type and stat vars, return the dates that have data for each stat var across all child places. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _get_escaped_arg('parentEntity') if not parent_entity: return 'error: must provide a parentEntity field', 400 - child_type = request.args.get('childType') + child_type = _get_escaped_arg('childType') if not child_type: return 'error: must provide a childType field', 400 - variable = request.args.get('variable') + variable = _get_escaped_arg('variable') if not variable: return 'error: must provide a variable field', 400 return dc.get_series_dates(parent_entity, child_type, [variable]) @@ -116,10 +128,10 @@ def observation_dates_entities(): } ``` """ - entities = request.args.getlist('entities') + entities = _get_escaped_arg_list('entities') if len(entities) == 0: return 'error: must provide entities field', 400 - variables = request.args.getlist('variables') + variables = _get_escaped_arg_list('variables') if len(variables) == 0: return 'error: must provide a variables field', 400 return util.get_series_dates_from_entities(entities, variables) diff --git a/server/routes/shared_api/observation/point.py b/server/routes/shared_api/observation/point.py index 65d5a62677..b9d313dced 100644 --- a/server/routes/shared_api/observation/point.py +++ b/server/routes/shared_api/observation/point.py @@ -14,6 +14,7 @@ from flask import Blueprint from flask import request +from markupsafe import escape from server.lib import fetch from server.lib.cache import cache @@ -28,6 +29,17 @@ bp = Blueprint('point', __name__, url_prefix='/api/observations/point') +def _get_escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _get_escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + def _filter_point_for_facets(point_data, facet_ids: list[str]): """Filter the point data to only include the specified facets. Args: @@ -88,15 +100,16 @@ def _filter_point_for_facets(point_data, facet_ids: list[str]): @cache_and_log_mixer_usage(timeout=TIMEOUT, query_string=True) def point(): """Handler to get the observation point given multiple stat vars and places.""" - entities = list(filter(lambda x: x != "", request.args.getlist('entities'))) - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + entities = list(filter(lambda x: x != "", _get_escaped_arg_list('entities'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) facet_id = list(filter(lambda x: x != "", - request.args.getlist('facetId'))) or None + _get_escaped_arg_list('facetId'))) or None if not entities: return 'error: must provide a `entities` field', 400 if not variables: return 'error: must provide a `variables` field', 400 - date = request.args.get('date') or DATE_LATEST + date = _get_escaped_arg('date') or DATE_LATEST # Fetch recent observations with the highest entity coverage if date == DATE_HIGHEST_COVERAGE: return fetch_highest_coverage(entities=entities, @@ -123,13 +136,14 @@ def point(): @cache_and_log_mixer_usage(timeout=TIMEOUT, query_string=True) def point_all(): """Handler to get all the observation points given multiple stat vars and entities.""" - entities = list(filter(lambda x: x != "", request.args.getlist('entities'))) - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + entities = list(filter(lambda x: x != "", _get_escaped_arg_list('entities'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) if not entities: return 'error: must provide a `entities` field', 400 if not variables: return 'error: must provide a `variables` field', 400 - date = request.args.get('date') or DATE_LATEST + date = _get_escaped_arg('date') or DATE_LATEST # Fetch recent observations with the highest entity coverage if date == DATE_HIGHEST_COVERAGE: return fetch_highest_coverage(entities=entities, @@ -151,17 +165,18 @@ def point_within(): This returns the observation for the preferred facet. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _get_escaped_arg('parentEntity') if not parent_entity: return 'error: must provide a `parentEntity` field', 400 - child_type = request.args.get('childType') + child_type = _get_escaped_arg('childType') if not child_type: return 'error: must provide a `childType` field', 400 - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) if not variables: return 'error: must provide a `variables` field', 400 - date = request.args.get('date') or DATE_LATEST - facet_ids = list(filter(lambda x: x != "", request.args.getlist('facetIds'))) + date = _get_escaped_arg('date') or DATE_LATEST + facet_ids = list(filter(lambda x: x != "", _get_escaped_arg_list('facetIds'))) # Fetch recent observations with the highest entity coverage if date == DATE_HIGHEST_COVERAGE: return fetch_highest_coverage(parent_entity=parent_entity, @@ -186,16 +201,17 @@ def point_within_all(): This returns the observation for all facets. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _get_escaped_arg('parentEntity') if not parent_entity: return 'error: must provide a `parentEntity` field', 400 - child_type = request.args.get('childType') + child_type = _get_escaped_arg('childType') if not child_type: return 'error: must provide a `childType` field', 400 - variables = list(filter(lambda x: x != "", request.args.getlist('variables'))) + variables = list( + filter(lambda x: x != "", _get_escaped_arg_list('variables'))) if not variables: return 'error: must provide a `variables` field', 400 - date = request.args.get('date') or DATE_LATEST + date = _get_escaped_arg('date') or DATE_LATEST # Fetch recent observations with the highest entity coverage if date == DATE_HIGHEST_COVERAGE: return fetch_highest_coverage(parent_entity=parent_entity, diff --git a/server/routes/shared_api/observation/series.py b/server/routes/shared_api/observation/series.py index a80302bacb..d406b44e63 100644 --- a/server/routes/shared_api/observation/series.py +++ b/server/routes/shared_api/observation/series.py @@ -14,9 +14,11 @@ import logging from typing import List +from typing import Optional from flask import Blueprint from flask import request +from markupsafe import escape from server.lib import fetch from server.lib import shared @@ -45,6 +47,18 @@ bp = Blueprint("series", __name__, url_prefix='/api/observations/series') +def _escape_value(value): + if value is None: + return None + return str(escape(value)) + + +def _escape_list(values: Optional[List[str]]) -> List[str]: + if not values: + return [] + return [str(escape(v)) for v in values] + + # Filters a list for non empty values # TODO: use request directly in this function and pass in arg name def _get_filtered_arg_list(arg_list: List[str]) -> List[str]: @@ -61,13 +75,16 @@ def _get_filtered_arg_list(arg_list: List[str]) -> List[str]: def series(): """Handler to get preferred time series given multiple stat vars and entities.""" if request.method == 'POST': - entities = request.json.get('entities') - variables = request.json.get('variables') - facet_ids = request.json.get('facetIds') + entities = _escape_list(request.json.get('entities')) + variables = _escape_list(request.json.get('variables')) + facet_ids = _escape_list(request.json.get('facetIds')) else: - entities = _get_filtered_arg_list(request.args.getlist('entities')) - variables = _get_filtered_arg_list(request.args.getlist('variables')) - facet_ids = _get_filtered_arg_list(request.args.getlist('facetIds')) + entities = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'entities'))) + variables = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'variables'))) + facet_ids = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'facetIds'))) if not entities: return 'error: must provide a `entities` field', 400 if not variables: @@ -82,8 +99,10 @@ def series(): @cache_and_log_mixer_usage(timeout=TIMEOUT, query_string=True) def series_all(): """Handler to get all the time series given multiple stat vars and places.""" - entities = _get_filtered_arg_list(request.args.getlist('entities')) - variables = _get_filtered_arg_list(request.args.getlist('variables')) + entities = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'entities'))) + variables = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'variables'))) if not entities: return 'error: must provide a `entities` field', 400 if not variables: @@ -102,19 +121,21 @@ def series_within(): Note: the preferred facet is returned. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _escape_value(request.args.get('parentEntity')) if not parent_entity: return 'error: must provide a `parentEntity` field', 400 - child_type = request.args.get('childType') + child_type = _escape_value(request.args.get('childType')) if not child_type: return 'error: must provide a `childType` field', 400 - variables = _get_filtered_arg_list(request.args.getlist('variables')) + variables = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'variables'))) if not variables: return 'error: must provide a `variables` field', 400 - facet_ids = _get_filtered_arg_list(request.args.getlist('facetIds')) + facet_ids = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'facetIds'))) # Make batched calls there are too many child places for server to handle # Mixer checks num_places * num_variables and stop processing if the number is @@ -148,15 +169,16 @@ def series_within_all(): Note: all the facets are returned. """ - parent_entity = request.args.get('parentEntity') + parent_entity = _escape_value(request.args.get('parentEntity')) if not parent_entity: return 'error: must provide a `parentEntity` field', 400 - child_type = request.args.get('childType') + child_type = _escape_value(request.args.get('childType')) if not child_type: return 'error: must provide a `childType` field', 400 - variables = _get_filtered_arg_list(request.args.getlist('variables')) + variables = _get_filtered_arg_list(_escape_list(request.args.getlist( + 'variables'))) if not variables: return 'error: must provide a `variables` field', 400 diff --git a/server/routes/shared_api/place.py b/server/routes/shared_api/place.py index d04c9bfda2..34f7c9a236 100644 --- a/server/routes/shared_api/place.py +++ b/server/routes/shared_api/place.py @@ -121,6 +121,23 @@ bp = Blueprint("api_place", __name__, url_prefix='/api/place') +def _escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + +def _escaped_list(values) -> list[str]: + if not values: + return [] + return [str(escape(v)) for v in values] + + def get_place_type(place_dcids): place_types = fetch.property_values(place_dcids, 'typeOf') ret = {} @@ -151,19 +168,21 @@ def get_place_type_i18n_name(place_type: str, plural: bool = False) -> str: @bp.route('/type/') @cache.memoize(timeout=TIMEOUT) def api_place_type(place_dcid): + place_dcid = str(escape(place_dcid)) return get_place_type([place_dcid]).get(place_dcid, '') @bp.route('/name', methods=['GET', 'POST']) def api_name(): """Get place names.""" - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids: - dcids = request.json['dcids'] + dcids = _escaped_list(request.json['dcids']) dcids = list(filter(lambda d: d != '', dcids)) - prop = request.args.get('prop') + prop = _escaped_arg('prop') if request.is_json: - prop = request.json.get('prop') + if request.json.get('prop') is not None: + prop = str(escape(request.json.get('prop'))) try: return names(dcids, prop) except Exception as e: @@ -228,7 +247,7 @@ def extract_locale_name(entry, locale): @bp.route('/name/i18n') def api_i18n_name(): """Get place i18n names.""" - dcids = request.args.getlist('dcid') + dcids = _escaped_arg_list('dcid') result = get_i18n_name(dcids) return Response(json.dumps(result), 200, mimetype='application/json') @@ -236,7 +255,7 @@ def api_i18n_name(): @bp.route('/named_typed') def get_named_typed_place(): """Returns data for NamedTypedPlace, a dictionary of key -> NamedTypedPlace.""" - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') place2type = get_place_type(dcids) place_names = names(dcids) ret = {} @@ -257,9 +276,9 @@ def get_place_variable(): Returns: List of unique statistical variable dcids each as a string. """ - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids: - dcids = request.json['dcids'] + dcids = _escaped_list(request.json['dcids']) resp = fetch.entity_variables(dcids) # All the keys (stat var dcid) in resp are variables for at lease one of the # places. @@ -275,7 +294,7 @@ def get_place_variable_count(): Returns: A map from place dcid to the stat var count. """ - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids: return 'error: must provide `dcids` field', 400 result = {} @@ -292,6 +311,7 @@ def get_place_variable_count(): @cache.memoize(timeout=TIMEOUT) def child(dcid): """Get top child places for a place.""" + dcid = str(escape(dcid)) child_places = child_fetch(dcid) for place_type in child_places: child_places[place_type].sort(key=lambda x: x['pop'], reverse=True) @@ -363,7 +383,7 @@ def child_fetch(parent_dcid): @bp.route('/parent') @cache.cached(timeout=TIMEOUT, query_string=True) def api_parent_places(): - dcid = request.args.get("dcid") + dcid = _escaped_arg("dcid") result = parent_places([dcid])[dcid] return Response(json.dumps(result), 200, mimetype='application/json') @@ -406,6 +426,7 @@ def api_mapinfo(dcid): function for places with those complicated situations, need to adjust this function accordingly. """ + dcid = str(escape(dcid)) left = 180 right = -180 up = -90 @@ -468,10 +489,11 @@ def get_ranking_url(containing_dcid, @cache.cached(timeout=TIMEOUT, query_string=True) def api_ranking(dcid): """Get the ranking information for a given place.""" + dcid = str(escape(dcid)) current_place_type = api_place_type(dcid) parents = parent_places([dcid])[dcid] parent_i18n_names = get_i18n_name([x['dcid'] for x in parents], False) - should_return_all = request.args.get('all', '') == "1" + should_return_all = _escaped_arg('all', '') == "1" selected_parents = [] parent_names = {} @@ -617,9 +639,9 @@ def get_display_name(dcids): @bp.route('/displayname', methods=['GET', 'POST']) def api_display_name(): """Get display names for a list of places.""" - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') if not dcids: - dcids = request.json.get('dcids', []) + dcids = _escaped_list(request.json.get('dcids', [])) result = get_display_name(dcids) return Response(json.dumps(result), 200, mimetype='application/json') @@ -631,8 +653,8 @@ def descendent(): Returns: Dict keyed by ancestor DCIDs with lists of descendent place DCIDs as values. """ - dcids = request.args.getlist("dcids") - descendent_type = request.args.get("descendentType") + dcids = _escaped_arg_list("dcids") + descendent_type = _escaped_arg("descendentType") return fetch.descendent_places(dcids, descendent_type) @@ -644,9 +666,9 @@ def descendent_names(): Returns: Dicts keyed by desccendent place DCIDs with their names as values. """ - dcid = request.args.get("dcid") - descendent_type = request.args.get("descendentType") - prop = request.args.get("prop", None) + dcid = _escaped_arg("dcid") + descendent_type = _escaped_arg("descendentType") + prop = _escaped_arg("prop", None) child_places = fetch.descendent_places([dcid], descendent_type).get(dcid, []) result = {} if prop: @@ -676,7 +698,7 @@ def placeid2dcid(): This is to use together with the Google Maps Autocomplete API: https://developers.google.com/places/web-service/autocomplete. """ - place_ids = request.args.getlist("placeIds") + place_ids = _escaped_arg_list("placeIds") return findplacedcid(place_ids) @@ -689,9 +711,9 @@ def coords2places(): Returns a list of { latitude: number, longitude: number, placeDcid: str, placeName: str } objects. """ - latitudes = request.args.getlist("latitudes") - longitudes = request.args.getlist("longitudes") - place_type = request.args.get("placeType", "") + latitudes = _escaped_arg_list("latitudes") + longitudes = _escaped_arg_list("longitudes") + place_type = _escaped_arg("placeType", "") # Get resolved place coordinate information for each coordinate of interest coordinates = [] for idx in range(0, min(len(latitudes), len(longitudes))): diff --git a/server/routes/shared_api/stats.py b/server/routes/shared_api/stats.py index e64875ea0c..9ad0fc941f 100644 --- a/server/routes/shared_api/stats.py +++ b/server/routes/shared_api/stats.py @@ -20,6 +20,7 @@ from flask import request from flask import Response from google.cloud import discoveryengine_v1 as discoveryengine +from markupsafe import escape from server.lib import fetch from server.lib import shared @@ -38,6 +39,23 @@ logger = logging.getLogger(__name__) +def _escaped_arg(name: str, default=None): + value = request.args.get(name, default) + if value is None: + return None + return str(escape(value)) + + +def _escaped_arg_list(name: str) -> list[str]: + return [str(escape(v)) for v in request.args.getlist(name)] + + +def _escaped_list(values) -> list[str]: + if not values: + return [] + return [str(escape(v)) for v in values] + + # Constants for Vertex AI Search Application # TODO: Move the VAI app to a different GCP project and figure out a better way to authenticate (ex. use API keys) VAI_PROJECT_ID = "datcom-nl" @@ -54,7 +72,7 @@ def stat_var_property(): A dictionary keyed by stats var dcid with value being a dictionary of all the properties of each stats var. """ - dcids = request.args.getlist('dcids') + dcids = _escaped_arg_list('dcids') ranked_statvars = current_app.config['RANKED_STAT_VARS'] result = {} resp = fetch.triples(dcids) @@ -122,13 +140,14 @@ def search_statvar(): is_vai_enabled = is_feature_enabled(VAI_FOR_STATVAR_SEARCH_FEATURE_FLAG, request=request) if request.method == 'GET': - query = request.args.get("query") - entities = request.args.getlist("entities") - sv_only = request.args.get("svOnly", False) + query = _escaped_arg("query") + entities = _escaped_arg_list("entities") + sv_only = request.args.get("svOnly", "false").lower() == 'true' limit = int(request.args.get("limit", 100)) else: # Method is POST - query = request.json.get("query") - entities = request.json.get("entities", []) + query = str(escape(request.json.get("query"))) if request.json.get( + "query") else None + entities = _escaped_list(request.json.get("entities", [])) sv_only = request.json.get("svOnly") limit = int(request.json.get("limit", 100)) diff --git a/server/routes/shared_api/variable.py b/server/routes/shared_api/variable.py index 5a09674e74..901667989b 100644 --- a/server/routes/shared_api/variable.py +++ b/server/routes/shared_api/variable.py @@ -26,14 +26,17 @@ @bp.route('/path') def get_variable_path(): """Gets the path of a stat var to the root of the stat var hierarchy.""" - dcid = escape(request.args.get("dcid")) + dcid = request.args.get("dcid") + if not dcid: + return "error: must provide a dcid field", 400 + dcid = str(escape(dcid)) return json.dumps([dcid] + dc.get_variable_ancestors(dcid)), 200 @bp.route('/info') def variable_info(): """Gets the info of a list of stat var.""" - dcids = request.args.getlist("dcids") + dcids = [str(escape(dcid)) for dcid in request.args.getlist("dcids")] data = dc.variable_info(dcids).get("data", []) result = {} for item in data: diff --git a/server/routes/sustainability/html.py b/server/routes/sustainability/html.py index b6d9f2550a..74ac3a2f9e 100644 --- a/server/routes/sustainability/html.py +++ b/server/routes/sustainability/html.py @@ -22,6 +22,7 @@ from flask import redirect from flask import url_for from google.protobuf.json_format import MessageToJson +from markupsafe import escape from server.lib.cache import cache import server.lib.subject_page_config as lib_subject_page_config @@ -49,6 +50,7 @@ def sustainability_explorer(place_dcid=None): place_dcid=lib_subject_page_config.DEFAULT_PLACE_DCID), code=302) + place_dcid = str(escape(place_dcid)) raw_subject_config = current_app.config['DISASTER_SUSTAINABILITY_CONFIG'] if current_app.config['LOCAL']: # Reload configs for faster local iteration. diff --git a/server/routes/topic_page/html.py b/server/routes/topic_page/html.py index 26fc626201..716566b563 100644 --- a/server/routes/topic_page/html.py +++ b/server/routes/topic_page/html.py @@ -21,6 +21,7 @@ from flask import g from flask import request from google.protobuf.json_format import MessageToJson +from markupsafe import escape import server.lib.subject_page_config as lib_subject_page_config import server.lib.util as libutil @@ -65,6 +66,8 @@ def get_sdg_config(place_dcid, more_places, topic_config): @bp.route('/', strict_slashes=False) @bp.route('//', strict_slashes=False) def topic_page(topic_id=None, place_dcid=None): + topic_id = str(escape(topic_id)) if topic_id else topic_id + place_dcid = str(escape(place_dcid)) if place_dcid else place_dcid topics_summary = json.dumps(current_app.config['TOPIC_PAGE_SUMMARY']) # Redirect to the landing page. if not place_dcid and not topic_id: @@ -99,7 +102,7 @@ def topic_page(topic_id=None, place_dcid=None): sample_questions=json.dumps( current_app.config.get('HOMEPAGE_SAMPLE_QUESTIONS', []))) - more_places = request.args.getlist('places') + more_places = [str(escape(place)) for place in request.args.getlist('places')] place_name = '' place_type = ''