from Bio.PDB import PDBParser
from Bio.PDB import NeighborSearch
import numpy as np

def get_atoms(path, radius, atom_name_prefix="MN"):
    """Collect neighbour atoms and distances around selected atoms in PDB files.

    For every PDB file in *path*, every atom whose name starts with
    *atom_name_prefix* ("MN" by default) is used as a centre. All other atoms
    of the same model lying within *radius* of that centre are gathered
    together with their Euclidean distance to it.

    Args:
        path: Iterable of PDB file paths. A single path (str, bytes or
            os.PathLike) is also accepted and treated as a one-item list.
        radius: Search radius, in the units of the file's coordinates
            (Angstrom for standard PDB files).
        atom_name_prefix: Atom-name prefix that identifies centre atoms.

    Returns:
        One list per input file, in input order. Each holds one list per
        matched centre atom (in model/chain/residue/atom order), containing
        ``(neighbor_atom, distance)`` tuples. The centre atom itself is
        excluded.
    """
    import os

    # Without this, a bare string path would be iterated character by character.
    if isinstance(path, (str, bytes, os.PathLike)):
        path = [path]

    parser = PDBParser()  # one parser can be reused for every file
    results_for_all_files = []

    for pdb_path in path:
        structure = parser.get_structure("proteins", pdb_path)
        results_for_one_file = []

        for model in structure:
            model_atoms = list(model.get_atoms())
            centres = [
                atom for atom in model_atoms
                if atom.get_name().startswith(atom_name_prefix)
            ]
            if not centres:
                # Nothing to search around; also avoids building an index
                # over an empty atom list.
                continue

            # Build the spatial index once per model, not once per centre.
            # Restricting it to this model keeps multi-model files (e.g. NMR
            # ensembles) from reporting neighbours that belong to other models.
            ns = NeighborSearch(model_atoms)

            for atom in centres:
                centre_coord = np.array(atom.get_coord())
                neighbors = ns.search(centre_coord, radius, level="A")
                results_for_one_file.append([
                    (neighbor, np.linalg.norm(np.array(neighbor.get_coord()) - centre_coord))
                    for neighbor in neighbors
                    if neighbor is not atom  # drop the centre itself (distance 0)
                ])

        results_for_all_files.append(results_for_one_file)

    return results_for_all_files

def calculate_average_distance(results_for_all_files):
    """Return the mean neighbour distance for each file's results.

    Args:
        results_for_all_files: One entry per file. Each entry is a list of
            per-centre lists of ``(atom, distance)`` pairs, as produced by
            ``get_atoms``.

    Returns:
        A list with one value per file: the mean of all distances found in
        that file, rounded to 3 decimals, or ``np.nan`` when the file yielded
        no distances at all.
    """
    def _file_mean(per_centre_lists):
        # Flatten the distances of all centres in this file into one list.
        flat = [dist for centre in per_centre_lists for _atom, dist in centre]
        if not flat:
            return np.nan
        total = sum(flat)
        return round(total / len(flat), 3)

    return [_file_mean(file_results) for file_results in results_for_all_files]


if __name__ == "__main__":
    # Example run on a single structure. This path looks like a Google Colab
    # Drive mount specific to the author's setup; adjust it before running
    # elsewhere. The guard keeps this file I/O from running on import.
    path = ["/content/drive/MyDrive/Internship/structures/5ws6.pdb"]
    # Search radius around each Mn atom, in the PDB coordinate units (Angstrom).
    # NOTE(review): presumably chosen to capture the first coordination shell; confirm.
    radius = 2.4

    results_s3 = get_atoms(path, radius)
    average_distances_s3 = calculate_average_distance(results_s3)

    print(average_distances_s3)