Source code for exekall.utils

#! /usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (C) 2018, ARM Limited and contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This module is allowed to import engine, as engine does not import it (so
# there is no circular dependency). Most utils should go in _utils unless they
# really need to depend on the engine, without being part of it.

import inspect
import typing
import warnings

import exekall.engine as engine

# Re-export all _utils here
from exekall._utils import *


def get_callable_set(module_set, verbose=False):
    """
    Collect every usable callable found in the modules of ``module_set``.

    Callables defined outside of the packages the given modules belong to are
    ignored. When a selected callable is a class attribute, the module it is
    defined in is scanned in a later pass as well, until no unvisited module
    is left.

    :param module_set: Set of modules to scan.
    :type module_set: set(types.ModuleType)

    :param verbose: Forwarded to the per-module scan, where it selects the
        logging function used to report rejected callables.
    :type verbose: bool

    :returns: The set of all callables gathered over all passes.
    """
    package_set = set(map(get_package, module_set))
    visited_obj_set = set()
    visited_module_set = set()

    def defined_in_class(obj):
        # A dotted qualified name without "<locals>" denotes something that
        # lives in a class body rather than at module level or in a closure.
        try:
            qualname = obj.__qualname__
        except AttributeError:
            return False
        return '.' in qualname and '<locals>' not in qualname

    def scan(modules):
        found = set()
        pending_modules = set()
        for mod in get_recursive_module_set(modules, package_set, visited_module_set):
            # Mark the module as visited before scanning it
            visited_module_set.add(mod)
            mod_callables = _get_callable_set(
                namespace=mod,
                module=mod,
                visited_obj_set=visited_obj_set,
                verbose=verbose,
                package_set=package_set,
            )

            # Inherited methods may be defined in a module we have not
            # looked at yet, so queue these modules for the next pass.
            pending_modules.update(
                inspect.getmodule(candidate)
                for candidate in mod_callables
                if defined_in_class(candidate)
            )
            found.update(mod_callables)

        pending_modules.discard(None)
        return (found, pending_modules - visited_module_set)

    all_callables = set()
    remaining = module_set
    while remaining:
        found, remaining = scan(remaining)
        all_callables |= found

    return all_callables
def _get_members(*args, **kwargs):
    """
    Same as :func:`inspect.getmembers` except that it will silence warnings,
    which avoids triggering deprecation warnings just because we are scanning a
    class.

    All positional and keyword arguments are forwarded unchanged to
    :func:`inspect.getmembers`, and its return value is returned as-is.
    """
    with warnings.catch_warnings():
        # Ignore every warning category while the members are being listed
        warnings.simplefilter(action='ignore')
        return inspect.getmembers(*args, **kwargs)


