Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions server/routes/browser/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import flask
from flask import request
from flask import Response
from markupsafe import escape

from server.lib import fetch
from server.lib.cache import cache
Expand Down Expand Up @@ -75,18 +76,23 @@ def get_observation_id():
return Response(json.dumps("error: must provide a place field"),
400,
mimetype='application/json')
place_id = str(escape(place_id))
stat_var_id = request.args.get("statVar")
if not stat_var_id:
return Response(json.dumps("error: must provide a statVar field"),
400,
mimetype='application/json')
stat_var_id = str(escape(stat_var_id))
date = request.args.get("date", "")
if not date:
return Response(json.dumps("error: must provide a date field"),
400,
mimetype='application/json')
request_mmethod = request.args.get("measurementMethod", NO_MMETHOD_KEY)
request_obsPeriod = request.args.get("obsPeriod", NO_OBSPERIOD_KEY)
date = str(escape(date))
request_mmethod = str(
escape(request.args.get("measurementMethod", NO_MMETHOD_KEY)))
request_obsPeriod = str(
escape(request.args.get("obsPeriod", NO_OBSPERIOD_KEY)))
sparql_query = get_sparql_query(place_id, stat_var_id, date)
result = ""
(_, rows) = dc.query(sparql_query)
Expand Down
2 changes: 2 additions & 0 deletions server/routes/browser/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flask import current_app
from flask import g
from flask import render_template
from markupsafe import escape

import server.lib.render as lib_render
import server.lib.shared as shared_api
Expand All @@ -40,6 +41,7 @@ def bio_browser_main():

@bp.route('/<path:dcid>')
def browser_node(dcid):
dcid = str(escape(dcid))
node_name = dcid
try:
api_name = shared_api.names([dcid]).get(dcid)
Expand Down
3 changes: 3 additions & 0 deletions server/routes/dev_datagemma/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flask import current_app
from flask import request
from flask import Response
from markupsafe import escape

from server.lib.nl.detection.llm_api import detect_model_name

Expand Down Expand Up @@ -80,6 +81,8 @@ def datagemma_query():
return 'error: must provide a query field', 400
if not mode or mode not in [_RIG_MODE, _RAG_MODE]:
return f'error: must provide a mode field with values {_RIG_MODE} or {_RAG_MODE}', 400
query = str(escape(query))
mode = str(escape(mode))
dg_result = _get_datagemma_result(query, mode)
result = {'answer': '', 'debug': ''}
if dg_result:
Expand Down
48 changes: 28 additions & 20 deletions server/routes/disaster/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from flask import current_app
from flask import request
from flask import Response
from markupsafe import escape

from server.lib.cache import cache
import server.lib.fetch as fetch
Expand All @@ -36,6 +37,12 @@
DATA_RETRIEVAL_DATE_LENGTH = 7


def _escape_value(value):
if value is None:
return None
return str(escape(value))


@bp.route('/event-date-range')
def event_date_range():
"""Gets the date range of event data for a specific event type
Expand All @@ -46,13 +53,13 @@ def event_date_range():
maxDate: string
}
"""
event_type = request.args.get('eventType', '')
event_type = _escape_value(request.args.get('eventType', ''))
if not event_type:
return "error: must provide a eventType field", 400
place = request.args.get('place', '')
place = _escape_value(request.args.get('place', ''))
if not place:
return "error: must provide a place field", 400
use_cache = request.args.get('useCache', '')
use_cache = _escape_value(request.args.get('useCache', ''))
result = {'minDate': "", 'maxDate': ""}
date_list = []
if use_cache == '1':
Expand Down Expand Up @@ -151,23 +158,24 @@ def json_event_data():
}
}
"""
event_type = request.args.get('eventType', '')
event_type = _escape_value(request.args.get('eventType', ''))
if not event_type:
return "error: must provide a eventType field", 400
min_date = request.args.get('minDate', '')
min_date = _escape_value(request.args.get('minDate', ''))
if not min_date:
return "error: must provide a minDate field", 400
max_date = request.args.get('maxDate', '')
max_date = _escape_value(request.args.get('maxDate', ''))
if not max_date:
return "error: must provide a maxDate field", 400
place = request.args.get('place', '')
place = _escape_value(request.args.get('place', ''))
if not place:
return "error: must provide a place field", 400
filter_prop = request.args.get('filterProp', '')
filter_unit = request.args.get('filterUnit', '')
filter_upper_limit = float(request.args.get('filterUpperLimit', float("inf")))
filter_lower_limit = float(request.args.get('filterLowerLimit',
-float("inf")))
filter_prop = _escape_value(request.args.get('filterProp', ''))
filter_unit = _escape_value(request.args.get('filterUnit', ''))
filter_upper_limit = float(
_escape_value(request.args.get('filterUpperLimit', float("inf"))))
filter_lower_limit = float(
_escape_value(request.args.get('filterLowerLimit', -float("inf"))))
event_points = []
disaster_data = current_app.config['DISASTER_DASHBOARD_DATA']
date_list = get_date_list(min_date, max_date)
Expand Down Expand Up @@ -233,23 +241,23 @@ def event_data():
}
}
"""
event_type = request.args.get('eventType', '')
event_type = _escape_value(request.args.get('eventType', ''))
if not event_type:
return "error: must provide a eventType field", 400
min_date = request.args.get('minDate', '')
min_date = _escape_value(request.args.get('minDate', ''))
if not min_date:
return "error: must provide a minDate field", 400
max_date = request.args.get('maxDate', '')
max_date = _escape_value(request.args.get('maxDate', ''))
if not max_date:
return "error: must provide a maxDate field", 400
place = request.args.get('place', '')
place = _escape_value(request.args.get('place', ''))
if not place:
return "error: must provide a place field", 400
filter_prop = request.args.get('filterProp', '')
filter_unit = request.args.get('filterUnit', '')
req_upper = request.args.get('filterUpperLimit', None)
filter_prop = _escape_value(request.args.get('filterProp', ''))
filter_unit = _escape_value(request.args.get('filterUnit', ''))
req_upper = _escape_value(request.args.get('filterUpperLimit', None))
filter_upper_limit = float(req_upper) if req_upper else None
req_lower = request.args.get('filterLowerLimit', None)
req_lower = _escape_value(request.args.get('filterLowerLimit', None))
filter_lower_limit = float(req_lower) if req_lower else None
date_list = get_date_list(min_date, max_date)
event_points = []
Expand Down
2 changes: 2 additions & 0 deletions server/routes/disaster/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flask import redirect
from flask import url_for
from google.protobuf.json_format import MessageToJson
from markupsafe import escape

