#     Copyright 2007-8 Jim Bublitz <jbublitz@nwinternet.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the
# Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

import keyword
from twineConfig import importFilter, includeFilter, copyright, header, license

# Type spellings treated as scalar when classifying arguments and template
# parameters. The list is mutable: TypedefObject.setArgumentType appends
# typedef names that alias one of these.
scalarTypes = [
    "int", "long int", "short int", "signed int", "unsigned int",
    "unsigned long int", "unsigned short", "unsigned short int",
    "char", "unsigned char", "signed char",
    "float", "double", "long", "long long", "unsigned long",
    "signed long", "unsigned long long", "signed long long",
    "short", "unsigned", "signed", "bool", "void", "wchar_t",
]
    
def eqlist(me, other):
    """Return True when two lists (enumerators, arguments) have the same
    length and every pair of items at the same position compares equal."""
    if len(me) != len(other):
        return False
    # lengths match, so zip pairs up every position; all() stops at the
    # first pair that is not equal
    return all(mine == theirs for mine, theirs in zip(me, other))

class Attributes:
    """Qualifiers attached to a declaration, plus a few processing flags.

    Only storageClass, functionQualifier and cv take part in equality.
    """

    def __init__(self):
        self.storageClass = ''          # auto, register, static, extern, mutable
        self.functionQualifier = ''     # inline, virtual, explicit, pure ( = pure virtual)
        self.cv = ''                    # const, volatile

        # flags; all start out False
        self.mcNeeded = False
        self.mtNeeded = False
        self.ioNeeded = False
        self.sipSlot = False

    def __eq__(self, other):
        """Compare the three qualifier strings; the flags are ignored."""
        return (self.storageClass == other.storageClass
                and self.functionQualifier == other.functionQualifier
                and self.cv == other.cv)

    def setStorage(self, candidate):
        """Split candidate on whitespace and file each word.

        'virtual' is stored as the function qualifier; any other word
        becomes the storage class (the last such word wins).
        """
        # split() with no argument already yields whitespace-free words
        for word in candidate.split():
            if word == 'virtual':
                self.functionQualifier = 'virtual'
                continue
            self.storageClass = word
        

def low(version):
    """Return version unchanged, or '-' when version is empty or None."""
    return version if version else '-'
    
def high(version):
    """Return version unchanged, or '~' when version is empty or None."""
    return version if version else '~'

class Version:
    """Low/high version bounds and platform, built from a 3-item sequence."""

    def __init__(self, versionInfo):
        """Store the three items of versionInfo.

        versionInfo -- a 3-item sequence (versionLow, versionHigh, platform).
        Raises ValueError/TypeError if it cannot be unpacked into three items.
        """
        # Unpack in the body rather than in the signature: tuple parameters
        # are Python 2-only syntax and a SyntaxError on Python 3. Callers
        # still pass one positional tuple, exactly as before.
        versionLow, versionHigh, platform = versionInfo
        self.versionLow = versionLow
        self.versionHigh = versionHigh
        self.platform = platform

    def __eq__(self, other):
        """Equal when both bounds and the platform match."""
        return self.versionHigh == other.versionHigh\
            and self.versionLow == other.versionLow\
            and self.platform == other.platform

    def inRange(self, other):
        """Return False when this range and other's range do not overlap.

        Empty bounds are replaced by '-' (low) and '~' (high) before being
        compared. The range is rejected when this high bound is below
        other's low bound, or this low bound is at or above other's high
        bound; otherwise True is returned. The platform is not consulted.
        """
        if high (self.versionHigh) < low (other.versionLow)\
            or low (self.versionLow) >= high (other.versionHigh):
            return False
        return True
        
class Object(object):
    """The base class for all objects."""

    def __init__(self, name, ordinal, stateInfo, objectType):
        """Record the identity of the object and copy the current parse
        state (access, scope, version, flags, file, module, doc) from
        stateInfo."""
        self.name = name
        self.objectType = objectType
        self.ordinal = ordinal

        # snapshot of the state held by stateInfo
        self.access = stateInfo.access
        self.scope = stateInfo.scope
        self.version = Version(stateInfo.versionStack[-1])   # top of the version stack
        self.ignore = stateInfo.ignore
        self.force = stateInfo.force
        self.filename = stateInfo.filename
        self.filepath = None
        self.module = stateInfo.module

        self.pyName = self.checkPyName()
        self.doc = stateInfo.getDoc()
        self.sipDoc = None
        # this holds any sip blocks attached to (or appearing after) this object
        self.blocks = []
        self.skip = False

    def checkPyName(self):
        """Return the name with '_' appended when it is a Python keyword,
        otherwise None."""
        if not keyword.iskeyword(self.name):
            return None
        return self.name + '_'
            
            