def _get_callable_set(namespace, module, package_set, visited_obj_set, verbose):
    """
    Collect the callables found in ``namespace`` that can be wrapped in an
    ``engine.Operator`` and pass the annotation checks done below.

    Classes found among the members are scanned recursively.

    :param namespace: Module or class to scan.

    :param module: Module the namespace was defined in, or ``None`` to be
        ignored. When not ``None``, a callable is only selected if
        :func:`inspect.getmodule` returns that exact module for it (or returns
        ``None``).

    :param package_set: A callable whose module is not ``None`` is only
        selected if ``get_package()`` of that module is in this set.

    :param visited_obj_set: Set shared between recursive calls and updated in
        place. The ``id()`` of each scanned namespace is recorded in it, and a
        namespace already recorded yields an empty result.

    :param verbose: If true, rejected callables are reported using ``info``,
        otherwise using ``debug``.

    :returns: Set of selected callables. Functions and methods found in a
        class are wrapped in ``engine.UnboundMethod``.
    """
    log_f = info if verbose else debug
    callable_pool = set()
    # Never scan the same namespace twice
    if id(namespace) in visited_obj_set:
        return callable_pool
    else:
        visited_obj_set.add(id(namespace))

    attributes = [
        callable_
        for name, callable_ in _get_members(
            namespace,
            predicate=callable
        )
    ]
    # A class is itself a callable candidate (calling it builds an instance)
    if isinstance(namespace, type):
        attributes.append(namespace)

    def select(attr):
        _module = inspect.getmodule(attr)
        # Operator precedence makes this:
        # "_module is None or (all of the remaining conditions)"
        return (
            # Module of builtins is None
            _module is None or
            # skip internal classes that may end up being exposed as a global
            _module is not engine and
            (True if module is None else _module is module) and
            get_package(_module) in package_set
        )

    # NOTE(review): this stores the attribute objects themselves, whereas all
    # membership tests on visited_obj_set are done on id() values, so it does
    # not filter anything out in the comprehension below. It does require
    # every attribute to be hashable -- confirm this is intended.
    visited_obj_set.update(attributes)
    attributes = [
        attr for attr in attributes
        if id(attr) not in visited_obj_set
        and select(attr)
    ]

    for callable_ in attributes:
        # Explore the class attributes as well for nested types
        if (
            isinstance(callable_, type) and
            # Do not recurse into the "_type" class attribute of namedtuple, as
            # it has broken globals in Python 3.6. The crazy checks are a
            # workaround for the fact that there is no direct way to detect if
            # a class is a namedtuple.
            not (
                isinstance(namespace, type) and
                callable_.__name__ == '_type' and
                issubclass(callable_, tuple) and
                hasattr(callable_, '_make') and
                hasattr(callable_, '_asdict') and
                hasattr(callable_, '_replace')
            )
        ):
            callable_pool.update(
                _get_callable_set(
                    namespace=callable_,
                    visited_obj_set=visited_obj_set,
                    verbose=verbose,
                    # We want to get all the attributes in classes, regardless
                    # of which module owns them. For example, we want to select
                    # a method inherited from a base class even if that base
                    # class and the method definition live somewhere else.
                    module=None,
                    package_set=package_set,
                )
            )

        # Functions defined in a class are methods, and have to be wrapped so
        # engine.Operator() can correctly resolve annotations using the class
        # attributes as a context in which the function was defined.
        #
        # Note that we also wrap things like classmethod, as they have the same
        # needs
        if (
            isinstance(namespace, type) and
            isinstance(
                callable_,
                (
                    # Instance and static methods
                    types.FunctionType,
                    # Class methods
                    types.MethodType,
                )
            )
        ):
            callable_ = engine.UnboundMethod(callable_, namespace)

        try:
            op = engine.Operator(callable_)
            # Trigger exceptions if they have to be raised
            op.prototype
        # If the callable is partially annotated, warn about it since it is
        # likely to be a mistake.
        except engine.PartialAnnotationError as e:
            log_f('Partially-annotated callable "{callable}" will not be used: {e}'.format(
                callable=get_name(callable_),
                e=e,
            ))
            continue
        # If some annotations fail to resolve
        except NameError as e:
            log_f('callable "{callable}" with unresolvable annotations will not be used: {e}'.format(
                callable=get_name(callable_),
                e=e,
            ))
            continue
        # If something goes wrong, that means it is not properly annotated
        # so we just ignore it
        except (SyntaxError, AttributeError, ValueError, KeyError, engine.AnnotationError):
            continue

        # True if the value type or any entry of the first element of the
        # prototype is still a bare TypeVar
        def has_typevar(op):
            return any(
                isinstance(x, typing.TypeVar)
                for x in {op.value_type, *op.prototype[0].values()}
            )

        # Report TypeVar class attributes whose __name__ does not match the
        # attribute name they are bound to, and return the attribute name
        def check_typevar_name(cls, name, var):
            if name != var.__name__:
                log_f('__name__ of {cls}.{var.__name__} typing.TypeVar differs from the name it is bound to "{name}", which will prevent using it for polymorphic parameters or return annotation'.format(
                    cls=get_name(cls, full_qual=True),
                    var=var,
                    name=name,
                ))
            return name

        # Sorted names of the members of the value type that are TypeVar
        # instances
        type_vars = sorted(
            check_typevar_name(op.value_type, name, attr)
            for name, attr in _get_members(
                op.value_type,
                lambda x: isinstance(x, typing.TypeVar)
            )
        )

        # Also make sure we don't accidentally get callables that will
        # return an abstract base class instance, since that would not work
        # anyway.
        if inspect.isabstract(op.value_type):
            log_f('Instances of {} will not be created since it has non-implemented abstract methods'.format(
                get_name(op.value_type, full_qual=True)
            ))
        elif type_vars:
            log_f('Instances of {} will not be created since it has non-overridden TypeVar class attributes: {}'.format(
                get_name(op.value_type, full_qual=True),
                ', '.join(type_vars)
            ))
        elif has_typevar(op):
            log_f('callable "{callable}" with non-resolved associated TypeVar annotations will not be used'.format(
                callable=get_name(callable_),
            ))
        # Only callables that went through all the checks are kept
        else:
            callable_pool.add(callable_)
    return callable_pool
