#! /usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (C) 2018, ARM Limited and contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import contextlib
import fnmatch
import functools
import gc
import importlib
import inspect
import io
import itertools
import logging
import pathlib
import pickle
import subprocess
import sys
import tempfile
import traceback
import types
import uuid
import glob
import textwrap
import argparse
import time
import datetime
import copy
import warnings
import os.path
# Filename under which a serialized value database is stored.
DB_FILENAME = 'VALUE_DB.pickle.xz'
class NotSerializableError(Exception):
    """
    Raised when an object cannot be serialized with :mod:`pickle`.
    """
    pass
def get_class_from_name(cls_name, module_map=None):
    """
    Get a class object from its full name (including the module name).
    """
    # Avoid argument default value that would be huge in Sphinx doc
    if module_map is None:
        module_map = sys.modules
    # Modules whose full name is a prefix of the class' full name
    candidate_mods = [
        mod_name
        for mod_name in module_map
        if cls_name.startswith(mod_name)
    ]
    if not candidate_mods:
        return None
    # Longest match in terms of number of dotted components
    mod_name = max(candidate_mods, key=lambda name: len(name.split('.')))
    mod = module_map[mod_name]
    # Strip the module part (and the following dot) from the full name
    relative_name = cls_name[len(mod_name) + 1:]
    return _get_class_from_name(relative_name, mod)
def _get_class_from_name(cls_name, namespace):
if isinstance(namespace, collections.abc.Mapping):
namespace = types.SimpleNamespace(**namespace)
split = cls_name.split('.', 1)
try:
obj = getattr(namespace, split[0])
except AttributeError as e:
raise ValueError('Object not found') from e
if len(split) > 1:
return _get_class_from_name('.'.join(split[1:]), obj)
else:
return obj
def create_uuid():
    """
    Creates a UUID.
    """
    # 32 lowercase hex characters, without the dashes
    new_id = uuid.uuid4()
    return new_id.hex
def get_mro(cls):
    """
    Wrapper on top of :func:`inspect.getmro` that recognizes ``None`` as a
    type (treated like ``type(None)``).
    """
    # "None" and "type(None)" share the same fixed MRO
    if cls is None or cls is type(None):
        return (type(None), object)
    assert isinstance(cls, type)
    return inspect.getmro(cls)
def get_method_class(function):
    """
    Get the class of a method by analyzing its name.
    """
    qualname = function.__qualname__
    # For a plain function this is the function name itself, and eval()
    # will then return the function rather than a class.
    cls_name = qualname.rsplit('.', 1)[0]
    # Classes defined inside closures cannot be resolved by name
    if '<locals>' in cls_name:
        return None
    # eval() is only fed names extracted from the function object itself
    return eval(cls_name, function.__globals__)
def get_name(obj, full_qual=True, qual=True, pretty=False):
    """
    Get a name for ``obj`` (function or class) that can be used in a generated
    script.
    :param full_qual: Full name of the object, including its module name.
    :type full_qual: bool
    :param qual: Qualified name of the object
    :type qual: bool
    :param pretty: If ``True``, will show a prettier name for some types,
    although it is not guaranteed it will actually be a type name. For example,
    ``type(None)`` will be shown as ``None`` instead of ``NoneType``.
    :type pretty: bool
    """
    # full_qual enabled implies qual enabled
    _qual = qual or full_qual
    # qual disabled implies full_qual disabled
    full_qual = full_qual and qual
    qual = _qual
    if qual:
        def _get_name(x): return x.__qualname__
    else:
        def _get_name(x): return x.__name__
    # None itself is always pretty-printed, since "NoneType" is never a
    # useful name to display
    if obj is None:
        pretty = True
        obj = type(obj)
    if pretty:
        for prettier_obj in {None, NoValue}:
            if obj == type(prettier_obj):
                # For these types, qual=False or qual=True makes no difference
                obj = prettier_obj
                # The pretty name is just the str() of the singleton
                def _get_name(x): return str(x)
                break
    # Add the module's name in front of the name to get a fully
    # qualified name
    if full_qual:
        try:
            module_name = obj.__module__
        except AttributeError:
            module_name = ''
        else:
            # Hide modules that would add noise rather than information
            module_name = (
                module_name + '.'
                if module_name not in ('__main__', 'builtins', None)
                else ''
            )
    else:
        module_name = ''
    # Classmethods appear as bound method of classes. Since each subclass will
    # get a different bound method object, we want to reflect that in the
    # name we use, instead of always using the same name that the method got
    # when it was defined
    if inspect.ismethod(obj):
        name = _get_name(obj.__self__) + '.' + obj.__name__
    else:
        name = _get_name(obj)
    return module_name + name
def get_toplevel_module(obj):
    """
    Return the outermost module object in which ``obj`` is defined.
    This is usually a package.
    """
    # NOTE(review): the docstring used to claim a (path, line) tuple was
    # returned, but a module object clearly is.
    module = inspect.getmodule(obj)
    # First component of the dotted module name is the top-level package
    toplevel_module_name = module.__name__.split('.')[0]
    toplevel_module = sys.modules[toplevel_module_name]
    return toplevel_module