class ScopeObject:
    """What's kept after a file has been processed
    (mostly for computing scopes/verifying definition)."""

    # attributes copied verbatim from the source object, in this order
    _copied = ('name', 'objectType', 'access', 'scope',
               'version', 'filename', 'filepath', 'module')

    def __init__(self, object):
        for attr in self._copied:
            setattr(self, attr, getattr(object, attr))
        # 'template' is carried over only when the source object has one
        if hasattr(object, 'template'):
            self.template = object.template
            
class ScopedTemplateObject:
    """Index entry for a templated object: the name without its '<...>'
    part, the template info, and a simplified list of parameter types."""

    def __init__(self, object):
        # the name is whatever precedes the first '<'; a '<' at position 0
        # (or no '<' at all) leaves the name untouched
        bracket = object.name.find('<')
        self.name = object.name[:bracket] if bracket > 0 else object.name
        self.objectType = 'template'

        # prefer existing template info, else build it from templateParams
        if object.template:
            self.template = object.template
        elif object.templateParams:
            self.template = Template(self.name, object.templateParams)
        else:
            self.template = None

        if self.template:
            self.deconstructTemplate(self.template)
        else:
            # no template info: typelist is the full original name (a string)
            self.typelist = object.name

    def primaryType(self, argtype):
        """Get the unmodified (almost) identifier: everything before the
        first '*' or '&', or the whole string when neither occurs."""
        cut = len(argtype)
        for marker in ("*", "&"):
            where = argtype.find(marker)
            if 0 <= where < cut:
                cut = where
        return argtype[:cut]

    def deconstructTemplate(self, template):
        """Fill self.typelist with one entry per comma-separated parameter
        in template.params.

        A parameter containing a space is reduced to its first word. When
        the part before any '*'/'&' is in scalarTypes the parameter is kept
        as is; otherwise that part is replaced by 'nonscalar' and the
        trailing '*'/'&' characters are kept.
        """
        entries = []
        self.typelist = entries
        for param in template.params.split(','):
            if ' ' in param:
                param = param.split()[0].strip()
            base = self.primaryType(param)
            if base in scalarTypes:
                entries.append(param)
            else:
                entries.append('nonscalar' + param[len(base):])
                
class ScopeEnumeratorObject:
    """Keeps an enumerator as it appears in default values, together with
    the name, access, scope, version and module of its enum."""

    def __init__(self, enum, enumerator):
        self.name = enumerator.name
        self.objectType = 'enumerator'
        self.enum = enum.name          # name of the owning enum
        # the remaining fields are inherited from the enum itself
        for attr in ('access', 'scope', 'version', 'module'):
            setattr(self, attr, getattr(enum, attr))
        
class ScopeTypedefObject(ScopeObject):
    """A ScopeObject for a typedef; keeps the extra typedef info
    (argumentType) so inheritance can be traced if needed."""

    def __init__(self, object):
        ScopeObject.__init__(self, object)
        self.argumentType = object.argumentType