import server.lib.subject_page_config as lib_subject_page_config
import server.lib.util
Expand All @@ -42,6 +43,7 @@ def disaster_dashboard(place_dcid=None):
place_dcid=lib_subject_page_config.DEFAULT_PLACE_DCID),
code=302)

place_dcid = str(escape(place_dcid))
raw_dashboard_config = current_app.config['DISASTER_DASHBOARD_CONFIG']
if current_app.config['LOCAL']:
# Reload configs for faster local iteration.
Expand Down
2 changes: 2 additions & 0 deletions server/routes/experiments/biomed_nl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flask import request
from flask import Response
from google import genai
from markupsafe import escape
from pydantic import BaseModel
from pydantic import Field

Expand Down Expand Up @@ -201,6 +202,7 @@ def llm_search():
query = request.args.get('q')
if not query:
return 'error: q param is required', 400
query = str(escape(query))
result = _fulfill_traversal_query(query).model_dump(by_alias=True,
mode="json")
return Response(json.dumps(result), 200, mimetype='application/json')
26 changes: 18 additions & 8 deletions server/routes/explore/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flask import current_app
from flask import request
from flask import Response
from markupsafe import escape

from server.lib.cache import cache
from server.lib.nl.common import serialize
Expand All @@ -49,13 +50,20 @@
bp = Blueprint('explore_api', __name__, url_prefix='/api/explore')


def _escape_value(value):
if value is None:
return None
return str(escape(value))


#
# The detection endpoint.
#
@bp.route('/detect', methods=['POST'])
def detect():
debug_logs = {}
client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value)
client = _escape_value(request.args.get(Params.CLIENT.value,
Clients.DEFAULT.value))

utterance, error_json = helpers.parse_query_and_detect(
request, 'explore', client, debug_logs)
Expand Down Expand Up @@ -105,8 +113,9 @@ def fulfill():
def detect_and_fulfill():
debug_logs = {}

test = request.args.get(Params.TEST.value, '')
client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value)
test = _escape_value(request.args.get(Params.TEST.value, ''))
client = _escape_value(
request.args.get(Params.CLIENT.value, Clients.DEFAULT.value))

# First sanity DC name, if any.
dc_name = request.get_json().get(Params.DC.value)
Expand All @@ -127,8 +136,8 @@ def detect_and_fulfill():
utterance.insight_ctx[
Params.EXP_MORE_DISABLED.value] = request.get_json().get(
Params.EXP_MORE_DISABLED, "")
utterance.insight_ctx[Params.SKIP_RELATED_THINGS] = request.args.get(
Params.SKIP_RELATED_THINGS.value, '') == 'true'
utterance.insight_ctx[Params.SKIP_RELATED_THINGS] = _escape_value(
request.args.get(Params.SKIP_RELATED_THINGS.value, '')) == 'true'
helpers.update_insight_ctx_for_chart_fulfill(request, utterance, dc_name)

