Context

Context managers

context manager

  • context manager an object designed to be used in a with-statement

    with context-manager:
        body
    
    with context-manager:
        context-manager.begin()
        body
        context-manager.end()
    
    with context-manager:
        setup()
        body
        teardown()
    
    with context-manager:
        context-manager.begin()
        body
        context-manager.end()        
    
    with context-manager:
        allocation()
        body
        deallocation()
    
    with context-manager:
        enter()
        body
        exit()
    
    
  • A context-manager ensures that resources are properly and automatically managed

    • enter() prepares the manager for use
    • exit() cleans it up
  • Context-manager Protocol

__enter__(self)

__exit__(self,
        exc_type, ## exception type 
        exc_val,  ## exception object
        exc_tb)    ## exception traceback
  • __enter__()

    • called before entering with-statement body
    • return value bound to as variable
    • can return value of any type
    • commonly returns context-manager itself
  • __exit__() called when with-statement body exits

  • By default __exit__() propagates exceptions thrown from the with-statement’s code block

  • __exit__() should never explicitly re-raise exceptions

  • __exit__() should only raise exceptions if it fails itself

contextlib

  • contextlib : standard library module for working with context managers
  • contextmanager is a decorator you can use to create new context managers

    @contextlib.contextmanager
    def my_context_manager():
        ## <ENTER>
        try:
        yield [value]
        ## <NORMAL EXIT>
        except:
        ## <EXCEPTIONAL EXIT>
        raise
        with my_context_manager() as x:
        ## . . .
    
  • contextmanager lets you define context-managers with simple control flow It allows you to leverage the statefulness of generators

  • contextmanager uses standard exception handling to propagate exceptions

  • contextmanager explicitly re-raises – or doesn’t catch – to propagate exceptions

  • contextmanager swallows exceptions by not re-raising them

  • Exceptions propagated from inner context managers will be seen by outer context managers

  • multiple context managers with-statements can use as many context-managers as you need

  • Never pass list

    • Example – The code below is the same

      with cm1() as a, cm2() as b:
          BODY
      
      with cm1() as a:
          with cm2() as b:
              BODY
      
  • Examples

    – simple sample

    import contextlib
    
    class LoggingContextManager:
        def __enter__(self):
            print('LoggingContextManager.__enter__()')
            return "You're in a with-block!"
    
        def __exit__(self, exc_type, exc_val, exc_tb):
            if exc_type is None:
                print('LoggingContextManager.__exit__: '
                    'normal exit detected')
            else:
                print('LoggingContextManager.__exit__: '
                    'Exception detected! '
                    'type={}, value={}, traceback={}'.format(
                        exc_type, exc_val, exc_tb))                     
    
    @contextlib.contextmanager
    def logging_context_manager():
        print('logging_context_manager: enter')
        try:
            yield "You're in a with-block!"
            print('logging_context_manager: normal exit')
        except Exception:
            print('logging_context_manager: exceptional exit',
                sys.exc_info())
            raise
    
    
    if __name__ == "__main__":
        with LoggingContextManager() as log:
            pass
    
        with logging_context_manager() as clog:
            pass
    
    ## test result
    ## LoggingContextManager.__enter__()
    ## LoggingContextManager.__exit__: normal exit detected
    ## logging_context_manager: enter
    ## logging_context_manager: normal exit
    
    
    
    • nested
                
    @contextlib.contextmanager
    def nest_test(name):
        print('Entering', name)
        yield name
        print('Exiting', name)
    
    
    
    
    if __name__ == "__main__":
        with nest_test('outer') as n1, nest_test('inner nested in ' + n1):
            pass
    
    ## test result
    ## Entering outer
    ## Entering inner nested in outer
    ## Exiting inner nested in outer
    ## Exiting outer
                
    
    • db & transaction
    import contextlib
    
    class Connection:
        def __init__(self):
            self.xid = 0
    
        def _start_transaction(self, read_only=True):
            print('starting transaction', self.xid)
            rslt = self.xid
            self.xid = self.xid + 1
            return rslt
    
        def _commit_transaction(self, xid):
            print('committing transaction', xid)
    
        def _rollback_transaction(self, xid):
            print('rolling back transaction', xid)
    
    class Transaction:
        def __init__(self, conn, read_only=True):
            self.conn = conn
            self.xid = conn._start_transaction(read_only=read_only)
    
        def commit(self):
            self.conn._commit_transaction(self.xid)
    
        def rollback(self):
            self.conn._rollback_transaction(self.xid)
    
    @contextlib.contextmanager
    def start_transaction(connection):
        tx = Transaction(connection)
    
        try:
            yield tx
        except Exception:
            tx.rollback()
            raise
    
        tx.commit()
    

