#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2021 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import hashlib
import os
import re
import sys
import numpy as np
from . import ui
_z_str_to_z_dict = {
'H': 1,
'He': 2,
'Li': 3,
'Be': 4,
'B': 5,
'C': 6,
'N': 7,
'O': 8,
'F': 9,
'Ne': 10,
'Na': 11,
'Mg': 12,
'Al': 13,
'Si': 14,
'P': 15,
'S': 16,
'Cl': 17,
'Ar': 18,
'K': 19,
'Ca': 20,
'Sc': 21,
'Ti': 22,
'V': 23,
'Cr': 24,
'Mn': 25,
'Fe': 26,
'Co': 27,
'Ni': 28,
'Cu': 29,
'Zn': 30,
'Ga': 31,
'Ge': 32,
'As': 33,
'Se': 34,
'Br': 35,
'Kr': 36,
'Rb': 37,
'Sr': 38,
'Y': 39,
'Zr': 40,
'Nb': 41,
'Mo': 42,
'Tc': 43,
'Ru': 44,
'Rh': 45,
'Pd': 46,
'Ag': 47,
'Cd': 48,
'In': 49,
'Sn': 50,
'Sb': 51,
'Te': 52,
'I': 53,
'Xe': 54,
'Cs': 55,
'Ba': 56,
'La': 57,
'Ce': 58,
'Pr': 59,
'Nd': 60,
'Pm': 61,
'Sm': 62,
'Eu': 63,
'Gd': 64,
'Tb': 65,
'Dy': 66,
'Ho': 67,
'Er': 68,
'Tm': 69,
'Yb': 70,
'Lu': 71,
'Hf': 72,
'Ta': 73,
'W': 74,
'Re': 75,
'Os': 76,
'Ir': 77,
'Pt': 78,
'Au': 79,
'Hg': 80,
'Tl': 81,
'Pb': 82,
'Bi': 83,
'Po': 84,
'At': 85,
'Rn': 86,
'Fr': 87,
'Ra': 88,
'Ac': 89,
'Th': 90,
'Pa': 91,
'U': 92,
'Np': 93,
'Pu': 94,
'Am': 95,
'Cm': 96,
'Bk': 97,
'Cf': 98,
'Es': 99,
'Fm': 100,
'Md': 101,
'No': 102,
'Lr': 103,
'Rf': 104,
'Db': 105,
'Sg': 106,
'Bh': 107,
'Hs': 108,
'Mt': 109,
'Ds': 110,
'Rg': 111,
'Cn': 112,
    'Uuq': 114,  # legacy systematic symbol for Fl (flerovium, 114)
    'Uuh': 116,  # legacy systematic symbol for Lv (livermorium, 116)
}
_z_to_z_str_dict = {v: k for k, v in _z_str_to_z_dict.items()}


def z_str_to_z(z_str):
return np.array([_z_str_to_z_dict[x] for x in z_str])


def z_to_z_str(z):
return [_z_to_z_str_dict[int(x)] for x in z]
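

# Illustrative usage (sketch): round-tripping between element symbols and
# atomic numbers with the two helpers above.
#
#   >>> z_str_to_z(['H', 'O', 'H'])
#   array([1, 8, 1])
#   >>> z_to_z_str([1, 8, 1])
#   ['H', 'O', 'H']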


def train_dir_name(dataset, n_train, use_sym, use_E, use_E_cstr):
theory_level_str = re.sub(r'[^\w\-_\.]', '.', str(dataset['theory']))
theory_level_str = re.sub(r'\.\.', '.', theory_level_str)
sym_str = '-sym' if use_sym else ''
# cprsn_str = '-cprsn' if use_cprsn else ''
noE_str = '-noE' if not use_E else ''
Ecstr_str = '-Ecstr' if use_E_cstr else ''
return 'sgdml_cv_%s-%s-train%d%s%s%s' % (
dataset['name'].astype(str),
theory_level_str,
n_train,
sym_str,
# cprsn_str,
noE_str,
Ecstr_str,
)


def task_file_name(task):
n_train = task['idxs_train'].shape[0]
n_perms = task['perms'].shape[0]
sig = np.squeeze(task['sig'])
return 'task-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig)


def model_file_name(task_or_model, is_extended=False):
n_train = task_or_model['idxs_train'].shape[0]
n_perms = task_or_model['perms'].shape[0]
sig = np.squeeze(task_or_model['sig'])
if is_extended:
dataset = np.squeeze(task_or_model['dataset_name'])
theory_level_str = re.sub(
r'[^\w\-_\.]', '.', str(np.squeeze(task_or_model['dataset_theory']))
)
theory_level_str = re.sub(r'\.\.', '.', theory_level_str)
return '%s-%s-train%d-sym%d.npz' % (dataset, theory_level_str, n_train, n_perms)
return 'model-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig)
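

# Illustrative file names (sketch): with 200 training points, 12 permutational
# symmetries and sigma = 20, task_file_name() yields
# 'task-train200-sym12-sig0020.npz' and model_file_name() yields
# 'model-train200-sym12-sig0020.npz' ('%04d' zero-pads the sigma value).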


def dataset_md5(dataset):
md5_hash = hashlib.md5()
keys = ['z', 'R']
if 'E' in dataset:
keys.append('E')
keys.append('F')
# only include new extra keys in fingerprint for 'modern' dataset files
# 'code_version' was included from 0.4.0.dev1
# opt_keys = ['lattice', 'e_unit', 'E_min', 'E_max', 'E_mean', 'E_var', 'f_unit', 'F_min', 'F_max', 'F_mean', 'F_var']
# for k in opt_keys:
# if k in dataset:
# keys.append(k)
for k in keys:
d = dataset[k]
if type(d) is np.ndarray:
d = d.ravel()
md5_hash.update(hashlib.md5(d).digest())
return md5_hash.hexdigest().encode('utf-8')
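

# Illustrative usage (sketch; the dict below is a hypothetical minimal dataset,
# real dataset .npz files carry additional keys): the fingerprint covers the
# raveled 'z', 'R' and 'F' arrays, plus 'E' when present.
#
#   >>> ds = {'z': np.array([1, 8, 1]),
#   ...       'R': np.zeros((5, 3, 3)),
#   ...       'F': np.zeros((5, 3, 3))}
#   >>> dataset_md5(ds)  # UTF-8 encoded hex digest (bytes)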


### FILES
# Read geometry file (xyz format).
# R: (n_geo,3*n_atoms)
# z: (n_atoms,)
def read_xyz(file_path):
with open(file_path, 'r') as f:
n_atoms = None
R, z = [], []
for i, line in enumerate(f):
line = line.strip()
if not n_atoms:
n_atoms = int(line)
cols = line.split()
file_i, line_i = divmod(i, n_atoms + 2)
if line_i >= 2:
R.append(list(map(float, cols[1:4])))
if file_i == 0: # first molecule
z.append(_z_str_to_z_dict[cols[0]])
R = np.array(R).reshape(-1, 3 * n_atoms)
z = np.array(z)
f.close()
return R, z
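

# Illustrative usage (sketch; 'geometries.xyz' is a hypothetical file holding
# n_geo concatenated xyz frames of the same molecule):
#
#   >>> R, z = read_xyz('geometries.xyz')
#   >>> R.shape  # (n_geo, 3 * n_atoms)
#   >>> z.shape  # (n_atoms,)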


# Write geometry file (xyz format).
def write_geometry(filename, r, z, comment_str=''):
r = np.squeeze(r)
try:
with open(filename, 'w') as f:
f.write(str(len(r)) + '\n' + comment_str)
for i, atom in enumerate(r):
f.write('\n' + _z_to_z_str_dict[z[i]] + '\t')
f.write('\t'.join(str(x) for x in atom))
except IOError:
sys.exit("ERROR: Writing xyz file failed.")


# Generate an (extended) xyz format string for a single geometry.
def generate_xyz_str(r, z, e=None, f=None, lattice=None):
comment_str = ''
if lattice is not None:
comment_str += 'Lattice=\"{}\" '.format(
' '.join(['{:.12g}'.format(l) for l in lattice.T.ravel()])
)
if e is not None:
comment_str += 'Energy={:.12g} '.format(e)
comment_str += 'Properties=species:S:1:pos:R:3'
if f is not None:
comment_str += ':forces:R:3'
species_str = '\n'.join([_z_to_z_str_dict[z_i] for z_i in z])
r_f_str = ui.gen_mat_str(r)[0]
if f is not None:
r_f_str = ui.merge_col_str(r_f_str, ui.gen_mat_str(f)[0])
xyz_str = str(len(r)) + '\n' + comment_str + '\n'
xyz_str += ui.merge_col_str(species_str, r_f_str)
return xyz_str
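

# Illustrative usage (sketch; values are arbitrary placeholders): building an
# extended xyz block for a single geometry with forces attached.
#
#   >>> r = np.zeros((3, 3))
#   >>> f = np.zeros((3, 3))
#   >>> print(generate_xyz_str(r, [8, 1, 1], e=-76.4, f=f))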


def lattice_vec_to_par(lat):
lat = lat.T
lengths = [np.linalg.norm(v) for v in lat]
angles = []
for i in range(3):
j = i - 1
k = i - 2
ll = lengths[j] * lengths[k]
if ll > 1e-16:
x = np.dot(lat[j], lat[k]) / ll
angle = 180.0 / np.pi * np.arccos(x)
else:
angle = 90.0
angles.append(angle)
return lengths, angles
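

# Worked example (sketch): for a simple cubic cell (lattice vectors as columns
# of a 3x3 matrix), all cell lengths equal the lattice constant and all cell
# angles are 90 degrees.
#
#   >>> lengths, angles = lattice_vec_to_par(5.0 * np.eye(3))
#   >>> lengths  # [5.0, 5.0, 5.0]
#   >>> angles   # [90.0, 90.0, 90.0]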


### FILE HANDLING
def is_file_type(arg, type):
"""
Validate file path and check if the file is of the specified type.
Parameters
----------
arg : :obj:`str`
File path.
type : {'dataset', 'task', 'model'}
Possible file types.
Returns
-------
(:obj:`str`, :obj:`dict`)
Tuple of file path (as provided) and data stored in the
file. The returned instance of NpzFile class must be
closed to avoid leaking file descriptors.
Raises
------
ArgumentTypeError
If the provided file path does not lead to a NpzFile.
ArgumentTypeError
If the file is not readable.
ArgumentTypeError
If the file is of wrong type.
ArgumentTypeError
If path/fingerprint is provided, but the path is not valid.
ArgumentTypeError
If fingerprint could not be resolved.
ArgumentTypeError
If multiple files with the same fingerprint exist.
"""
# Replace MD5 dataset fingerprint with file name, if necessary.
if type == 'dataset' and not arg.endswith('.npz') and not os.path.isdir(arg):
dir = '.'
if re.search(r'^[a-f0-9]{32}$', arg): # arg looks similar to MD5 hash string
md5_str = arg
else: # is it a path with a MD5 hash at the end?
md5_str = os.path.basename(os.path.normpath(arg))
dir = os.path.dirname(os.path.normpath(arg))
if dir == '': # it is only a filename after all, hence not the right type
raise argparse.ArgumentTypeError('{0} is not a .npz file'.format(arg))
if re.search(r'^[a-f0-9]{32}$', md5_str) and not os.path.isdir(
dir
): # path has MD5 hash string at the end, but directory is not valid
raise argparse.ArgumentTypeError('{0} is not a directory'.format(dir))
file_names = filter_file_type(dir, type, md5_match=md5_str)
if not len(file_names):
raise argparse.ArgumentTypeError(
"No {0} files with fingerprint '{1}' found in '{2}'".format(
type, md5_str, dir
)
)
elif len(file_names) > 1:
error_str = (
"Multiple {0} files with fingerprint '{1}' found in '{2}'".format(
type, md5_str, dir
)
)
for file_name in file_names:
error_str += '\n {0}'.format(file_name)
raise argparse.ArgumentTypeError(error_str)
else:
arg = os.path.join(dir, file_names[0])
if not arg.endswith('.npz'):
        raise argparse.ArgumentTypeError('{0} is not a .npz file'.format(arg))
try:
file = np.load(arg, allow_pickle=True)
except Exception:
raise argparse.ArgumentTypeError('{0} is not readable'.format(arg))
if 'type' not in file or file['type'].astype(str) != type[0]:
raise argparse.ArgumentTypeError('{0} is not a {1} file'.format(arg, type))
return arg, file
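

# Illustrative usage (sketch): this validator is intended as an argparse 'type'
# callable; it raises argparse.ArgumentTypeError on invalid input, and the
# returned NpzFile must be closed by the caller.
#
#   >>> parser = argparse.ArgumentParser()
#   >>> parser.add_argument('dataset', type=lambda a: is_file_type(a, 'dataset'))
#   >>> # args.dataset is then a (path, NpzFile) tuple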


def filter_file_type(dir, type, md5_match=None):
"""
Filters all files from a directory that match a given type and (optionally)
a given fingerprint.
Parameters
----------
    dir : :obj:`str`
        Directory path.
type : {'dataset', 'task', 'model'}
Possible file types.
md5_match : :obj:`str`, optional
Fingerprint string.
Returns
-------
:obj:`list` of :obj:`str`
List of file names that match the specified type and fingerprint
(if provided).
Raises
------
ArgumentTypeError
If the directory contains unreadable .npz files.
"""
file_names = []
for file_name in sorted(os.listdir(dir)):
if file_name.endswith('.npz'):
file_path = os.path.join(dir, file_name)
try:
file = np.load(file_path, allow_pickle=True)
except Exception:
raise argparse.ArgumentTypeError(
                    '{0} contains unreadable .npz files'.format(dir)
)
if 'type' in file and file['type'].astype(str) == type[0]:
if md5_match is None:
file_names.append(file_name)
elif 'md5' in file and file['md5'] == md5_match:
file_names.append(file_name)
file.close()
return file_names
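

# Illustrative usage (sketch): list all model files in the current directory,
# or only those matching a given MD5 fingerprint.
#
#   >>> filter_file_type('.', 'model')
#   >>> filter_file_type('.', 'dataset', md5_match='0123456789abcdef0123456789abcdef')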


def is_valid_file_type(arg_in):
"""
Check if file is either a valid dataset, task or model file.
Parameters
----------
arg_in : :obj:`str`
File path.
Returns
-------
(:obj:`str`, :obj:`dict`)
Tuple of file path (as provided) and data stored in the
file. The returned instance of NpzFile class must be
closed to avoid leaking file descriptors.
Raises
------
ArgumentTypeError
If the provided file path does not point to a supported
file type.
"""
arg, file = None, None
try:
arg, file = is_file_type(arg_in, 'dataset')
except argparse.ArgumentTypeError:
pass
if file is None:
try:
arg, file = is_file_type(arg_in, 'task')
except argparse.ArgumentTypeError:
pass
if file is None:
try:
arg, file = is_file_type(arg_in, 'model')
except argparse.ArgumentTypeError:
pass
if file is None:
raise argparse.ArgumentTypeError(
            '{0} is neither a dataset, task, nor model file'.format(arg_in)
)
return arg, file


def is_dir_with_file_type(arg, type, or_file=False):
"""
Validate directory path and check if it contains files of the specified type.
Note
----
    If a file path is provided, this function acts as if it were a directory
    containing just that one file.
Parameters
----------
arg : :obj:`str`
File path.
type : {'dataset', 'task', 'model'}
Possible file types.
or_file : bool
        If `arg` is a file path, act as if it were a directory
        with just that single file inside.
Returns
-------
(:obj:`str`, :obj:`list` of :obj:`str`)
Tuple of directory path (as provided) and a list of
contained file names of the specified type.
Raises
------
ArgumentTypeError
If the provided directory path does not lead to a directory.
ArgumentTypeError
If directory contains unreadable files.
ArgumentTypeError
If directory contains no files of the specified type.
"""
if or_file and os.path.isfile(arg): # arg: file path
_, file = is_file_type(
arg, type
) # raises exception if there is a problem with the file
file.close()
file_name = os.path.basename(arg)
file_dir = os.path.dirname(arg)
return file_dir, [file_name]
else: # arg: dir
if not os.path.isdir(arg):
raise argparse.ArgumentTypeError('{0} is not a directory'.format(arg))
file_names = filter_file_type(arg, type)
# if not len(file_names):
# raise argparse.ArgumentTypeError(
# '{0} contains no {1} files'.format(arg, type)
# )
return arg, file_names
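

# Illustrative usage (sketch; 'models' is a hypothetical directory):
#
#   >>> dir_path, file_names = is_dir_with_file_type('models', 'model')
#   >>> # with or_file=True, a single file path is accepted as well and is
#   >>> # reported as a one-element file list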


def is_task_dir_resumeable(
train_dir, train_dataset, test_dataset, n_train, n_test, sigs, gdml
):
r"""
Check if a directory contains `task` and/or `model` files that
match the configuration of a training process specified in the
remaining arguments.
Check if the training and test datasets in each task match
`train_dataset` and `test_dataset`, if the number of training and
test points matches and if the choices for the kernel
hyper-parameter :math:`\sigma` are contained in the list. Check
also, if the existing tasks/models contain symmetries and if
that's consistent with the flag `gdml`. This function is useful
for determining if a training process can be resumed using the
existing files or not.
Parameters
----------
train_dir : :obj:`str`
Path to training directory.
train_dataset : :obj:`dataset`
Dataset from which training points are sampled.
test_dataset : :obj:`test_dataset`
Dataset from which test points are sampled (may be the
same as `train_dataset`).
n_train : int
Number of training points to sample.
n_test : int
Number of test points to sample.
sigs : :obj:`list` of int
List of :math:`\sigma` kernel hyper-parameter choices
(usually: the hyper-parameter search grid)
gdml : bool
If `True`, don't include any symmetries in model (GDML),
otherwise do (sGDML).
Returns
-------
bool
False, if any of the files in the directory do not match
the training configuration.
"""
for file_name in sorted(os.listdir(train_dir)):
if file_name.endswith('.npz'):
file_path = os.path.join(train_dir, file_name)
file = np.load(file_path, allow_pickle=True)
if 'type' not in file:
continue
elif file['type'] == 't' or file['type'] == 'm':
if (
file['md5_train'] != train_dataset['md5']
or file['md5_valid'] != test_dataset['md5']
or len(file['idxs_train']) != n_train
or len(file['idxs_valid']) != n_test
                    or (gdml and file['perms'].shape[0] > 1)
or file['sig'] not in sigs
):
return False
return True
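

# Illustrative usage (sketch; 'train_dir' and 'dataset' are hypothetical, e.g.
# a directory created via train_dir_name() and a loaded dataset .npz):
#
#   >>> resumeable = is_task_dir_resumeable(
#   ...     train_dir, dataset, dataset, 200, 2000, list(range(10, 110, 10)), False
#   ... )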


### ARGUMENT VALIDATION
def is_strict_pos_int(arg):
"""
Validate strictly positive integer input.
Parameters
----------
arg : :obj:`str`
Integer as string.
Returns
-------
int
Parsed integer.
Raises
------
ArgumentTypeError
If integer is not > 0.
"""
x = int(arg)
if x <= 0:
raise argparse.ArgumentTypeError('must be strictly positive')
return x
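

# Illustrative usage (sketch): typical use as an argparse 'type' callable.
#
#   >>> parser = argparse.ArgumentParser()
#   >>> parser.add_argument('n_train', type=is_strict_pos_int)
#   >>> parser.parse_args(['200']).n_train
#   200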


def parse_list_or_range(arg):
"""
    Parses a string that represents either an integer or a range in
    the notation ``<start>:<stop>`` or ``<start>:<step>:<stop>``.
Parameters
----------
arg : :obj:`str`
Integer or range string.
Returns
-------
int or :obj:`list` of int
Raises
------
ArgumentTypeError
If input can neither be interpreted as an integer nor a valid range.
"""
if re.match(r'^\d+:\d+:\d+$', arg) or re.match(r'^\d+:\d+$', arg):
rng_params = list(map(int, arg.split(':')))
step = 1
if len(rng_params) == 2: # start, stop
start, stop = rng_params
else: # start, step, stop
start, step, stop = rng_params
rng = list(range(start, stop + 1, step)) # include last stop-element in range
if len(rng) == 0:
raise argparse.ArgumentTypeError('{0} is an empty range'.format(arg))
return rng
elif re.match(r'^\d+$', arg):
return int(arg)
raise argparse.ArgumentTypeError(
        '{0} is neither an integer nor a valid range in the form <start>:[<step>:]<stop>'.format(
arg
)
)
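

# Illustrative usage (sketch): a plain integer is returned as an int, a range
# string is expanded into a list that includes the stop value.
#
#   >>> parse_list_or_range('42')
#   42
#   >>> parse_list_or_range('10:5:30')
#   [10, 15, 20, 25, 30]
#   >>> parse_list_or_range('2:4')  # step defaults to 1
#   [2, 3, 4]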