def get_src_loc(obj, shorten=True):
    """
    Get the source code location of ``obj``
    :param shorten: Shorten the paths of the source files by only keeping the
    part relative to the top-level package.
    :type shorten: bool
    :returns: ``(path, line)`` tuple. NOTE(review): when the source cannot be
        located, the path becomes the string ``'None'`` since ``str()`` is
        applied unconditionally — confirm callers expect that.
    """
    try:
        src_line = inspect.getsourcelines(obj)[1]
        src_file = inspect.getsourcefile(obj)
        src_file = pathlib.Path(src_file).resolve()
    # Builtins and C extensions have no reachable Python source
    except (OSError, TypeError):
        src_line, src_file = None, None
    if shorten and src_file:
        mod = get_toplevel_module(obj)
        try:
            # Packages can have several roots (namespace packages)
            paths = mod.__path__
        except AttributeError:
            paths = [mod.__file__]
        paths = [pathlib.Path(p) for p in paths if p]
        # Keep the roots that actually contain the source file.
        # NOTE(review): "p" is not resolved before the parents membership
        # test, so symlinked package roots may fail to match the resolved
        # src_file — confirm this is acceptable.
        paths = [
            p.parent.resolve()
            for p in paths
            if p in src_file.parents
        ]
        if paths:
            src_file = src_file.relative_to(paths[0])
    return (str(src_file), src_line)
class ExceptionPickler(pickle.Pickler):
    """
    Fix pickling of exceptions so they can be reloaded without troubles, even
    when called with keyword arguments:
    https://bugs.python.org/issue40917
    """
    class _DispatchTable:
        """
        For Python < 3.8 (i.e. before Pickler.reducer_override is available)
        """
        def __init__(self, pickler):
            self.pickler = pickler

        def __getitem__(self, cls):
            # KeyError makes the Pickler fall back to its default behaviour
            try:
                return self.pickler._get_reducer(cls)
            except ValueError:
                raise KeyError

    def __init__(self, *args, **kwargs):
        # Due to this issue, it is critical that "dispatch_table" is set after
        # calling super().__init__
        # https://bugs.python.org/issue45830
        super().__init__(*args, **kwargs)
        self.dispatch_table = self._DispatchTable(self)

    @staticmethod
    def _make_excep(excep_cls, dct):
        """Rebuild an exception instance of ``excep_cls`` from ``dct``."""
        # __new__ avoids calling __init__, which may require arguments we
        # do not have
        new = excep_cls.__new__(excep_cls)
        # "args" is a bit magical: it's not stored in __dict__, and any "arg"
        # key __dict__ will basically be ignored, so it needs to be restored
        # manually.
        dct = copy.copy(dct)
        try:
            args = dct.pop('args')
        except KeyError:
            pass
        else:
            new.args = args
        new.__dict__ = dct
        return new

    @classmethod
    def _reduce_excep(cls, obj):
        """Reducer for exception instances, paired with :meth:`_make_excep`."""
        # The __dict__ of exceptions does not contain "args", so shoe-horn
        # it in there so it's passed as a regular attribute.
        dct = obj.__dict__.copy()
        try:
            dct['args'] = obj.args
        except AttributeError:
            pass
        return (cls._make_excep, (obj.__class__, dct))

    # Only for Python >= 3.8, otherwise self.dispatch_table is used
    def reducer_override(self, obj):
        try:
            reducer = self._get_reducer(type(obj))
        except ValueError:
            # Fallback to default behaviour
            return NotImplemented
        else:
            return reducer(obj)

    def _get_reducer(self, cls):
        """
        Return the custom reducer for ``cls``, or raise :exc:`ValueError` when
        the default pickle machinery should be used instead.
        """
        # Workaround this bug:
        # https://bugs.python.org/issue43460
        if issubclass(cls, BaseException):
            return self._reduce_excep
        else:
            raise ValueError('Class not handled')

    @classmethod
    def dump_bytestring(cls, obj, **kwargs):
        """Serialize ``obj`` and return the pickle data as :class:`bytes`."""
        f = io.BytesIO()
        cls.dump_file(f, obj, **kwargs)
        return f.getvalue()

    @classmethod
    def dump_file(cls, f, obj, **kwargs):
        """Serialize ``obj`` into the file-like object ``f``."""
        pickler = cls(f, **kwargs)
        return pickler.dump(obj)

    @classmethod
    def is_serializable(cls, obj, raise_excep=False):
        """
        Try to Pickle the object to see if that raises any exception.
        """
        # Sink that discards the pickle stream: only the outcome matters
        class DevNull:
            def write(self, _):
                pass
        try:
            # This may be slow for big objects but it is the only way to be
            # sure it can actually be serialized
            cls.dump_file(DevNull(), obj)
        except (TypeError, pickle.PickleError, AttributeError) as e:
            debug('Cannot serialize instance of {}: {}'.format(
                type(obj).__qualname__, str(e)
            ))
            if raise_excep:
                raise NotSerializableError(obj) from e
            return False
        else:
            return True
