Creating Type Stubs for Scientific Python (Part 3)

Posted by Graham Wheeler on Tuesday, September 6, 2022

Generating Output for a Whole Package

The approach we took in the last post to finding the files for a package is not strictly correct. We imported the package, looked at the file associated with it, and, if that was an __init__.py file, added all the other .py files in the same directory. This works in many cases but not all; it happened to work for matplotlib.axes, which is the example I have used until now. There is probably an elegant way to tell when that approach is appropriate and when it is not, but I don’t know what it is. Instead, I am going to treat each .py file as an independent module. The old get_module_and_files function is replaced by get_module_and_children, which returns the module and its single source file. If that file is an __init__.py, it also returns a list of submodule names, built from all the other .py files and subdirectories in the same folder.

import glob
import importlib
import inspect
import os
from types import ModuleType


def get_module_and_children(m: str) -> tuple[ModuleType | None, str | None, list[str]]:
    try:
        mod = importlib.import_module(m)
        file = inspect.getfile(mod)
    except Exception as e:
        print(f'Could not import module {m}: {e}')
        return None, None, []

    submodules = []
    if os.path.basename(file) == '__init__.py':
        # This is a package: every other .py file and every subdirectory
        # (except __pycache__) in its folder is a candidate submodule
        folder = os.path.dirname(file)
        for f in glob.glob(os.path.join(folder, '*')):
            name = os.path.basename(f)
            if name == '__init__.py':
                continue
            if name.endswith('.py'):
                submodules.append(f'{m}.{name[:-3]}')
            elif os.path.isdir(f) and name != '__pycache__':
                submodules.append(f'{m}.{name}')
    return mod, file, submodules
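For comparison, the standard library's pkgutil.iter_modules can list a package's direct children from its __path__ without globbing the folder by hand. The sketch below (list_submodules is a hypothetical helper, not part of the code in this series) returns fully qualified names in the same form as the submodules list above. It only reports a subdirectory if that directory contains an __init__.py, so data folders are skipped, and it also reports compiled extension modules; whether that is a better fit for matplotlib is untested.

```python
import importlib
import pkgutil


def list_submodules(package_name: str) -> list[str]:
    """Return the direct submodules of a package as fully qualified names."""
    pkg = importlib.import_module(package_name)
    # Plain modules (as opposed to packages) have no __path__ and no children
    path = getattr(pkg, '__path__', None)
    if path is None:
        return []
    # iter_modules yields ModuleInfo(module_finder, name, ispkg) for each child
    return sorted(f'{package_name}.{info.name}' for info in pkgutil.iter_modules(path))


print(list_submodules('json'))
```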

We can then add a flag to the process_module function to control whether we want to include submodules:

import os
from typing import Callable


def process_module(m: str,
        state: object,
        processor: Callable,
        targeter: Callable,
        post_processor: Callable | None = None,
        include_submodules: bool = True,
        **kwargs):

    modules = [m]
    while modules:
        mod, file, submodules = get_module_and_children(modules.pop())
        if include_submodules:
            if not mod:
                continue  # skip submodules that fail to import
            modules.extend(submodules)
        else:
            if not mod:
                return

        try:
            with open(file) as f:
                source = f.read()
        except Exception as e:
            print(f"Failed to read {file}: {e}")
            continue

        result = processor(mod, file, source, state, **kwargs)
        if post_processor is None:
            # Without a post-processor, each file's result is written out directly
            if result is None:
                print(f"Failed to process {file}")
                continue
            target = targeter(file)
            folder = target[: target.rfind("/")]
            os.makedirs(folder, exist_ok=True)
            with open(target, "w") as f:
                f.write(result)
        print(f"Processed file {file}")

    if post_processor:
        # With a post-processor, results accumulate in state and are written once
        result = post_processor(m, state)
        target = targeter(m)
        folder = target[: target.rfind("/")]
        os.makedirs(folder, exist_ok=True)
        with open(target, "w") as f:
            f.write(result)

Now, if we call process_module with m set to 'matplotlib' (and include_submodules left at its default of True), we process every module in the package.
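process_module leaves the choice of output path to the targeter callback. As a hypothetical example (stub_targeter and the typings output folder are illustrative choices, not names from this project), a targeter for stub files could keep only the part of the path after site-packages and swap the .py suffix for .pyi:

```python
import os


def stub_targeter(fname: str, out_root: str = 'typings') -> str:
    """Map a source path such as .../site-packages/matplotlib/axes/_axes.py
    to an output path such as typings/matplotlib/axes/_axes.pyi."""
    marker = 'site-packages' + os.sep
    i = fname.find(marker)
    # Keep the package-relative part of the path; fall back to the bare file name
    rel = fname[i + len(marker):] if i >= 0 else os.path.basename(fname)
    return os.path.join(out_root, os.path.splitext(rel)[0] + '.pyi')
```

It would be passed as the targeter argument, alongside a processor that returns the stub text for each file.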

The Type Translation Function

We’ve assembled a fair number of components by now, but we shouldn’t lose sight of the goal while in the weeds. As we generate the stubs, we want a function that essentially takes these inputs:

  • the type from the docstring
  • whether this is a parameter or a return value/assignment
  • the default value (if a parameter or assignment)

and returns the outputs:

  • the type annotation
  • the (modified) default value (which in most cases should be the same as the input, or ‘…’)
  • any imports that are needed to support the type annotation.

Later on we’ll use the import info together with the file path to determine what import statements we want to add (favoring relative imports for imports from the same package).
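Working out a relative import is mostly a matter of comparing dotted module names. A minimal sketch, assuming we know the fully qualified name of the module being stubbed and of the module that defines the imported name (import_statement is a hypothetical helper):

```python
def import_statement(current: str, name: str, source: str,
                     current_is_package: bool = False) -> str:
    """Build an import of `name` from module `source` for use inside module `current`.

    Uses a relative import when both modules live in the same top-level package.
    """
    cur = current.split('.')
    src = source.split('.')
    if cur[0] != src[0]:
        return f'from {source} import {name}'
    # A relative import is resolved against the package containing `current`
    # (or `current` itself, if it is a package's __init__)
    pkg = cur if current_is_package else cur[:-1]
    # Length of the shared leading part of the two dotted paths
    common = 0
    while common < min(len(pkg), len(src)) and pkg[common] == src[common]:
        common += 1
    # One dot for the containing package, plus one per level we need to climb
    dots = '.' * (len(pkg) - common + 1)
    rest = '.'.join(src[common:])
    return f'from {dots}{rest} import {name}'
```

For example, inside matplotlib.axes._axes, Line2D from matplotlib.lines becomes from ..lines import Line2D, while ArrayLike from numpy.typing stays an absolute import.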

We distinguish parameters from non-parameters because we want to be less restrictive with parameters. Say a docstring says list of int: for a parameter we may want to annotate with Sequence[int], because the function just needs something that is “list-like”, while for a return value it may be reasonable to use list[int], since that is the concrete type the function actually returns.

For simplicity, we can split this desired function into separate cases for assignments, parameters and return values, but they will likely share much of their logic.
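To make the shape of this concrete, here is a toy version of the translation function. Everything in it (the Translation container, translate_type, and the few patterns it recognizes) is hypothetical and only illustrates the inputs and outputs listed above; anything it does not recognize comes back with no annotation, to be resolved by hand.

```python
from __future__ import annotations

import re
from dataclasses import dataclass, field


@dataclass
class Translation:
    annotation: str | None  # None means "no confident annotation"
    default: str | None     # default value to emit in the stub
    imports: dict[str, str] = field(default_factory=dict)  # name -> module


_LIST_OF = re.compile(r'list of (\w+)')
_ALTERNATIVES = re.compile(r'\w+( or \w+)*')
# Defaults simple enough to keep verbatim in a stub
_SIMPLE_DEFAULT = re.compile(r"None|True|False|-?\d+(\.\d+)?|'[^']*'")


def translate_type(doctype: str, is_param: bool, default: str | None = None) -> Translation:
    doctype = doctype.strip()
    imports: dict[str, str] = {}
    if m := _LIST_OF.fullmatch(doctype):
        if is_param:
            # Parameters only need something list-like
            annotation = f'Sequence[{m.group(1)}]'
            imports['Sequence'] = 'typing'
        else:
            # Return values and assignments get the concrete type
            annotation = f'list[{m.group(1)}]'
    elif _ALTERNATIVES.fullmatch(doctype):
        # Single-word alternatives such as 'bool or float'
        annotation = ' | '.join(doctype.split(' or '))
    else:
        annotation = None
    if default is not None and not _SIMPLE_DEFAULT.fullmatch(default):
        default = '...'
    return Translation(annotation, default, imports)
```

Note that 'list of bool or float' matches neither pattern and so gets no annotation, which is the conservative outcome we want for ambiguous types.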

Collecting Classes and Import Information

As we analyze types, it is useful to collect the classes defined in each module. This will also help later to make sure we have the necessary import statements in the stubs. We can add a second dictionary to the state object in the analyzer to keep track of this, keyed on class name, with the file path as the value:

    def __init__(self, mod: ModuleType, fname: str, counter: Counter, context: dict,
            imports: dict):
        super().__init__()
        self._mod = mod
        i = fname.find('site-packages')
        if i > 0:
            # Strip off the irrelevant part of the path
            self._fname = fname[i+14:]
        else:
            self._fname = fname
        self._classname = ''
        self._parser = NumpyDocstringParser()
        self._counter = counter
        self._context = context
        self._imports = imports
        
    ...

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        if self.at_top_level():
            self._classname = node.name.value
            self._imports[self._classname] = self._fname
            obj = AnalyzingTransformer.get_top_level_obj(self._mod, self._fname, node.name.value)
            self._analyze_obj(obj, self._classname)
        return super().visit_ClassDef(node)
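The same bookkeeping can be sketched without LibCST, using the standard library's ast module. This is a simplified stand-in for visit_ClassDef above, not the code used in this series: collect_classes is a hypothetical helper, it records the module name rather than the file path (the form the later version of the analyzer in this post uses), and it only looks at statements directly in the module body, so classes defined inside if blocks or functions are ignored.

```python
import ast


def collect_classes(source: str, modname: str, classes: dict[str, str]) -> None:
    """Record each top-level class defined in `source` as classes[name] = modname."""
    tree = ast.parse(source)
    # Only iterate over the module body, so nested and local classes are skipped
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            classes[node.name] = modname
```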

Discarding ‘Trivial’ Analysis Lines

As we generate the map files, there are some types that are obviously correct and need no additional processing or extra imports. Some of these are built-in types, like float. Another case is a restricted value that is just a set of possible string values (these are common in matplotlib at least). We can drop these when generating the map file.

I would have liked to split up the types on " or " and deal with each alternative as a separate entity. Unfortunately, this is where we run into issues with poor specification. It’s possible to see types like list of bool or float. This is ambiguous: it probably means a list of bools or a list of floats, as opposed to a list of bools or a single float. If we simply split at the word "or", we would get the second interpretation, so it is not safe to do this.

Having said that, it is probably okay to do this when the alternatives are single words, such as in bool or float, or if they correspond to classes in the package that we have in the map file.
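That rule can be sketched as a split that refuses to split when any alternative is more than a single word. safe_split_or is a hypothetical helper; it treats a dotted name such as matplotlib.colors.Colormap (after stripping Sphinx backticks and a leading ~ or .) as a single word, and accepts a multi-word alternative only if it appears in a set of known names.

```python
from __future__ import annotations

import re
from collections.abc import Container

# A bare or dotted identifier, e.g. 'float' or 'matplotlib.colors.Colormap'
_SINGLE_WORD = re.compile(r'[A-Za-z_][A-Za-z0-9_.]*')


def safe_split_or(s: str, known: Container[str] = frozenset()) -> list[str] | None:
    """Split a docstring type on ' or ', or return None if that could change its meaning."""
    # Strip Sphinx markup such as `~matplotlib.colors.Colormap` or `.Line2D`
    parts = [p.strip().strip('`').lstrip('~.') for p in s.split(' or ')]
    if all(_SINGLE_WORD.fullmatch(p) or p in known for p in parts):
        return parts
    # e.g. 'list of bool or float': splitting would yield 'list of bool' and 'float'
    return None
```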

There are also many references to ‘array-like’ in matplotlib, which we can normalize to ArrayLike.
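A minimal sketch of that normalization, assuming we target numpy.typing.ArrayLike (available in NumPy 1.20 and later) and want to report the import it needs. normalize_array_like is a hypothetical helper, not the normalize_type function used in this series, and it only rewrites the token itself: a qualifier such as 1D in 1D array-like is left in place and still needs a human decision.

```python
import re

# 'array-like' or 'array_like', in any capitalization
_ARRAY_LIKE = re.compile(r'\barray[-_]like\b', re.IGNORECASE)


def normalize_array_like(s: str) -> tuple[str, dict[str, str]]:
    """Rewrite array-like to ArrayLike and return any import the result needs."""
    new = _ARRAY_LIKE.sub('ArrayLike', s)
    imports = {'ArrayLike': 'numpy.typing'} if 'ArrayLike' in new else {}
    return new, imports
```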

import re

# Start with {, end with }, comma-separated quoted words
_single_restricted = re.compile(r'^{([ ]*[\"\'][A-Za-z0-9\-_]+[\"\'][,]?)+}$')


def is_trivial(s):
    if s.lower() in ['float', 'int', 'bool', 'str', 'set', 'list', 'dict', 'tuple', 'callable', 'array-like', 'none']:
        return True

    if _single_restricted.match(s):
        return True

    # A set of alternatives is trivial only if every alternative is
    if s.find(' or ') > 0:
        if all(is_trivial(c.strip()) for c in s.split(' or ')):
            return True

    return False

Skipping such types reduces our output to 154 lines:

19##
14#1-D array#1-D array
10#1D or 2D array-like#1D|2D array-like
9#(float, float)#tuple[float, float]
8#indexable object#indexable object
7#color#color
7#`~matplotlib.lines.Line2D`#matplotlib.lines.Line2D
7#callable or ndarray#callable|ndarray
6#list of `.Line2D`#Sequence[Line2D]
6#array (length N) or scalar#array|scalar|tuple[length N]
5#str or `~matplotlib.colors.Colormap`#str|matplotlib.colors.Colormap
5#`~matplotlib.colors.Normalize`#matplotlib.colors.Normalize
5#1-D array or sequence#1-D array|sequence
4#(M, N) array-like#ArrayLike|tuple[M, N]
4#Transform#Transform
3#`.Bbox`#Bbox
3#`.BarContainer`#BarContainer
3#float or array-like, shape (n, )#float|ArrayLike|shape|tuple[n, ]
3#color or color sequence#color|color sequence
3#array (length N)#array|tuple[length N]
3#array of bool (length N)#Sequence[bool]|tuple[length N]
3#`.PolyCollection`#PolyCollection
3#2D array-like#2D array-like
3#array-like, shape (n, )#ArrayLike|shape|tuple[n, ]
3#`~.axes.Axes`#axes.Axes
2#`.Text`#Text
2#list of str#Sequence[str]
2#[x0, y0, width, height]#tuple[x0, y0, width, height]
2#`.Transform`#Transform
2#`.Axes`#Axes
2#`.patches.Rectangle`#patches.Rectangle
2#4-tuple of `.patches.ConnectionPatch`#4-tuple of .patches.ConnectionPatch
2#2-tuple of func, or Transform with an inverse#2-tuple of func|Transform with an inverse
2#axes._secondary_axes.SecondaryAxis#axes._secondary_axes.SecondaryAxis
2#str or `.Artist` or `.Transform` or callable or (float, float)#str|Artist|Transform|callable|tuple[float, float]
2#`~matplotlib.patches.Polygon`#matplotlib.patches.Polygon
2#list of colors#Sequence[colors]
2#`~matplotlib.collections.LineCollection`#matplotlib.collections.LineCollection
2#array-like or scalar#ArrayLike|scalar
2#sequence#sequence
2#array (length ``2*maxlags+1``)#array|tuple[length ``2*maxlags+1``]
2#array  (length ``2*maxlags+1``)#array|tuple[length ``2*maxlags+1``]
2#`.LineCollection` or `.Line2D`#LineCollection|Line2D
2#`.Line2D` or None#Line2D|None
2#array-like of length n#array-like of length n
2#1D array-like#1D array-like
2#float or array-like, shape(N,) or shape(2, N)#float|ArrayLike|shape(N|)|shape|tuple[2, N]
2#int or (int, int)#int|tuple[int, int]
2#Array or a sequence of vectors.#Array|a sequence of vectors.
2#list of dicts#Sequence[dicts]
2#`~matplotlib.image.AxesImage`#matplotlib.image.AxesImage
2#{'none', None, 'face', color, color sequence}#Literal['none', None, 'face', color, color sequence]
2#`~.contour.QuadContourSet`#contour.QuadContourSet
2#2D array#2D array
2#1D array#1D array
2#1-D arrays or sequences#1-D arrays|sequences
2#float greater than -0.5#float greater than -0.5
2#bool or 'line'#bool|Literal['line']
2#{"linear", "log", "symlog", "logit", ...} or `.ScaleBase`#ScaleBase|Literal["linear", "log", "symlog", "logit", ...]
2#`.MouseButton`#MouseButton
2#Axes#Axes
1#{'center', 'left', 'right'}, str#str|Literal['center', 'left', 'right']
1#sequence of `.Artist`#Sequence[Artist]
1#`~matplotlib.legend.Legend`#matplotlib.legend.Legend
1#number#number
1#ax#ax
1#`.Annotation`#Annotation
1#`.Line2D`#Line2D
1#array-like or list of array-like#ArrayLike|list of array-like
1#color or list of colors#color|Sequence[colors]
1#str or tuple or list of such values#str|tuple|list of such values
1#list of `.EventCollection`#Sequence[EventCollection]
1#timezone string or `datetime.tzinfo`#timezone string|datetime.tzinfo
1#list of `.Text`#Sequence[Text]
1#sequence of tuples (*xmin*, *xwidth*)#Sequence[tuples]|tuple[*xmin*, *xwidth*]
1#(*ymin*, *yheight*)#tuple[*ymin*, *yheight*]
1#`~.collections.BrokenBarHCollection`#collections.BrokenBarHCollection
1#`.StemContainer`#StemContainer
1#`.ErrorbarContainer`#ErrorbarContainer
1#float or (float, float)#float|tuple[float, float]
1#color or sequence or sequence of color or None#color|sequence|Sequence[color]|None
1#color or sequence of color or {'face', 'none'} or None#color|Sequence[color]|None|Literal['face', 'none']
1#c#c
1#array(N, 4) or None#array|None|tuple[N, 4]
1#edgecolors#edgecolors
1#array-like or list of colors or color#ArrayLike|Sequence[colors]|color
1#`~.markers.MarkerStyle`#markers.MarkerStyle
1#{'face', 'none', *None*} or color or sequence of color#color|Sequence[color]|Literal['face', 'none', *None*]
1#`~matplotlib.collections.PathCollection`#matplotlib.collections.PathCollection
1#'log' or int or sequence#Literal['log']|int|sequence
1#int > 0#int > 0
1#4-tuple of float#4-tuple of float
1#`~matplotlib.collections.PolyCollection`#matplotlib.collections.PolyCollection
1#`.FancyArrow`#FancyArrow
1#`matplotlib.quiver.Quiver`#matplotlib.quiver.Quiver
1#`~matplotlib.quiver.Quiver`#matplotlib.quiver.Quiver
1#bool or array-like of bool#bool|array-like of bool
1#`~matplotlib.quiver.Barbs`#matplotlib.quiver.Barbs
1#sequence of x, y, [color]#Sequence[x]|y|tuple[color]
1#list of `~matplotlib.patches.Polygon`#Sequence[matplotlib.patches.Polygon]
1#{{'pre', 'post', 'mid'}}#{|Literal['pre', 'post', 'mid'}]
1#array-like or PIL image#ArrayLike|PIL image
1#floats (left, right, bottom, top)#floats|tuple[left, right, bottom, top]
1#float > 0#float > 0
1#`matplotlib.collections.Collection`#matplotlib.collections.Collection
1#`matplotlib.collections.QuadMesh`#matplotlib.collections.QuadMesh
1#`.AxesImage` or `.PcolorImage` or `.QuadMesh`#AxesImage|PcolorImage|QuadMesh
1#`.ContourSet` instance#ContourSet instance
1#(n,) array or sequence of (n,) arrays#(n|) array|sequence of  arrays|tuple[n,]
1#int or sequence or str#int|sequence|str
1#(n,) array-like or None#ArrayLike|None|tuple[n,]
1#bool or -1#bool|Literal[-1]
1#array-like, scalar, or None#ArrayLike|scalar|None
1#color or array-like of colors or None#color|array-like of colors|None
1#array or list of arrays#array|Sequence[arrays]
1#array#array
1#`.BarContainer` or list of a single `.Polygon` or list of such objects#BarContainer|list of a single .Polygon|list of such objects
1#float, array-like or None#float|ArrayLike|None
1#`matplotlib.patches.StepPatch`#matplotlib.patches.StepPatch
1#None or int or [int, int] or array-like or [array, array]#None|int|[int|int]|ArrayLike|tuple[array, array]
1#array-like shape(2, 2)#array-like shape|tuple[2, 2]
1#`~.matplotlib.collections.QuadMesh`#matplotlib.collections.QuadMesh
1#`.Colormap`#Colormap
1#*None* or (xmin, xmax)#*None*|tuple[xmin, xmax]
1#`.AxesImage`#AxesImage
1#float or 'present'#float|Literal['present']
1#{'equal', 'auto', None} or float#float|Literal['equal', 'auto', None]
1#`~matplotlib.image.AxesImage` or `.Line2D`#matplotlib.image.AxesImage|Line2D
1#str, scalar or callable#str|scalar|callable
1#result#result
1#`~matplotlib.figure.Figure`#matplotlib.figure.Figure
1#[left, bottom, width, height]#tuple[left, bottom, width, height]
1#`.Figure`#Figure
1#[left, bottom, width, height] or `~matplotlib.transforms.Bbox`#matplotlib.transforms.Bbox|tuple[left, bottom, width, height]
1#Callable[[Axes, Renderer], Bbox]#Callable[|tuple[Axes, Renderer], Bbox]
1#Patch#Patch
1#Cycler#Cycler
1#iterable#iterable
1#None or str or (float, float)#None|str|tuple[float, float]
1#(float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', ...}#Literal['C', 'SW', 'S', 'SE', 'E', 'NE', ...]|tuple[float, float]
1#`.RendererBase` subclass.#RendererBase subclass.
1#`.Line2D` properties#Line2D properties
1#pair of ints (m, n)#pair of ints|tuple[m, n]
1#The limit value after call to convert(), or None if limit is None.#The limit value after call to convert|None if limit is None.|tuple[]
1#4-tuple or 3 tuple#4-tuple|3 tuple
1#`matplotlib.backend_bases.MouseEvent`#matplotlib.backend_bases.MouseEvent
1#`.RendererBase` subclass#RendererBase subclass
1#list of `.Artist` or ``None``#Sequence[Artist]|None
1#default: False#default: False
1#`.BboxBase`#BboxBase
1#`matplotlib.figure.Figure`#matplotlib.figure.Figure
1#tuple (*nrows*, *ncols*, *index*) or int#tuple|int|tuple[*nrows*, *ncols*, *index*]
1#list of floats#Sequence[floats]
1#2-tuple of func, or `Transform` with an inverse.#2-tuple of func|Transform with an inverse.

We can extend this further by dropping any types that are just names of classes we collected in the previous section. This makes the most sense when processing a whole package (so we gather all the possible classes from the package). If I run the analyzer on all of matplotlib, this reduces the output from about 690 lines down to about 530.

def is_trivial(s, m: str, classes: set | None = None):

    if s.find(' or ') > 0:
        if all([is_trivial(c.strip(), m, classes) for c in s.split(' or ')]):
            return True

    if _single_restricted.match(s):
        return True

    nt = normalize_type(s)

    if nt.lower() in ['float', 'int', 'bool', 'str', 'set', 'list', 'dict', 'tuple', 'array-like', 
                     'callable', 'none']:
        return True

    if classes:
        # Check unqualified classname

        if nt in classes:
            return True

        # Check fully qualified classname
        if nt.startswith(m + '.'):
            if nt[nt.rfind('.')+1:] in classes:
                return True

    return False
    
...

def _post_process(m: str, state: tuple):
    result = ''
    freq: Counter = state[0]
    imports: dict = state[1]
    classes: set = set(imports.keys())
    for typ, cnt in freq.most_common():
        if not is_trivial(typ, m, classes):
            result += f'{cnt}#{typ}#{normalize_type(typ)}\n'
    return result

Persistent Maps

We could press on with reducing the output from analysis, and there’s some value in doing that, because all the remaining entries are going to need human inspection to come up with an equivalent type annotation. But we are now at the point of diminishing returns, where the risk is that we will exclude ambiguous lines and end up with bad annotations (there’s already a small risk of this from our last step if there are classes with identical names).

Instead, let’s put the various pieces together to create the type translation function. Before calling this function, we run the analysis phase to collect the classes and all the ‘non-trivial’ types. We can then load a persistent map file (a human-edited version of the output from an earlier analysis) to find even more annotations. Anything that is both non-trivial and missing from the persistent map file is unhandled, and we can output those entries so they can be added to the map file for the next iteration.
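In outline, the round trip looks like the sketch below. parse_map and unhandled_types are hypothetical stand-ins for the load_map and _post_process functions that follow, and the line format (docstring type, then #, then the annotation) follows the .map files used in this post. A type is reported only if it is neither trivial nor already in the map, most frequent first.

```python
from __future__ import annotations

from collections import Counter
from collections.abc import Callable, Iterable


def parse_map(lines: Iterable[str]) -> dict[str, str]:
    """Parse 'docstring type#annotation' lines into a dict, skipping blank lines."""
    type_map: dict[str, str] = {}
    for line in lines:
        # Split at the last '#', so the annotation never contains one
        typ, sep, annotation = line.strip().rpartition('#')
        if sep:
            type_map[typ] = annotation
    return type_map


def unhandled_types(freq: Counter, type_map: dict[str, str],
                    is_trivial: Callable[[str], bool]) -> list[str]:
    """Docstring types that are neither trivial nor mapped, most frequent first."""
    return [typ for typ, _ in freq.most_common()
            if typ not in type_map and not is_trivial(typ)]
```

Whatever unhandled_types returns would be written to the .map.missing file for a human to annotate and fold back into the map.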

There are a lot of changes needed for all of this, and rather than make them iteratively, I’ll show the end result.

Note: I store a lot of metadata from the analysis transformer for use by the stubbing transformer. LibCST has a metadata mechanism that would probably make all of this simpler, but frankly, I find the documentation next to useless, so I implemented my own. I mostly just rely on dictionaries that are keyed on ‘contexts’.
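To illustrate the idea of keying metadata on contexts, here is a small stand-alone sketch that uses the standard library's ast module rather than LibCST. ContextCollector is hypothetical and its dotted key format is an assumption (the real BaseTransformer.context() is not shown in this post); only the trailing '->' key for return values mirrors the analyzer below.

```python
import ast


class ContextCollector(ast.NodeVisitor):
    """Map a dotted 'context' string to the kind of definition it names."""

    def __init__(self) -> None:
        self._stack: list[str] = []
        self.contexts: dict[str, str] = {}

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._stack.append(node.name)
        self.contexts['.'.join(self._stack)] = 'class'
        self.generic_visit(node)  # visit methods and nested classes
        self._stack.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._stack.append(node.name)
        ctx = '.'.join(self._stack)
        self.contexts[ctx] = 'function'
        # One key per named parameter, plus a special key for the return value
        args = node.args
        for arg in args.posonlyargs + args.args + args.kwonlyargs:
            self.contexts[f'{ctx}.{arg.arg}'] = 'param'
        self.contexts[ctx + '->'] = 'return'
        self.generic_visit(node)
        self._stack.pop()
```

In the real analyzer the values would be the parsed docstring information rather than these kind labels.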

First, utils.py:

import glob
import importlib
import inspect
import os
import re
from types import ModuleType
from typing import Callable
from .normalize import normalize_type


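def load_map(m: str) -> dict[str, str]:
    """Load the hand-edited 'docstring type#annotation' map for module m, if it exists."""
    type_map = {}
    mapfile = f"analysis/{m}.map"
    if os.path.exists(mapfile):
        with open(mapfile) as f:
            for line in f:
                parts = line.strip().split('#')
                if len(parts) < 2:
                    continue  # skip blank or malformed lines
                type_map[parts[0]] = parts[1]
    return type_map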


def get_module_and_children(m: str) -> tuple[ModuleType | None, str | None, list[str]]:
    try:
        mod = importlib.import_module(m)
        file = inspect.getfile(mod)
    except Exception as e:
        print(f'Could not import module {m}: {e}')
        return None, None, []

    submodules = []
    if os.path.basename(file) == '__init__.py':
        # This is a package: every other .py file and every subdirectory
        # (except __pycache__) in its folder is a candidate submodule
        folder = os.path.dirname(file)
        for f in glob.glob(os.path.join(folder, '*')):
            name = os.path.basename(f)
            if name == '__init__.py':
                continue
            if name.endswith('.py'):
                submodules.append(f'{m}.{name[:-3]}')
            elif os.path.isdir(f) and name != '__pycache__':
                submodules.append(f'{m}.{name}')
    return mod, file, submodules


def process_module(m: str, 
        state: object,
        processor: Callable, 
        targeter: Callable,
        post_processor: Callable|None = None, 
        include_submodules: bool = True,
        **kwargs):

    modules = [m]
    while modules:
        # Use a separate name so that m still holds the root module
        # when we call the post-processor and targeter below
        modname = modules.pop()
        mod, file, submodules = get_module_and_children(modname)
        if include_submodules:
            if not mod:
                continue
            modules.extend(submodules)
        else:
            if not mod:
                return

        try:
            with open(file) as f:
                source = f.read()
        except Exception as e:
            print(f"Failed to read {file}: {e}")
            continue

        result = processor(mod, modname, file, source, state, **kwargs)
        if post_processor is None:
            if result is None:
                print(f"Failed to process {file}")
                continue
            else:
                target = targeter(file)
                folder = target[: target.rfind("/")]
                os.makedirs(folder, exist_ok=True)
                with open(target, "w") as f:
                    f.write(result)
        print(f"Processed file {file}")

    if post_processor:
        result, rtn = post_processor(m, state)
        if result:
            target = targeter(m)
            folder = target[: target.rfind("/")]
            os.makedirs(folder, exist_ok=True)
            with open(target, "w") as f:
                f.write(result)
        return rtn
    return None


# Start with {, end with }, comma-separated quoted words
_single_restricted = re.compile(r'^{([ ]*[\"\'][A-Za-z0-9\-_]+[\"\'][,]?)+}$') 


def is_trivial(s, m: str, classes: set|dict|None = None):
    """
    s - the type docstring to check
    m - the module name
    classes - a set of class names or dictionary keyed on classnames 
    """

    if s.find(' or ') > 0:
        if all([is_trivial(c.strip(), m, classes) for c in s.split(' or ')]):
            return True

    if _single_restricted.match(s):
        return True

    nt = normalize_type(s)

    if nt.lower() in ['float', 'int', 'bool', 'str', 'set', 'list', 'dict', 'tuple', 'array-like', 
                     'callable', 'none']:
        return True

    if classes:
        # Check unqualified classname

        if nt in classes:
            return True

        # Check fully qualified classname
        if nt.startswith(m + '.'):
            if nt[nt.rfind('.')+1:] in classes:
                return True

    return False


_generic_type_map = {
    'float': 'float',
    'int': 'int',
    'bool': 'bool',
    'str': 'str',
    'dict': 'dict',
    'list': 'list',
}

_generic_import_map = {

}

Then, analyzer.py:

from collections import Counter
import inspect
from types import ModuleType
import libcst as cst
from .basetransformer import BaseTransformer
from .utils import process_module, is_trivial, load_map
from .parser import NumpyDocstringParser
from .normalize import normalize_type

class AnalyzingTransformer(BaseTransformer):

    def __init__(self, 
            mod: ModuleType, 
            modname: str,
            fname: str, 
            counter: Counter,
            classes: dict,
            docs: dict):
        super().__init__(modname, fname)
        self._mod = mod
        self._parser = NumpyDocstringParser()
        self._counter = counter
        self._classes = classes
        self._docs = {}
        docs[modname] = self._docs
        self._classname = None
        

    def _analyze_obj(self, obj, context: str) -> tuple[dict[str, str]|None, ...] | None:
        doc = None
        if obj:
            doc = inspect.getdoc(obj)
        if not doc:
            return None
        rtn = self._parser.parse(doc)
        for section in rtn:
            if section:
                for typ in section.values():
                    self._counter[typ] += 1
        return rtn

    @staticmethod
    def get_top_level_obj(mod: ModuleType, fname: str, oname: str):
        try:
            return mod.__dict__[oname]
        except KeyError as e:
            try:
                submod = fname[fname.rfind('/')+1:-3]
                return mod.__dict__[submod].__dict__[oname]
            except Exception:
                print(f'{fname}: Could not get obj for {oname}')
                return None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        rtn = super().visit_ClassDef(node)
        if self.at_top_level_class_level():
            self._classname = node.name.value
            self._classes[self._classname] = self._modname
            obj = AnalyzingTransformer.get_top_level_obj(self._mod, self._fname, node.name.value)
            self._docs[self.context()] = self._analyze_obj(obj, self._classname)
        return rtn

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        outer_context = self.context()
        rtn = super().visit_FunctionDef(node)
        name = node.name.value
        obj = None
        context = self.context()
        if self.at_top_level_function_level():
            #context = name
            obj = AnalyzingTransformer.get_top_level_obj(self._mod, self._fname, name)
        elif self.at_top_level_class_method_level():
            #context = f'{self._classname}.{name}'
            parent = AnalyzingTransformer.get_top_level_obj(self._mod, self._fname, self._classname)
            if parent:
                if name in parent.__dict__:
                    obj = parent.__dict__[name]
                else:
                    print(f'{self._fname}: Could not get obj for {context}')
        docs = self._analyze_obj(obj, context)
        self._docs[context] = docs

        if name == '__init__':
            # If we actually had a docstring with params section, we're done
            if docs and docs[0]:
                return rtn
            # Else use the class docstring for __init__
            self._docs[context] = self._docs.get(outer_context)

        return rtn

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.CSTNode:
        # Add a special entry for the return type
        context = self.context()
        doc = self._docs[context]
        if doc:
            self._docs[context + '->'] = doc[1]
        return super().leave_FunctionDef(original_node, updated_node)

    def visit_Param(self, node: cst.Param) -> bool:
        parent_context = self.context()
        parent_doc = self._docs.get(parent_context)
        rtn = super().visit_Param(node)
        if parent_doc and not isinstance(parent_doc, str):
            # The string check makes sure it's not a parameter of a lambda or function that was
            # assigned as a default value of some other parameter
            param_docs = parent_doc[0]
            if param_docs:
                try:
                    self._docs[self.context()] = param_docs.get(node.name.value)
                except Exception as e:
                    print(e)
        return rtn


def _analyze(mod: ModuleType, m: str, fname: str, source: str, state: tuple, **kwargs):
    try:
        cstree = cst.parse_module(source)
    except Exception as e:
        return None
    try:
        patcher = AnalyzingTransformer(mod, m, fname, 
            counter=state[0], 
            classes = state[1],
            docs = state[2])
        cstree.visit(patcher)
    except:  # Exception as e:
        # Note: I know that e is undefined below; this actually lets me
        # successfully see the stack trace from the original exception
        # as traceback.print_exc() was not working for me.
        print(f"Failed to analyze file: {e}")
        return None
    return state


def _post_process(m: str, state: tuple):
    map = load_map(m)
    result = ''
    freq: Counter = state[0]
    classes: dict = state[1]
    docs: dict = state[2]
    for typ, cnt in freq.most_common():
        if typ not in map and not is_trivial(typ, m, classes):
            result += f'{typ}#{normalize_type(typ)}\n'
    return result, (map, classes, docs)


def _targeter(m: str) -> str:
    """ Turn module name into map file name """
    return f"analysis/{m}.map.missing"


def analyze_module(m: str, include_submodules: bool = True):
    return process_module(m, (Counter(), {}, {}), _analyze, _targeter, post_processor=_post_process,
        include_submodules=include_submodules)

And then, the updated stubber.py:

from __future__ import annotations
import glob
import inspect
import os
import re
from types import ModuleType
import libcst as cst

from docs2stubs.analyzer import analyze_module
from docs2stubs.normalize import normalize_type
from .basetransformer import BaseTransformer
from .utils import is_trivial, process_module


class StubbingTransformer(BaseTransformer):
    def __init__(self, modname: str, fname: str, map: dict, classes: dict, docs: dict, 
        strip_defaults=False, infer_types_from_defaults=False):
        super().__init__(modname, fname)
        self._map = map
        self._classes = classes
        self._docs = docs[modname]
        self._strip_defaults = strip_defaults
        self._infer_types = infer_types_from_defaults
        self._method_names = set()
        self._local_class_names = set()
        self._need_imports = {}
        self._ident_re = re.compile(r'([A-Za-z_][A-Za-z0-9_]*)')

    @staticmethod
    def get_value_type(node: cst.CSTNode) -> str|None:
        typ: str | None = None
        if isinstance(node, cst.Name):
            if node.value in [ 'True', 'False']:
                typ = 'bool'
            elif node.value == 'None':
                typ = 'None'
        else:
            for k, v in {
                cst.Integer: 'int',
                cst.Float: 'float',
                cst.Imaginary: 'complex',
                cst.BaseString: 'str',
                cst.BaseDict: 'dict',
                cst.BaseList: 'list',
                cst.BaseSlice: 'slice',
                cst.BaseSet: 'set',
                # TODO: check the next two
                cst.Lambda: 'Callable',
                cst.MatchPattern: 'pattern'
            }.items():
                if isinstance(node, k):
                    typ = v
                    break
        return typ
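For comparison, the same default-value inference can be written against the standard library's ast module. infer_default_type is a hypothetical helper, not part of this project. Unlike get_value_type above, it treats a negated number such as -1 as a number (LibCST represents it as a unary operation wrapping a literal, so get_value_type as written returns None for it) and it also recognizes tuples; it does not try to cover slices or match patterns, which cannot appear as default values.

```python
from __future__ import annotations

import ast

# Looked up by exact type, so True maps to 'bool' rather than 'int'
_CONSTANT_TYPES = {bool: 'bool', int: 'int', float: 'float', complex: 'complex',
                   str: 'str', bytes: 'bytes'}
_NODE_TYPES = {ast.List: 'list', ast.Tuple: 'tuple', ast.Dict: 'dict', ast.Set: 'set',
               ast.JoinedStr: 'str', ast.Lambda: 'Callable'}


def infer_default_type(expr: str) -> str | None:
    """Infer a type name from the source text of a default value, or return None."""
    node = ast.parse(expr, mode='eval').body
    # Unwrap -1, +0.5 and the like, but only around numeric literals
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
        node = node.operand
        if not (isinstance(node, ast.Constant) and type(node.value) in (int, float, complex)):
            return None
    if isinstance(node, ast.Constant):
        return 'None' if node.value is None else _CONSTANT_TYPES.get(type(node.value))
    return _NODE_TYPES.get(type(node))
```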

    def get_assign_value(self, node: cst.Assign) -> cst.CSTNode:
        # See if this is an alias, in which case we want to
        # preserve the value; else we set the new value to ...
        new_value = None
        if isinstance(node.value, cst.Name) and not self.in_function():
            check = set()
            if self.at_top_level():
                check = self._local_class_names
            elif self.at_top_level_class_level(): # Class level
                check = self._method_names
            if node.value.value in check:
                new_value = node.value
        if new_value is None:
            new_value = cst.parse_expression("...")  
        return new_value

    def get_assign_props(self, node: cst.Assign) -> tuple[str|None, cst.CSTNode]:
        typ = StubbingTransformer.get_value_type(node.value)
        value = self.get_assign_value(node)
        return typ, value

    def leave_Assign(
        self, original_node: cst.Assign, updated_node: cst.Assign
    ) -> cst.CSTNode:
        typ, value = self.get_assign_props(original_node)
        # Make sure the assignment was not to a tuple before
        # changing to AnnAssign
        # TODO: if this is an attribute, see if it had an annotation in 
        # the class docstring and use that
        if typ is not None and len(original_node.targets) == 1:
            return cst.AnnAssign(target=original_node.targets[0].target,
                annotation=cst.Annotation(annotation=cst.Name(typ)),
                value=value)
        else:
            return updated_node.with_changes(value=value)

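    def leave_AnnAssign(
        self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
    ) -> cst.CSTNode:
        if original_node.value is None:
            # A bare annotation such as `x: int` has no value to replace
            return updated_node
        value = self.get_assign_value(original_node)
        return updated_node.with_changes(value=value)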

    def leave_Param(
        self, original_node: cst.Param, updated_node: cst.Param
    ) -> cst.CSTNode:
        doctyp = self._docs.get(self.context())
        super().leave_Param(original_node, updated_node)
        annotation = original_node.annotation
        default = original_node.default
        valtyp = None
        is_optional = False

        if default:
            valtyp = StubbingTransformer.get_value_type(default) # Inferred type from default
            if (not valtyp or self._strip_defaults):
                # Default is something too complex for a stub or should be stripped; replace with '...'
                default = cst.parse_expression("...")

        if doctyp and not annotation:
            typ = None
            if doctyp in self._map:
                typ = self._map[doctyp]
            elif is_trivial(doctyp, self._modname, self._classes):
                typ = normalize_type(doctyp)
            if typ:
                if typ.find('list') >= 0:
                    # Make this more robust
                    typ = typ.replace('list', 'Sequence')
                    self._need_imports['Sequence'] = 'typing'
                # Figure out other needed imports. A crude but maybe good
                # enough approach is to search for identifiers with a regexp, and
                # then add those if they are in the imports dict.
                for m in self._ident_re.findall(typ):
                    if m in ['Any', 'Callable', 'Iterable', 'Literal', 'Sequence']:
                        self._need_imports[m] = 'typing'
                    elif m in self._classes and m not in self._local_class_names:
                        self._need_imports[m] = self._classes[m]

                # If the default value is None, make sure we include it in the type
                is_optional = 'None' in typ.split('|')
                if not is_optional and valtyp == 'None':
                    typ = typ + '|None'

                print(f'Annotated {self.context()} with {typ} from {doctyp}')
                annotation = cst.Annotation(annotation=cst.parse_expression(typ))
            else:
                print(f'Could not annotate {self.context()} from {doctyp}')

        if self._infer_types and valtyp and not annotation and valtyp != 'None':
            # Use the inferred type from default value as long as it is not None
            annotation = cst.Annotation(annotation=cst.Name(valtyp))
            
        return updated_node.with_changes(annotation=annotation, default=default)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        # Record the names of top-level classes
        if not self.in_class():
            self._local_class_names.add(node.name.value)
        return super().visit_ClassDef(node)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        super().leave_ClassDef(original_node, updated_node)
        if not self.in_class():
            # Clear the method name set
            self._method_names = set()
            return updated_node
        else:
            # Nested class; return ...
            return cst.parse_statement('...')

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        if self.at_top_level_class_level():
            # Record the method name
            self._method_names.add(node.name.value)
        return super().visit_FunctionDef(node)

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.CSTNode:
        """Remove function bodies"""
        doctyp = self._docs.get(self.context() + '->')
        annotation = original_node.returns
        super().leave_FunctionDef(original_node, updated_node)
        if self.in_function(): 
            # Nested function; return ...
            return cst.parse_statement('...')

        if not annotation and doctyp:
            if all(t in self._map or is_trivial(t, self._modname, self._classes) for t in doctyp.values()):
                v = [self._map[t] if t in self._map else normalize_type(t) for t in doctyp.values()]
                if len(v) > 1:
                    rtntyp = 'tuple[' + ', '.join(v) + ']'
                else:
                    rtntyp = v[0]
                print(f'Annotating {self.context()}-> as {rtntyp}')
                return updated_node.with_changes(body=cst.parse_statement("..."),
                    returns=cst.Annotation(annotation=cst.parse_expression(rtntyp)))
            else:
                print(f'Could not annotate {self.context()}-> from {doctyp}')

        # Remove the body only
        return updated_node.with_changes(body=cst.parse_statement("..."))

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.CSTNode:
        newbody = [
            node
            for node in updated_node.body
            if any(
                isinstance(node, cls)
                for cls in [cst.Assign, cst.AnnAssign, cst.Import, cst.ImportFrom]
            )
        ]
        return updated_node.with_changes(body=newbody)

    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Remove everything from the body that is not an import,
        class def, function def, or assignment.
        """
        newbody = [
            node
            for node in updated_node.body
            if any(
                isinstance(node, cls)
                for cls in [cst.ClassDef, cst.FunctionDef, cst.SimpleStatementLine]
            )
        ]
        return updated_node.with_changes(body=newbody)


def patch_source(m: str, fname: str, source: str, map: dict, imports: dict, docs: dict, strip_defaults: bool = False) -> str|None:
    try:
        cstree = cst.parse_module(source)
    except Exception as e:
        print(f'Could not parse {fname}: {e}')
        return None

    patcher = StubbingTransformer(m, fname, map, imports, docs, strip_defaults=strip_defaults)
    modified = cstree.visit(patcher)

    # Group the needed imports by module. Use a new name rather than
    # reusing the 'imports' parameter, and sort so the output is stable.
    import_lines = ''
    for module in sorted(set(patcher._need_imports.values())):
        typs = sorted(k for k, v in patcher._need_imports.items() if v == module)
        # TODO: make these relative imports if appropriate
        import_lines += f'from {module} import {", ".join(typs)}\n'
    if import_lines:
        return import_lines + '\n\n' + modified.code

    return modified.code


def _stub(mod: ModuleType, m: str, fname: str, source: str, state: tuple, **kwargs):
    return patch_source(m, fname, source, state[0], state[1], state[2], **kwargs)

def _targeter(fname: str) -> str:
    # Map .../site-packages/pkg/mod.py to typings/pkg/mod.pyi
    marker = "/site-packages/"
    return "typings/" + fname[fname.find(marker) + len(marker):] + "i"

def stub_module(m: str, include_submodules: bool = True, strip_defaults: bool = False):
    map, imports, docs = analyze_module(m, include_submodules=include_submodules)
    process_module(m, (map, imports, docs), _stub, _targeter, include_submodules=include_submodules,
        strip_defaults=strip_defaults)

At some point I want to write some code that will populate the map file for matplotlib from the stubs I have already created. However, the process above will never match what I did earlier, because it still relies on there being appropriate docstrings with types before it will add annotations. At best, this code will annotate every parameter and return type that is ‘properly’ documented, but nothing more. When I created the matplotlib stubs, I drew on context that I learned over time. For example, I recognized that a parameter named gc represented a GraphicsContext even if this wasn’t otherwise documented. We could enhance this code to do something similar, but it’s still going to be quite dumb compared to a human who can apply judgement. So it’s unlikely I will ever use this code to regenerate the matplotlib stubs. However, it could be a great tool for other libraries for which we have no stubs. I’ll be applying it next to scikit-learn and seeing how that goes, and will describe that in my next post.
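
The gc example suggests one cheap enhancement: a hand-maintained table keyed on parameter names, used as a fallback when a parameter has no annotation. The sketch below is purely illustrative and again uses the standard-library ast module rather than LibCST. NAME_HINTS, NameHintAnnotator and annotate_by_name are hypothetical names, and the two table entries are assumptions (GraphicsContextBase and RendererBase are the base classes in matplotlib.backend_bases). It also records the needed imports in a dict shaped like _need_imports.

```python
from __future__ import annotations

import ast

# Hypothetical hint table: parameter name -> (type name, defining module).
# These entries are illustrative; a real table would be curated by hand.
NAME_HINTS: dict[str, tuple[str, str]] = {
    "gc": ("GraphicsContextBase", "matplotlib.backend_bases"),
    "renderer": ("RendererBase", "matplotlib.backend_bases"),
}


class NameHintAnnotator(ast.NodeTransformer):
    """Annotate unannotated parameters whose names appear in a hint table."""

    def __init__(self, hints: dict[str, tuple[str, str]]):
        self.hints = hints
        # Same shape as StubbingTransformer._need_imports: type name -> module.
        self.need_imports: dict[str, str] = {}

    def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda:
        # Lambda parameters cannot carry annotations, so don't descend.
        return node

    def visit_arg(self, node: ast.arg) -> ast.arg:
        # Never override an annotation that is already present.
        if node.annotation is None and node.arg in self.hints:
            typ, module = self.hints[node.arg]
            node.annotation = ast.Name(id=typ, ctx=ast.Load())
            self.need_imports[typ] = module
        return node


def annotate_by_name(source: str, hints=NAME_HINTS) -> tuple[str, dict[str, str]]:
    """Return the annotated source and the imports it now needs."""
    annotator = NameHintAnnotator(hints)
    tree = annotator.visit(ast.parse(source))
    return ast.unparse(tree), annotator.need_imports


if __name__ == "__main__":
    code, needed = annotate_by_name(
        "def draw_path(self, gc, path, renderer: int = 0):\n    ...\n")
    print(code)    # gc is annotated; renderer keeps its existing annotation
    print(needed)
```

In the LibCST transformer, the equivalent lookup would most naturally sit in leave_Param, after the docstring-based annotation has failed. It remains a heuristic: a parameter called gc in another library could mean something else entirely.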