Source code for src.rossmassey.fetch_leetcode_problem._func_parsing

"""
Functions for parsing leetcode code snippets
"""
import ast
from collections import defaultdict


def generate_function_ast(code: str) -> ast.FunctionDef | None:
    """
    Builds the abstract syntax tree (ast) for the method of a leetcode snippet

    The method stubs in the snippet have no bodies, so `pass` is inserted into
    each one before parsing. Every node of the parsed tree is then grouped by
    its node type, in the order `ast.walk` yields them.

    Args:
        code (str): class source code

    Returns:
        ast.FunctionDef | None: the first function node found, or None when
        the snippet has no class, has no function, or its first class is not
        named `Solution`
    """
    parsed = ast.parse(_add_pass_to_functions(code))

    by_type = defaultdict(list)
    for item in ast.walk(parsed):
        by_type[type(item).__name__].append(item)

    classes = by_type['ClassDef']
    functions = by_type['FunctionDef']
    if not (classes and functions):
        return None

    # class based solutions (i.e. 155 - MinStack) are skipped for now;
    # only the first class found is inspected
    if classes[0].name != 'Solution':
        return None

    # assume the snippet defines a single method
    return functions[0]
def get_params(function_ast: ast.FunctionDef) -> tuple:
    """
    Collects the positional parameter names of a function and their types

    Every entry of `function_ast.args.args` is included in declaration order,
    so for a method the `self` parameter is part of the result as well.

    Args:
        function_ast (ast.FunctionDef): function ast

    Returns:
        tuple: a list of parameter names and a parallel list of their
        annotation strings (as produced by `_parse_annotation`)
    """
    arguments = function_ast.args.args
    names = [argument.arg for argument in arguments]
    types = [_parse_annotation(argument.annotation) for argument in arguments]
    return names, types
def get_rtype(function_ast: ast.FunctionDef) -> str | None:
    """
    Reads the return annotation of a function as a string

    Args:
        function_ast (ast.FunctionDef): function ast

    Returns:
        str | None: the return annotation as rendered by `_parse_annotation`,
        or None when the function has no return annotation
    """
    returns_node = function_ast.returns
    return _parse_annotation(returns_node)
[docs] def _add_pass_to_functions(class_src: str) -> str: """ Adds `pass` to each function in class source code to allow for ast parsing Args: class_src: source code of class Returns: source code with `pass` added to each function """ lines = class_src.split('\n') modified_lines = [] for line in lines: modified_lines.append(line) if line.strip().startswith('def'): # account for class offset indent = 4 + len(line) - len(line.lstrip()) indented_pass = ' ' * indent + 'pass' modified_lines.append(indented_pass) return '\n'.join(modified_lines)
[docs] def _parse_annotation(node: ast.arg.annotation) -> str: """ Converts the annotation binary tree to a string Travels down recursively, don't expect many levels... Args: ast.arg.annotation: the annotation object Returns: str: the annotation string """ if isinstance(node, ast.Name): return node.id elif isinstance(node, ast.Subscript): return f"{node.value.id}[{_parse_annotation(node.slice)}]"