# Module-level convenience alias, so callers do not need to spell the class.
is_serializable = ExceptionPickler.is_serializable
def once(callable_):
    """
    Call the given function at most once per set of parameters
    """
    # Unbounded, type-sensitive memoization: each distinct call signature
    # triggers exactly one real call.
    memoize = functools.lru_cache(maxsize=None, typed=True)
    return memoize(callable_)
def remove_indices(iterable, ignored_indices):
    """
    Filter the given ``iterable``, removing the items located at the 0-based
    positions listed in ``ignored_indices``.

    :param iterable: Iterable to filter.
    :param ignored_indices: Collection of indices to drop.
    :returns: A list of the remaining items, in original order.
    """
    # Set lookup keeps the filter O(n) even when ignored_indices is a list
    ignored = set(ignored_indices)
    return [v for i, v in enumerate(iterable) if i not in ignored]
def get_module_basename(path):
    """
    Get the module name of the module defined in source ``path``.
    """
    path = pathlib.Path(path)
    name = inspect.getmodulename(str(path))
    # getmodulename() gives None for anything that is not a source file,
    # which covers package folders: fall back on the folder name.
    return name if name is not None else path.name
def iterate_cb(iterator, pre_hook=None, post_hook=None):
    """
    Iterate over ``iterator``, and call some callbacks.
    :param pre_hook: Callback called right before getting a new value.
    :type pre_hook: collections.abc.Callable
    :param post_hook: Callback called right after getting a new value.
    :type post_hook: collections.abc.Callable
    """
    first = True
    while True:
        # The pre-hook is skipped for the very first item
        if pre_hook and not first:
            pre_hook()
        first = False
        try:
            val = next(iterator)
        except StopIteration:
            return
        if post_hook:
            post_hook()
        yield val
def format_exception(e):
    """
    Format the traceback of the exception ``e`` in a string.
    """
    # Same rendering as the interpreter's default traceback output
    lines = traceback.format_exception(type(e), e, e.__traceback__)
    return ''.join(lines)
# Logging level above CRITICAL that is always displayed and used for output
# (CRITICAL is 50, so 60 is never filtered out by level settings)
LOGGING_OUT_LEVEL = 60
"""
Log level used for the ``OUT`` level.
This allows sending all the output through the logging module instead of using
:func:`print`, so it can easily be recorded to a file
"""
class ExekallFormatter(logging.Formatter):
    """
    Custom :class:`logging.Formatter` that takes care of ``OUT`` level.
    This ``OUT`` level allows using :mod:`logging` instead of :func:`print` so
    it can be redirected to a file easily.
    """
    def __init__(self, fmt, *args, **kwargs):
        # Two delegate formatters: OUT records are emitted verbatim,
        # everything else goes through the user-supplied format.
        self.default_fmt = logging.Formatter(fmt, *args, **kwargs)
        self.out_fmt = logging.Formatter('%(message)s', *args, **kwargs)

    def format(self, record):
        # OUT level sits above CRITICAL, so it is always displayed
        is_out = (record.levelno == LOGGING_OUT_LEVEL)
        fmt = self.out_fmt if is_out else self.default_fmt
        return fmt.format(record)