Introspection

  • type of object is type

  • Example

    • introspector

      import inspect
      import reprlib
      import itertools
      ## from sorted_set import SortedSet
      
      from bisect import bisect_left
      from collections.abc import Sequence, Set
      from itertools import chain
      
      
      class SortedSet(Sequence, Set):
      
          def __init__(self, items=None):
              self._items = sorted(set(items)) if items is not None else []
      
          def __contains__(self, item):
              try:
                  self.index(item)
                  return True
              except ValueError:
                  return False
      
          def __len__(self):
              return len(self._items)
      
          def __iter__(self):
              return iter(self._items)
      
          def __getitem__(self, index):
              result = self._items[index]
              return SortedSet(result) if isinstance(index, slice) else result
      
          def __repr__(self):
              return "SortedSet({})".format(
                  repr(self._items) if self._items else ''
              )
      
          def __eq__(self, rhs):
              if not isinstance(rhs, SortedSet):
                  return NotImplemented
              return self._items == rhs._items
      
          def __ne__(self, rhs):
              if not isinstance(rhs, SortedSet):
                  return NotImplemented
              return self._items != rhs._items
      
          def _is_unique_and_sorted(self):
              return all(self[i] < self[i + 1] for i in range(len(self) - 1))
      
          def index(self, item):
              assert self._is_unique_and_sorted()
              index = bisect_left(self._items, item)
              if (index != len(self._items)) and (self._items[index] == item):
                  return index
              raise ValueError("{} not found".format(repr(item)))
      
          def count(self, item):
              assert self._is_unique_and_sorted()
              return int(item in self)
      
          def __add__(self, rhs):
              return SortedSet(chain(self._items, rhs._items))
      
          def __mul__(self, rhs):
              return self if rhs > 0 else SortedSet()
      
          def __rmul__(self, lhs):
              return self * lhs
      
          def issubset(self, iterable):
              return self <= SortedSet(iterable)
      
          def issuperset(self, iterable):
              return self >= SortedSet(iterable)
      
          def intersection(self, iterable):
              return self & SortedSet(iterable)
      
          def union(self, iterable):
              return self | SortedSet(iterable)
      
          def symmetric_difference(self, iterable):
              return self ^ SortedSet(iterable)
      
          def difference(self, iterable):
              return self - SortedSet(iterable)
      
      def full_sig(method):
          try:
              return method.__name__ + str(inspect.signature(method))
          except ValueError:
              return method.__name__ + '(...)'
      
      
      def brief_doc(obj):
          doc = obj.__doc__
          if doc is not None:
              lines = doc.splitlines()
              if len(lines) > 0:
                  return lines[0]
          return ''
      
      
      def print_table(rows_of_columns, *headers):
          num_columns = len(rows_of_columns[0])
          num_headers = len(headers)
          if len(headers) != num_columns:
              raise TypeError("Expected {} header arguments, "
                              "got {}".format(num_columns, num_headers))
          rows_of_columns_with_header = itertools.chain([headers], rows_of_columns)
          columns_of_rows = list(zip(*rows_of_columns_with_header))
          column_widths = [max(map(len, column)) for column in columns_of_rows]
          column_specs = ('{{:{w}}}'.format(w=width) for width in column_widths)
          format_spec = ' '.join(column_specs)
          print(format_spec.format(*headers))
          rules = ('-' * width for width in column_widths)
          print(format_spec.format(*rules))
          for row in rows_of_columns:
              print(format_spec.format(*row))
      
      
      def dump(obj):
          print("Type")
          print("====")
          print(type(obj))
          print()
      
          print("Documentation")
          print("=============")
          print(inspect.getdoc(obj))
          print()
      
          print("Attributes")
          print("==========")
          all_attr_names = SortedSet(dir(obj))
          method_names = SortedSet(
              filter(lambda attr_name: callable(getattr(obj, attr_name)),
                  all_attr_names))
          assert method_names <= all_attr_names
          attr_names = all_attr_names - method_names
          attr_names_and_values = [(name, reprlib.repr(getattr(obj, name)))
                                  for name in attr_names]
          print_table(attr_names_and_values, "Name", "Value")
          print()
      
          print("Methods")
          print("=======")
          methods = (getattr(obj, method_name) for method_name in method_names)
          method_names_and_doc = [(full_sig(method), brief_doc(method))
                                  for method in methods]
          print_table(method_names_and_doc, "Name", "Description")
          print()
      
      if __name__ == "__main__":
          dump('a')
      
      ## test result
      ## Type
      ## ====
      ## <class 'str'>
      #
      ## Documentation
      ## =============
      ## str(object='') -> str
      ## str(bytes_or_buffer[, encoding[, errors]]) -> str
      ## Create a new string object from the given object. If encoding or
      ## errors is specified, then the object must expose a data buffer
      ## that will be decoded using the given encoding and error handler.
      ## Otherwise, returns the result of object.__str__() (if defined)
      ## or repr(object).
      ## encoding defaults to sys.getdefaultencoding().
      ## errors defaults to 'strict'.
      #
      ## Attributes
      ## ==========
      ## Name    Value                         
      ## ------- ------------------------------
      ## __doc__ "str(object='... to 'strict'."
      #
      ## Methods
      ## =======
      ## Name              Description                                                            
      ## ----------------- -----------------------------------------------------------------------
      ## __add__(value, /)               Return self+value.                                                     
      ## str(...)                        str(object='') -> str                                                  
      ## __contains__(key, /)            Return key in self.                                                    
      ## __delattr__(name, /)            Implement delattr(self, name).                                         
      ## __dir__(...)                    __dir__() -> list                                                      
      ## __eq__(value, /)                Return self==value.                                                    
      ## __format__(...)                 S.__format__(format_spec) -> str                                       
      ## __ge__(value, /)                Return self>=value.                                                    
      ## __getattribute__(name, /)       Return getattr(self, name).                                            
      ## __getitem__(key, /)             Return self[key].                                                      
      ## __getnewargs__(...)                                                                                    
      ## __gt__(value, /)                Return self>value.                                                     
      ## __hash__()                      Return hash(self).                                                     
      ## __init__(*args, **kwargs)       Initialize self.  See help(type(self)) for accurate signature.         
      ## __init_subclass__(...)          This method is called when a class is subclassed.                      
      ## __iter__()                      Implement iter(self).                                                  
      ## __le__(value, /)                Return self<=value.                                                    
      ## __len__()                       Return len(self).                                                      
      ## __lt__(value, /)                Return self<value.                                                     
      ## __mod__(value, /)               Return self%value.                                                     
      ## __mul__(value, /)               Return self*value.n                                                    
      ## __ne__(value, /)                Return self!=value.                                                    
      ## __new__(*args, **kwargs)        Create and return a new object.  See help(type) for accurate signature.
      ## __reduce__(...)                 helper for pickle                                                      
      ## __reduce_ex__(...)              helper for pickle                                                      
      ## __repr__()                      Return repr(self).                                                     
      ## __rmod__(value, /)              Return value%self.                                                     
      ## __rmul__(value, /)              Return self*value.                                                     
      ## __setattr__(name, value, /)     Implement setattr(self, name, value).                                  
      ## __sizeof__(...)                 S.__sizeof__() -> size of S in memory, in bytes                        
      ## __str__()                       Return str(self).                                                      
      ## __subclasshook__(...)           Abstract classes can override this to customize issubclass().          
      ## capitalize(...)                 S.capitalize() -> str                                                  
      ## casefold(...)                   S.casefold() -> str                                                    
      ## center(...)                     S.center(width[, fillchar]) -> str                                     
      ## count(...)                      S.count(sub[, start[, end]]) -> int                                    
      ## encode(...)                     S.encode(encoding='utf-8', errors='strict') -> bytes                   
      ## endswith(...)                   S.endswith(suffix[, start[, end]]) -> bool                             
      ## expandtabs(...)                 S.expandtabs(tabsize=8) -> str                                         
      ## find(...)                       S.find(sub[, start[, end]]) -> int                                     
      ## format(...)                     S.format(*args, **kwargs) -> str      
      ## format_map(...)                 S.format_map(mapping) -> str                                           
      ## index(...)                      S.index(sub[, start[, end]]) -> int                                    
      ## isalnum(...)                    S.isalnum() -> bool                                                    
      ## isalpha(...)                    S.isalpha() -> bool                                                    
      ## isdecimal(...)                  S.isdecimal() -> bool                                                  
      ## isdigit(...)                    S.isdigit() -> bool                                                    
      ## isidentifier(...)               S.isidentifier() -> bool                                               
      ## islower(...)                    S.islower() -> bool                                                    
      ## isnumeric(...)                  S.isnumeric() -> bool                                                  
      ## isprintable(...)                S.isprintable() -> bool                                                
      ## isspace(...)                    S.isspace() -> bool                                                    
      ## istitle(...)                    S.istitle() -> bool                                                    
      ## isupper(...)                    S.isupper() -> bool                                                    
      ## join(...)                       S.join(iterable) -> str                                                
      ## ljust(...)                      S.ljust(width[, fillchar]) -> str                                      
      ## lower(...)                      S.lower() -> str                                                       
      ## lstrip(...)                     S.lstrip([chars]) -> str                                               
      ## maketrans(x, y=None, z=None, /) Return a translation table usable for str.translate().                 
      ## partition(...)                  S.partition(sep) -> (head, sep, tail)                                  
      ## replace(...)                    S.replace(old, new[, count]) -> str                                    
      ## rfind(...)                      S.rfind(sub[, start[, end]]) -> int                                    
      ## rindex(...)                     S.rindex(sub[, start[, end]]) -> int                                   
      ## rjust(...)                      S.rjust(width[, fillchar]) -> str                                      
      ## rpartition(...)                 S.rpartition(sep) -> (head, sep, tail)                                 
      ## rsplit(...)                     S.rsplit(sep=None, maxsplit=-1) -> list of strings                     
      ## rstrip(...)                     S.rstrip([chars]) -> str                                               
      ## split(...)                      S.split(sep=None, maxsplit=-1) -> list of strings                      
      ## splitlines(...)                 S.splitlines([keepends]) -> list of strings                            
      ## startswith(...)                 S.startswith(prefix[, start[, end]]) -> bool                           
      ## strip(...)                      S.strip([chars]) -> str                                                
      ## swapcase(...)                   S.swapcase() -> str                                                    
      ## title(...)                      S.title() -> str                                                       
      ## translate(...)                  S.translate(table) -> str                                              
      ## upper(...)                      S.upper() -> str                                                       
      ## zfill(...)                      S.zfill(width) -> str