import os, sys
from openeye.oechem import *

def AllowedAtoms(mol, ok_atoms):
  """Report the first element symbol in `mol` that is not in `ok_atoms`.

  Atoms are visited in the order produced by mol.GetAtoms().  Each atomic
  number is converted to a symbol with OEGetAtomicSymbol and tested for
  membership in `ok_atoms`.

  Returns the symbol of the first atom that fails the test, or the string
  'none' when every atom passes (which includes a molecule with no atoms).
  """
  for candidate in mol.GetAtoms():
    symbol = OEGetAtomicSymbol(candidate.GetAtomicNum())
    if symbol in ok_atoms:
      continue
    # First offending element found; later atoms are not inspected.
    return symbol
  return 'none'

def CalcMolWT(mol):
  """Sum a weight for every atom from mol.GetAtoms() and return the total.

  An atom whose atomic number and isotope value are both non-zero
  contributes OEGetIsotopicWeight(atomic_number, isotope); any other atom
  contributes OEGetAverageWeight(atomic_number).

  Returns a float; 0.0 when the molecule yields no atoms.
  """
  total = 0.0
  for atom in mol.GetAtoms():
    atomic_num = atom.GetAtomicNum()
    isotope = atom.GetIsotope()
    if atomic_num == 0 or isotope == 0:
      # No specific isotope recorded (or atomic number 0): use the average.
      total += OEGetAverageWeight(atomic_num)
    else:
      total += OEGetIsotopicWeight(atomic_num, isotope)
  return total

def CountChiral(mol):
  """Count the atoms and bonds of `mol` for which IsChiral() is true.

  Atoms from mol.GetAtoms() are examined first, then bonds from
  mol.GetBonds().  Returns the combined count as a single integer.
  """
  chiral_atoms = sum(1 for atom in mol.GetAtoms() if atom.IsChiral())
  chiral_bonds = sum(1 for bond in mol.GetBonds() if bond.IsChiral())
  return chiral_atoms + chiral_bonds

##########################################################
#from flipper.py in OEChem examples
def EnumerateChirality(mol, ofs):
  """Write `mol` to `ofs` once for every combination of unspecified stereo.

  Collects the atoms that satisfy (chiral AND NOT stereo-specified) and the
  bonds that satisfy the same condition, then passes both lists to flip().
  flip() assigns each collected centre both of its configurations in turn
  and writes every resulting combination.  When nothing is collected,
  flip() returns 0 and the molecule is written once here instead.

  mol -- molecule to enumerate; flip() changes its stereo settings in place.
  ofs -- output stream handed to OEWriteMolecule.
  """
  # Atom predicate: chiral AND NOT stereo-specified.  The component
  # predicates are kept in named locals so they stay referenced for as long
  # as the composite predicate is in use.
  isChiralAtom = IsChiralAtom()
  atomStereoSpecified = HasAtomStereoSpecified()
  notAtomStereoSpecified = OENotAtom(atomStereoSpecified)
  okAtom = OEAndAtom(isChiralAtom, notAtomStereoSpecified)
  # Bond predicate: the same condition applied to bonds.
  isChiralBond = IsChiralBond()
  bondStereoSpecified = HasBondStereoSpecified()
  notBondStereoSpecified = OENotBond(bondStereoSpecified)
  okBond = OEAndBond(isChiralBond, notBondStereoSpecified)
  # Snapshot the matching atoms and bonds into lists because flip() indexes
  # into them.  Comprehensions are used so the toolkit iterators are consumed
  # through the same for-loop protocol as before.
  atomlist = [atom for atom in mol.GetAtoms(okAtom)]
  bondlist = [bond for bond in mol.GetBonds(okBond)]
  if flip(ofs, mol, atomlist, len(atomlist), bondlist, len(bondlist)) == 0:
    # Nothing to enumerate: emit the molecule unchanged.
    OEWriteMolecule(ofs, mol)

