"""Tests for checking consistency between yaml files and their corresponding training scripts.
Authors
* Mirco Ravanelli 2022
* Andreas Nautsch 2022
"""
import os
import re
def get_yaml_var(hparam_file):
    """Extracts from the input yaml file (hparams_file) the list of variables that
    should be used in the script file.

    A variable is kept when it is declared at the top level of the yaml file
    (a non-indented ``key:`` line that contains neither ``!!`` nor ``!apply``)
    and is not referenced through ``<key>`` by a later ``!ref`` line of the
    same yaml file.

    Arguments
    ---------
    hparam_file : path
        Path of the yaml file containing the hyperparameters.

    Returns
    -------
    var_list: list
        list of the variables declared in the yaml file (sub-variables are not
        included).
    """
    var_lst = []
    with open(hparam_file) as f:
        for line in f:
            # Avoid empty lines or comments - skip pretrainer definitions
            if (
                len(line) <= 1
                or line.startswith("#")
                or "speechbrain.utils.parameter_transfer.Pretrainer" in line
            ):
                continue

            # Remove trailing characters
            line = line.rstrip()

            # Only lines that look like 'key:' or '- !ref' are of interest
            colon_pos = line.find(":")
            if colon_pos == -1 and "- !ref" not in line:
                continue

            # A declaration needs a non-empty name before the colon. Without
            # this guard a line starting with ":" would raise an IndexError and
            # a colon-less '- !ref' line would register a truncated name.
            if colon_pos > 0:
                var_name = line[:colon_pos]
                # The variables to check are like "key:" (we do not need to
                # check sub-variables such as " key:")
                is_sub_variable = var_name[0] in (" ", "\t")
                if not (is_sub_variable or "!!" in line or "!apply" in line):
                    var_lst.append(var_name)

            # Check for the reference pattern, e.g.
            # 'output_folder: !ref results/<experiment_name>/<seed>' or
            # 'annotation_list_to_check: [!ref <train_csv>, !ref <valid_csv>]'
            if "!ref" not in line:
                continue
            # Only text that is actually enclosed in '<...>' counts as a
            # reference.
            for sub_var in re.findall(r"<([^<>]*)>", line):
                # 'models[generator]' pattern (dictionary reference): keep
                # only the name of the dictionary
                sub_var = sub_var.split("[", 1)[0]
                # Remove variables already used in yaml
                if sub_var in var_lst:
                    var_lst.remove(sub_var)
    return var_lst
def detect_script_vars(script_file, var_lst):
    """Detects from the input script file (script_file) which of given variables (var_lst) are demanded.

    Each line of the script is checked, for every variable, against four
    textual heuristics: an explicit ``["key"]`` lookup, an f-string lookup such
    as ``hparams[f"{dataset}_annotation"]`` (reported with a warning since it
    is only a guess), an attribute-like access (``hparams.key``,
    ``modules.key``, ``attr(self.hparams, "key``, ``hparams.get("key``) and a
    format-string lookup such as ``hparams['tea{}_enc']``.

    Arguments
    ---------
    script_file : path
        Path of the script file needing the hyperparameters.
    var_lst : list
        list of the variables declared in the yaml file.

    Returns
    -------
    detected_var: list
        list of the variables detected in the script file, without
        duplicates, in order of first detection.
    """
    var_types = (
        "hparams.",
        "modules.",
        'attr(self.hparams, "',
        'hparams.get("',
    )
    detected_var = []
    already_detected = set()
    with open(script_file) as f:
        for line in f:
            # Both searches depend on the line only, so they are run once per
            # line instead of once per variable.
            # case: hparams[f"{dataset}_annotation"] - only that structure at the moment
            fstring_key = re.search(r"\[f.\{.*\}(.*).\]", line)
            # case: tea_enc_list.append(hparams['tea{}_enc'])
            format_key = re.search(r"\[.(.*){}(.*).\]", line)

            for var in var_lst:
                # The pattern can be ["key"] or ".key"
                used = '["' + var + '"]' in line

                if (
                    not used
                    and fstring_key is not None
                    and fstring_key.group(1) in var
                ):
                    print(
                        "\t\tWARNING: potential inconsistency %s maybe used in %s (or not)."
                        % (var, fstring_key.group(0))
                    )
                    used = True

                # Check var types
                if not used:
                    used = any(
                        var_type + var in line for var_type in var_types
                    )

                if not used and format_key is not None:
                    # The text around '{}' is compared literally: the variable
                    # must start with what precedes '{}' and end with what
                    # follows it. Using that text as a regular expression
                    # could raise re.error or match unrelated names.
                    prefix = format_key.group(1)
                    suffix = format_key.group(2)
                    used = (
                        len(var) >= len(prefix) + len(suffix)
                        and var.startswith(prefix)
                        and var.endswith(suffix)
                    )

                if used and var not in already_detected:
                    already_detected.add(var)
                    detected_var.append(var)
    return detected_var
