diff --git a/flagcx/core/cost_model.cc b/flagcx/core/cost_model.cc index 8fdbd54c..fc18f91f 100644 --- a/flagcx/core/cost_model.cc +++ b/flagcx/core/cost_model.cc @@ -10,7 +10,8 @@ flagcxResult_t flagcxAlgoTimeEstimator::getAlgoTime(float *time) { const char *interServerTopoFile = flagcxGetEnv("FLAGCX_INTERSERVER_ROUTE_FILE"); if (enableTopoDetect && interServerTopoFile && - strcmp(enableTopoDetect, "TRUE") == 0) { + (strcmp(enableTopoDetect, "TRUE") == 0 || + strcmp(enableTopoDetect, "True") == 0)) { // algo time estimator depends on cluster level topology detection float preHomoTime, heteroTime, postHomoTime; INFO(FLAGCX_GRAPH, "COST_MODEL: getting time for prehomo funcs"); diff --git a/flagcx/core/init.cc b/flagcx/core/init.cc index ae5ae51e..3e23b2ea 100644 --- a/flagcx/core/init.cc +++ b/flagcx/core/init.cc @@ -216,7 +216,7 @@ static flagcxResult_t flagcxCommInitRankFunc(struct flagcxAsyncJob *job_) { } FLAGCXCHECK(flagcxNetInit(comm)); INFO(FLAGCX_INIT, "Using network %s", comm->netAdaptor->name); - if (env && strcmp(env, "TRUE") == 0) { + if (env && (strcmp(env, "TRUE") == 0 || strcmp(env, "True") == 0)) { INFO(FLAGCX_INIT, "getting busId for cudaDev %d", comm->cudaDev); FLAGCXCHECK(getBusId(comm->cudaDev, &comm->busId)); INFO(FLAGCX_INIT, "getting commHash for rank %d", comm->rank); diff --git a/flagcx/core/topo.cc b/flagcx/core/topo.cc index 5477d8ab..c301dad2 100644 --- a/flagcx/core/topo.cc +++ b/flagcx/core/topo.cc @@ -537,7 +537,8 @@ flagcxResult_t flagcxGetLocalNetFromGpu(int apu, int *dev, } if (strlen(name) == 0 && enable_topo_detect && - strcmp(enable_topo_detect, "TRUE") == 0) { + (strcmp(enable_topo_detect, "TRUE") == 0 || + strcmp(enable_topo_detect, "True") == 0)) { FLAGCXCHECK(flagcxTopoGetLocalNet(comm->topoServer, comm->rank, dev)); } diff --git a/flagcx/flagcx.cc b/flagcx/flagcx.cc index bbf72b9d..6ad21e77 100644 --- a/flagcx/flagcx.cc +++ b/flagcx/flagcx.cc @@ -484,9 +484,10 @@ flagcxResult_t flagcxCommInitRank(flagcxComm_t *comm, int nranks, struct flagcxNicDistance *nicDistanceData; FLAGCXCHECK(flagcxCalloc(&nicDistanceData, nranks)); const char *enableTopoDetect = flagcxGetEnv("FLAGCX_ENABLE_TOPO_DETECT"); - if (enableTopoDetect && strcmp(enableTopoDetect, "TRUE") == - 0) { // safety check nic distance is only - // available after topo detection + if (enableTopoDetect && (strcmp(enableTopoDetect, "TRUE") == 0 || + strcmp(enableTopoDetect, "True") == + 0)) { // safety check nic distance is only + // available after topo detection FLAGCXCHECK(flagcxGetNicDistance((*comm)->hetero_comm->topoServer, rank, nicDistanceData + rank)); } else { @@ -1762,4 +1763,4 @@ flagcxResult_t flagcxGroupEnd(flagcxComm_t comm) { } } return flagcxSuccess; -} \ No newline at end of file +}