# summary | refs | log | tree | commit | diff
# path: root/python_agent/to_ast.py
# blob: 3f5aadbf2382d9700a85f24794be9f0b04aa3bb2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import hashlib
from os.path import exists
from pathlib import Path
from ast import parse, unparse, FunctionDef, ClassDef

def rglob(directory, extension):
    """
    Recursively collect all files with the given extension under a directory.

    Args:
        directory (str or Path): The root directory to start the search.
        extension (str): The file extension (e.g., '.txt', 'py').
            The leading dot is optional; it is added when missing.

    Returns:
        list: A list of Path objects for the matching regular files.
            Directories whose names happen to end with the extension
            are excluded.
    """
    # Normalise the extension so that both "py" and ".py" are accepted.
    if not extension.startswith("."):
        extension = "." + extension

    # Path.rglob matches on names only, so a directory called e.g. "pkg.py"
    # would be returned as well; keep regular files only, since callers
    # expect to be able to open every returned path.
    return [
        path
        for path in Path(directory).rglob(f"*{extension}")
        if path.is_file()
    ]

# Export every top-level function and class found under "src/" into its own
# file below "agent/ast/". Each file is named after the SHA-256 digest of its
# content, so an unchanged definition maps to the same file on every run and
# is not rewritten.


def _sha256_hex(text):
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def _store_snippet(src, out_dir):
    """
    Write ``src`` to ``out_dir/<sha256 of src>.py`` unless it is already there.

    Args:
        src (str): The text to store.
        out_dir (Path): Directory that receives the file; created on demand.

    Returns:
        Path: The path of the file that holds ``src`` (possibly with
            disambiguating comment lines appended).
    """
    while True:
        target = out_dir / f"{_sha256_hex(src)}.py"
        if not target.exists():
            # A fresh checkout has no output directory yet; without this the
            # first write would fail with FileNotFoundError.
            out_dir.mkdir(parents=True, exist_ok=True)
            target.write_text(src, encoding="utf-8")
            return target
        # Compare the stored text itself rather than its digest: in a genuine
        # collision the digests are equal by definition, so only a content
        # comparison can tell "already exported" apart from "name taken".
        if target.read_text(encoding="utf-8") == src:
            return target
        # The name is occupied by different content: change the text (and
        # therefore its digest) and try again under the new name.
        src += "# Added due to hash collision\n"
        print("Collision detected!")


for file in rglob("src/", "py"):
    # Read with an explicit encoding so the result (and thus the digest) does
    # not depend on the platform's locale.
    with open(file, "r", encoding="utf-8") as fhandle:
        source = fhandle.read()
    # Passing the filename makes syntax errors point at the offending file.
    tree = parse(source, filename=str(file))
    for node in tree.body:
        # Only top-level function and class definitions are exported.
        if not isinstance(node, (FunctionDef, ClassDef)):
            continue
        src = f"# Function derived from {file}\n" + unparse(node)
        if not src.endswith("\n"):
            src += "\n"
        _store_snippet(src, Path("agent/ast"))