class Argument(object):
    """One typed item of a declaration: a function argument, a return
    value, or the type of a variable.

    The constructor normalizes argumentType: a trailing '[]' on the name
    becomes a '*' on the type, a leading 'const'/'volatile' is moved into
    attributes.cv, a leading '$fp' marks a function pointer, and scope
    qualifiers ('A::B::type') are split off into self.scope.
    """

    def __init__(self, argumentType, argumentName=None, argumentValue=None,
                 annotation=None, template=None, defaultTypes=None):
        self.argumentType = argumentType
        self.argumentName = argumentName
        self.defaultValue = argumentValue    # string (no leading '=') of default value/expression
        self.defaultTypes = defaultTypes     # any types pulled out of the default value expression
        self.functionPtr = None
        self.annotation = annotation         # a list of annotations
        self.attributes = Attributes()
        self.template = template             # the parsed info from any template-type argument
        self.scope = ''                      # the scope element found with argumentType

        # array to pointer
        if self.argumentName and self.argumentName.endswith('[]'):
            self.argumentType = self.argumentType + '*'
            self.argumentName = self.argumentName[:-2].strip()

        # check for cv qualifier or the function pointer marker
        # NOTE(review): startswith('const') also matches identifiers that
        # merely begin with 'const' (e.g. const_iterator) -- confirm the
        # parser never hands those in unqualified.
        if self.argumentType.startswith('const'):
            self.attributes.cv = 'const'
            self.argumentType = self.argumentType[5:].strip()
        elif self.argumentType.startswith('volatile'):
            self.attributes.cv = 'volatile'
            self.argumentType = self.argumentType[8:].strip()
        elif self.argumentType.startswith('$fp'):
            self.argumentType = self.argumentType[3:]
            self.defaultValue = None
            # The value slot holds a comma separated list here. Read the
            # constructor parameter: the instance never gets an
            # 'argumentValue' attribute, so self.argumentValue raised
            # AttributeError for every function pointer.
            if argumentValue:
                self.functionPtr = argumentValue.split(',')
            else:
                self.functionPtr = []

        # break out scope qualifiers if not a template argument
        if '<' not in self.argumentType and '::' in self.argumentType:
            parts = self.argumentType.split('::')
            self.argumentType = parts[-1].strip()
            self.scope = '::'.join(parts[:-1])

        # check for ambiguous types or handwritten code needed
        # (slices instead of [-1] so an empty type string cannot raise IndexError)
        if self.argumentType[-2:] in ['**', '*&']:
            self.attributes.mcNeeded = True
        elif self.argumentType[-1:] in ['*', '&'] and self.argumentType[:-1] in scalarTypes:
            self.attributes.ioNeeded = self.argumentType != 'char*'

    def __eq__(self, other):
        """Equal when the normalized type and the attributes match.

        The default value is deliberately left out of the comparison:
        default values in the current sip file set aren't correct in all
        cases (fix!!).
        """
        return self.argumentType == other.argumentType\
            and self.attributes == other.attributes
            
            
class Template(object):
    """Plain record for template info: a name, its params, and an optional
    further template value (None by default)."""

    def __init__(self, name, params, template=None):
        self.name, self.params, self.template = name, params, template

class Enumerator(object):
    """A single enumerator: name, value and an (initially empty) doc string."""

    def __init__(self, name, value, stateInfo):
        # stateInfo is accepted but not used here
        self.name = name
        self.value = value
        self.doc = ''

    def __eq__(self, other):
        """Enumerators match on name alone; the value is not compared."""
        return self.name == other.name
       
           
class NamespaceObject(Object):
    """An Object whose objectType is "namespace"."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="namespace")
        self.opaque = False     # opaque or forward declaration

    def __eq__(self, other):
        """Namespaces match on object type and name."""
        if not other.objectType == self.objectType:
            return False
        return self.name == other.name

class ClassObject(Object):
    """An Object whose objectType is "class"."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="class")
        self.abstract = False
        self.bases = []
        self.templateParams = []    # if this list isn't empty, it's a template class
        self.template = None
        self.opaque = False         # opaque or forward declaration
        self.annotation = []
        self.undefBase = False

    def __eq__(self, other):
        """Classes match on object type, name, bases, scope and access."""
        # short-circuit order matters: objects of another type may not
        # have 'bases', so objectType is checked first
        return (self.objectType == other.objectType
                and self.name == other.name
                and self.bases == other.bases
                and self.scope == other.scope
                and self.access == other.access)
            
class EndClassMarker(Object):
    """An Object whose objectType is "endclass"."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="endclass")

    def __eq__(self, other):
        """Matches any object with the same objectType."""
        return self.objectType == other.objectType
        
class EndNamespaceMarker(Object):
    """An Object whose objectType is "endnamespace"."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="endnamespace")

    def __eq__(self, other):
        """Matches any object with the same objectType."""
        return self.objectType == other.objectType