##########################################################
#from flipper.py in OEChem examples
def flip(ofs, mol, atomlist, asize, bondlist, bsize):
  """Recursively assign both configurations to each listed stereo centre.

  Works through the first `asize` atoms of `atomlist` (last one first) and,
  once those are exhausted, the first `bsize` bonds of `bondlist`.  For the
  centre being handled it sets one configuration, recurses on the remaining
  centres, then sets the other configuration and recurses again.  Whenever
  a recursive call reports that no centres remain (returns 0), the molecule
  is written to `ofs` with OEWriteMolecule.

  ofs      -- output stream passed to OEWriteMolecule.
  mol      -- molecule owning the atoms and bonds; modified in place.
  atomlist -- atoms whose tetrahedral stereo is to be enumerated.
  asize    -- number of leading entries of `atomlist` still to process.
  bondlist -- bonds whose cis/trans stereo is to be enumerated.
  bsize    -- number of leading entries of `bondlist` still to process.

  Returns 1 if this call handled a centre, 0 if `asize` and `bsize` are both
  not positive (nothing was set and nothing was written by this call).
  """
  if asize > 0:
    a = atomlist[asize - 1]
    nbrs = [nbr for nbr in a.GetAtoms()]
    for handedness in (OEAtomStereo_Right, OEAtomStereo_Left):
      a.SetStereo(nbrs, OEAtomStereo_Tetra, handedness)
      if flip(ofs, mol, atomlist, asize - 1, bondlist, bsize) == 0:
        OEWriteMolecule(ofs, mol)
    return 1
  if bsize > 0:
    b = bondlist[bsize - 1]
    bgn = b.GetBgn()
    end = b.GetEnd()
    # Reference atoms for the cis/trans label: the first neighbour of each
    # end atom that is not the atom at the other end of the bond.
    # NOTE: the `break` statements were tab-indented in a space-indented
    # body, which Python 3 rejects with TabError; they belong inside the
    # `if`, which is where an 8-column tab placed them.
    nbrs = []
    for nbr in bgn.GetAtoms():
      if nbr.GetIdx() != end.GetIdx():
        nbrs.append(nbr)
        break
    for nbr in end.GetAtoms():
      if nbr.GetIdx() != bgn.GetIdx():
        nbrs.append(nbr)
        break
    for arrangement in (OEBondStereo_Cis, OEBondStereo_Trans):
      b.SetStereo(nbrs, OEBondStereo_CisTrans, arrangement)
      if flip(ofs, mol, atomlist, asize, bondlist, bsize - 1) == 0:
        OEWriteMolecule(ofs, mol)
    return 1
  return 0

def stripH(mol):
  """Remove the hydrogen from matched hetero atoms and set their charge to -1.

  For each SMARTS pattern below, every unique match in `mol` is examined.
  A matched atom that is not carbon and has a total hydrogen count of
  exactly 1 has its explicit hydrogen neighbours deleted from `mol` and its
  formal charge set to -1.  `mol` is modified in place; nothing is returned.
  """
  smarts = ['c1nnnn1', '[n+][OH-]']
  for pattern in smarts:
    search = OESubSearch()
    search.Init(pattern)
    # Second argument 1 requests unique matches, as in the original call.
    for match in search.Match(mol, 1):
      for matchpair in match.GetAtoms():
        target = matchpair.target
        if target.GetTotalHCount() != 1 or target.IsCarbon():
          continue
        for bond in target.GetBonds():
          nbor = bond.GetNbr(target)
          if nbor.IsHydrogen():
            mol.DeleteAtom(nbor)
        target.SetFormalCharge(-1)

def addH(mol):
  """Set formal charge +1 on nitrogens matched by the amine patterns.

  Each SMARTS below pairs a [CX4] carbon with a nitrogen carrying three,
  two or one hydrogens.  For every unique match in `mol`, each matched atom
  that is a nitrogen gets formal charge +1.  Despite the name, this
  function adds no hydrogen atoms itself; it only sets formal charges.
  `mol` is modified in place; nothing is returned.
  """
  smarts = ['[CX4][NH3]', '[CX4][NH2][CX4]', '[CX4][NH]([CX4])[CX4]']
  for pattern in smarts:
    search = OESubSearch()
    search.Init(pattern)
    # Second argument 1 requests unique matches, as in the original call.
    for match in search.Match(mol, 1):
      for matchpair in match.GetAtoms():
        if matchpair.target.IsNitrogen():
          matchpair.target.SetFormalCharge(1)

def set_formal_charge(mol):
  """Rebuild hydrogens on `mol`, then assign atom types and charges.

  Steps, in the order executed (each call is passed `mol`):
    1. set the implicit hydrogen count to 0 on every atom whose IsPolar()
       is true;
    2. OEAssignImplicitHydrogens, then OEAddExplicitHydrogens;
    3. OETriposAtomTypes;
    4. stripH and addH, which set formal charges on pattern-matched atoms
       (stripH also deletes the matched atoms' hydrogens);
    5. OEAssignFormalCharges, then OEGasteigerPartialCharges.

  Nothing is returned.
  """
  # Clear the implicit hydrogen counts on polar atoms before they are
  # reassigned below.
  for atom in mol.GetAtoms():
    if atom.IsPolar():
      atom.SetImplicitHCount(0)
  OEAssignImplicitHydrogens(mol)
  # Make the hydrogens explicit; stripH looks for hydrogen neighbour atoms.
  OEAddExplicitHydrogens(mol)
  OETriposAtomTypes(mol)
# stripH deprotonates tetrazoles and oximes
  stripH(mol)
  # addH sets +1 formal charges on nitrogens matched by its amine patterns.
  addH(mol)
  OEAssignFormalCharges(mol)
  OEGasteigerPartialCharges(mol)
