Source code for golemflavor.param

# author : S. Mandalia
#          s.p.mandalia@qmul.ac.uk
#
# date   : April 19, 2018

"""
Param class and functions for the BSM flavor ratio analysis
"""

from __future__ import absolute_import, division
from six import string_types

import sys

from collections import Sequence
from copy import deepcopy

import numpy as np

from golemflavor.fr import fr_to_angles
from golemflavor.enums import DataType, Likelihood, ParamTag, PriorsCateg


[docs]class Param(object): """Parameter class to store parameters.""" def __init__(self, name, value, ranges, prior=None, seed=None, std=None, tex=None, tag=None): self._prior = None self._seed = None self._ranges = None self._tex = None self._tag = None self.name = name self.value = value self.nominal_value = deepcopy(value) self.prior = prior self.ranges = ranges self.seed = seed self.std = std self.tex = tex self.tag = tag @property def ranges(self): return tuple(self._ranges) @ranges.setter def ranges(self, values): self._ranges = [val for val in values] @property def prior(self): return self._prior @prior.setter def prior(self, value): if value is None: self._prior = PriorsCateg.UNIFORM else: assert value in PriorsCateg self._prior = value @property def seed(self): if self._seed is None: return self.ranges return tuple(self._seed) @seed.setter def seed(self, values): if values is None: return self._seed = [val for val in values] @property def tex(self): return r'{0}'.format(self._tex) @tex.setter def tex(self, t): self._tex = t if t is not None else r'{\rm %s}' % self.name @property def tag(self): return self._tag @tag.setter def tag(self, t): if t is None: self._tag = ParamTag.NONE else: assert t in ParamTag self._tag = t
[docs]class ParamSet(Sequence): """Container class for a set of parameters.""" def __init__(self, *args): param_sequence = [] for arg in args: try: param_sequence.extend(arg) except TypeError: param_sequence.append(arg) if len(param_sequence) != 0: # Disallow duplicated params all_names = [p.name for p in param_sequence] unique_names = set(all_names) if len(unique_names) != len(all_names): duplicates = set([x for x in all_names if all_names.count(x) > 1]) raise ValueError('Duplicate definitions found for param(s): ' + ', '.join(str(e) for e in duplicates)) # Elements of list must be Param type assert all([isinstance(x, Param) for x in param_sequence]), \ 'All params must be of type "Param"' self._params = param_sequence def __len__(self): return len(self._params) def __getitem__(self, i): if isinstance(i, int): return self._params[i] elif isinstance(i, string_types): return self._by_name[i] def __getattr__(self, attr): return super(ParamSet, self).__getattribute__(attr) def __iter__(self): return iter(self._params) def __str__(self): o = '\n' for obj in self._params: o += '== {0:<15} = {1:<15}, tag={2:<15}\n'.format( obj.name, obj.value, obj.tag ) return o @property def _by_name(self): return {obj.name: obj for obj in self._params} @property def names(self): return tuple([obj.name for obj in self._params]) @property def labels(self): return tuple([obj.tex for obj in self._params]) @property def values(self): return tuple([obj.value for obj in self._params]) @property def nominal_values(self): return tuple([obj.nominal_value for obj in self._params]) @property def seeds(self): return tuple([obj.seed for obj in self._params]) @property def ranges(self): return tuple([obj.ranges for obj in self._params]) @property def stds(self): return tuple([obj.std for obj in self._params]) @property def tags(self): return tuple([obj.tag for obj in self._params]) @property def params(self): return self._params def to_dict(self): return {obj.name: obj.value for obj in self._params} def from_tag(self, tag, values=False, index=False, invert=False): if values and index: assert 0 tag = np.atleast_1d(tag) if not invert: ps = [(idx, obj) for idx, obj in enumerate(self._params) if obj.tag in tag] else: ps = [(idx, obj) for idx, obj in enumerate(self._params) if obj.tag not in tag] if values: return tuple([io[1].value for io in ps]) elif index: return tuple([io[0] for io in ps]) else: return ParamSet([io[1] for io in ps]) def remove_params(self, params): rm_paramset = [] for parm in self.params: if parm.name not in params.names: rm_paramset.append(parm) return ParamSet(rm_paramset) def extend(self, p): param_sequence = self.params if isinstance(p, Param): param_sequence.append(p) elif isinstance(p, ParamSet): param_sequence.extend(p.params) return ParamSet(param_sequence)