# Console formatters, keyed by verbosity style.
# NOTE(review): the name keeps the historical "FOMATTER" (sic) spelling since
# this is a public module attribute other modules may reference.
LOGGING_FOMATTER_MAP = {
    'normal': ExekallFormatter('[%(asctime)s][%(name)s] %(levelname)s %(message)s'),
    'verbose': ExekallFormatter('[%(asctime)s][%(name)s/%(filename)s:%(lineno)s] %(levelname)s %(message)s'),
}
def setup_logging(log_level, debug_log_file=None, info_log_file=None, verbose=0):
    """
    Setup the :mod:`logging` module.
    :param log_level: Lowest log level name to display.
    :type log_level: str
    :param debug_log_file: Path to a file where logs are collected at the
    ``DEBUG`` level.
    :type debug_log_file: str
    :param info_log_file: Path to a file where logs are collected at the
    ``INFO`` level.
    :type info_log_file: str
    :param verbose: Verbosity level. The format string for log entries will
    contain more information when the level increases.
    :type verbose: int
    """
    logging.addLevelName(LOGGING_OUT_LEVEL, 'OUT')
    # Raises AttributeError if log_level is not a valid level name
    level = getattr(logging, log_level.upper())
    logger = logging.getLogger()
    # We do not filter anything at the logger level, only at the handler level
    logger.setLevel(logging.NOTSET)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    formatter = 'verbose' if verbose else 'normal'
    formatter = LOGGING_FOMATTER_MAP[formatter]
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    # Optional file handlers always capture at their fixed levels,
    # independently of the console level
    if debug_log_file:
        file_handler = logging.FileHandler(str(debug_log_file), encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(LOGGING_FOMATTER_MAP['verbose'])
        logger.addHandler(file_handler)
    if info_log_file:
        file_handler = logging.FileHandler(str(info_log_file), encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(LOGGING_FOMATTER_MAP['normal'])
        logger.addHandler(file_handler)
    # Redirect all warnings of the "warnings" module as log entries
    logging.captureWarnings(True)
# Dedicated logger used by the thin logging helpers below, so exekall's own
# output can be filtered separately from other loggers.
EXEKALL_LOGGER = logging.getLogger('EXEKALL')

def out(msg):
    """
    To be used as a replacement of :func:`print`.
    This allows easy redirection of the output to a file.
    """
    EXEKALL_LOGGER.log(LOGGING_OUT_LEVEL, msg)

def info(msg):
    """Write a log message at the INFO level."""
    EXEKALL_LOGGER.info(msg)

def debug(msg):
    """Write a log message at the DEBUG level."""
    EXEKALL_LOGGER.debug(msg)

def warn(msg):
    """Write a log message at the WARNING level."""
    EXEKALL_LOGGER.warning(msg)

def error(msg):
    """Write a log message at the ERROR level."""
    EXEKALL_LOGGER.error(msg)
def infer_mod_name(python_src, package_roots=None, excep_handler=None):
    """
    Compute the module name of a Python source file by inferring its top-level
    package.
    :param python_src: Path to the Python source file.
    :param package_roots: Optional paths to package roots to try as well.
    :param excep_handler: Exception handler forwarded to :func:`import_file`
        when importing parent packages.
    """
    python_src = pathlib.Path(python_src).absolute()
    module_path1 = None
    module_path2 = None
    # First look for the outermost package we find in the parent directories.
    # If we were supplied a path, it will not try to go past its highest folder.
    for folder in reversed(python_src.parents):
        if pathlib.Path(folder, '__init__.py').exists():
            package_root_parent = folder.parents[0]
            module_path1 = python_src.relative_to(package_root_parent)
            break
    # If no package was found, we try to find it through sys.path in case it is
    # only using namespace packages
    else:
        for package_root_parent in sys.path:
            try:
                module_path1 = python_src.relative_to(package_root_parent)
                break
            except ValueError:
                continue
    if package_roots:
        for path in package_roots:
            assert path
            path = pathlib.Path(path).absolute()
            package_root_parent = path.parents[0]
            try:
                module_path2 = python_src.relative_to(package_root_parent)
            except ValueError:
                continue
            else:
                break
        else:
            raise ValueError(f'Could not find {python_src} in any of the package roots: {package_roots}')
    # Pick the longest path of both
    # NOTE(review): if both candidates are None, the starred unpacking below
    # raises ValueError — presumably unreachable in practice, confirm against
    # callers.
    paths = [module_path1, module_path2]
    module_path, *_ = sorted(
        filter(bool, paths),
        key=lambda path: len(path.parents), reverse=True
    )
    # If we found the top-level package
    if module_path is not None:
        module_parents = list(module_path.parents)
        module_basename = get_module_basename(module_path)
        # Import all parent package_names before we import the module.
        # NOTE(review): package_root_parent is the variable leaked by
        # whichever of the loops above ran last — confirm this coupling is
        # intended.
        for package_name in reversed(module_parents[:-1]):
            import_file(
                pathlib.Path(package_root_parent, package_name),
                module_name='.'.join(package_name.parts),
                package_roots=package_roots,
                is_package=True,
                excep_handler=excep_handler,
            )
        module_dotted_path = list(module_parents[0].parts) + [module_basename]
        module_name = '.'.join(module_dotted_path)
    else:
        module_name = get_module_basename(python_src)
    return module_name
def find_customization_module_set(module_set):
    """
    Find all customization modules, where subclasses of
    :class:`exekall.customization.AdaptorBase` are expected to be found.
    It looks for modules named ``exekall_customize`` present in any enclosing
    package of modules in ``module_set``.
    """
    # Exception raised changed in 3.7:
    # https://docs.python.org/3/library/importlib.html#importlib.util.find_spec
    if sys.version_info >= (3, 7):
        import_excep = ModuleNotFoundError
    else:
        import_excep = AttributeError

    def enclosing_packages(module):
        """Yield the dotted name of every enclosing package of ``module``."""
        parts = module.__name__.split('.')
        for i in range(1, len(parts) + 1):
            yield '.'.join(parts[:i])

    package_name_set = {
        name
        for module in module_set
        for name in enclosing_packages(module)
    }

    customization_module_set = set()
    for name in package_name_set:
        customize_name = name + '.exekall_customize'
        # Only hide ModuleNotFoundError exceptions when looking up that
        # specific module, we don't want to hide issues inside the module
        # itself.
        spec = None
        with contextlib.suppress(import_excep, ValueError):
            spec = importlib.util.find_spec(customize_name)
        if spec:
            # Importing that module is enough to make the adaptor visible
            # to the Adaptor base class
            customization_module_set.add(
                importlib.import_module(customize_name)
            )
    return customization_module_set
def import_modules(paths_or_names, excep_handler=None):
    """
    Import the modules in the given list of paths.
    If a folder is passed, all Python sources are recursively imported.
    :param paths_or_names: Folders, files or plain module names.
    :param excep_handler: Callback invoked as ``excep_handler(str, excep)``
        instead of letting an import failure propagate.
    """
    def import_it(path_or_name):
        # Recursively import all modules when passed folders
        if path_or_name.is_dir():
            modules = list(import_folder(path_or_name, package_roots=[path_or_name], excep_handler=excep_handler))
            yield from modules
            # Import by name the top-level package corresponding to the imports
            # we just did, and import that top-level package by name if it is a
            # namespace package. This will ensure we find all the legs of the
            # namespace package, wherever they are
            if modules:
                names = [m.__name__ for m in modules]
                prefix = os.path.commonprefix(names)
                toplevel = '.'.join(prefix.split('.')[:-1])
                try:
                    toplevel = importlib.import_module(toplevel)
                except Exception:
                    pass
                else:
                    # Packages have a non-empty __path__; namespace packages
                    # may carry one entry per "leg"
                    is_namespace = len(toplevel.__path__ or [])
                    if is_namespace:
                        # Best-effort attempt, we ignore exceptions here as a
                        # namespace package is open to the world and we don't
                        # want to prevent loading part of it simply because
                        # another leg is broken.
                        yield from import_name_recursively(toplevel.__name__, excep_handler=lambda *args, **kwargs: None)
        # If passed a file, a symlink or something like that
        elif path_or_name.exists():
            mod = import_file(path_or_name, excep_handler=excep_handler)
            # Could be an exception handler return value
            if inspect.ismodule(mod):
                yield mod
        # Otherwise, assume it is just a module name
        else:
            yield from import_name_recursively(path_or_name, excep_handler=excep_handler)
    return set(itertools.chain.from_iterable(
        import_it(pathlib.Path(path))
        for path in paths_or_names
    ))
def import_name_recursively(name, excep_handler=None):
    """
    Import a module by its name.
    :param name: Full name of the module.
    :type name: str
    If it's a package, import all submodules recursively.
    .. note:: This is a generator, so the ``return excep_handler(...)`` below
        only stops the iteration: the handler's return value is discarded by
        the generator machinery.
    """
    name_str = str(name)
    try:
        mod = importlib.import_module(name_str)
    except Exception as e:
        if excep_handler:
            return excep_handler(name_str, e)
        else:
            raise
    try:
        paths = mod.__path__
    # This is a plain module
    except AttributeError:
        yield mod
    # This is a package, so we import all the submodules recursively
    else:
        # The top-level package gives the package roots used to infer
        # submodule names
        root, *_ = str(name).split('.', 1)
        package_roots = importlib.import_module(root).__path__
        for path in paths:
            yield from import_folder(pathlib.Path(path), package_roots=package_roots, excep_handler=excep_handler)
def import_folder(path, package_roots=None, excep_handler=None):
    """
    Import all modules contained in the given folder, recursively.
    """
    pattern = str(path / '**' / '*.py')
    for python_src in glob.iglob(pattern, recursive=True):
        mod = import_file(python_src, package_roots=package_roots, excep_handler=excep_handler)
        # Could be an exception handler return value rather than a module
        if inspect.ismodule(mod):
            yield mod
def import_file(python_src, *args, excep_handler=None, **kwargs):
    """
    Import a module.
    :param python_src: Path to a Python source file.
    :type python_src: str or pathlib.Path
    :param module_name: Name under which to import the module. If ``None``, the
    name is inferred using :func:`infer_mod_name`
    :type module_name: str
    :param package_roots: Paths to the root of the package, used by
    :func:`infer_mod_name`. A namespace package can have multiple roots.
    :type package_roots: list(str)
    :param is_package: ``True`` if the module is a package. If a folder or
    ``__init__.py`` is passed, this is forcefully set to ``True``.
    :type is_package: bool
    """
    try:
        return _import_file(python_src, *args, **kwargs)
    except Exception as e:
        # Without a handler, let the exception propagate untouched
        if excep_handler is None:
            raise
        return excep_handler(str(python_src), e)
def _import_file(python_src, module_name=None, is_package=False, package_roots=None, excep_handler=None):
    """
    Implementation of :func:`import_file`, without the exception-handler
    wrapping. See :func:`import_file` for the parameters.
    """
    python_src = pathlib.Path(python_src).resolve()
    # Directly importing __init__.py does not really make much sense and may
    # even break, so just import its package instead.
    if python_src.name == '__init__.py':
        return import_file(
            python_src=python_src.parent,
            module_name=module_name,
            package_roots=package_roots,
            is_package=True,
            excep_handler=excep_handler,
        )
    if python_src.is_dir():
        is_package = True
    if module_name is None:
        module_name = infer_mod_name(python_src, package_roots=package_roots, excep_handler=excep_handler)
    # Check if the module has already been imported
    if module_name in sys.modules:
        return sys.modules[module_name]
    is_namespace_package = False
    if is_package:
        # Signify that it is a package to
        # importlib.util.spec_from_file_location
        init_py = pathlib.Path(python_src, '__init__.py')
        # __init__.py does not exists for namespace packages
        if init_py.exists():
            python_src = init_py
        else:
            is_namespace_package = True
    # We manually build a ModuleSpec for namespace packages, since
    # spec_from_file_location apparently does not handle them
    if is_namespace_package:
        # If we get the module spec this way, we will automatically get the
        # full submodule_search_location so that the namespace package
        # contains all its legs
        spec = importlib.util.find_spec(module_name)
        if spec is None:
            spec = importlib.machinery.ModuleSpec(
                name=module_name,
                # loader is None for namespace packages
                loader=None,
                is_package=True,
            )
            # Set __path__ for namespace packages
            spec.submodule_search_locations = [str(python_src)]
    else:
        spec = importlib.util.spec_from_file_location(
            module_name,
            str(python_src),
        )
        if spec is None:
            raise ModuleNotFoundError(
                'Could not find module "{module}" at {path}'.format(
                    module=module_name,
                    path=python_src
                ),
                name=module_name,
                path=python_src,
            )
    module = importlib.util.module_from_spec(spec)
    try:
        # Register module before executing it so relative imports will
        # work
        sys.modules[module_name] = module
        # Silence warnings emitted while executing the module body
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore')
            spec.loader.exec_module(module)
    # If the module cannot be imported cleanly regardless of the reason,
    # make sure we remove it from sys.modules since it's broken. Future
    # attempt to import it should raise again, rather than returning the
    # broken module
    except BaseException:
        with contextlib.suppress(KeyError):
            del sys.modules[module_name]
        raise
    else:
        # Set the attribute on the parent package, so that this works:
        #
        # import foo.bar
        # print(foo.bar)
        try:
            parent_name, last = module_name.rsplit('.', 1)
        except ValueError:
            # Top-level module: no parent package to update
            pass
        else:
            parent = sys.modules[parent_name]
            setattr(parent, last, module)
    # Allow the module to change its entry in sys.modules, and get that if
    # it did. This is consistent with the behaviour of the import
    # statement.
    try:
        module = sys.modules[module_name]
    except KeyError:
        sys.modules[module_name] = module
    importlib.invalidate_caches()
    return module
def flatten_seq(seq, levels=1):
    """
    Flatten a nested sequence, up to ``levels`` levels.
    """
    # levels == 0 returns the input unchanged
    if levels == 0:
        return seq
    flattened = list(itertools.chain.from_iterable(seq))
    return flatten_seq(flattened, levels=levels - 1)
def take_first(iterable):
    """
    Pick the first item of ``iterable``, or ``NoValue`` when it is empty.
    """
    return next(iter(iterable), NoValue)
class _NoValueType:
"""
Type of the :attr:`NoValue` singleton.
This is mostly used like ``None``, in places where ``None`` may be an
acceptable value.
"""
# Use a singleton pattern to make sure that even deserialized instances
# will be the same object
def __new__(cls):
try:
return cls._instance
except AttributeError:
obj = super().__new__(cls)
cls._instance = obj
return obj
def __eq__(self, other):
return isinstance(other, _NoValueType)
def __hash__(self):
return 0
def __bool__(self):
return False
def __repr__(self):
return 'NoValue'
def __eq__(self, other):
return type(self) is type(other)
NoValue = _NoValueType()
"""
Singleton with similar purposes as ``None``.
"""
class RestartableIter:
    """
    Wrap an iterator to give a new iterator that is restartable.

    Once the underlying iterator is exhausted, iterating again replays the
    memoized values in the original order.
    """
    def __init__(self, it):
        self.values = []

        def memoize(it):
            # Record every value flowing through, so it can be replayed later
            for item in it:
                self.values.append(item)
                yield item

        self.it = memoize(it)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.it)
        except StopIteration:
            # From now on, serve the memoized values again
            self.it = iter(self.values)
            raise
def get_froz_val_set_set(db, uuid_seq=None, type_pattern_seq=None):
    """
    Get a set of sets of :class:`exekall.engine.FrozenExprVal`.
    :param db: :class:`exekall.engine.ValueDB` to look into
    :type db: exekall.engine.ValueDB
    :param uuid_seq: Sequence of UUIDs to select.
    :type uuid_seq: list(str)
    :param type_pattern_seq: Sequence of :func:`fnmatch.fnmatch` patterns
    matching type names (including module name).
    :type type_pattern_seq: list(str)
    """
    def matches_uuid(froz_val):
        return froz_val.uuid in uuid_seq

    def matches_type(froz_val):
        return match_base_cls(froz_val.type_, type_pattern_seq)

    # Combine only the filters that were actually requested; no filter at
    # all selects everything.
    checks = []
    if uuid_seq:
        checks.append(matches_uuid)
    if type_pattern_seq:
        checks.append(matches_type)

    def predicate(froz_val):
        return all(check(froz_val) for check in checks)

    return db.get_by_predicate(predicate, flatten=False, deduplicate=True)
def match_base_cls(cls, pattern_list):
    """
    Match the name of the class of the object and all its base classes.
    """
    for base_cls in get_mro(cls):
        name = get_name(base_cls, full_qual=True)
        # Nameless classes cannot be matched
        if name and match_name(name, pattern_list):
            return True
    return False
def match_name(name, pattern_list):
    """
    Return ``True`` if ``name`` is matched by any pattern in ``pattern_list``.
    If a pattern starts with ``!``, it is taken as a negative pattern.
    """
    if name is None or not pattern_list:
        return False

    # Split patterns into positive and negative ones ("!" prefix)
    neg_patterns = set()
    pos_patterns = set()
    for pattern in pattern_list:
        if pattern.startswith('!'):
            neg_patterns.add(pattern[1:])
        else:
            pos_patterns.add(pattern)

    def any_match(patterns):
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    # At least one positive pattern must match (vacuously true if there are
    # none) ...
    if pos_patterns and not any_match(pos_patterns):
        return False
    # ... and no negative pattern may match
    if neg_patterns and any_match(neg_patterns):
        return False
    return True
def get_common_base(cls_list):
    """
    Get the most derived common base class of classes in ``cls_list``.
    """
    def common(cls1, cls2):
        # Walk both MROs starting from "object" and remember the most
        # derived class they share. The first pair is always
        # (object, object), so "last" is always set.
        last = None
        pairs = zip(
            reversed(inspect.getmro(cls1)),
            reversed(inspect.getmro(cls2)),
        )
        for base1, base2 in pairs:
            if base1 is not base2:
                break
            last = base1
        return last

    return functools.reduce(common, cls_list)
def get_subclasses(cls):
    """
    Get all the (direct and indirect) subclasses of ``cls``.
    """
    # The class itself is included in the result
    found = {cls}
    for direct_subcls in cls.__subclasses__():
        found |= get_subclasses(direct_subcls)
    return found
def get_package(module):
    """
    Find the package name of a module. If the module has no package, its own
    name is taken.
    """
    package = module.__package__
    # Standalone modules that are not inside a package will get "" (or
    # possibly None) as package name: fall back on the module name itself.
    if package:
        return package.split('.', 1)[0]
    return module.__name__
def get_recursive_module_set(module_set, package_set, visited_module_set=None):
    """
    Retrieve the set of all modules recursively imported from the modules in
    ``module_set`` if they are (indirectly) part of one of the packages named
    in ``package_set``.
    """
    visited = set(visited_module_set) if visited_module_set else set()
    recursive_module_set = set()

    def visit(obj):
        # Skip non-modules and anything already seen
        if not isinstance(obj, types.ModuleType) or obj in visited:
            return
        visited.add(obj)
        # We only recurse into modules that are part of the given set
        # of packages
        if get_package(obj) in package_set:
            recursive_module_set.add(obj)
            for value in vars(obj).values():
                visit(value)

    for module in module_set:
        visit(module)
    return recursive_module_set
@contextlib.contextmanager
def disable_gc():
    """
    Context manager to disable garbage collection.
    This can result in significant speed-up in code creating a lot of objects,
    like :func:`pickle.load`.
    """
    if gc.isenabled():
        gc.disable()
        try:
            yield
        finally:
            gc.enable()
    else:
        # GC already disabled: nothing to do, nothing to undo
        yield
def render_graphviz(expr):
    """
    Render the structure of an expression as a graphviz description or SVG.
    :returns: A tuple(bool, content) where the boolean is ``True`` if SVG could
    be rendered or ``False`` if it still a graphviz description.
    :param expr: Expression to render
    :type expr: exekall.engine.ExpressionBase
    """
    graphviz = expr.format_structure(graphviz=True)
    # NOTE(review): "dot" re-opens the temporary file by name while it is
    # still open here — this may not work on platforms where open files
    # cannot be re-opened (e.g. Windows); confirm supported platforms.
    with tempfile.NamedTemporaryFile('wt') as f:
        f.write(graphviz)
        # Make sure the content is visible to "dot" before it reads the file
        f.flush()
        try:
            svg = subprocess.check_output(
                ['dot', f.name, '-Tsvg'],
                stderr=subprocess.DEVNULL,
            ).decode('utf-8')
        # If "dot" is not installed
        except FileNotFoundError:
            pass
        except subprocess.CalledProcessError as e:
            debug(f'dot failed to execute: {e}')
        else:
            return (True, svg)
    # Fall back on the raw graphviz description
    return (False, graphviz)
def add_argument(parser, *args, help, **kwargs):
    """
    Equivalent to :meth:`argparse.ArgumentParser.add_argument`, with ``help``
    formatting.
    This allows using parsers setup using raw formatters.
    """
    if help is not argparse.SUPPRESS:
        # Preserve existing newlines, and wrap each line independently
        lines = textwrap.dedent(help).splitlines()
        help = '\n'.join(map(textwrap.fill, lines))
    return parser.add_argument(*args, **kwargs, help=help)
def create_adaptor_parser_group(parser, adaptor_cls):
    """
    Create an argparse argument group dedicated to ``adaptor_cls`` options.
    """
    name = adaptor_cls.name
    description = f'{name} custom options.\nCan only be specified *after* positional parameters.'
    return parser.add_argument_group(name, description)
def powerset(iterable):
    """
    Powerset of the given iterable ::
        powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = list(iterable)
    # One combinations() iterator per subset size, then flattened
    per_size = (
        itertools.combinations(items, size)
        for size in range(len(items) + 1)
    )
    return itertools.chain.from_iterable(per_size)
def measure_time(iterator):
    """
    Measure how long it took to yield each value of the given iterator.
    Yields ``(duration_in_seconds, value)`` tuples.
    """
    while True:
        start = time.monotonic()
        try:
            val = next(iterator)
        except StopIteration:
            return
        stop = time.monotonic()
        yield (stop - start, val)
def capture_log(iterator):
    """
    Iterate over ``iterator``, capturing everything logged on the root logger
    while each value is computed.
    Yields ``(utc_datetime, log_map, value)`` tuples, where ``log_map`` maps
    log level names to the log output captured at that level.
    """
    logger = logging.getLogger()
    def make_handler(level):
        # One StringIO-backed handler per level, so each level's output can
        # be extracted separately
        formatter = LOGGING_FOMATTER_MAP['normal']
        string = io.StringIO()
        handler = logging.StreamHandler(string)
        handler.setLevel(level)
        handler.setFormatter(formatter)
        return (string, handler)
    def setup():
        # NOTE(review): range() excludes logging.CRITICAL itself, so there is
        # no dedicated "CRITICAL" entry; CRITICAL records still appear in the
        # lower-level captures.
        handler_map = {
            logging.getLevelName(level): make_handler(level)
            for level in range(logging.NOTSET, logging.CRITICAL, 10)
        }
        for string, handler in handler_map.values():
            logger.addHandler(handler)
        return handler_map
    def teardown(handler_map):
        def extract(string, handler):
            logger.removeHandler(handler)
            return string.getvalue().rstrip()
        return {
            name: extract(string, handler)
            for name, (string, handler) in handler_map.items()
        }
    while True:
        handler_map = setup()
        utc = utc_datetime()
        try:
            val = next(iterator)
        except StopIteration:
            return
        finally:
            # FIX: teardown() used to be skipped when the iterator raised
            # StopIteration, leaking the handlers installed by the last
            # setup() call on the root logger forever.
            log_map = teardown(handler_map)
        yield (utc, log_map, val)
class OrderedSetBase:
    """
    Base class for custom ordered sets.

    Items are stored both in a list (preserving insertion order for
    iteration) and in a set (for O(1) membership tests).
    """
    def __init__(self, items=()):
        # Make sure to iterate over items only once, in case it's a generator
        self._list = list(items)
        self._set = set(self._list)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._set == other._set
        elif isinstance(other, collections.abc.Sequence):
            # FIX: the "return" keyword was missing here, so comparing
            # against a sequence always fell through and returned None.
            return self._set == set(other)
        else:
            return False

    def __contains__(self, item):
        return item in self._set

    def __iter__(self):
        return iter(self._list)

    def __len__(self):
        return len(self._list)
class FrozenOrderedSet(OrderedSetBase, collections.abc.Set):
    """
    Like a regular ``frozenset``, but iterating over it will yield items in
    insertion order.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Order-insensitive hash, precomputed since the content is frozen.
        # NOTE(review): duplicated items XOR-cancel out of this hash; the
        # content is assumed to be duplicate-free.
        self._hash = functools.reduce(
            lambda hash_, item: hash_ ^ hash(item),
            self._list,
            0
        )

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Fast path: two instances with different hashes cannot be equal.
        # FIX: the original fell through (returning None) whenever "other"
        # was not a FrozenOrderedSet; now it always delegates to the base
        # class comparison.
        if isinstance(other, self.__class__) and hash(self) != hash(other):
            return False
        return super().__eq__(other)
class OrderedSet(OrderedSetBase, collections.abc.MutableSet):
    """
    Like a regular ``set``, but iterating over it will yield items in insertion
    order.
    """
    def add(self, item):
        # Only the first insertion is recorded, preserving order
        if item not in self._set:
            self._set.add(item)
            self._list.append(item)

    def update(self, *sets):
        for iterable in sets:
            for item in iterable:
                self.add(item)

    def discard(self, item):
        self._set.discard(item)
        # The item may not be present: mirror set.discard() semantics
        with contextlib.suppress(ValueError):
            self._list.remove(item)
def utc_datetime():
    """
    Return a UTC :class:`datetime.datetime`.
    """
    # Timezone-aware "now", pinned to UTC
    return datetime.datetime.now(tz=datetime.timezone.utc)