def check_yaml_vs_script(hparam_file, script_file):
    """Checks consistency between the given yaml file (hparams_file) and the
    script file. The function detects if there are variables declared in the yaml
    file, but not used in the script file.

    Arguments
    ---------
    hparam_file : path
        Path of the yaml file containing the hyperparameters.
    script_file : path
        Path of the script file (.py) containing the training recipe.

    Returns
    -------
    Bool
        This function returns False if some mismatch is detected (or if one of
        the two files does not exist) and True otherwise.
        An error message is printed for each variable that has been declared
        but not used.
    """
    print("Checking %s..." % (hparam_file,))

    # Check if files exist
    for path in (hparam_file, script_file):
        if not os.path.exists(path):
            print("File %s not found!" % (path,))
            return False

    # Getting list of variables declared in yaml
    var_lst = get_yaml_var(hparam_file)

    # Detect which of these variables are used in the script file
    detected_vars_train = detect_script_vars(script_file, var_lst)

    # Run options are excluded from the check even when they do not appear in
    # the script file.
    default_run_opt_keys = [
        "debug",
        "debug_batches",
        "debug_epochs",
        "device",
        "cpu",
        "data_parallel_backend",
        "distributed_launch",
        "distributed_backend",
        "find_unused_parameters",
        "jit_module_keys",
        "auto_mix_prec",
        "max_grad_norm",
        "nonfinite_patience",
        "noprogressbar",
        "ckpt_interval_minutes",
        "grad_accumulation_factor",
        "optimizer_step_limit",
    ]

    # Check which variables are declared but not used (sorted so that the
    # report does not depend on set ordering)
    unused_vars = sorted(
        set(var_lst) - set(detected_vars_train) - set(default_run_opt_keys)
    )
    for unused_var in unused_vars:
        print(
            '\tERROR: variable "%s" not used in %s!' % (unused_var, script_file)
        )

    return len(unused_vars) == 0
def check_module_vars(
    hparam_file, script_file, module_key="modules:", module_var="self.modules."
):
    """Checks if the variables self.modules.var are properly declared in the
    hparam file.

    Arguments
    ---------
    hparam_file : path
        Path of the yaml file containing the hyperparameters.
    script_file : path
        Path of the script file (.py) containing the training recipe.
    module_key: string
        String that denotes the start of the module in the hparam file.
    module_var: string
        String that denotes the start of the module in the script file.

    Returns
    -------
    Bool
        This function returns False if some mismatch is detected and True otherwise.
        An error message is printed for each variable that has been used but
        not declared.
    """
    stop_char = [" ", ",", "(", ")", "[", "]", "{", "}", ".", ":", "\n"]
    avoid_lst = ["parameters", "keys", "eval", "train"]

    # Extract Modules variables from the hparam file: every indented line
    # that follows a line containing module_key belongs to the block, which
    # ends at the first line that does not start with a space or a tab.
    module_vars_hparams = []
    module_block = False
    with open(hparam_file) as f:
        for line in f:
            if module_key in line:
                module_block = True
                continue
            if not line.startswith((" ", "\t")):
                module_block = False
            if module_block:
                module_vars_hparams.append(line.strip().split(":")[0])

    # Extract Modules variables from the script file
    with open(script_file) as file:
        lines = [line.rstrip() for line in file]
    module_var_script = set(extract_patterns(lines, module_var, stop_char))
    # Names in avoid_lst are never reported
    module_var_script.difference_update(avoid_lst)

    # Remove optional variables "if hasattr(self.modules, "env_corrupt"):"
    opt_stop_char = stop_char + ['"']
    opt_vars = extract_patterns(
        lines, 'if hasattr(self.modules, "', opt_stop_char
    )
    opt_vars.extend(
        extract_patterns(lines, 'if hasattr(self.hparams, "', opt_stop_char)
    )
    module_var_script.difference_update(opt_vars)

    # Check Module variables (sorted so that the report does not depend on
    # set ordering)
    unused_vars = sorted(module_var_script - set(module_vars_hparams))
    for unused_var in unused_vars:
        # module_var and module_key are used so that the message stays
        # accurate when the caller overrides the defaults.
        print(
            '\tERROR: variable "%s%s" used in %s, but not listed in the "%s" section of %s'
            % (module_var, unused_var, script_file, module_key, hparam_file)
        )

    return len(unused_vars) == 0