import pymel.all as pm
from pymel.internal.plogging import pymelLogger

# Select up to 3 joints, in this order:
#     first, the roll's PARENT joint, already properly aligned
#     second, the ROLL joint itself (optional; a new one is created if omitted)
#     third, the roll's CHILD joint, already properly aligned
#         (optional if the parent has exactly one child joint)
# The script moves the roll joint into place so that perfect alignment is maintained between the three.

def alignRollJoint():
    """Place a roll joint halfway between a parent and a child joint.

    Works on the current selection; anything that is not a joint is ignored.
    The remaining joints are interpreted by count:

    * 1 joint:  it is the parent; it must have exactly one child joint,
      which is used as the child. A new roll joint is created.
    * 2 joints: parent, then child. A new roll joint is created.
    * 3 joints: parent, existing roll joint, then child.

    The child is parented under the parent if it is not already a direct
    child. The roll joint is moved to the midpoint of the parent's and
    child's world-space rotate pivots, oriented to match the parent, has its
    rotation frozen, is parented under the parent, and is left selected.

    Problems with the selection are reported through ``pymelLogger.error``
    and the function returns without modifying the scene.
    """
    # Only process joints. Build a new list instead of removing entries from
    # the list being iterated, which would skip the item after each removal.
    myObjects = [item for item in pm.ls(sl=1) if pm.nodeType(item) == 'joint']

    # Collect the appropriate joints; rollJoint stays None when one must be created.
    rollJoint = None
    if len(myObjects) == 1:  # 1 joint w/ only 1 child joint: use it as parent
        myChildren = pm.listRelatives(myObjects[0], c=1, type='joint')
        if not myChildren:
            pymelLogger.error('This joint has no children, what are you rolling?')
            return
        if len(myChildren) != 1:
            pymelLogger.error('Unspecified child joint (multiple found): please select the roll joint (optional), and the child joint')
            return
        childJoint = myChildren[0]
    elif len(myObjects) == 2:  # 2 joints: the 2nd is the child
        childJoint = myObjects[1]
    elif len(myObjects) == 3:  # 3 joints: the middle one is an existing roll to align
        rollJoint = myObjects[1]
        childJoint = myObjects[2]
    else:
        pymelLogger.error('Please select the parent joint, the roll joint (optional), and the child joint (optional)')
        return
    parentJoint = myObjects[0]

    # Make sure the child sits directly under the parent.
    if parentJoint not in pm.listRelatives(childJoint, p=1):
        try:
            pm.parent(childJoint, parentJoint)
        except RuntimeError:
            # The parent command failed; treat that as the selection being
            # reversed (the "child" was actually the parent) and swap roles.
            parentJoint, childJoint = childJoint, parentJoint
            pymelLogger.warning('Parent and child selection out of order: please check results for accuracy')

    # Create the roll joint if required. The selection is cleared first so
    # the new joint is not created under whatever is currently selected.
    if rollJoint is None:
        pm.select(cl=1)
        rollJoint = pm.joint(n='%sRoll' % parentJoint)

    # Temporarily remove the roll joint from the joint chain.
    if pm.listRelatives(rollJoint, p=1):
        pm.parent(rollJoint, w=1)

    # World-space rotate-pivot positions of the two end joints.
    parentPos = pm.xform(parentJoint, q=1, ws=1, rp=1)
    childPos = pm.xform(childJoint, q=1, ws=1, rp=1)

    # Move the roll joint to the midpoint between parent and child.
    midpoint = tuple((a + b) / 2 for a, b in zip(parentPos, childPos))
    pm.move(rollJoint, midpoint, rpr=1)

    # Orient the roll joint like the parent: a throwaway orient constraint
    # snaps the rotation, then the rotation is frozen on the joint.
    tempConstraint = pm.orientConstraint(parentJoint, rollJoint, mo=0)
    pm.delete(tempConstraint)
    pm.makeIdentity(rollJoint, apply=1, r=1, n=0)

    pm.parent(rollJoint, parentJoint)  # return roll joint to joint chain
    pm.select(rollJoint)  # return the selection to the roll joint