diff --git a/RELEASE.md b/RELEASE.md
index 51c61e66a..d860f21cf 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,14 @@
+# Release 1.11.0
+## Major Features and Improvements
+* Add a data table preview query interface
+
+## Bug Fixes
+* Fix performance problems in the upload and reader components when processing large amounts of data
+* Fix a bug where online inference could not be performed after model migration
+* Fix a bug where the model could not be saved to the specified database
+* Fix a reader data preview display bug
+
+
 # Release 1.10.1
 ## Major Features and Improvements
 * Optimize table info API
diff --git a/doc/swagger/swagger.yaml b/doc/swagger/swagger.yaml
index e84b5955e..0855feb68 100644
--- a/doc/swagger/swagger.yaml
+++ b/doc/swagger/swagger.yaml
@@ -732,6 +732,57 @@ paths:
                     type: string
                     example: no find table
+  '/table/preview':
+    post:
+      summary: table data preview
+      tags:
+        - table
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              type: object
+              required:
+                - name
+                - namespace
+              properties:
+                name:
+                  type: string
+                  example: "guest"
+                namespace:
+                  type: string
+                  example: "data"
+      responses:
+        '200':
+          description: get preview table success
+          content:
+            application/json:
+              schema:
+                type: object
+                properties:
+                  retcode:
+                    type: integer
+                    example: 0
+                  retmsg:
+                    type: string
+                    example: success
+                  data:
+                    type: object
+        '404':
+          description: table not found
+          content:
+            application/json:
+              schema:
+                type: object
+                properties:
+                  retcode:
+                    type: integer
+                    example: 210
+                  retmsg:
+                    type: string
+                    example: no find table
   '/job/submit':
     post:
       summary: submit job
diff --git a/python/fate_flow/apps/job_app.py b/python/fate_flow/apps/job_app.py
index 7beac12bd..ee78cea84 100644
--- a/python/fate_flow/apps/job_app.py
+++ b/python/fate_flow/apps/job_app.py
@@ -153,10 +153,11 @@ def update_job():
 
 @manager.route('/report', methods=['POST'])
 def job_report():
+    jobs = JobSaver.query_job(**request.json)
     tasks = JobSaver.query_task(**request.json)
-    if not tasks:
+    if not tasks or not jobs:
         return get_json_result(retcode=101, retmsg='find task failed')
-    return get_json_result(retcode=0, retmsg='success', data=job_utils.task_report(tasks))
+    return get_json_result(retcode=0, retmsg='success', data=job_utils.task_report(jobs, tasks))
 
 
 @manager.route('/parameter/update', methods=['POST'])
diff --git a/python/fate_flow/apps/model_app.py b/python/fate_flow/apps/model_app.py
index f0a261856..f191f5c26 100644
--- a/python/fate_flow/apps/model_app.py
+++ b/python/fate_flow/apps/model_app.py
@@ -643,6 +643,12 @@ def query_model():
     return get_json_result(retcode=retcode, retmsg=retmsg, data=data)
 
 
+@manager.route('/query/detail', methods=['POST'])
+def query_model_detail():
+    retcode, retmsg, data = model_utils.query_model_detail(**request.json)
+    return get_json_result(retcode=retcode, retmsg=retmsg, data=data)
+
+
 @manager.route('/deploy', methods=['POST'])
 @validate_request('model_id', 'model_version')
 def deploy():
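For reviewers who want to exercise the two new HTTP routes above (`/table/preview` and `/model/query/detail`), here is a minimal client sketch. The host, port, `/v1` prefix, and the model_id/model_version values are assumptions for illustration, not taken from this diff.

```python
# Minimal sketch for trying the new routes added in this diff.
# Host, port, the /v1 prefix, and the model identifiers are assumptions;
# adjust them to your deployment.
import requests

BASE = "http://127.0.0.1:9380/v1"

# Preview rows of an uploaded table (new /table/preview route).
preview = requests.post(
    f"{BASE}/table/preview",
    json={"name": "guest", "namespace": "data", "limit": 20},
)
print(preview.json().get("retcode"), preview.json().get("retmsg"))

# Query component-level detail of a model (new /model/query/detail route).
detail = requests.post(
    f"{BASE}/model/query/detail",
    json={"model_id": "guest-9999#host-10000#model",
          "model_version": "202301010000000000000"},
)
print(detail.json().get("data", {}).get("component_info"))
```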
diff --git a/python/fate_flow/apps/table_app.py b/python/fate_flow/apps/table_app.py
index 5c617a65a..54f2248fb 100644
--- a/python/fate_flow/apps/table_app.py
+++ b/python/fate_flow/apps/table_app.py
@@ -158,6 +158,19 @@ def table_download():
     )
 
 
+@manager.route('/preview', methods=['post'])
+def table_data_preview():
+    request_data = request.json
+    from fate_flow.component_env_utils.env_utils import import_component_output_depend
+    import_component_output_depend()
+    data_table_meta = storage.StorageTableMeta(name=request_data.get("name"), namespace=request_data.get("namespace"))
+    if not data_table_meta:
+        return error_response(response_code=210, retmsg=f'no found table:{request_data.get("namespace")}, {request_data.get("name")}')
+
+    data = TableStorage.read_table_data(data_table_meta, limit=request_data.get("limit"))
+    return get_json_result(retcode=0, retmsg='success', data=data)
+
+
 @manager.route('/delete', methods=['post'])
 def table_delete():
     request_data = request.json
diff --git a/python/fate_flow/components/api_reader.py b/python/fate_flow/components/api_reader.py
index 82be52557..5a55950ec 100644
--- a/python/fate_flow/components/api_reader.py
+++ b/python/fate_flow/components/api_reader.py
@@ -192,9 +192,14 @@ def upload_data(self):
         )
         upload_registry_info = self.service_info.get("upload")
         logger.info(f"upload info:{upload_registry_info.to_dict()}")
+        params = self.parameters.get("parameters", {})
+        params.update({"job_id": self.tracker.job_id})
+        en_content = self.encrypt_content()
+        if en_content:
+            params.update({"sign": en_content})
         response = getattr(requests, upload_registry_info.f_method.lower(), None)(
             url=upload_registry_info.f_url,
-            params={"requestBody": json.dumps(self.parameters.get("parameters", {}))},
+            params={"requestBody": json.dumps(params)},
             data=data,
             headers={'Content-Type': data.content_type}
         )
@@ -206,3 +211,11 @@ def set_service_registry_info(self):
             if key == info.f_service_name:
                 self.service_info[key] = info
         logger.info(f"set service registry info:{self.service_info}")
+
+    def encrypt_content(self, job_id=None):
+        if not job_id:
+            job_id = self.tracker.job_id
+        import hashlib
+        md5 = hashlib.md5()
+        md5.update(job_id.encode())
+        return md5.hexdigest()
diff --git a/python/fate_flow/components/reader.py b/python/fate_flow/components/reader.py
index 6f162223c..44f57ffad 100644
--- a/python/fate_flow/components/reader.py
+++ b/python/fate_flow/components/reader.py
@@ -35,6 +35,7 @@
 from fate_flow.manager.data_manager import DataTableTracker, TableStorage, AnonymousGenerator
 from fate_flow.operation.job_tracker import Tracker
 from fate_flow.utils import data_utils
+from federatedml.feature.instance import Instance
 
 LOGGER = log.getLogger()
 MAX_NUM = 10000
@@ -305,8 +306,22 @@ def data_info_display(output_table_meta):
             data_list[0].extend(headers)
     LOGGER.info(f"data info header: {data_list[0]}")
     for data in output_table_meta.get_part_of_data():
-        delimiter = schema.get("meta", {}).get("delimiter") or output_table_meta.id_delimiter
-        data_list.append(data[1].split(delimiter))
+        if isinstance(data[1], str):
+            delimiter = schema.get("meta", {}).get("delimiter") or output_table_meta.id_delimiter
+            data_list.append(data[1].split(delimiter))
+        elif isinstance(data[1], Instance):
+            table_data = []
+            if data[1].inst_id:
+                table_data.append(data[1].inst_id)
+            if data[1].label is not None:
+                table_data.append(data[1].label)
+            table_data.extend(data[1].features)
+            data_list.append([str(v) for v in table_data])
+        else:
+            data_list.append(data[1])
+
     data = np.array(data_list)
     Tdata = data.transpose()
     for data in Tdata:
@@ -317,7 +332,7 @@ def data_info_display(output_table_meta):
         if schema.get("label_name"):
             anonymous_info[schema.get("label_name")] = schema.get("anonymous_label")
             attribute_info[schema.get("label_name")] = "label"
-        if schema.get("meta").get("id_list"):
+        if schema.get("meta", {}).get("id_list"):
             for id_name in schema.get("meta").get("id_list"):
                 if id_name in attribute_info:
                     attribute_info[id_name] = "match_id"
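The reader change above now has to display three value shapes: delimited strings, `Instance` objects, and anything else. Below is a self-contained illustration of how an Instance-like record is flattened into a display row; `SimpleInstance` is a stand-in for `federatedml.feature.instance.Instance`, not the real class.

```python
# Stand-in illustration of the Instance branch in data_info_display.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SimpleInstance:
    inst_id: Optional[str] = None
    label: Optional[int] = None
    features: List[float] = field(default_factory=list)


def instance_to_row(inst: SimpleInstance) -> List[str]:
    row = []
    if inst.inst_id:              # match id column, when extended sids are used
        row.append(inst.inst_id)
    if inst.label is not None:    # a label of 0 must still be shown
        row.append(inst.label)
    row.extend(inst.features)
    return [str(v) for v in row]


print(instance_to_row(SimpleInstance(inst_id="id_0", label=0, features=[0.1, 0.2])))
# ['id_0', '0', '0.1', '0.2']
```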
diff --git a/python/fate_flow/components/upload.py b/python/fate_flow/components/upload.py
index 309e6a0c7..b0bc17700 100644
--- a/python/fate_flow/components/upload.py
+++ b/python/fate_flow/components/upload.py
@@ -287,54 +287,59 @@ def get_count(input_file):
                 count += 1
         return count
 
+    def kv_generator(self, input_feature_count, fp, job_id, part_of_data):
+        fate_uuid = uuid.uuid1().hex
+        get_line = self.get_line()
+        line_index = 0
+        LOGGER.info(input_feature_count)
+        while True:
+            lines = fp.readlines(JobDefaultConfig.upload_block_max_bytes)
+            LOGGER.info(JobDefaultConfig.upload_block_max_bytes)
+            if lines:
+                for line in lines:
+                    values = line.rstrip().split(self.parameters["id_delimiter"])
+                    k, v = get_line(
+                        values=values,
+                        line_index=line_index,
+                        extend_sid=self.parameters["extend_sid"],
+                        auto_increasing_sid=self.parameters["auto_increasing_sid"],
+                        id_delimiter=self.parameters["id_delimiter"],
+                        fate_uuid=fate_uuid,
+                    )
+                    yield k, v
+                    line_index += 1
+                    if line_index <= 100:
+                        part_of_data.append((k, v))
+                save_progress = line_index / input_feature_count * 100 // 1
+                job_info = {
+                    "progress": save_progress,
+                    "job_id": job_id,
+                    "role": self.parameters["local"]["role"],
+                    "party_id": self.parameters["local"]["party_id"],
+                }
+                ControllerClient.update_job(job_info=job_info)
+            else:
+                return
+
+    def update_schema(self, head, fp):
+        read_status = False
+        if head is True:
+            data_head = fp.readline()
+            self.update_table_schema(data_head)
+            read_status = True
+        else:
+            self.update_table_schema()
+        return read_status
+
     def upload_file(self, input_file, head, job_id=None, input_feature_count=None, table=None):
         if not table:
             table = self.table
-        with open(input_file, "r") as fin:
-            lines_count = 0
-            if head is True:
-                data_head = fin.readline()
+        part_of_data = []
+        with open(input_file, "r") as fp:
+            if self.update_schema(head, fp):
                 input_feature_count -= 1
-                self.update_table_schema(data_head)
-            else:
-                self.update_table_schema()
-            n = 0
-            fate_uuid = uuid.uuid1().hex
-            get_line = self.get_line()
-            line_index = 0
-            while True:
-                data = list()
-                lines = fin.readlines(JobDefaultConfig.upload_block_max_bytes)
-                LOGGER.info(JobDefaultConfig.upload_block_max_bytes)
-                if lines:
-                    # self.append_data_line(lines, data, n)
-                    for line in lines:
-                        values = line.rstrip().split(self.parameters["id_delimiter"])
-                        k, v = get_line(
-                            values=values,
-                            line_index=line_index,
-                            extend_sid=self.parameters["extend_sid"],
-                            auto_increasing_sid=self.parameters["auto_increasing_sid"],
-                            id_delimiter=self.parameters["id_delimiter"],
-                            fate_uuid=fate_uuid,
-                        )
-                        data.append((k, v))
-                        line_index += 1
-                    lines_count += len(data)
-                    save_progress = lines_count / input_feature_count * 100 // 1
-                    job_info = {
-                        "progress": save_progress,
-                        "job_id": job_id,
-                        "role": self.parameters["local"]["role"],
-                        "party_id": self.parameters["local"]["party_id"],
-                    }
-                    ControllerClient.update_job(job_info=job_info)
-                    table.put_all(data)
-                    if n == 0:
-                        table.meta.update_metas(part_of_data=data)
-                    else:
-                        return
-                    n += 1
+            self.table.put_all(self.kv_generator(input_feature_count, fp, job_id, part_of_data))
+            table.meta.update_metas(part_of_data=part_of_data)
 
     def get_computing_table(self, name, namespace, schema=None):
         storage_table_meta = storage.StorageTableMeta(name=name, namespace=namespace)
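The upload rewrite above replaces per-block buffering with a generator handed straight to `put_all`. A generic sketch of that pattern follows; it is not the FATE API, and `fake_put_all` is a stand-in sink used only to show that nothing larger than one line is held at a time.

```python
# Generic sketch of the streaming upload pattern: the sink pulls key/value
# pairs lazily from the generator, so no block of parsed lines is buffered.
import io


def kv_generator(fp, delimiter=","):
    for line in fp:
        values = line.rstrip().split(delimiter)
        yield values[0], delimiter.join(values[1:])


def fake_put_all(kv_iterable):
    count = 0
    for _k, _v in kv_iterable:   # a real storage backend would write here
        count += 1
    return count


fp = io.StringIO("id0,1,2\nid1,3,4\n")
print(fake_put_all(kv_generator(fp)))  # 2
```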
diff --git a/python/fate_flow/db/service_registry.py b/python/fate_flow/db/service_registry.py
index c68a87620..183b7644c 100644
--- a/python/fate_flow/db/service_registry.py
+++ b/python/fate_flow/db/service_registry.py
@@ -102,7 +102,14 @@ def save(cls, service_config):
             cls.parameter_check(server_info)
             api_info = server_info.pop("api", {})
             for service_name, info in api_info.items():
-                ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
+                ServiceRegistry.save_service_info(
+                    server_name, service_name, uri=info.get('uri'),
+                    method=info.get('method', 'POST'),
+                    server_info=server_info,
+                    data=info.get("data", {}),
+                    headers=info.get("headers", {}),
+                    params=info.get("params", {})
+                )
             cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
             setattr(cls, server_name.upper(), server_info)
         return update_server
diff --git a/python/fate_flow/external/storage/mysql.py b/python/fate_flow/external/storage/mysql.py
index b73162530..fff456af9 100644
--- a/python/fate_flow/external/storage/mysql.py
+++ b/python/fate_flow/external/storage/mysql.py
@@ -59,7 +59,7 @@ def save(self):
                     self._con.commit()
                     sql = None
             LOGGER.info(f"save data count:{count}")
-            if count > 0:
+            if count > 0 and sql:
                 sql = ",".join(sql.split(",")[:-1]) + ";"
                 self._cur.execute(sql)
                 self._con.commit()
diff --git a/python/fate_flow/hook/api/site_authentication.py b/python/fate_flow/hook/api/site_authentication.py
index 881b3846e..7e36fbe50 100644
--- a/python/fate_flow/hook/api/site_authentication.py
+++ b/python/fate_flow/hook/api/site_authentication.py
@@ -14,9 +14,11 @@ def signature(parm: SignatureParameters) -> SignatureReturn:
     if not service_list:
         raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature")
     service = service_list[0]
+    data = service.f_data if service.f_data else {}
+    data.update(parm.to_dict())
     response = getattr(requests, service.f_method.lower(), None)(
         url=service.f_url,
-        json=parm.to_dict()
+        json=data
     )
     if response.status_code == 200:
         if response.json().get("code") == 0:
@@ -37,9 +39,11 @@ def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
         raise Exception(
             f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication")
     service = service_list[0]
+    data = service.f_data if service.f_data else {}
+    data.update(parm.to_dict())
     response = getattr(requests, service.f_method.lower(), None)(
         url=service.f_url,
-        json=parm.to_dict()
+        json=data
    )
     if response.status_code != 200:
         raise Exception(
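The two hook changes above follow the same pattern: any static fields registered with the service (`f_data`) are merged into the request body before the HTTP call. A tiny sketch of that merge, with plain dicts standing in for the ServiceRegistryInfo record and the parameter object (both names and values below are illustrative):

```python
# Plain-dict sketch of the payload merge now done in the API hooks.
service_f_data = {"app_key": "demo-key"}          # registered with the service
parm = {"party_id": "9999", "body": {"k": "v"}}   # per-request parameters

payload = dict(service_f_data or {})
payload.update(parm)
print(payload)  # {'app_key': 'demo-key', 'party_id': '9999', 'body': {'k': 'v'}}
```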
diff --git a/python/fate_flow/hook/flow/site_authentication.py b/python/fate_flow/hook/flow/site_authentication.py
index 439627585..c87795303 100644
--- a/python/fate_flow/hook/flow/site_authentication.py
+++ b/python/fate_flow/hook/flow/site_authentication.py
@@ -3,7 +3,7 @@
 
 from Crypto.PublicKey import RSA
 from Crypto.Signature import PKCS1_v1_5
-from Crypto.Hash import MD5
+from Crypto.Hash import SHA256
 
 from fate_flow.db.key_manager import RsaKeyManager
 from fate_flow.entity import RetCode
@@ -19,7 +19,7 @@ def signature(parm: SignatureParameters) -> SignatureReturn:
     private_key = RsaKeyManager.get_key(parm.party_id, key_name=SiteKeyName.PRIVATE.value)
     if not private_key:
         raise Exception(f"signature error: no found party id {parm.party_id} private key")
-    sign= PKCS1_v1_5.new(RSA.importKey(private_key)).sign(MD5.new(json.dumps(parm.body).encode()))
+    sign = PKCS1_v1_5.new(RSA.importKey(private_key)).sign(SHA256.new(json.dumps(parm.body).encode()))
     return SignatureReturn(site_signature=base64.b64encode(sign).decode())
 
 
@@ -30,7 +30,7 @@ def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
     if not public_key:
         raise Exception(f"signature error: no found party id {party_id} public key")
     verifier = PKCS1_v1_5.new(RSA.importKey(public_key))
-    if verifier.verify(MD5.new(json.dumps(parm.body).encode()), base64.b64decode(parm.site_signature)) is True:
+    if verifier.verify(SHA256.new(json.dumps(parm.body).encode()), base64.b64decode(parm.site_signature)) is True:
         return AuthenticationReturn()
     else:
         return AuthenticationReturn(code=RetCode.AUTHENTICATION_ERROR, message="authentication failed")
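Because the flow hook now signs with SHA-256 instead of MD5, signatures produced by an older party will no longer verify, so both sides of a deployment need this change. Below is a standalone round-trip check of the scheme (PKCS#1 v1.5 over SHA-256) that mirrors the hook but uses a locally generated key instead of `RsaKeyManager`; it assumes pycryptodome is installed.

```python
# Standalone sign/verify round-trip matching the new hook scheme.
# Requires pycryptodome; the key is generated locally for the demo.
import base64
import json

from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5

key = RSA.generate(2048)
body = {"party_id": "9999", "role": "guest"}

# private-key side (signature hook)
digest = SHA256.new(json.dumps(body).encode())
site_signature = base64.b64encode(PKCS1_v1_5.new(key).sign(digest)).decode()

# public-key side (authentication hook)
ok = PKCS1_v1_5.new(key.publickey()).verify(
    SHA256.new(json.dumps(body).encode()),
    base64.b64decode(site_signature),
)
print(ok)  # True
```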
diff --git a/python/fate_flow/manager/data_manager.py b/python/fate_flow/manager/data_manager.py
index fba8239ee..437bbcbed 100644
--- a/python/fate_flow/manager/data_manager.py
+++ b/python/fate_flow/manager/data_manager.py
@@ -23,6 +23,7 @@
 
 from flask import send_file
 
+from fate_arch import storage
 from fate_arch.abc import StorageTableABC
 from fate_arch.common.base_utils import fate_uuid
 from fate_arch.session import Session
@@ -184,86 +185,81 @@ def track_job(cls, table_name, table_namespace, display=False):
 
 class TableStorage:
     @staticmethod
-    def copy_table(src_table: StorageTableABC, dest_table: StorageTableABC, deserialize_value=False):
+    def collect(src_table, part_of_data):
+        line_index = 0
+        count = 0
+        fate_uuid = uuid.uuid1().hex
+        for k, v in src_table.collect():
+            if src_table.meta.get_extend_sid():
+                v = src_table.meta.get_id_delimiter().join([k, v])
+                k = line_extend_uuid(fate_uuid, line_index)
+                line_index += 1
+            yield k, v
+            if count <= 100:
+                part_of_data.append((k, v))
+            count += 1
+
+    @staticmethod
+    def read(src_table, schema, part_of_data):
+        line_index = 0
         count = 0
-        data_temp = []
+        src_table_meta = src_table.meta
+        fate_uuid = uuid.uuid1().hex
+        if src_table_meta.get_have_head():
+            get_head = False
+        else:
+            get_head = True
+        if not src_table.meta.get_extend_sid():
+            get_line = data_utils.get_data_line
+        elif not src_table_meta.get_auto_increasing_sid():
+            get_line = data_utils.get_sid_data_line
+        else:
+            get_line = data_utils.get_auto_increasing_sid_data_line
+        for line in src_table.read():
+            if not get_head:
+                schema.update(data_utils.get_header_schema(
+                    header_line=line,
+                    id_delimiter=src_table_meta.get_id_delimiter(),
+                    extend_sid=src_table_meta.get_extend_sid(),
+                ))
+                get_head = True
+                continue
+            values = line.rstrip().split(src_table.meta.get_id_delimiter())
+            k, v = get_line(
+                values=values,
+                line_index=line_index,
+                extend_sid=src_table.meta.get_extend_sid(),
+                auto_increasing_sid=src_table.meta.get_auto_increasing_sid(),
+                id_delimiter=src_table.meta.get_id_delimiter(),
+                fate_uuid=fate_uuid,
+            )
+            line_index += 1
+            yield k, v
+            if count <= 100:
+                part_of_data.append((k, v))
+            count += 1
+
+    @staticmethod
+    def copy_table(src_table: StorageTableABC, dest_table: StorageTableABC):
         part_of_data = []
         src_table_meta = src_table.meta
         schema = {}
         update_schema = False
-        line_index = 0
-        fate_uuid = uuid.uuid1().hex
         if not src_table_meta.get_in_serialized():
-            if src_table_meta.get_have_head():
-                get_head = False
-            else:
-                get_head = True
-            if not src_table.meta.get_extend_sid():
-                get_line = data_utils.get_data_line
-            elif not src_table_meta.get_auto_increasing_sid():
-                get_line = data_utils.get_sid_data_line
-            else:
-                get_line = data_utils.get_auto_increasing_sid_data_line
-            for line in src_table.read():
-                if not get_head:
-                    schema = data_utils.get_header_schema(
-                        header_line=line,
-                        id_delimiter=src_table_meta.get_id_delimiter(),
-                        extend_sid=src_table_meta.get_extend_sid(),
-                    )
-                    get_head = True
-                    continue
-                values = line.rstrip().split(src_table.meta.get_id_delimiter())
-                k, v = get_line(
-                    values=values,
-                    line_index=line_index,
-                    extend_sid=src_table.meta.get_extend_sid(),
-                    auto_increasing_sid=src_table.meta.get_auto_increasing_sid(),
-                    id_delimiter=src_table.meta.get_id_delimiter(),
-                    fate_uuid=fate_uuid,
-                )
-                line_index += 1
-                count = TableStorage.put_in_table(
-                    table=dest_table,
-                    k=k,
-                    v=v,
-                    temp=data_temp,
-                    count=count,
-                    part_of_data=part_of_data,
-                )
+            dest_table.put_all(TableStorage.read(src_table, schema, part_of_data))
         else:
             source_header = copy.deepcopy(src_table_meta.get_schema().get("header"))
             TableStorage.update_full_header(src_table_meta)
-            for k, v in src_table.collect():
-                if src_table.meta.get_extend_sid():
-                    # extend id
-                    v = src_table.meta.get_id_delimiter().join([k, v])
-                    k = line_extend_uuid(fate_uuid, line_index)
-                    line_index += 1
-                if deserialize_value:
-                    # writer component: deserialize value
-                    v, extend_header = feature_utils.get_deserialize_value(v, dest_table.meta.get_id_delimiter())
-                    if not update_schema:
-                        header_list = get_component_output_data_schema(src_table.meta, extend_header)
-                        schema = get_header_schema(dest_table.meta.get_id_delimiter().join(header_list),
-                                                   dest_table.meta.get_id_delimiter())
-                        _, dest_table.meta = dest_table.meta.update_metas(schema=schema)
-                        update_schema = True
-                count = TableStorage.put_in_table(
-                    table=dest_table,
-                    k=k,
-                    v=v,
-                    temp=data_temp,
-                    count=count,
-                    part_of_data=part_of_data,
-                )
+            dest_table.put_all(TableStorage.collect(src_table, part_of_data))
             schema = src_table.meta.get_schema()
             schema["header"] = source_header
-        if data_temp:
-            dest_table.put_all(data_temp)
         if schema.get("extend_tag"):
             schema.update({"extend_tag": False})
-        _, dest_table.meta = dest_table.meta.update_metas(schema=schema if not update_schema else None, part_of_data=part_of_data)
+        _, dest_table.meta = dest_table.meta.update_metas(
+            schema=schema if not update_schema else None,
+            part_of_data=part_of_data,
+            id_delimiter=src_table_meta.get_id_delimiter()
+        )
         return dest_table.count()
 
     @staticmethod
@@ -275,14 +271,39 @@ def update_full_header(table_meta):
         table_meta.set_metas(schema=schema)
 
     @staticmethod
-    def put_in_table(table: StorageTableABC, k, v, temp, count, part_of_data, max_num=10000):
-        temp.append((k, v))
-        if count < 100:
-            part_of_data.append((k, v))
-        if len(temp) == max_num:
-            table.put_all(temp)
-            temp.clear()
-        return count + 1
+    def read_table_data(data_table_meta, limit=100):
+        if not limit or limit > 100:
+            limit = 100
+        data_table = storage.StorageTableMeta(
+            name=data_table_meta.get_name(),
+            namespace=data_table_meta.get_namespace()
+        )
+        if data_table:
+            table_schema = data_table_meta.get_schema()
+            out_header = None
+            data_list = []
+            all_extend_header = {}
+            for k, v in data_table_meta.get_part_of_data():
+                data_line, is_str, all_extend_header = feature_utils.get_component_output_data_line(
+                    src_key=k,
+                    src_value=v,
+                    schema=table_schema,
+                    all_extend_header=all_extend_header
+                )
+                data_list.append(data_line)
+                if len(data_list) == limit:
+                    break
+            if data_list:
+                extend_header = feature_utils.generate_header(all_extend_header, schema=table_schema)
+                out_header = get_component_output_data_schema(
+                    output_table_meta=data_table_meta,
+                    is_str=is_str,
+                    extend_header=extend_header
+                )
+
+            return {'header': out_header, 'data': data_list}
+
+        return {'header': [], 'data': []}
 
     @staticmethod
     def send_table(output_tables_meta, tar_file_name="", limit=-1, need_head=True, local_download=False, output_data_file_path=None):
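The new `TableStorage.read_table_data` is what backs the `/table/preview` route: the row limit is clamped to 100 and the result is always a dict with `header` and `data` keys. A small sketch of the clamping rule only (the table access itself is elided):

```python
# Sketch of the preview limit handling in read_table_data: anything falsy or
# above 100 falls back to 100 rows. Only the clamp is reproduced here; the
# route then returns {"header": [...], "data": [...]} built from the table's
# part_of_data cache.
def clamp_preview_limit(limit):
    if not limit or limit > 100:
        return 100
    return limit

print([clamp_preview_limit(v) for v in (None, 0, 20, 500)])  # [100, 100, 20, 100]
```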
diff --git a/python/fate_flow/model/mysql_model_storage.py b/python/fate_flow/model/mysql_model_storage.py
index d7008f6cf..80faabcbb 100644
--- a/python/fate_flow/model/mysql_model_storage.py
+++ b/python/fate_flow/model/mysql_model_storage.py
@@ -309,6 +309,7 @@ class MachineLearningModel(DataBaseModel):
     f_slice_index = IntegerField(default=0, index=True)
 
     class Meta:
+        database = DB
         db_table = 't_machine_learning_model'
         primary_key = CompositeKey('f_model_id', 'f_model_version', 'f_slice_index')
 
@@ -322,6 +323,7 @@ class MachineLearningComponent(DataBaseModel):
     f_slice_index = IntegerField(default=0, index=True)
 
     class Meta:
+        database = DB
         db_table = 't_machine_learning_component'
         indexes = (
             (('f_party_model_id', 'f_model_version', 'f_component_name', 'f_slice_index'), True),
diff --git a/python/fate_flow/pipelined_model/deploy_model.py b/python/fate_flow/pipelined_model/deploy_model.py
index 57c70c573..287437fe1 100644
--- a/python/fate_flow/pipelined_model/deploy_model.py
+++ b/python/fate_flow/pipelined_model/deploy_model.py
@@ -129,7 +129,9 @@ def deploy(config_data):
             job_id=model_version, role=local_role, party_id=local_party_id,
             dsl_parser=parser, origin_inference_dsl=inference_dsl,
         )
-        pipeline_model.inference_dsl = json_dumps(inference_dsl, byte=True)
+        # migrated models may be missing CodePath in the inference DSL, so rebuild the predict DSL
+        module_object_dict = get_module_object_dict(json_loads(pipeline_model.inference_dsl))
+        pipeline_model.inference_dsl = json_dumps(parser.get_predict_dsl(inference_dsl, module_object_dict), byte=True)
 
         train_runtime_conf = JobRuntimeConfigAdapter(
             train_runtime_conf,
@@ -214,3 +216,12 @@ def deploy(config_data):
         f'deploy model of role {local_role} {local_party_id} success'
         + ('' if not warning_msg else f', warning: {warning_msg}')
     )
+
+
+def get_module_object_dict(inference_dsl):
+    module_object_dict = {}
+    for _, components in inference_dsl.items():
+        for name, module in components.items():
+            module_object_dict[name] = module.get("CodePath")
+    return module_object_dict
+
diff --git a/python/fate_flow/scheduling_apps/client/control_client.py b/python/fate_flow/scheduling_apps/client/control_client.py
index f0ebabf56..e6743af74 100644
--- a/python/fate_flow/scheduling_apps/client/control_client.py
+++ b/python/fate_flow/scheduling_apps/client/control_client.py
@@ -22,7 +22,7 @@ class ControllerClient(object):
 
     @classmethod
     def update_job(cls, job_info):
-        LOGGER.info("request update job {} on {} {}".format(job_info["job_id"], job_info["role"], job_info["party_id"]))
+        LOGGER.info(f"request update job {job_info['job_id']} on {job_info['role']} {job_info['party_id']}: {job_info}")
         response = api_utils.local_api(
             job_id=job_info["job_id"],
             method='POST',
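`get_module_object_dict` exists because a migrated model's inference DSL can arrive without `CodePath` entries, so the deploy step now rebuilds the predict DSL from the stored one. The sketch below shows the shape of the mapping it extracts, using an invented minimal DSL:

```python
# Invented minimal inference DSL; only the CodePath extraction is shown here.
inference_dsl = {
    "components": {
        "reader_0": {"module": "Reader", "CodePath": "Reader"},
        "data_transform_0": {"module": "DataTransform", "CodePath": "DataTransform"},
    }
}

module_object_dict = {}
for _, components in inference_dsl.items():
    for name, module in components.items():
        module_object_dict[name] = module.get("CodePath")

print(module_object_dict)
# {'reader_0': 'Reader', 'data_transform_0': 'DataTransform'}
```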
diff --git a/python/fate_flow/utils/job_utils.py b/python/fate_flow/utils/job_utils.py
index 3f28222e2..cc8eda27f 100644
--- a/python/fate_flow/utils/job_utils.py
+++ b/python/fate_flow/utils/job_utils.py
@@ -32,7 +32,7 @@
 from fate_flow.entity.run_status import JobStatus, TaskStatus
 from fate_flow.entity.types import InputSearchType
 from fate_flow.settings import FATE_BOARD_DASHBOARD_ENDPOINT
-from fate_flow.utils import data_utils, detect_utils, process_utils, session_utils
+from fate_flow.utils import data_utils, detect_utils, process_utils, session_utils, schedule_utils
 from fate_flow.utils.base_utils import get_fate_flow_directory
 from fate_flow.utils.log_utils import schedule_logger
 from fate_flow.utils.schedule_utils import get_dsl_parser_by_version
@@ -478,12 +478,24 @@ def _wrapper(*args, **kwargs):
     return _wrapper
 
 
-def task_report(tasks):
-    now_time = current_timestamp()
-    report_list = [{"component_name": task.f_component_name, "start_time": task.f_start_time,
-                    "end_time": task.f_end_time, "elapsed": task.f_elapsed, "status": task.f_status}
-                   for task in tasks]
-    report_list.sort(key=lambda x: (x["start_time"] if x["start_time"] else now_time, x["status"]))
+def task_report(jobs, tasks):
+    job = jobs[0]
+    report_list = []
+    dsl_parser = schedule_utils.get_job_dsl_parser(
+        dsl=job.f_dsl,
+        runtime_conf=job.f_runtime_conf,
+        train_runtime_conf=job.f_train_runtime_conf
+    )
+    name_component_maps, hierarchical_structure = dsl_parser.get_dsl_hierarchical_structure()
+    for index, cpn_list in enumerate(hierarchical_structure):
+        for name in cpn_list:
+            for task in tasks:
+                if task.f_component_name == name:
+                    report_list.append({
+                        "component_name": task.f_component_name, "start_time": task.f_start_time,
+                        "end_time": task.f_end_time, "elapsed": task.f_elapsed, "status": task.f_status,
+                        "index": index
+                    })
     return report_list
diff --git a/python/fate_flow/utils/model_utils.py b/python/fate_flow/utils/model_utils.py
index 1bccbe04a..d746eae0e 100644
--- a/python/fate_flow/utils/model_utils.py
+++ b/python/fate_flow/utils/model_utils.py
@@ -22,6 +22,7 @@
 from fate_flow.pipelined_model.pipelined_model import PipelinedModel
 from fate_flow.scheduler.cluster_scheduler import ClusterScheduler
 from fate_flow.settings import ENABLE_MODEL_STORE, stat_logger
+from fate_flow.utils import schedule_utils
 from fate_flow.utils.base_utils import compare_version, get_fate_flow_directory
 
 
@@ -138,6 +139,42 @@ def gather_model_info_data(model: PipelinedModel):
     return model_info
 
 
+def query_model_detail(model_id, model_version, **kwargs):
+    model_detail = {}
+    retcode, retmsg, model_infos = query_model_info(model_id=model_id, model_version=model_version)
+    if not model_infos:
+        return retcode, retmsg, model_detail
+    model_info = model_infos[0]
+    model_detail["runtime_conf"] = model_info.get("f_runtime_conf") or model_info.get("f_train_runtime_conf")
+    model_detail["dsl"] = model_info.get("f_train_dsl")
+    model_detail["inference_dsl"] = model_info.get("f_inference_dsl", {})
+    is_parent = model_info.get("f_parent")
+    model_detail["component_info"] = get_component_list(model_detail["runtime_conf"], model_detail["dsl"], is_parent)
+    model_detail["inference_component_info"] = get_component_list(model_detail["runtime_conf"], model_detail["inference_dsl"], is_parent)
+
+    return retcode, retmsg, model_detail
+
+
+def get_component_list(conf, dsl, is_train):
+    job_type = "train"
+    if not is_train:
+        job_type = "predict"
+    dsl_parser = schedule_utils.get_job_dsl_parser(
+        dsl=dsl,
+        runtime_conf=conf,
+        train_runtime_conf=conf,
+        job_type=job_type
+    )
+    name_component_maps, hierarchical_structure = dsl_parser.get_dsl_hierarchical_structure()
+    return [{"component_name": k, "module": v["module"], "index": get_component_index(k, hierarchical_structure)}
+            for k, v in dsl.get("components", {}).items()]
+
+
+def get_component_index(component_name, hierarchical_structure):
+    for index, cpn_list in enumerate(hierarchical_structure):
+        if component_name in cpn_list:
+            return index
+
+
 def query_model_info(**kwargs):
     file_only = kwargs.pop('file_only', False)
     kwargs['query_filters'] = set(kwargs['query_filters']) if kwargs.get('query_filters') else set()
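Both `query_model_detail` and the reworked `task_report` rank components by the layer they occupy in the parsed DSL (`hierarchical_structure`). A small illustration with an invented structure:

```python
# hierarchical_structure is a list of DSL layers; a component's "index" is the
# layer it appears in. The structure below is invented for illustration.
hierarchical_structure = [
    ["reader_0"],
    ["data_transform_0"],
    ["hetero_lr_0", "evaluation_0"],
]

def get_component_index(component_name, hierarchical_structure):
    for index, cpn_list in enumerate(hierarchical_structure):
        if component_name in cpn_list:
            return index

print(get_component_index("evaluation_0", hierarchical_structure))  # 2
```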
diff --git a/python/fate_flow/utils/schedule_utils.py b/python/fate_flow/utils/schedule_utils.py
index 24efa5b8e..3bef373ec 100644
--- a/python/fate_flow/utils/schedule_utils.py
+++ b/python/fate_flow/utils/schedule_utils.py
@@ -96,7 +96,7 @@ def get_conf_version(conf: dict):
     return int(conf.get("dsl_version", "1"))
 
 
-def get_job_dsl_parser(dsl=None, runtime_conf=None, pipeline_dsl=None, train_runtime_conf=None):
+def get_job_dsl_parser(dsl=None, runtime_conf=None, pipeline_dsl=None, train_runtime_conf=None, job_type=None):
     parser_version = get_conf_version(runtime_conf)
 
     if parser_version == 1:
@@ -106,7 +106,8 @@ def get_job_dsl_parser(dsl=None, runtime_conf=None, pipeline_dsl=None, train_run
         parser_version = 2
 
     dsl_parser = get_dsl_parser_by_version(parser_version)
-    job_type = JobRuntimeConfigAdapter(runtime_conf).get_job_type()
+    if not job_type:
+        job_type = JobRuntimeConfigAdapter(runtime_conf).get_job_type()
     dsl_parser.run(dsl=dsl,
                    runtime_conf=runtime_conf,
                    pipeline_dsl=pipeline_dsl,