class FunctionObject(Object):
    """An Object whose objectType is "function" (functions, methods and
    constructors), with its return and argument lists."""

    def __init__(self, name, ordinal, stateInfo, ctor=False):
        """Create the function object.

        For a constructor (ctor true) the last '::' element is removed from
        the scope, or the scope is cleared when it has no '::'.
        stateInfo.inline is copied and then reset to False.
        """
        Object.__init__(self, name, ordinal, stateInfo, "function")
        if ctor:
            if '::' in self.scope:
                parts = self.scope.split('::')
                self.scope = '::'.join(parts[:-1])
            else:
                self.scope = ''
        self.returns = []
        self.cppReturns = []
        self.arguments = []
        self.cppArguments = []
        self.templateParams = []        # if this list isn't empty, it's template function (and usually tossed)
        self.methodCode = False
        self.attributes = Attributes()
        self.annotation = []
        self.exceptions = ''
        self.inline = stateInfo.inline
        stateInfo.inline = False

    def setArguments(self, args):
        """Append one Argument per entry of args.

        Each entry is indexed 0..5 and passed to Argument in that order
        (type, name, value, annotation, template, defaultTypes).
        """
        for arg in args:
            self.arguments.append(Argument(arg[0], arg[1], arg[2], arg[3], arg[4], arg[5]))

    def setCppArgs(self, args, ctor=False):
        """Store the cpp arg signature for functions/methods with
        %MethodCode and different sip types than the C++ version.

        Unless ctor is true, the first entry of args is the return value;
        for a constructor a placeholder Argument('ctor') is recorded as the
        return instead. The remaining entries become cppArguments. Nothing
        happens when args is empty.
        """
        if not args:
            return

        if not ctor:
            arg = args[0]
            self.cppReturns.append(Argument(arg[0], arg[1], arg[2], arg[3], arg[4], arg[5]))
            args = args[1:]
        else:
            self.cppReturns.append(Argument('ctor'))
        for arg in args:
            # Pass every available field (up to six) so template and
            # defaultTypes are kept, as they are for the return value and
            # in setArguments; the slice keeps shorter entries working.
            self.cppArguments.append(Argument(*arg[:6]))

    def __eq__(self, other):
        """Functions match on type, name, attributes, access, scope and
        pyName, and then on their return and argument lists.

        templateParams are not compared.
        """
        equal = self.objectType == other.objectType\
            and self.name == other.name\
            and self.attributes == other.attributes\
            and self.access == other.access\
            and self.scope == other.scope\
            and self.pyName == other.pyName

        if not equal:
            return False

        # if the argument list for the sip version differs from
        # the h file version, use the cppArgument data to compare
        if self.cppReturns or self.cppArguments:
            returns = self.cppReturns
            arguments = self.cppArguments
        else:
            returns = self.returns
            arguments = self.arguments

        if other.cppReturns or other.cppArguments:
            otherReturns = other.cppReturns
            otherArguments = other.cppArguments
        else:
            otherReturns = other.returns
            otherArguments = other.arguments

        # no trailing line continuation here: the statement must not depend
        # on the blank line that follows it
        return eqlist(returns, otherReturns)\
            and eqlist(arguments, otherArguments)

class EnumObject(Object):
    """An Object whose objectType is "enum", holding its enumerators."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="enum")
        self.enumerators = []
        # pass the doc just fetched by Object.__init__ back to stateInfo
        stateInfo.setDoc(self.doc)

    def __eq__(self, other):
        """Enums match on object type, name, access, scope and their
        enumerator lists (compared with eqlist)."""
        return (self.objectType == other.objectType
                and self.name == other.name
                and self.access == other.access
                and self.scope == other.scope
                and eqlist(self.enumerators, other.enumerators))

class TypedefObject(Object):
    """An Object whose objectType is "typedef".

    'name' is the type being defined
    'typeName' is the type being aliased
    if a pointer-to-function, 'functionPtr' holds the function arguments,
       'name' is the function name and 'typeName' is the return type
    """

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, "typedef")
        self.argumentType = None
        self.functionPtr = None
        self.template = None
        self.attributes = Attributes()

    def setArgumentType(self, aType):
        """Record the aliased type.

        When this is not a function pointer and aType is in the module
        level scalarTypes list, the typedef's own name is added to that
        list too.
        """
        self.argumentType = aType

        # Register the alias only once: the same typedef can be set more
        # than once, and repeated appends only grow the list with
        # duplicates that every later membership test has to scan.
        if not self.functionPtr and self.argumentType in scalarTypes\
            and self.name not in scalarTypes:
            scalarTypes.append(self.name)

    def __eq__(self, other):
        """Typedefs match on object type, name, access, scope and
        argumentType."""
        return self.objectType == other.objectType\
            and self.name == other.name\
            and self.access == other.access\
            and self.scope == other.scope\
            and self.argumentType == other.argumentType
        
class VariableObject(Object):
    """An Object whose objectType is "variable"."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="variable")
        self.variable = Argument('unknown')     # placeholder type until one is set
        self.bitmap = None
        self.attributes = Attributes()
        self.functionPtr = None
        self.annotation = []

    def __eq__(self, other):
        """Variables match on object type, name, access, scope, the
        variable's Argument and functionPtr."""
        return (self.objectType == other.objectType
                and self.name == other.name
                and self.access == other.access
                and self.scope == other.scope
                and self.variable == other.variable
                and self.functionPtr == other.functionPtr)