# Important to setup utterance for explore flow (this is really the only difference
Expand Down Expand Up @@ -252,9 +261,10 @@ def _fulfill_with_chart_config(utterance: nl_utterance.Utterance,
def _fulfill_with_insight_ctx(request: Dict, debug_logs: Dict,
counters: ctr.Counters) -> Dict:
insight_ctx = request.get_json()
test = request.args.get(Params.TEST.value, '')
client = request.args.get(Params.CLIENT.value, Clients.DEFAULT.value)
mode = request.args.get(Params.MODE.value, '')
test = _escape_value(request.args.get(Params.TEST.value, ''))
client = _escape_value(
request.args.get(Params.CLIENT.value, Clients.DEFAULT.value))
mode = _escape_value(request.args.get(Params.MODE.value, ''))
if not insight_ctx:
return helpers.abort('Sorry, could not answer your query.',
'', [],
Expand Down
38 changes: 22 additions & 16 deletions server/routes/explore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@
_SANITY_TEST = 'sanity'


def _escape_value(value):
if value is None:
return None
return str(escape(value))


# Get the default place to be used for fulfillment. If there is a place in the
# request, use that. Otherwise, use pre-chosen places.
def _get_default_place(request: Dict, is_special_dc: bool, debug_logs: Dict):
default_place_dcid = request.args.get(params.Params.DEFAULT_PLACE,
default='',
type=str)
default_place_dcid = _escape_value(
request.args.get(params.Params.DEFAULT_PLACE, default='', type=str))
# If default place from request is earth, use the Earth place object
if default_place_dcid == constants.EARTH.dcid:
return constants.EARTH
Expand Down Expand Up @@ -92,15 +97,15 @@ def parse_query_and_detect(request: Dict, backend: str, client: str,
flask.abort(404)
nl_bad_words = current_app.config['NL_BAD_WORDS']

test = request.args.get(params.Params.TEST.value, '')
test = _escape_value(request.args.get(params.Params.TEST.value, ''))
# i18n param
i18n_str = request.args.get(params.Params.I18N.value, '')
i18n_str = _escape_value(request.args.get(params.Params.I18N.value, ''))
i18n = i18n_str and i18n_str.lower() == 'true'

# Index-type default is in nl_server.
idx_param_str = request.args.get(params.Params.INDEX.value, '')
idx_param_str = _escape_value(request.args.get(params.Params.INDEX.value, ''))
embeddings_index_types = [x.strip() for x in idx_param_str.split(',')]
original_query = request.args.get('q')
original_query = _escape_value(request.args.get('q'))
if not original_query:
err_json = helpers.abort(
'Received an empty query, please type a few words :)',
Expand All @@ -115,13 +120,14 @@ def parse_query_and_detect(request: Dict, backend: str, client: str,
embeddings_index_types = params.dc_to_embedding_types(dc,
embeddings_index_types)

detector_type = request.args.get(params.Params.DETECTOR.value,
default=RequestedDetectorType.Hybrid.value,
type=str)
detector_type = _escape_value(
request.args.get(params.Params.DETECTOR.value,
default=RequestedDetectorType.Hybrid.value,
type=str))

# mode param
use_default_place = True
mode = request.args.get(params.Params.MODE.value, '')
mode = _escape_value(request.args.get(params.Params.MODE.value, ''))
if mode == QueryMode.STRICT:
# Strict mode is compatible only with Heuristic Detector!
detector_type = RequestedDetectorType.Heuristic.value
Expand Down Expand Up @@ -180,10 +186,10 @@ def parse_query_and_detect(request: Dict, backend: str, client: str,
use_default_place = False

# See if we have a variable reranker model specified.
reranker = request.args.get(params.Params.RERANKER.value)
reranker = _escape_value(request.args.get(params.Params.RERANKER.value))

# Get sv threshold as a float if it was passed in the request
var_threshold = request.args.get(params.Params.VAR_THRESHOLD.value)
var_threshold = _escape_value(request.args.get(params.Params.VAR_THRESHOLD.value))
if var_threshold:
# if sv_threshold is not a float, don't set sv_threshold
try:
Expand All @@ -192,8 +198,8 @@ def parse_query_and_detect(request: Dict, backend: str, client: str,
var_threshold = None

# StopWords handling
include_stop_words_str = request.args.get(
params.Params.INCLUDE_STOP_WORDS.value, '')
include_stop_words_str = _escape_value(
request.args.get(params.Params.INCLUDE_STOP_WORDS.value, ''))

detection_args = DetectionArgs(
embeddings_index_types=embeddings_index_types,
Expand Down Expand Up @@ -280,7 +286,7 @@ def update_insight_ctx_for_chart_fulfill(request: Dict,
params.Params.MAX_CHARTS,
params.Params.CHART_TYPE,
]:
param_val = request.args.get(p, None)
param_val = _escape_value(request.args.get(p, None))
if param_val != None:
if param_val.isnumeric():
param_val = int(param_val)
Expand Down
2 changes: 2 additions & 0 deletions server/routes/explore/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flask import Blueprint
from flask import current_app
from flask import render_template
from markupsafe import escape

bp = Blueprint('explore', __name__, url_prefix='/explore')

Expand All @@ -32,6 +33,7 @@ def page():

@bp.route('/<string:topic>')
def landing(topic):
topic = str(escape(topic))
return render_template('/explore_landing.html',
topic=topic,
website_hash=os.environ.get("WEBSITE_HASH"))
Loading