How to detect top-level print() and logging calls executed on Python module import



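I want to detect calls to print() and logging (e.g. logging.info()) that are reachable at the top level, i.e. executed when the module is loaded, and fail the build if any are found.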

I maintain a service that other teams commit to regularly, so I'd like this to be a lint check in CI. How can I do this?

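I don't care about non-top-level calls (e.g. calls inside functions). I'd like to keep allowing other teams to make those if they really want to, since they only run when those teams execute their own code.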

So far I've tried or come across a few things without success, generally a dynamic import_module of every Python file I care about, followed by:

  • pytest's capsys/capfd fixtures, which don't catch it, possibly because of a bug? https://github.com/pytest-dev/pytest/issues/5997#issuecomment-1028710193
    For example:
    # foo.py
    print("hello")

    # test module, in a separate file from foo.py
    from importlib import import_module

    def test_does_not_print(capfd):
        import_module('foo')
        out, err = capfd.readouterr()
        assert out == ""  # surprise: this will pass
  • Walking the AST: https://stackoverflow.com/a/25249854/234593 (hard to tell whether a call is reachable at the top level)
  • Mocking Python's built-in print function
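
To make the AST option concrete, here is a minimal sketch of a static check. It is not the code from the linked answer; the names `LOGGING_FUNCS`, `is_output_call` and `find_toplevel_output_calls` are hypothetical. It parses the source without importing it and reports `print(...)` and `logging.<level>(...)` calls that are not nested inside a function or lambda body. It also shows why this route is only an approximation: it flags calls under an `if` even when that branch never runs, and it misses output produced indirectly (a top-level call to a function that prints, `from logging import info`, or a logger object).

```python
import ast

# Module-level logging functions treated as output; this set is an assumption.
LOGGING_FUNCS = {"debug", "info", "warning", "error", "critical", "exception", "log"}


def is_output_call(node):
    """Return True for print(...) or logging.<level>(...) call nodes."""
    if not isinstance(node, ast.Call):
        return False
    func = node.func
    if isinstance(func, ast.Name):
        return func.id == "print"
    if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        return func.value.id == "logging" and func.attr in LOGGING_FUNCS
    return False


def find_toplevel_output_calls(source):
    """Return sorted line numbers of output calls outside function/lambda bodies."""
    hits = []

    def visit(node):
        # Function and lambda bodies only run when called, so skip them.
        # Simplification: this also skips decorators and default arguments,
        # which do run at import time.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
            return
        if is_output_call(node):
            hits.append(node.lineno)
        # Class bodies and if/for/try blocks execute on import, so descend.
        for child in ast.iter_child_nodes(node):
            visit(child)

    visit(ast.parse(source))
    return sorted(hits)
```

In CI this could be run over each file's text, failing the build when the returned list is non-empty.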

Note: the following is a workaround, since capsys/capfd should be able to solve this, but for unknown reasons they don't work for my particular project.

I've been able to accomplish this with a runtime monkeypatch of the print and logging.info functions in a standalone script that runs during CI, for example:

import builtins
from contextlib import contextmanager
import functools as ft
from importlib import import_module
import logging
import os
import sys

# Keep references to the originals so they can still be called and restored.
orig_print = builtins.print
orig_info, orig_warning, orig_error, orig_critical = logging.info, logging.warning, logging.error, logging.critical
sys.path.insert(0, 'src')

def main():
    orig_print("Checking files for print() & logging on import...")
    for path in files_under_watch():
        orig_print("  " + path)
        output = detect_toplevel_output(path)
        if output:
            raise SyntaxWarning(f"Top-level output (print & logging) detected in {path}: {output}")

def files_under_watch():
    for root, _, files in os.walk('src'):
        for file in files:
            if should_watch_file(file):  # your impl here
                yield os.path.join(root, file)

def detect_toplevel_output(python_file_path):
    with capture_print() as printed, capture_logging() as logged:
        # Strip the ".py" suffix and turn the file path into a dotted module name.
        module_name = python_file_path[:-3].replace(os.sep, '.')
        import_module(module_name)
    output = {'print': printed, 'logging': logged}
    return {k: v for k, v in output.items() if v}

@contextmanager
def capture_print():
    calls = []
    @ft.wraps(orig_print)
    def captured_print(*args, **kwargs):
        calls.append((args, kwargs))
        return orig_print(*args, **kwargs)
    builtins.print = captured_print
    try:
        yield calls
    finally:
        # Restore the real print even if the import raised.
        builtins.print = orig_print

@contextmanager
def capture_logging():
    calls = []
    @ft.wraps(orig_info)
    def captured_info(*args, **kwargs):
        calls.append(('info', args, kwargs))
        return orig_info(*args, **kwargs)
    @ft.wraps(orig_warning)
    def captured_warning(*args, **kwargs):
        calls.append(('warning', args, kwargs))
        return orig_warning(*args, **kwargs)
    @ft.wraps(orig_error)
    def captured_error(*args, **kwargs):
        calls.append(('error', args, kwargs))
        return orig_error(*args, **kwargs)
    @ft.wraps(orig_critical)
    def captured_critical(*args, **kwargs):
        calls.append(('critical', args, kwargs))
        return orig_critical(*args, **kwargs)
    logging.info, logging.warning, logging.error, logging.critical = captured_info, captured_warning, captured_error, captured_critical
    try:
        yield calls
    finally:
        # Restore the real logging functions even if the import raised.
        logging.info, logging.warning, logging.error, logging.critical = orig_info, orig_warning, orig_error, orig_critical

if __name__ == '__main__':
    main()
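
One limitation worth noting: patching logging.info and its siblings only catches calls made through the module-level functions. A call such as logging.getLogger(__name__).info(...) goes through a Logger method and is not recorded. A possible complement, sketched below under assumptions, is to attach a recording handler to the root logger for the duration of the import. The names `RecordingHandler` and `capture_log_records` are hypothetical, and the sketch will not see records from loggers that set propagate to False or that set their own, higher level.

```python
import logging
from contextlib import contextmanager


class RecordingHandler(logging.Handler):
    """Collects every log record that reaches it instead of writing it out."""

    def __init__(self):
        super().__init__(level=logging.NOTSET)
        self.records = []

    def emit(self, record):
        self.records.append(record)


@contextmanager
def capture_log_records():
    """Temporarily attach a RecordingHandler to the root logger."""
    root = logging.getLogger()
    handler = RecordingHandler()
    old_level = root.level
    root.addHandler(handler)
    root.setLevel(logging.DEBUG)  # let every level through while capturing
    try:
        yield handler.records
    finally:
        # Undo both changes so later code sees the original configuration.
        root.removeHandler(handler)
        root.setLevel(old_level)
```

It could be used in place of, or next to, capture_logging() around the import_module call, treating a non-empty record list as a failure.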