class SipBlockObject(Object):
    """A sip block - %MethodCode, %ModuleCode, etc."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="sipBlock")
        self.block = None
        
class SipDirectiveObject(Object):
    """A %If, %End, or other directive."""

    def __init__(self, name, ordinal, stateInfo):
        Object.__init__(self, name, ordinal, stateInfo, objectType="sipDirective")
        self.argument = None

class SipTypeObject(Object):
    """Currently, either a %MappedType or %Exception; the objectType is
    the lower-cased typeName."""

    def __init__(self, name, ordinal, stateInfo, typeName):
        Object.__init__(self, name, ordinal, stateInfo, objectType=typeName.lower())
        self.block = ''
        self.base = ''
        self.templateParams = []
        self.template = None

    def __eq__(self, other):
        """Sip types match on object type, name, access and scope."""
        return (self.objectType == other.objectType
                and self.name == other.name
                and self.access == other.access
                and self.scope == other.scope)
            
class ScopeIndex(dict):
    """Holds information about any object seen so far, and maps symbols
    (names) to lists of scope objects."""

    def __init__(self):
        dict.__init__(self)

    def add(self, name, object):
        """File object under name.

        End markers, sip blocks, sip directives and anything whose own name
        starts with 'operator' are not indexed.
        """
        if object.objectType in ['endclass', 'endnamespace', 'sipBlock', 'sipDirective']\
            or object.name.startswith('operator'):
            return

        self.setdefault(name, []).append(object)

    def getMatchingObjects(self, name, objectType=None):
        """Find objects that match by name and type.

        name       -- symbol to look up; with a 'Scope::name' or 'Scope.name'
                      form only the last part is looked up, and the first
                      part is used as the required scope
        objectType -- collection of acceptable objectType values (more than
                      one can be given); when empty or omitted every object
                      stored under the name is returned

        Returns a list (empty when nothing matches). Without objectType the
        stored list itself is returned, not a copy.
        """
        # objectType defaults to None rather than a shared mutable list
        namespace = None
        if '::' in name:
            parts = name.split('::')
            name = parts[-1]
            namespace = parts[0]
        if '.' in name:
            parts = name.split('.')
            name = parts[-1]
            namespace = parts[0]

        if name not in self:
            return []

        # NOTE(review): the scope qualifier is not applied on this path --
        # confirm callers that pass no objectType expect that.
        if not objectType:
            return self[name]

        return [candidate for candidate in self[name]
                if candidate.objectType in objectType
                and (namespace is None or candidate.scope == namespace)]

    def delete(self, object):
        """Remove object from the list stored under object.name, dropping
        the key when the list becomes empty.

        Raises KeyError if the name is not indexed and ValueError if the
        object is not in the list.
        """
        name = object.name
        objList = self[name]
        objList.remove(object)
        if not objList:
            del self[name]

class Hierarchy(dict):
    """A class hierarchy (mostly for computing %ConvertToSubClass blocks).

    Maps subclass names to their parent class (str -> str). Doesn't support
    multiple inheritance. self.subclasses maps the other way: a parent name
    to the set of its direct subclass names.
    """

    def __init__(self):
        dict.__init__(self)
        self.subclasses = {}
        self.clearCurrentMod()

    def clearCurrentMod(self):
        """Reset the list of subclass names added since the last reset."""
        self.currentModClasses = []

    def add(self, name, object):
        """Record that 'object' (a subclass name) derives from 'name'.

        A later call for the same subclass replaces its recorded parent.
        """
        self.currentModClasses.append(object)
        self[object] = name
        self.subclasses.setdefault(name, set()).add(object)

    def getSubclasses(self, classname):
        """Return the set of direct subclasses recorded for classname
        (a new empty set when there are none)."""
        return self.subclasses.get(classname, set())

    def getSuperclasses(self, classname):
        """Return the ancestors of classname, nearest parent first.

        The walk stops at a class with no recorded parent or with an empty
        parent value. It also stops when a name repeats, so a cyclic entry
        (for instance a class recorded as its own base) ends the list
        instead of recursing until the interpreter gives up.
        """
        result = []
        seen = set([classname])
        current = self.get(classname)
        while current is not None and len(current) != 0:
            if current in seen:
                break
            seen.add(current)
            result.append(current)
            current = self.get(current)
        return result

            
class SipIndex(dict):
    """The sip file object list converted to a dict for lookups: each name
    maps to the list of objects added under it."""

    def __init__(self):
        dict.__init__(self)

    def add(self, name, object):
        """Append object to the list kept under name, creating the list on
        first use."""
        self.setdefault(name, []).append(object)

    def getMatchingObjects(self, name):
        """Return the list stored under name, or None when the name is not
        indexed."""
        return self.get(name)

    def delete(self, object):
        """Remove object from the list under object.name; the key is
        dropped once its list is empty."""
        key = object.name
        bucket = self[key]
        bucket.remove(object)
        if len(bucket) == 0:
            del self[key]
                
class Data(object):
    """The various structures that hold module data: the scope index, the
    class hierarchy, the sip index, and the per-file object lists."""

    def __init__(self):
        self.index = ScopeIndex()
        self.hierarchy = Hierarchy()
        self.definedScalar = []
        self.parents = {}
        self.sipIndex = SipIndex()
        self.sipIndex.parent = self
        self.clear()

        # the header info for every sip file
        self.copyright = copyright
        self.header = header
        self.license = license

    def addFile(self, filename, objects, modData, imported=False, versionRange=None):
        """Filter the objects of one file and add them to the index.

        Imported files go through importFilter (with versionRange) and are
        only indexed. Other files go through includeFilter (with modData)
        and are also recorded in self.files and self.fileData.
        """
        if imported:
            filteredObjects = importFilter(objects, versionRange)
        else:
            filteredObjects = includeFilter(objects, modData)
            self.files.append(filename)
            self.fileData[filename] = filteredObjects

        self.addToIndex(filteredObjects)

    def addToIndex(self, objects):
        """Add a scope entry for every object, plus the extra bookkeeping
        its type needs (enumerators, class parents/hierarchy, template
        entries, scalar typedef names)."""
        for obj in objects:
            if obj.objectType == 'typedef':
                self.index.add(obj.name, ScopeTypedefObject(obj))
            else:
                self.index.add(obj.name, ScopeObject(obj))

            if obj.objectType == 'enum':
                for enumerator in obj.enumerators:
                    self.index.add(enumerator.name, ScopeEnumeratorObject(obj, enumerator))

            elif obj.objectType == 'class':
                self.parents[obj.name] = obj.bases
                if not obj.bases:
                    self.hierarchy[obj.name] = []
                else:
                    # The hierarchy keeps a single parent per class, so only
                    # the first base is recorded -- and only once. Looping
                    # over all bases re-added the same pair and duplicated
                    # the class in hierarchy.currentModClasses.
                    self.hierarchy.add(obj.bases[0], obj.name)
                if obj.templateParams:
                    tObj = ScopedTemplateObject(obj)
                    self.index.add(tObj.name, tObj)

            elif obj.objectType == 'mappedtype':
                if obj.templateParams or '<' in obj.name:
                    tObj = ScopedTemplateObject(obj)
                    self.index.add(tObj.name, tObj)

            elif obj.objectType == 'typedef':
                if obj.argumentType in scalarTypes and obj.name not in self.definedScalar:
                    self.definedScalar.append(obj.name)

    def clear(self):
        """Forget the recorded files, their data and the object list; the
        indexes are left alone."""
        self.files = []
        self.fileData = {}
        self.objectList = []

    def createSipIndex(self, clear=True):
        """Index every object in self.objectList by name and return the
        sip index.

        With clear true a fresh SipIndex is started; otherwise the objects
        are added to the existing one. Mapped types are also added to the
        scope index.
        """
        if clear:
            self.sipIndex = SipIndex()
            # keep the back-reference that __init__ sets on the first index
            self.sipIndex.parent = self

        for object in self.objectList:
            self.sipIndex.add(object.name, object)
            if object.objectType == 'mappedtype':
                self.index.add(object.name, ScopeObject(object))
        return self.sipIndex
            
    
