#!/usr/bin/env python3

import asyncio
import argparse
import os
import re
import signal
import sys

# Default value of --timeout, in seconds. monitor() also compares the timeout
# it receives with this value to choose its status when a pattern matches.
DEFAULT_TIMEOUT = 30

# TODO: signal handler to terminate spawned process group when wait-for is killed
# TODO: better return codes esp. when matches are found
# TODO: multiple patterns (multiple out, err, both)
# TODO: print unmatched patterns


async def terminate_group(p: asyncio.subprocess.Process):
    """
    Send SIGTERM to every process in the group of `p`
    (the shell itself, crowdsec plugins, ...).

    Nothing happens if the process can no longer be found.
    """
    try:
        group_id = os.getpgid(p.pid)
        os.killpg(group_id, signal.SIGTERM)
    except ProcessLookupError:
        # already gone: there is nothing left to terminate
        return


async def monitor(
        cmd: str,
        args: list[str],
        out_regex: re.Pattern[str] | None,
        err_regex: re.Pattern[str] | None,
        timeout: float
) -> int:
    """
    Run a subprocess, monitor its stdout/stderr for matches, and handle timeouts or pattern hits.

    Args:
        cmd: The command to run.
        args: A list of arguments to pass to the command.
        out_regex: A compiled regular expression to search for in stdout.
        err_regex: A compiled regular expression to search for in stderr.
        timeout: The maximum number of seconds to wait for the process to terminate.

    Returns:
        The exit code of the process.
    """

    status = None

    async def read_stream(stream: asyncio.StreamReader | None, out, pattern: re.Pattern[str] | None):
        nonlocal status
        if stream is None:
            return

        while True:
            line = await stream.readline()
            if line:
                line = line.decode('utf-8')
                out.write(line)
                if pattern and pattern.search(line):
                    await terminate_group(process)
                    # this is nasty.
                    # if we timeout, we want to return a different exit code
                    # in case of a match, so that the caller can tell
                    # if the application was still running.
                    # XXX: still not good for match found, but return code != 0
                    if timeout != DEFAULT_TIMEOUT:
                        status = 128
                    else:
                        status = 0
                    break
            else:
                break

    process = await asyncio.create_subprocess_exec(
        cmd,
        *args,
        # capture stdout
        stdout=asyncio.subprocess.PIPE,
        # capture stderr
        stderr=asyncio.subprocess.PIPE,
        # disable buffering
        bufsize=0,
        # create a new process group
        # (required to kill child processes when cmd is a shell)
        preexec_fn=os.setsid)

    # Apply a timeout
    try:
        await asyncio.wait_for(
            asyncio.wait([
                asyncio.create_task(process.wait()),
                asyncio.create_task(read_stream(process.stdout, sys.stdout, out_regex)),
                asyncio.create_task(read_stream(process.stderr, sys.stderr, err_regex))
            ]), timeout)
        if status is None:
            status = process.returncode
    except asyncio.TimeoutError:
        await terminate_group(process)
        status = 241

    # Return the same exit code, stdout and stderr as the spawned process
    return status or 0


class Args(argparse.Namespace):
    """
    Typed namespace handed to parse_args(), so that the parsed command line
    is accessed through annotated attributes.
    """
    # options: pattern for stdout, pattern for stderr, time limit in seconds
    out: str = ''
    err: str = ''
    timeout: float = DEFAULT_TIMEOUT

    # positionals: the command to run, followed by its own arguments
    cmd: str = ''
    args: list[str] = []


async def main():
    """
    Parse the command line, run the monitored command and return the
    status computed by monitor().
    """
    cli = argparse.ArgumentParser(
        description='Monitor a process and terminate it if a pattern is matched in stdout or stderr.')

    # positionals: the command, then everything that follows it
    _ = cli.add_argument('cmd', help='The command to run.')
    _ = cli.add_argument('args', nargs=argparse.REMAINDER, help='A list of arguments to pass to the command.')

    # one optional pattern per output stream
    stream_options = (
        ('--out', 'A regular expression pattern to search for in stdout.'),
        ('--err', 'A regular expression pattern to search for in stderr.'),
    )
    for flag, text in stream_options:
        _ = cli.add_argument(flag, help=text)

    _ = cli.add_argument('--timeout', type=float, default=DEFAULT_TIMEOUT)

    opts: Args = cli.parse_args(namespace=Args())

    def compiled(pattern: str) -> re.Pattern[str] | None:
        # an empty pattern means "do not watch this stream"
        if not pattern:
            return None
        return re.compile(pattern)

    stdout_pattern = compiled(opts.out)
    stderr_pattern = compiled(opts.err)

    return await monitor(opts.cmd, opts.args, stdout_pattern, stderr_pattern, opts.timeout)


if __name__ == '__main__':
    # the value returned by main() becomes the exit status of the script
    raise SystemExit(asyncio.run(main()))