def sweep_param(callable_, param, start, stop, step=1):
    """
    Used to generate a stream of numbers or strings to feed to a callable.

    :param callable_: Callable the numbers will be used by.
    :type callable_: collections.abc.Callable

    :param param: Name of the parameter of the callable the numbers will be
        providing values for.
    :type param: str

    :param start: Starting value.
    :type start: str

    :param stop: End value (inclusive)
    :type stop: str

    :param step: Increment step.
    :type step: str

    If ``start == stop``, only that value will be yielded, and it can be of any
    type.

    The type used will either be one that is annotated on the callable, or the
    one from the default value if no annotation is available, or ``str`` if
    the parameter has no default value or defaults to ``None``. The value will
    then be built by passing the string to the type as only parameter.

    .. note:: When the type falls back to ``str`` and ``start != stop``, the
        sweep relies on string comparison and string concatenation, since the
        loop uses ``<=`` and ``+=`` on the converted values.

    :raises KeyError: If ``param`` is neither annotated nor part of the
        signature of the callable.
    """
    op = engine.Operator(callable_)
    annot = op.prototype[0]
    try:
        type_ = annot[param]
    except KeyError:
        default = op.signature.parameters[param].default
        # inspect.Parameter.empty is the sentinel marking the absence of a
        # default value. Comparing against inspect.Parameter.default (the
        # property object) is always unequal, which led to using
        # type(inspect.Parameter.empty) as the conversion type.
        if default is not inspect.Parameter.empty and default is not None:
            type_ = type(default)
        else:
            type_ = str

    if start == stop:
        yield type_(start)
    else:
        i = type_(start)
        step = type_(step)
        stop = type_(stop)
        # "stop" is inclusive
        while i <= stop:
            yield type_(i)
            i += step
def get_method_class(function):
    """
    Return the class that ``function`` belongs to.

    * For an ``engine.UnboundMethod``, its ``cls`` attribute is returned.
    * For an object with a ``__self__`` attribute, ``__self__`` is returned
      if it is a class, otherwise the type of ``__self__`` is returned.
    * Otherwise, ``__qualname__`` stripped of its last component is
      evaluated in the globals of the function. ``None`` is returned if that
      name contains ``<locals>``.
    """
    # Unbound instance methods carry their class explicitly
    if isinstance(function, engine.UnboundMethod):
        return function.cls

    try:
        bound_to = function.__self__
    except AttributeError:
        # Not bound to anything: resolved through the qualified name below
        pass
    else:
        # Bound class methods are bound to the class itself, bound instance
        # methods to an instance of it
        return bound_to if isinstance(bound_to, type) else type(bound_to)

    owner_path = function.__qualname__.rsplit('.', 1)[0]
    # Something defined inside a function body cannot be looked up by name
    if '<locals>' in owner_path:
        return None
    return eval(owner_path, function.__globals__)