Skip to content

JSON config support for cmd_conf utility #3227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,20 @@ def set_embedding_config(
# pyre-ignore [24]
def cmd_conf(func: Callable) -> Callable:

def _load_config_file(config_path: str, is_json: bool = False) -> Dict[str, Any]:
if not config_path:
return {}

try:
with open(config_path, "r") as f:
if is_json:
return json.load(f) or {}
else:
return yaml.safe_load(f) or {}
except Exception as e:
logger.warning(f"Failed to load config because {e}. Proceeding without it.")
return {}

# pyre-ignore [3]
def wrapper() -> Any:
sig = inspect.signature(func)
Expand All @@ -548,6 +562,13 @@ def wrapper() -> Any:
help="YAML config file for benchmarking",
)

parser.add_argument(
"--json_config",
type=str,
default=None,
help="JSON config file for benchmarking",
)

# Add loglevel argument with current logger level as default
parser.add_argument(
"--loglevel",
Expand All @@ -558,18 +579,18 @@ def wrapper() -> Any:

pre_args, _ = parser.parse_known_args()

yaml_defaults: Dict[str, Any] = {}
if pre_args.yaml_config:
try:
with open(pre_args.yaml_config, "r") as f:
yaml_defaults = yaml.safe_load(f) or {}
logger.info(
f"Loaded YAML config from {pre_args.yaml_config}: {yaml_defaults}"
)
except Exception as e:
logger.warning(
f"Failed to load YAML config because {e}. Proceeding without it."
)
yaml_defaults: Dict[str, Any] = (
_load_config_file(pre_args.yaml_config, is_json=False)
if pre_args.yaml_config
else {}
)
json_defaults: Dict[str, Any] = (
_load_config_file(pre_args.json_config, is_json=True)
if pre_args.json_config
else {}
)
# Merge the two dictionaries, JSON overrides YAML
merged_defaults = {**yaml_defaults, **json_defaults}

seen_args = set() # track all --<name> we've added

Expand All @@ -595,10 +616,10 @@ def wrapper() -> Any:
ftype = non_none[0]
origin = get_origin(ftype)

# Handle default_factory value and allow YAML config to override it
default_value = yaml_defaults.get(
# Handle default_factory value and allow config to override
default_value = merged_defaults.get(
arg_name, # flat lookup
yaml_defaults.get(cls.__name__, {}).get( # hierarchy lookup
merged_defaults.get(cls.__name__, {}).get( # hierarchy lookup
arg_name,
(
f.default_factory() # pyre-ignore [29]
Expand Down
Loading